diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index 538a29b..f0aef2e 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -41,6 +41,20 @@ from .conftest import get_mock_connection_params PREAMBLE = b"\x00" +def _extract_encrypted_payload_from_handshake(handshake_pkt: bytes) -> bytes: + noise_hello = handshake_pkt[0:3] + pkt_header = handshake_pkt[3:6] + assert noise_hello == NOISE_HELLO + assert pkt_header[0] == 1 # type + pkg_length_high = pkt_header[1] + pkg_length_low = pkt_header[2] + pkg_length = (pkg_length_high << 8) + pkg_length_low + assert pkg_length == 49 + noise_prefix = handshake_pkt[6:7] + assert noise_prefix == b"\x00" + return handshake_pkt[7:] + + def _make_noise_hello_pkt(hello_pkt: bytes) -> bytes: """Make a noise hello packet.""" preamble = 1 @@ -51,6 +65,42 @@ def _make_noise_hello_pkt(hello_pkt: bytes) -> bytes: return hello_header + hello_pkt +def _make_noise_handshake_pkt(proto: NoiseConnection) -> bytes: + handshake = proto.write_message(b"") + handshake_pkt = b"\x00" + handshake + preamble = 1 + handshake_pkg_length = len(handshake_pkt) + handshake_pkg_length_high = (handshake_pkg_length >> 8) & 0xFF + handshake_pkg_length_low = handshake_pkg_length & 0xFF + handshake_header = bytes( + (preamble, handshake_pkg_length_high, handshake_pkg_length_low) + ) + + return handshake_header + handshake_pkt + + +def _make_encrypted_packet( + proto: NoiseConnection, msg_type: int, encrypted_payload: bytes +) -> bytes: + msg_type = 42 + msg_type_high = (msg_type >> 8) & 0xFF + msg_type_low = msg_type & 0xFF + msg_length = len(encrypted_payload) + msg_length_high = (msg_length >> 8) & 0xFF + msg_length_low = msg_length & 0xFF + msg_header = bytes((msg_type_high, msg_type_low, msg_length_high, msg_length_low)) + encrypted_payload = proto.encrypt(msg_header + b"from device") + + preamble = 1 + encrypted_pkg_length = len(encrypted_payload) + encrypted_pkg_length_high = (encrypted_pkg_length >> 8) & 0xFF + encrypted_pkg_length_low = encrypted_pkg_length & 0xFF + encrypted_header = bytes( + (preamble, encrypted_pkg_length_high, encrypted_pkg_length_low) + ) + return encrypted_header + encrypted_payload + + def _mock_responder_proto(psk_bytes: bytes) -> NoiseConnection: proto = NoiseConnection.from_name( b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND @@ -440,18 +490,7 @@ async def test_noise_frame_helper_handshake_failure(): assert len(writes) == 1 handshake_pkt = writes.pop() - - noise_hello = handshake_pkt[0:3] - pkt_header = handshake_pkt[3:6] - assert noise_hello == NOISE_HELLO - assert pkt_header[0] == 1 # type - pkg_length_high = pkt_header[1] - pkg_length_low = pkt_header[2] - pkg_length = (pkg_length_high << 8) + pkg_length_low - assert pkg_length == 49 - noise_prefix = handshake_pkt[6:7] - assert noise_prefix == b"\x00" - encrypted_payload = handshake_pkt[7:] + encrypted_payload = _extract_encrypted_payload_from_handshake(handshake_pkt) decrypted = proto.read_message(encrypted_payload) assert decrypted == b"" @@ -502,35 +541,14 @@ async def test_noise_frame_helper_handshake_success_with_single_packet(): assert len(writes) == 1 handshake_pkt = writes.pop() - noise_hello = handshake_pkt[0:3] - pkt_header = handshake_pkt[3:6] - assert noise_hello == NOISE_HELLO - assert pkt_header[0] == 1 # type - pkg_length_high = pkt_header[1] - pkg_length_low = pkt_header[2] - pkg_length = (pkg_length_high << 8) + pkg_length_low - assert pkg_length == 49 - noise_prefix = handshake_pkt[6:7] - assert noise_prefix == b"\x00" - encrypted_payload = handshake_pkt[7:] - + encrypted_payload = _extract_encrypted_payload_from_handshake(handshake_pkt) decrypted = proto.read_message(encrypted_payload) assert decrypted == b"" hello_pkt_with_header = _make_noise_hello_pkt(b"\x01servicetest\0") mock_data_received(helper, hello_pkt_with_header) - handshake = proto.write_message(b"") - handshake_pkt = b"\x00" + handshake - preamble = 1 - handshake_pkg_length = len(handshake_pkt) - handshake_pkg_length_high = (handshake_pkg_length >> 8) & 0xFF - handshake_pkg_length_low = handshake_pkg_length & 0xFF - handshake_header = bytes( - (preamble, handshake_pkg_length_high, handshake_pkg_length_low) - ) - handshake_with_header = handshake_header + handshake_pkt - + handshake_with_header = _make_noise_handshake_pkt(proto) mock_data_received(helper, handshake_with_header) assert not writes @@ -545,28 +563,14 @@ async def test_noise_frame_helper_handshake_success_with_single_packet(): pkg_length = (pkg_length_high << 8) + pkg_length_low assert len(encrypted_packet) == 3 + pkg_length - msg_type = 42 - msg_type_high = (msg_type >> 8) & 0xFF - msg_type_low = msg_type & 0xFF - msg_length = len(encrypted_payload) - msg_length_high = (msg_length >> 8) & 0xFF - msg_length_low = msg_length & 0xFF - msg_header = bytes((msg_type_high, msg_type_low, msg_length_high, msg_length_low)) - encrypted_payload = proto.encrypt(msg_header + b"from device") + encrypted_packet = _make_encrypted_packet(proto, 42, b"from device") - preamble = 1 - encrypted_pkg_length = len(encrypted_payload) - encrypted_pkg_length_high = (encrypted_pkg_length >> 8) & 0xFF - encrypted_pkg_length_low = encrypted_pkg_length & 0xFF - encrypted_header = bytes( - (preamble, encrypted_pkg_length_high, encrypted_pkg_length_low) - ) - mock_data_received(helper, encrypted_header + encrypted_payload) + mock_data_received(helper, encrypted_packet) assert packets == [(42, b"from device")] helper.close() - mock_data_received(helper, encrypted_header + encrypted_payload) + mock_data_received(helper, encrypted_packet) @pytest.mark.asyncio