Add test coverage for noise handshake failure (#604)

This commit is contained in:
J. Nick Koston 2023-10-24 14:44:57 -05:00 committed by GitHub
parent cb85c9724f
commit 9f30e9d0df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,14 +1,16 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import base64
from datetime import timedelta from datetime import timedelta
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
from noise.connection import NoiseConnection # type: ignore[import-untyped]
from aioesphomeapi import HandshakeAPIError
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
from aioesphomeapi._frame_helper.base import WRITE_EXCEPTIONS from aioesphomeapi._frame_helper.base import WRITE_EXCEPTIONS
from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND, NOISE_HELLO
from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint
from aioesphomeapi._frame_helper.plain_text import ( from aioesphomeapi._frame_helper.plain_text import (
_cached_bytes_to_varuint as cached_bytes_to_varuint, _cached_bytes_to_varuint as cached_bytes_to_varuint,
@ -19,6 +21,7 @@ from aioesphomeapi._frame_helper.plain_text import (
from aioesphomeapi._frame_helper.plain_text import _varuint_to_bytes as varuint_to_bytes from aioesphomeapi._frame_helper.plain_text import _varuint_to_bytes as varuint_to_bytes
from aioesphomeapi.core import ( from aioesphomeapi.core import (
BadNameAPIError, BadNameAPIError,
HandshakeAPIError,
InvalidEncryptionKeyAPIError, InvalidEncryptionKeyAPIError,
SocketAPIError, SocketAPIError,
) )
@ -307,3 +310,84 @@ def test_varuint_to_bytes(val, encoded):
def test_bytes_to_varuint(val, encoded): def test_bytes_to_varuint(val, encoded):
assert bytes_to_varuint(encoded) == val assert bytes_to_varuint(encoded) == val
assert cached_bytes_to_varuint(encoded) == val assert cached_bytes_to_varuint(encoded) == val
@pytest.mark.asyncio
async def test_noise_frame_helper_handshake_failure():
"""Test the noise frame helper handshake failure."""
noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc="
psk_bytes = base64.b64decode(noise_psk)
packets = []
writes = []
def _packet(type_: int, data: bytes):
packets.append((type_, data))
def _writer(data: bytes):
writes.append(data)
def _on_error(exc: Exception):
raise exc
helper = MockAPINoiseFrameHelper(
on_pkt=_packet,
on_error=_on_error,
noise_psk=noise_psk,
expected_name="servicetest",
client_info="my client",
log_name="test",
)
helper._transport = MagicMock()
helper._writer = _writer
proto = NoiseConnection.from_name(
b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND
)
proto.set_as_responder()
proto.set_psks(psk_bytes)
proto.set_prologue(b"NoiseAPIInit\x00\x00")
proto.start_handshake()
handshake_task = asyncio.create_task(helper.perform_handshake(30))
await asyncio.sleep(0) # let the task run to read the hello packet
assert len(writes) == 1
handshake_pkt = writes.pop()
noise_hello = handshake_pkt[0:3]
pkt_header = handshake_pkt[3:6]
assert noise_hello == NOISE_HELLO
assert pkt_header[0] == 1 # type
pkg_length_high = pkt_header[1]
pkg_length_low = pkt_header[2]
pkg_length = (pkg_length_high << 8) + pkg_length_low
assert pkg_length == 49
noise_prefix = handshake_pkt[6:7]
assert noise_prefix == b"\x00"
encrypted_payload = handshake_pkt[7:]
decrypted = proto.read_message(encrypted_payload)
assert decrypted == b""
hello_pkt = b"\x01servicetest\0"
preamble = 1
hello_pkg_length = len(hello_pkt)
hello_pkg_length_high = (hello_pkg_length >> 8) & 0xFF
hello_pkg_length_low = hello_pkg_length & 0xFF
hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low))
hello_pkt_with_header = hello_header + hello_pkt
helper.data_received(hello_pkt_with_header)
error_pkt = b"\x01forced to fail"
preamble = 1
error_pkg_length = len(error_pkt)
error_pkg_length_high = (error_pkg_length >> 8) & 0xFF
error_pkg_length_low = error_pkg_length & 0xFF
error_header = bytes((preamble, error_pkg_length_high, error_pkg_length_low))
error_pkt_with_header = error_header + error_pkt
with pytest.raises(HandshakeAPIError, match="forced to fail"):
helper.data_received(error_pkt_with_header)
with pytest.raises(HandshakeAPIError, match="forced to fail"):
await handshake_task