diff --git a/tests/conftest.py b/tests/conftest.py index 84519b6..d02ef7c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,8 +6,8 @@ import asyncio from dataclasses import replace from functools import partial import socket -from typing import Any, Callable -from unittest.mock import AsyncMock, MagicMock, create_autospec, patch +from typing import Callable +from unittest.mock import MagicMock, create_autospec, patch import pytest import pytest_asyncio @@ -244,21 +244,3 @@ async def api_client( await connect_task transport.reset_mock() yield client, conn, transport, protocol - - -def make_mock_connection() -> tuple[APIConnection, list[tuple[int, bytes]]]: - """Make a mock connection.""" - packets: list[tuple[int, bytes]] = [] - - class MockConnection(APIConnection): - def __init__(self, *args: Any, **kwargs: Any) -> None: - """Swallow args.""" - super().__init__( - get_mock_connection_params(), AsyncMock(), True, None, *args, **kwargs - ) - - def process_packet(self, type_: int, data: bytes): - packets.append((type_, data)) - - connection = MockConnection() - return connection, packets diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index ebca2b5..cdb8f5e 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio import base64 from typing import Any -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch from noise.connection import NoiseConnection # type: ignore[import-untyped] import pytest @@ -27,7 +27,7 @@ from aioesphomeapi.core import ( ) from .common import get_mock_protocol, mock_data_received -from .conftest import make_mock_connection +from .conftest import get_mock_connection_params PREAMBLE = b"\x00" @@ -108,6 +108,24 @@ def _mock_responder_proto(psk_bytes: bytes) -> NoiseConnection: return proto +def _make_mock_connection() -> tuple[APIConnection, list[tuple[int, bytes]]]: + """Make a mock connection.""" + packets: list[tuple[int, bytes]] = [] + + class MockConnection(APIConnection): + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Swallow args.""" + super().__init__( + get_mock_connection_params(), AsyncMock(), True, None, *args, **kwargs + ) + + def process_packet(self, type_: int, data: bytes): + packets.append((type_, data)) + + connection = MockConnection() + return connection, packets + + class MockAPINoiseFrameHelper(APINoiseFrameHelper): def __init__(self, *args: Any, writer: Any | None = None, **kwargs: Any) -> None: """Swallow args.""" @@ -135,6 +153,83 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper): ) from err +@pytest.mark.parametrize( + "in_bytes, pkt_data, pkt_type", + [ + (PREAMBLE + varuint_to_bytes(0) + varuint_to_bytes(1), b"", 1), + ( + PREAMBLE + varuint_to_bytes(192) + varuint_to_bytes(1) + (b"\x42" * 192), + (b"\x42" * 192), + 1, + ), + ( + PREAMBLE + varuint_to_bytes(192) + varuint_to_bytes(100) + (b"\x42" * 192), + (b"\x42" * 192), + 100, + ), + ( + PREAMBLE + varuint_to_bytes(4) + varuint_to_bytes(100) + (b"\x42" * 4), + (b"\x42" * 4), + 100, + ), + ( + PREAMBLE + + varuint_to_bytes(8192) + + varuint_to_bytes(8192) + + (b"\x42" * 8192), + (b"\x42" * 8192), + 8192, + ), + ( + PREAMBLE + varuint_to_bytes(256) + varuint_to_bytes(256) + (b"\x42" * 256), + (b"\x42" * 256), + 256, + ), + ( + PREAMBLE + varuint_to_bytes(1) + varuint_to_bytes(32768) + b"\x42", + b"\x42", + 32768, + ), + ( + PREAMBLE + + varuint_to_bytes(32768) + + varuint_to_bytes(32768) + + (b"\x42" * 32768), + (b"\x42" * 32768), + 32768, + ), + ], +) +@pytest.mark.asyncio +async def test_plaintext_frame_helper( + in_bytes: bytes, pkt_data: bytes, pkt_type: int +) -> None: + for _ in range(3): + connection, packets = _make_mock_connection() + helper = APIPlaintextFrameHelper( + connection=connection, client_info="my client", log_name="test" + ) + + mock_data_received(helper, in_bytes) + + pkt = packets.pop() + type_, data = pkt + + assert type_ == pkt_type + assert data == pkt_data + + # Make sure we correctly handle fragments + for i in range(len(in_bytes)): + mock_data_received(helper, in_bytes[i : i + 1]) + + pkt = packets.pop() + type_, data = pkt + + assert type_ == pkt_type + assert data == pkt_data + helper.close() + + @pytest.mark.parametrize( "byte_type", (bytes, bytearray, memoryview), @@ -148,7 +243,7 @@ def test_plaintext_frame_helper_protractor_event_loop(byte_type: Any) -> None: https://github.com/esphome/issues/issues/5117 """ for _ in range(3): - connection, packets = make_mock_connection() + connection, packets = _make_mock_connection() helper = APIPlaintextFrameHelper( connection=connection, client_info="my client", log_name="test" ) @@ -196,7 +291,7 @@ async def test_noise_protector_event_loop(byte_type: Any) -> None: "01000d01736572766963657465737400", "0100160148616e647368616b65204d4143206661696c757265", ] - connection, _ = make_mock_connection() + connection, _ = _make_mock_connection() helper = MockAPINoiseFrameHelper( connection=connection, @@ -227,7 +322,7 @@ async def test_noise_frame_helper_incorrect_key(): "01000d01736572766963657465737400", "0100160148616e647368616b65204d4143206661696c757265", ] - connection, _ = make_mock_connection() + connection, _ = _make_mock_connection() helper = MockAPINoiseFrameHelper( connection=connection, @@ -258,7 +353,7 @@ async def test_noise_frame_helper_incorrect_key_fragments(): "01000d01736572766963657465737400", "0100160148616e647368616b65204d4143206661696c757265", ] - connection, _ = make_mock_connection() + connection, _ = _make_mock_connection() helper = MockAPINoiseFrameHelper( connection=connection, @@ -291,7 +386,7 @@ async def test_noise_incorrect_name(): "01000d01736572766963657465737400", "0100160148616e647368616b65204d4143206661696c757265", ] - connection, _ = make_mock_connection() + connection, _ = _make_mock_connection() helper = MockAPINoiseFrameHelper( connection=connection, @@ -337,7 +432,7 @@ async def test_noise_frame_helper_handshake_failure(): def _writer(data: bytes): writes.append(data) - connection, _ = make_mock_connection() + connection, _ = _make_mock_connection() helper = MockAPINoiseFrameHelper( connection=connection, @@ -386,7 +481,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet(): def _writer(data: bytes): writes.append(data) - connection, packets = make_mock_connection() + connection, packets = _make_mock_connection() helper = MockAPINoiseFrameHelper( connection=connection, @@ -448,7 +543,7 @@ async def test_noise_frame_helper_bad_encryption( def _writer(data: bytes): writes.append(data) - connection, packets = make_mock_connection() + connection, packets = _make_mock_connection() helper = MockAPINoiseFrameHelper( connection=connection, @@ -549,7 +644,7 @@ async def test_init_noise_with_wrong_byte_marker(noise_conn: APIConnection) -> N @pytest.mark.asyncio async def test_noise_frame_helper_empty_hello(): """Test empty hello with noise.""" - connection, _ = make_mock_connection() + connection, _ = _make_mock_connection() helper = MockAPINoiseFrameHelper( connection=connection, noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", @@ -569,7 +664,7 @@ async def test_noise_frame_helper_empty_hello(): @pytest.mark.asyncio async def test_noise_frame_helper_wrong_protocol(): """Test noise with the wrong protocol.""" - connection, _ = make_mock_connection() + connection, _ = _make_mock_connection() helper = MockAPINoiseFrameHelper( connection=connection, noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", @@ -666,7 +761,7 @@ async def test_connection_lost_closes_connection_and_logs( ) async def test_noise_bad_psks(bad_psk: str, error: str) -> None: """Test we raise on bad psks.""" - connection, _ = make_mock_connection() + connection, _ = _make_mock_connection() with pytest.raises(InvalidEncryptionKeyAPIError, match=error): MockAPINoiseFrameHelper( connection=connection,