mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-22 12:05:12 +01:00
Add test coverage for noise handshake failure (#604)
This commit is contained in:
parent
cb85c9724f
commit
9f30e9d0df
@ -1,14 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from datetime import timedelta
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from noise.connection import NoiseConnection # type: ignore[import-untyped]
|
||||
|
||||
from aioesphomeapi import HandshakeAPIError
|
||||
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
|
||||
from aioesphomeapi._frame_helper.base import WRITE_EXCEPTIONS
|
||||
from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND, NOISE_HELLO
|
||||
from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint
|
||||
from aioesphomeapi._frame_helper.plain_text import (
|
||||
_cached_bytes_to_varuint as cached_bytes_to_varuint,
|
||||
@ -19,6 +21,7 @@ from aioesphomeapi._frame_helper.plain_text import (
|
||||
from aioesphomeapi._frame_helper.plain_text import _varuint_to_bytes as varuint_to_bytes
|
||||
from aioesphomeapi.core import (
|
||||
BadNameAPIError,
|
||||
HandshakeAPIError,
|
||||
InvalidEncryptionKeyAPIError,
|
||||
SocketAPIError,
|
||||
)
|
||||
@ -307,3 +310,84 @@ def test_varuint_to_bytes(val, encoded):
|
||||
def test_bytes_to_varuint(val, encoded):
|
||||
assert bytes_to_varuint(encoded) == val
|
||||
assert cached_bytes_to_varuint(encoded) == val
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noise_frame_helper_handshake_failure():
|
||||
"""Test the noise frame helper handshake failure."""
|
||||
noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc="
|
||||
psk_bytes = base64.b64decode(noise_psk)
|
||||
packets = []
|
||||
writes = []
|
||||
|
||||
def _packet(type_: int, data: bytes):
|
||||
packets.append((type_, data))
|
||||
|
||||
def _writer(data: bytes):
|
||||
writes.append(data)
|
||||
|
||||
def _on_error(exc: Exception):
|
||||
raise exc
|
||||
|
||||
helper = MockAPINoiseFrameHelper(
|
||||
on_pkt=_packet,
|
||||
on_error=_on_error,
|
||||
noise_psk=noise_psk,
|
||||
expected_name="servicetest",
|
||||
client_info="my client",
|
||||
log_name="test",
|
||||
)
|
||||
helper._transport = MagicMock()
|
||||
helper._writer = _writer
|
||||
|
||||
proto = NoiseConnection.from_name(
|
||||
b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND
|
||||
)
|
||||
proto.set_as_responder()
|
||||
proto.set_psks(psk_bytes)
|
||||
proto.set_prologue(b"NoiseAPIInit\x00\x00")
|
||||
proto.start_handshake()
|
||||
|
||||
handshake_task = asyncio.create_task(helper.perform_handshake(30))
|
||||
await asyncio.sleep(0) # let the task run to read the hello packet
|
||||
|
||||
assert len(writes) == 1
|
||||
handshake_pkt = writes.pop()
|
||||
|
||||
noise_hello = handshake_pkt[0:3]
|
||||
pkt_header = handshake_pkt[3:6]
|
||||
assert noise_hello == NOISE_HELLO
|
||||
assert pkt_header[0] == 1 # type
|
||||
pkg_length_high = pkt_header[1]
|
||||
pkg_length_low = pkt_header[2]
|
||||
pkg_length = (pkg_length_high << 8) + pkg_length_low
|
||||
assert pkg_length == 49
|
||||
noise_prefix = handshake_pkt[6:7]
|
||||
assert noise_prefix == b"\x00"
|
||||
encrypted_payload = handshake_pkt[7:]
|
||||
|
||||
decrypted = proto.read_message(encrypted_payload)
|
||||
assert decrypted == b""
|
||||
|
||||
hello_pkt = b"\x01servicetest\0"
|
||||
preamble = 1
|
||||
hello_pkg_length = len(hello_pkt)
|
||||
hello_pkg_length_high = (hello_pkg_length >> 8) & 0xFF
|
||||
hello_pkg_length_low = hello_pkg_length & 0xFF
|
||||
hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low))
|
||||
hello_pkt_with_header = hello_header + hello_pkt
|
||||
helper.data_received(hello_pkt_with_header)
|
||||
|
||||
error_pkt = b"\x01forced to fail"
|
||||
preamble = 1
|
||||
error_pkg_length = len(error_pkt)
|
||||
error_pkg_length_high = (error_pkg_length >> 8) & 0xFF
|
||||
error_pkg_length_low = error_pkg_length & 0xFF
|
||||
error_header = bytes((preamble, error_pkg_length_high, error_pkg_length_low))
|
||||
error_pkt_with_header = error_header + error_pkt
|
||||
|
||||
with pytest.raises(HandshakeAPIError, match="forced to fail"):
|
||||
helper.data_received(error_pkt_with_header)
|
||||
|
||||
with pytest.raises(HandshakeAPIError, match="forced to fail"):
|
||||
await handshake_task
|
||||
|
Loading…
Reference in New Issue
Block a user