diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index f0aef2e..c0cbf56 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -80,17 +80,20 @@ def _make_noise_handshake_pkt(proto: NoiseConnection) -> bytes: def _make_encrypted_packet( - proto: NoiseConnection, msg_type: int, encrypted_payload: bytes + proto: NoiseConnection, msg_type: int, payload: bytes ) -> bytes: msg_type = 42 msg_type_high = (msg_type >> 8) & 0xFF msg_type_low = msg_type & 0xFF - msg_length = len(encrypted_payload) + msg_length = len(payload) msg_length_high = (msg_length >> 8) & 0xFF msg_length_low = msg_length & 0xFF msg_header = bytes((msg_type_high, msg_type_low, msg_length_high, msg_length_low)) - encrypted_payload = proto.encrypt(msg_header + b"from device") + encrypted_payload = proto.encrypt(msg_header + payload) + return _make_encrypted_packet_from_encrypted_payload(encrypted_payload) + +def _make_encrypted_packet_from_encrypted_payload(encrypted_payload: bytes) -> bytes: preamble = 1 encrypted_pkg_length = len(encrypted_payload) encrypted_pkg_length_high = (encrypted_pkg_length >> 8) & 0xFF @@ -573,6 +576,69 @@ async def test_noise_frame_helper_handshake_success_with_single_packet(): mock_data_received(helper, encrypted_packet) +@pytest.mark.asyncio +async def test_noise_frame_helper_bad_encryption( + caplog: pytest.LogCaptureFixture, +) -> None: + """Test the noise frame helper closes connection on encryption error.""" + noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=" + psk_bytes = base64.b64decode(noise_psk) + writes = [] + + def _writer(data: bytes): + writes.append(data) + + connection, packets = _make_mock_connection() + + helper = MockAPINoiseFrameHelper( + connection=connection, + noise_psk=noise_psk, + expected_name="servicetest", + client_info="my client", + log_name="test", + writer=_writer, + ) + + proto = _mock_responder_proto(psk_bytes) + + handshake_task = asyncio.create_task(helper.perform_handshake(30)) + await asyncio.sleep(0) # let the task run to read the hello packet + + assert len(writes) == 1 + handshake_pkt = writes.pop() + + encrypted_payload = _extract_encrypted_payload_from_handshake(handshake_pkt) + decrypted = proto.read_message(encrypted_payload) + assert decrypted == b"" + + hello_pkt_with_header = _make_noise_hello_pkt(b"\x01servicetest\0") + mock_data_received(helper, hello_pkt_with_header) + + handshake_with_header = _make_noise_handshake_pkt(proto) + mock_data_received(helper, handshake_with_header) + + assert not writes + + await handshake_task + helper.write_packets([(1, b"to device")], True) + encrypted_packet = writes.pop() + header = encrypted_packet[0:1] + assert header == b"\x01" + pkg_length_high = encrypted_packet[1] + pkg_length_low = encrypted_packet[2] + pkg_length = (pkg_length_high << 8) + pkg_length_low + assert len(encrypted_packet) == 3 + pkg_length + + encrypted_packet = _make_encrypted_packet_from_encrypted_payload(b"corrupt") + mock_data_received(helper, encrypted_packet) + await asyncio.sleep(0) + + assert packets == [] + assert connection.is_connected is False + assert "Invalid encryption key" in caplog.text + helper.close() + + @pytest.mark.asyncio async def test_init_plaintext_with_wrong_preamble(conn: APIConnection): loop = asyncio.get_event_loop()