diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 32abe56..d19e08e 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -206,6 +206,13 @@ ExecuteServiceDataType = Dict[ # pylint: disable=too-many-public-methods class APIClient: + __slots__ = ( + "_params", + "_connection", + "_cached_name", + "_background_tasks", + ) + def __init__( self, address: str, diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 9b97144..c6e550b 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -115,6 +115,28 @@ class APIConnection: a new instance should be established. """ + __slots__ = ( + "_params", + "on_stop", + "_on_stop_task", + "_socket", + "_frame_helper", + "_api_version", + "_connection_state", + "_is_authenticated", + "_connect_complete", + "_message_handlers", + "log_name", + "_read_exception_handlers", + "_ping_timer", + "_pong_timer", + "_keep_alive_interval", + "_keep_alive_timeout", + "_connect_task", + "_fatal_exception", + "_expected_disconnect", + ) + def __init__( self, params: ConnectionParams, diff --git a/tests/test_connection.py b/tests/test_connection.py index 38f3444..1dacdcd 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,5 +1,6 @@ import asyncio import socket +from typing import Optional import pytest from mock import MagicMock, patch @@ -63,13 +64,39 @@ def _get_mock_protocol(conn: APIConnection): @pytest.mark.asyncio async def test_connect(conn, resolve_host, socket_socket, event_loop): loop = asyncio.get_event_loop() - protocol = _get_mock_protocol(conn) + protocol: Optional[APIPlaintextFrameHelper] = None + transport = MagicMock() + connected = asyncio.Event() + + def _create_mock_transport_protocol(create_func, **kwargs): + nonlocal protocol + protocol = create_func() + protocol.connection_made(transport) + connected.set() + return transport, protocol + with patch.object(event_loop, "sock_connect"), patch.object( - loop, "create_connection", return_value=(MagicMock(), protocol) - ), patch.object(conn, "_connect_start_ping"), patch.object( - conn, "send_message_await_response", return_value=HelloResponse() + loop, "create_connection", side_effect=_create_mock_transport_protocol ): - await conn.connect(login=False) + connect_task = asyncio.create_task(conn.connect(login=False)) + await connected.wait() + protocol.data_received( + bytes.fromhex( + "003602080110091a216d6173746572617672656c61792028657" + "370686f6d652076323032332e362e3329220d6d617374657261" + "7672656c6179" + ) + ) + protocol.data_received( + bytes.fromhex( + "005b0a120d6d6173746572617672656c61791a1130383a33413a" + "46323a33453a35453a36302208323032332e362e332a154a756e" + "20323820323032332c2031383a31323a3236320965737033322d" + "65766250506209457370726573736966" + ) + ) + + await connect_task assert conn.is_connected @@ -99,6 +126,16 @@ async def test_plaintext_connection(conn: APIConnection, resolve_host, socket_so loop = asyncio.get_event_loop() protocol = _get_mock_protocol(conn) messages = [] + protocol: Optional[APIPlaintextFrameHelper] = None + transport = MagicMock() + connected = asyncio.Event() + + def _create_mock_transport_protocol(create_func, **kwargs): + nonlocal protocol + protocol = create_func() + protocol.connection_made(transport) + connected.set() + return transport, protocol def on_msg(msg): messages.append(msg) @@ -106,13 +143,11 @@ async def test_plaintext_connection(conn: APIConnection, resolve_host, socket_so remove = conn.add_message_callback(on_msg, {HelloResponse, DeviceInfoResponse}) transport = MagicMock() - with patch.object(conn, "_connect_hello"), patch.object( - loop, "sock_connect" - ), patch.object(loop, "create_connection") as create_connection, patch.object( - protocol, "perform_handshake" + with patch.object( + loop, "create_connection", side_effect=_create_mock_transport_protocol ): - create_connection.return_value = (transport, protocol) - await conn.connect(login=False) + connect_task = asyncio.create_task(conn.connect(login=False)) + await connected.wait() protocol.data_received( b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m' @@ -127,6 +162,7 @@ async def test_plaintext_connection(conn: APIConnection, resolve_host, socket_so b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif" ) await asyncio.sleep(0) + await connect_task assert conn.is_connected assert len(messages) == 2 assert isinstance(messages[0], HelloResponse)