diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index 1bf60ed..092e1b8 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -41,6 +41,16 @@ from .conftest import get_mock_connection_params PREAMBLE = b"\x00" +def _make_noise_hello_pkt(hello_pkt: bytes) -> bytes: + """Make a noise hello packet.""" + preamble = 1 + hello_pkg_length = len(hello_pkt) + hello_pkg_length_high = (hello_pkg_length >> 8) & 0xFF + hello_pkg_length_low = hello_pkg_length & 0xFF + hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low)) + return hello_header + hello_pkt + + def _mock_responder_proto(psk_bytes: bytes) -> NoiseConnection: proto = NoiseConnection.from_name( b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND @@ -446,13 +456,7 @@ async def test_noise_frame_helper_handshake_failure(): decrypted = proto.read_message(encrypted_payload) assert decrypted == b"" - hello_pkt = b"\x01servicetest\0" - preamble = 1 - hello_pkg_length = len(hello_pkt) - hello_pkg_length_high = (hello_pkg_length >> 8) & 0xFF - hello_pkg_length_low = hello_pkg_length & 0xFF - hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low)) - hello_pkt_with_header = hello_header + hello_pkt + hello_pkt_with_header = _make_noise_hello_pkt(b"\x01servicetest\0") mock_data_received(helper, hello_pkt_with_header) error_pkt = b"\x01forced to fail" @@ -513,13 +517,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet(): decrypted = proto.read_message(encrypted_payload) assert decrypted == b"" - hello_pkt = b"\x01servicetest\0" - preamble = 1 - hello_pkg_length = len(hello_pkt) - hello_pkg_length_high = (hello_pkg_length >> 8) & 0xFF - hello_pkg_length_low = hello_pkg_length & 0xFF - hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low)) - hello_pkt_with_header = hello_header + hello_pkt + hello_pkt_with_header = _make_noise_hello_pkt(b"\x01servicetest\0") mock_data_received(helper, hello_pkt_with_header) handshake = proto.write_message(b"") @@ -630,13 +628,7 @@ async def test_noise_frame_helper_empty_hello(): ) handshake_task = asyncio.create_task(helper.perform_handshake(30)) - empty_hello_pkt = b"" - preamble = 1 - hello_pkg_length = len(empty_hello_pkt) - hello_pkg_length_high = (hello_pkg_length >> 8) & 0xFF - hello_pkg_length_low = hello_pkg_length & 0xFF - hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low)) - hello_pkt_with_header = hello_header + empty_hello_pkt + hello_pkt_with_header = _make_noise_hello_pkt(b"") mock_data_received(helper, hello_pkt_with_header) @@ -644,6 +636,30 @@ async def test_noise_frame_helper_empty_hello(): await handshake_task +@pytest.mark.asyncio +async def test_noise_frame_helper_wrong_protocol(): + """Test noise with the wrong protocol.""" + connection, _ = _make_mock_connection() + helper = MockAPINoiseFrameHelper( + connection=connection, + noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", + expected_name="servicetest", + client_info="my client", + log_name="test", + ) + + handshake_task = asyncio.create_task(helper.perform_handshake(30)) + # wrong protocol 5 instead of 1 + hello_pkt_with_header = _make_noise_hello_pkt(b"\x05servicetest\0") + + mock_data_received(helper, hello_pkt_with_header) + + with pytest.raises( + HandshakeAPIError, match="Unknown protocol selected by client 5" + ): + await handshake_task + + @pytest.mark.asyncio async def test_init_noise_attempted_when_esp_uses_plaintext( noise_conn: APIConnection,