Use slots for APIConnection and APIClient (#453)

This commit is contained in:
J. Nick Koston 2023-07-01 16:31:58 -05:00 committed by GitHub
parent 45ded8590d
commit 34f6badcde
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 76 additions and 11 deletions

View File

@ -206,6 +206,13 @@ ExecuteServiceDataType = Dict[
# pylint: disable=too-many-public-methods
class APIClient:
__slots__ = (
"_params",
"_connection",
"_cached_name",
"_background_tasks",
)
def __init__(
self,
address: str,

View File

@ -115,6 +115,28 @@ class APIConnection:
a new instance should be established.
"""
__slots__ = (
"_params",
"on_stop",
"_on_stop_task",
"_socket",
"_frame_helper",
"_api_version",
"_connection_state",
"_is_authenticated",
"_connect_complete",
"_message_handlers",
"log_name",
"_read_exception_handlers",
"_ping_timer",
"_pong_timer",
"_keep_alive_interval",
"_keep_alive_timeout",
"_connect_task",
"_fatal_exception",
"_expected_disconnect",
)
def __init__(
self,
params: ConnectionParams,

View File

@ -1,5 +1,6 @@
import asyncio
import socket
from typing import Optional
import pytest
from mock import MagicMock, patch
@ -63,13 +64,39 @@ def _get_mock_protocol(conn: APIConnection):
@pytest.mark.asyncio
async def test_connect(conn, resolve_host, socket_socket, event_loop):
loop = asyncio.get_event_loop()
protocol = _get_mock_protocol(conn)
protocol: Optional[APIPlaintextFrameHelper] = None
transport = MagicMock()
connected = asyncio.Event()
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", return_value=(MagicMock(), protocol)
), patch.object(conn, "_connect_start_ping"), patch.object(
conn, "send_message_await_response", return_value=HelloResponse()
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
await conn.connect(login=False)
connect_task = asyncio.create_task(conn.connect(login=False))
await connected.wait()
protocol.data_received(
bytes.fromhex(
"003602080110091a216d6173746572617672656c61792028657"
"370686f6d652076323032332e362e3329220d6d617374657261"
"7672656c6179"
)
)
protocol.data_received(
bytes.fromhex(
"005b0a120d6d6173746572617672656c61791a1130383a33413a"
"46323a33453a35453a36302208323032332e362e332a154a756e"
"20323820323032332c2031383a31323a3236320965737033322d"
"65766250506209457370726573736966"
)
)
await connect_task
assert conn.is_connected
@ -99,6 +126,16 @@ async def test_plaintext_connection(conn: APIConnection, resolve_host, socket_so
loop = asyncio.get_event_loop()
protocol = _get_mock_protocol(conn)
messages = []
protocol: Optional[APIPlaintextFrameHelper] = None
transport = MagicMock()
connected = asyncio.Event()
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
def on_msg(msg):
messages.append(msg)
@ -106,13 +143,11 @@ async def test_plaintext_connection(conn: APIConnection, resolve_host, socket_so
remove = conn.add_message_callback(on_msg, {HelloResponse, DeviceInfoResponse})
transport = MagicMock()
with patch.object(conn, "_connect_hello"), patch.object(
loop, "sock_connect"
), patch.object(loop, "create_connection") as create_connection, patch.object(
protocol, "perform_handshake"
with patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
create_connection.return_value = (transport, protocol)
await conn.connect(login=False)
connect_task = asyncio.create_task(conn.connect(login=False))
await connected.wait()
protocol.data_received(
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
@ -127,6 +162,7 @@ async def test_plaintext_connection(conn: APIConnection, resolve_host, socket_so
b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
)
await asyncio.sleep(0)
await connect_task
assert conn.is_connected
assert len(messages) == 2
assert isinstance(messages[0], HelloResponse)