diff --git a/aioesphomeapi/_frame_helper/noise.py b/aioesphomeapi/_frame_helper/noise.py index 8757660..3801849 100644 --- a/aioesphomeapi/_frame_helper/noise.py +++ b/aioesphomeapi/_frame_helper/noise.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import binascii import logging from functools import partial @@ -128,10 +129,10 @@ class APINoiseFrameHelper(APIFrameHelper): exc.__cause__ = original_exc super()._handle_error(exc) - async def perform_handshake(self, timeout: float) -> None: - """Perform the handshake with the server.""" + def connection_made(self, transport: asyncio.BaseTransport) -> None: + """Handle a new connection.""" + super().connection_made(transport) self._send_hello_handshake() - await super().perform_handshake(timeout) def data_received(self, data: bytes | bytearray | memoryview) -> None: self._add_to_buffer(data) diff --git a/tests/common.py b/tests/common.py index c7f76ad..dbec784 100644 --- a/tests/common.py +++ b/tests/common.py @@ -10,7 +10,7 @@ from google.protobuf import message from zeroconf import Zeroconf from zeroconf.asyncio import AsyncZeroconf -from aioesphomeapi._frame_helper import APIPlaintextFrameHelper +from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes from aioesphomeapi.api_pb2 import ( ConnectResponse, @@ -31,6 +31,20 @@ utcnow.__doc__ = "Get now in UTC time." PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()} +def mock_data_received( + protocol: APINoiseFrameHelper | APIPlaintextFrameHelper, data: bytes +) -> None: + """Mock data received on the protocol.""" + try: + protocol.data_received(data) + except Exception as err: # pylint: disable=broad-except + loop = asyncio.get_running_loop() + loop.call_soon( + protocol.connection_lost, + err, + ) + + def get_mock_zeroconf() -> MagicMock: with patch("zeroconf.Zeroconf.start"): zc = Zeroconf() diff --git a/tests/conftest.py b/tests/conftest.py index 42585d6..6bd41ff 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -46,8 +46,7 @@ def socket_socket(): yield func -@pytest.fixture -def connection_params() -> ConnectionParams: +def get_mock_connection_params() -> ConnectionParams: return ConnectionParams( address="fake.address", port=6052, @@ -60,6 +59,11 @@ def connection_params() -> ConnectionParams: ) +@pytest.fixture +def connection_params() -> ConnectionParams: + return get_mock_connection_params() + + @pytest.fixture def noise_connection_params() -> ConnectionParams: return ConnectionParams( diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index 4c024cc..1604b6e 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -4,7 +4,7 @@ import asyncio import base64 from datetime import timedelta from typing import Any -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from noise.connection import NoiseConnection # type: ignore[import-untyped] @@ -30,7 +30,13 @@ from aioesphomeapi.core import ( SocketClosedAPIError, ) -from .common import async_fire_time_changed, get_mock_protocol, utcnow +from .common import ( + async_fire_time_changed, + get_mock_protocol, + mock_data_received, + utcnow, +) +from .conftest import get_mock_connection_params PREAMBLE = b"\x00" @@ -42,18 +48,27 @@ def _make_mock_connection() -> tuple[APIConnection, list[tuple[int, bytes]]]: class MockConnection(APIConnection): def __init__(self, *args: Any, **kwargs: Any) -> None: """Swallow args.""" + super().__init__(get_mock_connection_params(), AsyncMock(), *args, **kwargs) def process_packet(self, type_: int, data: bytes): packets.append((type_, data)) - def report_fatal_error(self, exc: Exception): - raise exc - connection = MockConnection() return connection, packets class MockAPINoiseFrameHelper(APINoiseFrameHelper): + def __init__(self, *args: Any, writer: Any | None = None, **kwargs: Any) -> None: + """Swallow args.""" + super().__init__(*args, **kwargs) + transport = MagicMock() + transport.write = writer or MagicMock() + self.__transport = transport + self.connection_made(transport) + + def connection_made(self, transport: Any) -> None: + return super().connection_made(self.__transport) + def mock_write_frame(self, frame: bytes) -> None: """Write a packet to the socket. @@ -125,7 +140,7 @@ def test_plaintext_frame_helper( connection=connection, client_info="my client", log_name="test" ) - helper.data_received(in_bytes) + mock_data_received(helper, in_bytes) pkt = packets.pop() type_, data = pkt @@ -135,7 +150,7 @@ def test_plaintext_frame_helper( # Make sure we correctly handle fragments for i in range(len(in_bytes)): - helper.data_received(in_bytes[i : i + 1]) + mock_data_received(helper, in_bytes[i : i + 1]) pkt = packets.pop() type_, data = pkt @@ -166,7 +181,7 @@ def test_plaintext_frame_helper_protractor_event_loop(byte_type: Any) -> None: PREAMBLE + varuint_to_bytes(4) + varuint_to_bytes(100) + (b"\x42" * 4) ) - helper.data_received(in_bytes) + mock_data_received(helper, in_bytes) pkt = packets.pop() type_, data = pkt @@ -176,7 +191,7 @@ def test_plaintext_frame_helper_protractor_event_loop(byte_type: Any) -> None: # Make sure we correctly handle fragments for i in range(len(in_bytes)): - helper.data_received(in_bytes[i : i + 1]) + mock_data_received(helper, in_bytes[i : i + 1]) pkt = packets.pop() type_, data = pkt @@ -215,15 +230,12 @@ async def test_noise_protector_event_loop(byte_type: Any) -> None: client_info="my client", log_name="test", ) - helper._transport = MagicMock() - helper._writer = MagicMock() for pkt in outgoing_packets: helper.mock_write_frame(byte_type(bytes.fromhex(pkt))) - with pytest.raises(InvalidEncryptionKeyAPIError): - for pkt in incoming_packets: - helper.data_received(byte_type(bytes.fromhex(pkt))) + for pkt in incoming_packets: + mock_data_received(helper, byte_type(bytes.fromhex(pkt))) with pytest.raises(InvalidEncryptionKeyAPIError): await helper.perform_handshake(30) @@ -249,15 +261,12 @@ async def test_noise_frame_helper_incorrect_key(): client_info="my client", log_name="test", ) - helper._transport = MagicMock() - helper._writer = MagicMock() for pkt in outgoing_packets: helper.mock_write_frame(bytes.fromhex(pkt)) - with pytest.raises(InvalidEncryptionKeyAPIError): - for pkt in incoming_packets: - helper.data_received(bytes.fromhex(pkt)) + for pkt in incoming_packets: + mock_data_received(helper, bytes.fromhex(pkt)) with pytest.raises(InvalidEncryptionKeyAPIError): await helper.perform_handshake(30) @@ -283,17 +292,14 @@ async def test_noise_frame_helper_incorrect_key_fragments(): client_info="my client", log_name="test", ) - helper._transport = MagicMock() - helper._writer = MagicMock() for pkt in outgoing_packets: helper.mock_write_frame(bytes.fromhex(pkt)) - with pytest.raises(InvalidEncryptionKeyAPIError): - for pkt in incoming_packets: - in_pkt = bytes.fromhex(pkt) - for i in range(len(in_pkt)): - helper.data_received(in_pkt[i : i + 1]) + for pkt in incoming_packets: + in_pkt = bytes.fromhex(pkt) + for i in range(len(in_pkt)): + mock_data_received(helper, in_pkt[i : i + 1]) with pytest.raises(InvalidEncryptionKeyAPIError): await helper.perform_handshake(30) @@ -319,15 +325,12 @@ async def test_noise_incorrect_name(): client_info="my client", log_name="test", ) - helper._transport = MagicMock() - helper._writer = MagicMock() for pkt in outgoing_packets: helper.mock_write_frame(bytes.fromhex(pkt)) - with pytest.raises(BadNameAPIError): - for pkt in incoming_packets: - helper.data_received(bytes.fromhex(pkt)) + for pkt in incoming_packets: + mock_data_received(helper, bytes.fromhex(pkt)) with pytest.raises(BadNameAPIError): await helper.perform_handshake(30) @@ -350,8 +353,6 @@ async def test_noise_timeout(): client_info="my client", log_name="test", ) - helper._transport = MagicMock() - helper._writer = MagicMock() for pkt in outgoing_packets: helper.mock_write_frame(bytes.fromhex(pkt)) @@ -408,9 +409,8 @@ async def test_noise_frame_helper_handshake_failure(): expected_name="servicetest", client_info="my client", log_name="test", + writer=_writer, ) - helper._transport = MagicMock() - helper._writer = _writer proto = NoiseConnection.from_name( b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND @@ -448,7 +448,7 @@ async def test_noise_frame_helper_handshake_failure(): hello_pkg_length_low = hello_pkg_length & 0xFF hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low)) hello_pkt_with_header = hello_header + hello_pkt - helper.data_received(hello_pkt_with_header) + mock_data_received(helper, hello_pkt_with_header) error_pkt = b"\x01forced to fail" preamble = 1 @@ -458,8 +458,7 @@ async def test_noise_frame_helper_handshake_failure(): error_header = bytes((preamble, error_pkg_length_high, error_pkg_length_low)) error_pkt_with_header = error_header + error_pkt - with pytest.raises(HandshakeAPIError, match="forced to fail"): - helper.data_received(error_pkt_with_header) + mock_data_received(helper, error_pkt_with_header) with pytest.raises(HandshakeAPIError, match="forced to fail"): await handshake_task @@ -483,9 +482,8 @@ async def test_noise_frame_helper_handshake_success_with_single_packet(): expected_name="servicetest", client_info="my client", log_name="test", + writer=_writer, ) - helper._transport = MagicMock() - helper._writer = _writer proto = NoiseConnection.from_name( b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND @@ -523,7 +521,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet(): hello_pkg_length_low = hello_pkg_length & 0xFF hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low)) hello_pkt_with_header = hello_header + hello_pkt - helper.data_received(hello_pkt_with_header) + mock_data_received(helper, hello_pkt_with_header) handshake = proto.write_message(b"") handshake_pkt = b"\x00" + handshake @@ -536,7 +534,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet(): ) handshake_with_header = handshake_header + handshake_pkt - helper.data_received(handshake_with_header) + mock_data_received(helper, handshake_with_header) assert not writes @@ -566,13 +564,12 @@ async def test_noise_frame_helper_handshake_success_with_single_packet(): encrypted_header = bytes( (preamble, encrypted_pkg_length_high, encrypted_pkg_length_low) ) - helper.data_received(encrypted_header + encrypted_payload) + mock_data_received(helper, encrypted_header + encrypted_payload) assert packets == [(42, b"from device")] helper.close() - with pytest.raises(ProtocolAPIError, match="Connection closed"): - helper.data_received(encrypted_header + encrypted_payload) + mock_data_received(helper, encrypted_header + encrypted_payload) @pytest.mark.asyncio @@ -590,7 +587,7 @@ async def test_init_plaintext_with_wrong_preamble(conn: APIConnection): task = asyncio.create_task(conn._connect_hello_login(login=True)) await asyncio.sleep(0) # The preamble should be \x00 but we send \x09 - protocol.data_received(b"\x09\x00\x00") + mock_data_received(protocol, b"\x09\x00\x00") with pytest.raises(ProtocolAPIError): await task @@ -615,7 +612,7 @@ async def test_init_noise_with_wrong_byte_marker(noise_conn: APIConnection) -> N assert protocol is not None assert isinstance(noise_conn._frame_helper, APINoiseFrameHelper) - protocol.data_received(b"\x00\x00\x00") + mock_data_received(protocol, b"\x00\x00\x00") with pytest.raises(ProtocolAPIError, match="Marker byte invalid"): await task @@ -632,8 +629,6 @@ async def test_noise_frame_helper_empty_hello(): client_info="my client", log_name="test", ) - helper._transport = MagicMock() - helper._writer = MagicMock() handshake_task = asyncio.create_task(helper.perform_handshake(30)) empty_hello_pkt = b"" @@ -644,8 +639,7 @@ async def test_noise_frame_helper_empty_hello(): hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low)) hello_pkt_with_header = hello_header + empty_hello_pkt - with pytest.raises(HandshakeAPIError, match="ServerHello is empty"): - helper.data_received(hello_pkt_with_header) + mock_data_received(helper, hello_pkt_with_header) with pytest.raises(HandshakeAPIError, match="ServerHello is empty"): await handshake_task diff --git a/tests/test_client.py b/tests/test_client.py index 8015585..2cc4ff8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -88,7 +88,12 @@ from aioesphomeapi.model import ( ) from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState -from .common import Estr, generate_plaintext_packet, get_mock_zeroconf +from .common import ( + Estr, + generate_plaintext_packet, + get_mock_zeroconf, + mock_data_received, +) @pytest.fixture @@ -849,7 +854,7 @@ async def test_bluetooth_disconnect( response: message.Message = BluetoothDeviceConnectionResponse( address=1234, connected=False ) - protocol.data_received(generate_plaintext_packet(response)) + mock_data_received(protocol, generate_plaintext_packet(response)) await disconnect_task @@ -864,7 +869,7 @@ async def test_bluetooth_pair( pair_task = asyncio.create_task(client.bluetooth_device_pair(1234)) await asyncio.sleep(0) response: message.Message = BluetoothDevicePairingResponse(address=1234) - protocol.data_received(generate_plaintext_packet(response)) + mock_data_received(protocol, generate_plaintext_packet(response)) await pair_task @@ -879,7 +884,7 @@ async def test_bluetooth_unpair( unpair_task = asyncio.create_task(client.bluetooth_device_unpair(1234)) await asyncio.sleep(0) response: message.Message = BluetoothDeviceUnpairingResponse(address=1234) - protocol.data_received(generate_plaintext_packet(response)) + mock_data_received(protocol, generate_plaintext_packet(response)) await unpair_task @@ -894,7 +899,7 @@ async def test_bluetooth_clear_cache( clear_task = asyncio.create_task(client.bluetooth_device_clear_cache(1234)) await asyncio.sleep(0) response: message.Message = BluetoothDeviceClearCacheResponse(address=1234) - protocol.data_received(generate_plaintext_packet(response)) + mock_data_received(protocol, generate_plaintext_packet(response)) await clear_task @@ -914,7 +919,7 @@ async def test_device_info( friendly_name="My Device", has_deep_sleep=True, ) - protocol.data_received(generate_plaintext_packet(response)) + mock_data_received(protocol, generate_plaintext_packet(response)) device_info = await device_info_task assert device_info.name == "realname" assert device_info.friendly_name == "My Device" @@ -923,7 +928,7 @@ async def test_device_info( disconnect_task = asyncio.create_task(client.disconnect()) await asyncio.sleep(0) response: message.Message = DisconnectResponse() - protocol.data_received(generate_plaintext_packet(response)) + mock_data_received(protocol, generate_plaintext_packet(response)) await disconnect_task with pytest.raises(APIConnectionError, match="CLOSED"): await client.device_info() @@ -943,12 +948,12 @@ async def test_bluetooth_gatt_read( other_response: message.Message = BluetoothGATTReadResponse( address=1234, handle=4567, data=b"4567" ) - protocol.data_received(generate_plaintext_packet(other_response)) + mock_data_received(protocol, generate_plaintext_packet(other_response)) response: message.Message = BluetoothGATTReadResponse( address=1234, handle=1234, data=b"1234" ) - protocol.data_received(generate_plaintext_packet(response)) + mock_data_received(protocol, generate_plaintext_packet(response)) assert await read_task == b"1234" @@ -966,12 +971,12 @@ async def test_bluetooth_gatt_read_descriptor( other_response: message.Message = BluetoothGATTReadResponse( address=1234, handle=4567, data=b"4567" ) - protocol.data_received(generate_plaintext_packet(other_response)) + mock_data_received(protocol, generate_plaintext_packet(other_response)) response: message.Message = BluetoothGATTReadResponse( address=1234, handle=1234, data=b"1234" ) - protocol.data_received(generate_plaintext_packet(response)) + mock_data_received(protocol, generate_plaintext_packet(response)) assert await read_task == b"1234" @@ -991,10 +996,10 @@ async def test_bluetooth_gatt_write( other_response: message.Message = BluetoothGATTWriteResponse( address=1234, handle=4567 ) - protocol.data_received(generate_plaintext_packet(other_response)) + mock_data_received(protocol, generate_plaintext_packet(other_response)) response: message.Message = BluetoothGATTWriteResponse(address=1234, handle=1234) - protocol.data_received(generate_plaintext_packet(response)) + mock_data_received(protocol, generate_plaintext_packet(response)) await write_task @@ -1034,10 +1039,10 @@ async def test_bluetooth_gatt_write_descriptor( other_response: message.Message = BluetoothGATTWriteResponse( address=1234, handle=4567 ) - protocol.data_received(generate_plaintext_packet(other_response)) + mock_data_received(protocol, generate_plaintext_packet(other_response)) response: message.Message = BluetoothGATTWriteResponse(address=1234, handle=1234) - protocol.data_received(generate_plaintext_packet(response)) + mock_data_received(protocol, generate_plaintext_packet(response)) await write_task @@ -1077,12 +1082,12 @@ async def test_bluetooth_gatt_read_descriptor( other_response: message.Message = BluetoothGATTReadResponse( address=1234, handle=4567, data=b"4567" ) - protocol.data_received(generate_plaintext_packet(other_response)) + mock_data_received(protocol, generate_plaintext_packet(other_response)) response: message.Message = BluetoothGATTReadResponse( address=1234, handle=1234, data=b"1234" ) - protocol.data_received(generate_plaintext_packet(response)) + mock_data_received(protocol, generate_plaintext_packet(response)) assert await read_task == b"1234" @@ -1102,9 +1107,9 @@ async def test_bluetooth_gatt_get_services( response: message.Message = BluetoothGATTGetServicesResponse( address=1234, services=[service1] ) - protocol.data_received(generate_plaintext_packet(response)) + mock_data_received(protocol, generate_plaintext_packet(response)) done_response: message.Message = BluetoothGATTGetServicesDoneResponse(address=1234) - protocol.data_received(generate_plaintext_packet(done_response)) + mock_data_received(protocol, generate_plaintext_packet(done_response)) services = await services_task assert services == ESPHomeBluetoothGATTServices( @@ -1129,9 +1134,9 @@ async def test_bluetooth_gatt_get_services_errors( response: message.Message = BluetoothGATTGetServicesResponse( address=1234, services=[service1] ) - protocol.data_received(generate_plaintext_packet(response)) + mock_data_received(protocol, generate_plaintext_packet(response)) done_response: message.Message = BluetoothGATTErrorResponse(address=1234) - protocol.data_received(generate_plaintext_packet(done_response)) + mock_data_received(protocol, generate_plaintext_packet(done_response)) with pytest.raises(BluetoothGATTAPIError): await services_task @@ -1164,9 +1169,10 @@ async def test_bluetooth_gatt_start_notify( data_response: message.Message = BluetoothGATTNotifyDataResponse( address=1234, handle=1, data=b"gotit" ) - protocol.data_received( + mock_data_received( + protocol, generate_plaintext_packet(notify_response) - + generate_plaintext_packet(data_response) + + generate_plaintext_packet(data_response), ) cancel_cb, abort_cb = await notify_task @@ -1175,7 +1181,7 @@ async def test_bluetooth_gatt_start_notify( second_data_response: message.Message = BluetoothGATTNotifyDataResponse( address=1234, handle=1, data=b"after finished" ) - protocol.data_received(generate_plaintext_packet(second_data_response)) + mock_data_received(protocol, generate_plaintext_packet(second_data_response)) assert notifies == [(1, b"gotit"), (1, b"after finished")] await cancel_cb() @@ -1244,7 +1250,7 @@ async def test_subscribe_bluetooth_le_advertisements( manufacturer_data={}, address_type=1, ) - protocol.data_received(generate_plaintext_packet(response)) + mock_data_received(protocol, generate_plaintext_packet(response)) assert advs == [ BluetoothLEAdvertisement( @@ -1290,7 +1296,7 @@ async def test_subscribe_bluetooth_le_raw_advertisements( ) ] ) - protocol.data_received(generate_plaintext_packet(response)) + mock_data_received(protocol, generate_plaintext_packet(response)) assert len(adv_groups) == 1 first_adv = adv_groups[0][0] assert first_adv.address == 1234 @@ -1318,7 +1324,7 @@ async def test_subscribe_bluetooth_connections_free( ) await asyncio.sleep(0) response: message.Message = BluetoothConnectionsFreeResponse(free=2, limit=3) - protocol.data_received(generate_plaintext_packet(response)) + mock_data_received(protocol, generate_plaintext_packet(response)) assert connections == [(2, 3)] unsub() @@ -1345,7 +1351,7 @@ async def test_subscribe_home_assistant_states( response: message.Message = SubscribeHomeAssistantStateResponse( entity_id="sensor.red", attribute="any" ) - protocol.data_received(generate_plaintext_packet(response)) + mock_data_received(protocol, generate_plaintext_packet(response)) assert states == [("sensor.red", "any")] diff --git a/tests/test_connection.py b/tests/test_connection.py index 16fc061..4eb313f 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -33,6 +33,7 @@ from .common import ( connect, generate_plaintext_packet, get_mock_protocol, + mock_data_received, send_ping_request, send_ping_response, send_plaintext_connect_response, @@ -52,20 +53,22 @@ async def test_connect( ) -> None: """Test that a plaintext connection works.""" conn, transport, protocol, connect_task = plaintext_connect_task_no_login - protocol.data_received( + mock_data_received( + protocol, bytes.fromhex( "003602080110091a216d6173746572617672656c61792028657" "370686f6d652076323032332e362e3329220d6d617374657261" "7672656c6179" - ) + ), ) - protocol.data_received( + mock_data_received( + protocol, bytes.fromhex( "005b0a120d6d6173746572617672656c61791a1130383a33413a" "46323a33453a35453a36302208323032332e362e332a154a756e" "20323820323032332c2031383a31323a3236320965737033322d" "65766250506209457370726573736966" - ) + ), ) await connect_task assert conn.is_connected @@ -80,13 +83,14 @@ async def test_timeout_sending_message( ) -> None: conn, transport, protocol, connect_task = plaintext_connect_task_no_login - protocol.data_received( + mock_data_received( + protocol, b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m' b"5stackatomproxy" b"\x00\x00$" b"\x00\x00\x04" b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d' - b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif" + b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif", ) await connect_task @@ -117,8 +121,9 @@ async def test_disconnect_when_not_fully_connected( # Only send the first part of the handshake # so we are stuck in the middle of the connection process - protocol.data_received( - b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m' + mock_data_received( + protocol, + b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m', ) await asyncio.sleep(0) @@ -156,7 +161,7 @@ async def test_requires_encryption_propagates(conn: APIConnection): with pytest.raises(RequiresEncryptionAPIError): task = asyncio.create_task(conn._connect_hello_login(login=True)) await asyncio.sleep(0) - protocol.data_received(b"\x01\x00\x00") + mock_data_received(protocol, b"\x01\x00\x00") await task @@ -175,17 +180,19 @@ async def test_plaintext_connection( messages.append(msg) remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse)) - protocol.data_received( - b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m' + mock_data_received( + protocol, + b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m', ) - protocol.data_received(b"5stackatomproxy") - protocol.data_received(b"\x00\x00$") - protocol.data_received(b"\x00\x00\x04") - protocol.data_received( - b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d' + mock_data_received(protocol, b"5stackatomproxy") + mock_data_received(protocol, b"\x00\x00$") + mock_data_received(protocol, b"\x00\x00\x04") + mock_data_received( + protocol, + b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d', ) - protocol.data_received( - b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif" + mock_data_received( + protocol, b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif" ) await asyncio.sleep(0) await connect_task @@ -308,8 +315,9 @@ async def test_finish_connection_times_out( messages.append(msg) remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse)) - protocol.data_received( - b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m' + mock_data_received( + protocol, + b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m', ) await asyncio.sleep(0) @@ -386,17 +394,19 @@ async def test_plaintext_connection_fails_handshake( assert conn._socket is not None assert conn._frame_helper is not None - protocol.data_received( - b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m' + mock_data_received( + protocol, + b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m', ) - protocol.data_received(b"5stackatomproxy") - protocol.data_received(b"\x00\x00$") - protocol.data_received(b"\x00\x00\x04") - protocol.data_received( - b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d' + mock_data_received(protocol, b"5stackatomproxy") + mock_data_received(protocol, b"\x00\x00$") + mock_data_received(protocol, b"\x00\x00\x04") + mock_data_received( + protocol, + b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d', ) - protocol.data_received( - b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif" + mock_data_received( + protocol, b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif" ) call_order = [] @@ -530,11 +540,9 @@ async def test_disconnect_fails_to_send_response( await connect_task assert conn.is_connected - with pytest.raises(SocketAPIError), patch.object( - protocol, "_writer", side_effect=OSError - ): + with patch.object(protocol, "_writer", side_effect=OSError): disconnect_request = DisconnectRequest() - protocol.data_received(generate_plaintext_packet(disconnect_request)) + mock_data_received(protocol, generate_plaintext_packet(disconnect_request)) # Wait one loop iteration for the disconnect to be processed await asyncio.sleep(0) @@ -589,7 +597,7 @@ async def test_disconnect_success_case( assert conn.is_connected disconnect_request = DisconnectRequest() - protocol.data_received(generate_plaintext_packet(disconnect_request)) + mock_data_received(protocol, generate_plaintext_packet(disconnect_request)) # Wait one loop iteration for the disconnect to be processed await asyncio.sleep(0) diff --git a/tests/test_log_runner.py b/tests/test_log_runner.py index e393675..f03610c 100644 --- a/tests/test_log_runner.py +++ b/tests/test_log_runner.py @@ -17,6 +17,7 @@ from .common import ( Estr, generate_plaintext_packet, get_mock_async_zeroconf, + mock_data_received, send_plaintext_connect_response, send_plaintext_hello, ) @@ -74,11 +75,11 @@ async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnec response: message.Message = SubscribeLogsResponse() response.message = b"Hello world" - protocol.data_received(generate_plaintext_packet(response)) + mock_data_received(protocol, generate_plaintext_packet(response)) assert len(messages) == 1 assert messages[0].message == b"Hello world" stop_task = asyncio.create_task(stop()) await asyncio.sleep(0) disconnect_response = DisconnectResponse() - protocol.data_received(generate_plaintext_packet(disconnect_response)) + mock_data_received(protocol, generate_plaintext_packet(disconnect_response)) await stop_task