mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-22 16:48:04 +01:00
Add test for corrupt or invalid encryption (#690)
This commit is contained in:
parent
9d27f0c772
commit
0eb468ec61
@ -80,17 +80,20 @@ def _make_noise_handshake_pkt(proto: NoiseConnection) -> bytes:
|
||||
|
||||
|
||||
def _make_encrypted_packet(
|
||||
proto: NoiseConnection, msg_type: int, encrypted_payload: bytes
|
||||
proto: NoiseConnection, msg_type: int, payload: bytes
|
||||
) -> bytes:
|
||||
msg_type = 42
|
||||
msg_type_high = (msg_type >> 8) & 0xFF
|
||||
msg_type_low = msg_type & 0xFF
|
||||
msg_length = len(encrypted_payload)
|
||||
msg_length = len(payload)
|
||||
msg_length_high = (msg_length >> 8) & 0xFF
|
||||
msg_length_low = msg_length & 0xFF
|
||||
msg_header = bytes((msg_type_high, msg_type_low, msg_length_high, msg_length_low))
|
||||
encrypted_payload = proto.encrypt(msg_header + b"from device")
|
||||
encrypted_payload = proto.encrypt(msg_header + payload)
|
||||
return _make_encrypted_packet_from_encrypted_payload(encrypted_payload)
|
||||
|
||||
|
||||
def _make_encrypted_packet_from_encrypted_payload(encrypted_payload: bytes) -> bytes:
|
||||
preamble = 1
|
||||
encrypted_pkg_length = len(encrypted_payload)
|
||||
encrypted_pkg_length_high = (encrypted_pkg_length >> 8) & 0xFF
|
||||
@ -573,6 +576,69 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
||||
mock_data_received(helper, encrypted_packet)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noise_frame_helper_bad_encryption(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test the noise frame helper closes connection on encryption error."""
|
||||
noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc="
|
||||
psk_bytes = base64.b64decode(noise_psk)
|
||||
writes = []
|
||||
|
||||
def _writer(data: bytes):
|
||||
writes.append(data)
|
||||
|
||||
connection, packets = _make_mock_connection()
|
||||
|
||||
helper = MockAPINoiseFrameHelper(
|
||||
connection=connection,
|
||||
noise_psk=noise_psk,
|
||||
expected_name="servicetest",
|
||||
client_info="my client",
|
||||
log_name="test",
|
||||
writer=_writer,
|
||||
)
|
||||
|
||||
proto = _mock_responder_proto(psk_bytes)
|
||||
|
||||
handshake_task = asyncio.create_task(helper.perform_handshake(30))
|
||||
await asyncio.sleep(0) # let the task run to read the hello packet
|
||||
|
||||
assert len(writes) == 1
|
||||
handshake_pkt = writes.pop()
|
||||
|
||||
encrypted_payload = _extract_encrypted_payload_from_handshake(handshake_pkt)
|
||||
decrypted = proto.read_message(encrypted_payload)
|
||||
assert decrypted == b""
|
||||
|
||||
hello_pkt_with_header = _make_noise_hello_pkt(b"\x01servicetest\0")
|
||||
mock_data_received(helper, hello_pkt_with_header)
|
||||
|
||||
handshake_with_header = _make_noise_handshake_pkt(proto)
|
||||
mock_data_received(helper, handshake_with_header)
|
||||
|
||||
assert not writes
|
||||
|
||||
await handshake_task
|
||||
helper.write_packets([(1, b"to device")], True)
|
||||
encrypted_packet = writes.pop()
|
||||
header = encrypted_packet[0:1]
|
||||
assert header == b"\x01"
|
||||
pkg_length_high = encrypted_packet[1]
|
||||
pkg_length_low = encrypted_packet[2]
|
||||
pkg_length = (pkg_length_high << 8) + pkg_length_low
|
||||
assert len(encrypted_packet) == 3 + pkg_length
|
||||
|
||||
encrypted_packet = _make_encrypted_packet_from_encrypted_payload(b"corrupt")
|
||||
mock_data_received(helper, encrypted_packet)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert packets == []
|
||||
assert connection.is_connected is False
|
||||
assert "Invalid encryption key" in caplog.text
|
||||
helper.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_plaintext_with_wrong_preamble(conn: APIConnection):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
Loading…
Reference in New Issue
Block a user