Add test for wrong protocol in noise (#687)

This commit is contained in:
J. Nick Koston 2023-11-24 09:57:31 -06:00 committed by GitHub
parent 6453aa87f6
commit 837d6ad650
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -41,6 +41,16 @@ from .conftest import get_mock_connection_params
PREAMBLE = b"\x00" PREAMBLE = b"\x00"
def _make_noise_hello_pkt(hello_pkt: bytes) -> bytes:
"""Make a noise hello packet."""
preamble = 1
hello_pkg_length = len(hello_pkt)
hello_pkg_length_high = (hello_pkg_length >> 8) & 0xFF
hello_pkg_length_low = hello_pkg_length & 0xFF
hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low))
return hello_header + hello_pkt
def _mock_responder_proto(psk_bytes: bytes) -> NoiseConnection: def _mock_responder_proto(psk_bytes: bytes) -> NoiseConnection:
proto = NoiseConnection.from_name( proto = NoiseConnection.from_name(
b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND
@ -446,13 +456,7 @@ async def test_noise_frame_helper_handshake_failure():
decrypted = proto.read_message(encrypted_payload) decrypted = proto.read_message(encrypted_payload)
assert decrypted == b"" assert decrypted == b""
hello_pkt = b"\x01servicetest\0" hello_pkt_with_header = _make_noise_hello_pkt(b"\x01servicetest\0")
preamble = 1
hello_pkg_length = len(hello_pkt)
hello_pkg_length_high = (hello_pkg_length >> 8) & 0xFF
hello_pkg_length_low = hello_pkg_length & 0xFF
hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low))
hello_pkt_with_header = hello_header + hello_pkt
mock_data_received(helper, hello_pkt_with_header) mock_data_received(helper, hello_pkt_with_header)
error_pkt = b"\x01forced to fail" error_pkt = b"\x01forced to fail"
@ -513,13 +517,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
decrypted = proto.read_message(encrypted_payload) decrypted = proto.read_message(encrypted_payload)
assert decrypted == b"" assert decrypted == b""
hello_pkt = b"\x01servicetest\0" hello_pkt_with_header = _make_noise_hello_pkt(b"\x01servicetest\0")
preamble = 1
hello_pkg_length = len(hello_pkt)
hello_pkg_length_high = (hello_pkg_length >> 8) & 0xFF
hello_pkg_length_low = hello_pkg_length & 0xFF
hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low))
hello_pkt_with_header = hello_header + hello_pkt
mock_data_received(helper, hello_pkt_with_header) mock_data_received(helper, hello_pkt_with_header)
handshake = proto.write_message(b"") handshake = proto.write_message(b"")
@ -630,13 +628,7 @@ async def test_noise_frame_helper_empty_hello():
) )
handshake_task = asyncio.create_task(helper.perform_handshake(30)) handshake_task = asyncio.create_task(helper.perform_handshake(30))
empty_hello_pkt = b"" hello_pkt_with_header = _make_noise_hello_pkt(b"")
preamble = 1
hello_pkg_length = len(empty_hello_pkt)
hello_pkg_length_high = (hello_pkg_length >> 8) & 0xFF
hello_pkg_length_low = hello_pkg_length & 0xFF
hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low))
hello_pkt_with_header = hello_header + empty_hello_pkt
mock_data_received(helper, hello_pkt_with_header) mock_data_received(helper, hello_pkt_with_header)
@ -644,6 +636,30 @@ async def test_noise_frame_helper_empty_hello():
await handshake_task await handshake_task
@pytest.mark.asyncio
async def test_noise_frame_helper_wrong_protocol():
"""Test noise with the wrong protocol."""
connection, _ = _make_mock_connection()
helper = MockAPINoiseFrameHelper(
connection=connection,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="servicetest",
client_info="my client",
log_name="test",
)
handshake_task = asyncio.create_task(helper.perform_handshake(30))
# wrong protocol 5 instead of 1
hello_pkt_with_header = _make_noise_hello_pkt(b"\x05servicetest\0")
mock_data_received(helper, hello_pkt_with_header)
with pytest.raises(
HandshakeAPIError, match="Unknown protocol selected by client 5"
):
await handshake_task
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_init_noise_attempted_when_esp_uses_plaintext( async def test_init_noise_attempted_when_esp_uses_plaintext(
noise_conn: APIConnection, noise_conn: APIConnection,