mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-01-11 20:01:14 +01:00
Add test for successful noise handshake and single packet (#605)
This commit is contained in:
parent
9f30e9d0df
commit
ae03a831b9
@ -391,3 +391,116 @@ async def test_noise_frame_helper_handshake_failure():
|
|||||||
|
|
||||||
with pytest.raises(HandshakeAPIError, match="forced to fail"):
|
with pytest.raises(HandshakeAPIError, match="forced to fail"):
|
||||||
await handshake_task
|
await handshake_task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_noise_frame_helper_handshake_success_with_single_packet():
|
||||||
|
"""Test the noise frame helper handshake success with a single packet."""
|
||||||
|
noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc="
|
||||||
|
psk_bytes = base64.b64decode(noise_psk)
|
||||||
|
packets = []
|
||||||
|
writes = []
|
||||||
|
|
||||||
|
def _packet(type_: int, data: bytes):
|
||||||
|
packets.append((type_, data))
|
||||||
|
|
||||||
|
def _writer(data: bytes):
|
||||||
|
writes.append(data)
|
||||||
|
|
||||||
|
def _on_error(exc: Exception):
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
helper = MockAPINoiseFrameHelper(
|
||||||
|
on_pkt=_packet,
|
||||||
|
on_error=_on_error,
|
||||||
|
noise_psk=noise_psk,
|
||||||
|
expected_name="servicetest",
|
||||||
|
client_info="my client",
|
||||||
|
log_name="test",
|
||||||
|
)
|
||||||
|
helper._transport = MagicMock()
|
||||||
|
helper._writer = _writer
|
||||||
|
|
||||||
|
proto = NoiseConnection.from_name(
|
||||||
|
b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND
|
||||||
|
)
|
||||||
|
proto.set_as_responder()
|
||||||
|
proto.set_psks(psk_bytes)
|
||||||
|
proto.set_prologue(b"NoiseAPIInit\x00\x00")
|
||||||
|
proto.start_handshake()
|
||||||
|
|
||||||
|
handshake_task = asyncio.create_task(helper.perform_handshake(30))
|
||||||
|
await asyncio.sleep(0) # let the task run to read the hello packet
|
||||||
|
|
||||||
|
assert len(writes) == 1
|
||||||
|
handshake_pkt = writes.pop()
|
||||||
|
|
||||||
|
noise_hello = handshake_pkt[0:3]
|
||||||
|
pkt_header = handshake_pkt[3:6]
|
||||||
|
assert noise_hello == NOISE_HELLO
|
||||||
|
assert pkt_header[0] == 1 # type
|
||||||
|
pkg_length_high = pkt_header[1]
|
||||||
|
pkg_length_low = pkt_header[2]
|
||||||
|
pkg_length = (pkg_length_high << 8) + pkg_length_low
|
||||||
|
assert pkg_length == 49
|
||||||
|
noise_prefix = handshake_pkt[6:7]
|
||||||
|
assert noise_prefix == b"\x00"
|
||||||
|
encrypted_payload = handshake_pkt[7:]
|
||||||
|
|
||||||
|
decrypted = proto.read_message(encrypted_payload)
|
||||||
|
assert decrypted == b""
|
||||||
|
|
||||||
|
hello_pkt = b"\x01servicetest\0"
|
||||||
|
preamble = 1
|
||||||
|
hello_pkg_length = len(hello_pkt)
|
||||||
|
hello_pkg_length_high = (hello_pkg_length >> 8) & 0xFF
|
||||||
|
hello_pkg_length_low = hello_pkg_length & 0xFF
|
||||||
|
hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low))
|
||||||
|
hello_pkt_with_header = hello_header + hello_pkt
|
||||||
|
helper.data_received(hello_pkt_with_header)
|
||||||
|
|
||||||
|
handshake = proto.write_message(b"")
|
||||||
|
handshake_pkt = b"\x00" + handshake
|
||||||
|
preamble = 1
|
||||||
|
handshake_pkg_length = len(handshake_pkt)
|
||||||
|
handshake_pkg_length_high = (handshake_pkg_length >> 8) & 0xFF
|
||||||
|
handshake_pkg_length_low = handshake_pkg_length & 0xFF
|
||||||
|
handshake_header = bytes(
|
||||||
|
(preamble, handshake_pkg_length_high, handshake_pkg_length_low)
|
||||||
|
)
|
||||||
|
handshake_with_header = handshake_header + handshake_pkt
|
||||||
|
|
||||||
|
helper.data_received(handshake_with_header)
|
||||||
|
|
||||||
|
assert not writes
|
||||||
|
|
||||||
|
await handshake_task
|
||||||
|
helper.write_packet(1, b"to device")
|
||||||
|
encrypted_packet = writes.pop()
|
||||||
|
header = encrypted_packet[0:1]
|
||||||
|
assert header == b"\x01"
|
||||||
|
pkg_length_high = encrypted_packet[1]
|
||||||
|
pkg_length_low = encrypted_packet[2]
|
||||||
|
pkg_length = (pkg_length_high << 8) + pkg_length_low
|
||||||
|
assert len(encrypted_packet) == 3 + pkg_length
|
||||||
|
|
||||||
|
msg_type = 42
|
||||||
|
msg_type_high = (msg_type >> 8) & 0xFF
|
||||||
|
msg_type_low = msg_type & 0xFF
|
||||||
|
msg_length = len(encrypted_payload)
|
||||||
|
msg_length_high = (msg_length >> 8) & 0xFF
|
||||||
|
msg_length_low = msg_length & 0xFF
|
||||||
|
msg_header = bytes((msg_type_high, msg_type_low, msg_length_high, msg_length_low))
|
||||||
|
encrypted_payload = proto.encrypt(msg_header + b"from device")
|
||||||
|
|
||||||
|
preamble = 1
|
||||||
|
encrypted_pkg_length = len(encrypted_payload)
|
||||||
|
encrypted_pkg_length_high = (encrypted_pkg_length >> 8) & 0xFF
|
||||||
|
encrypted_pkg_length_low = encrypted_pkg_length & 0xFF
|
||||||
|
encrypted_header = bytes(
|
||||||
|
(preamble, encrypted_pkg_length_high, encrypted_pkg_length_low)
|
||||||
|
)
|
||||||
|
helper.data_received(encrypted_header + encrypted_payload)
|
||||||
|
|
||||||
|
assert packets == [(42, b"from device")]
|
||||||
|
helper.close()
|
||||||
|
Loading…
Reference in New Issue
Block a user