diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index ccf3e2c..18dbd35 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -391,3 +391,116 @@ async def test_noise_frame_helper_handshake_failure(): with pytest.raises(HandshakeAPIError, match="forced to fail"): await handshake_task + + +@pytest.mark.asyncio +async def test_noise_frame_helper_handshake_success_with_single_packet(): + """Test the noise frame helper handshake success with a single packet.""" + noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=" + psk_bytes = base64.b64decode(noise_psk) + packets = [] + writes = [] + + def _packet(type_: int, data: bytes): + packets.append((type_, data)) + + def _writer(data: bytes): + writes.append(data) + + def _on_error(exc: Exception): + raise exc + + helper = MockAPINoiseFrameHelper( + on_pkt=_packet, + on_error=_on_error, + noise_psk=noise_psk, + expected_name="servicetest", + client_info="my client", + log_name="test", + ) + helper._transport = MagicMock() + helper._writer = _writer + + proto = NoiseConnection.from_name( + b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND + ) + proto.set_as_responder() + proto.set_psks(psk_bytes) + proto.set_prologue(b"NoiseAPIInit\x00\x00") + proto.start_handshake() + + handshake_task = asyncio.create_task(helper.perform_handshake(30)) + await asyncio.sleep(0) # let the task run to read the hello packet + + assert len(writes) == 1 + handshake_pkt = writes.pop() + + noise_hello = handshake_pkt[0:3] + pkt_header = handshake_pkt[3:6] + assert noise_hello == NOISE_HELLO + assert pkt_header[0] == 1 # type + pkg_length_high = pkt_header[1] + pkg_length_low = pkt_header[2] + pkg_length = (pkg_length_high << 8) + pkg_length_low + assert pkg_length == 49 + noise_prefix = handshake_pkt[6:7] + assert noise_prefix == b"\x00" + encrypted_payload = handshake_pkt[7:] + + decrypted = proto.read_message(encrypted_payload) + assert decrypted == b"" + + hello_pkt = b"\x01servicetest\0" + preamble = 1 + hello_pkg_length = len(hello_pkt) + hello_pkg_length_high = (hello_pkg_length >> 8) & 0xFF + hello_pkg_length_low = hello_pkg_length & 0xFF + hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low)) + hello_pkt_with_header = hello_header + hello_pkt + helper.data_received(hello_pkt_with_header) + + handshake = proto.write_message(b"") + handshake_pkt = b"\x00" + handshake + preamble = 1 + handshake_pkg_length = len(handshake_pkt) + handshake_pkg_length_high = (handshake_pkg_length >> 8) & 0xFF + handshake_pkg_length_low = handshake_pkg_length & 0xFF + handshake_header = bytes( + (preamble, handshake_pkg_length_high, handshake_pkg_length_low) + ) + handshake_with_header = handshake_header + handshake_pkt + + helper.data_received(handshake_with_header) + + assert not writes + + await handshake_task + helper.write_packet(1, b"to device") + encrypted_packet = writes.pop() + header = encrypted_packet[0:1] + assert header == b"\x01" + pkg_length_high = encrypted_packet[1] + pkg_length_low = encrypted_packet[2] + pkg_length = (pkg_length_high << 8) + pkg_length_low + assert len(encrypted_packet) == 3 + pkg_length + + msg_type = 42 + msg_type_high = (msg_type >> 8) & 0xFF + msg_type_low = msg_type & 0xFF + msg_length = len(encrypted_payload) + msg_length_high = (msg_length >> 8) & 0xFF + msg_length_low = msg_length & 0xFF + msg_header = bytes((msg_type_high, msg_type_low, msg_length_high, msg_length_low)) + encrypted_payload = proto.encrypt(msg_header + b"from device") + + preamble = 1 + encrypted_pkg_length = len(encrypted_payload) + encrypted_pkg_length_high = (encrypted_pkg_length >> 8) & 0xFF + encrypted_pkg_length_low = encrypted_pkg_length & 0xFF + encrypted_header = bytes( + (preamble, encrypted_pkg_length_high, encrypted_pkg_length_low) + ) + helper.data_received(encrypted_header + encrypted_payload) + + assert packets == [(42, b"from device")] + helper.close()