From 3ccb36b6fc8fbe43c93e91cb3ec03467f0b3e0aa Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 16 Nov 2023 12:24:50 -0600 Subject: [PATCH] Refactor frame helper to avoid py conversions when processing packets (#641) --- aioesphomeapi/_frame_helper/base.pxd | 13 +-- aioesphomeapi/_frame_helper/base.py | 14 ++-- aioesphomeapi/_frame_helper/noise.pxd | 1 + aioesphomeapi/_frame_helper/noise.py | 10 ++- aioesphomeapi/_frame_helper/plain_text.pxd | 1 + aioesphomeapi/_frame_helper/plain_text.py | 2 +- aioesphomeapi/connection.pxd | 4 +- aioesphomeapi/connection.py | 20 +++-- tests/test__frame_helper.py | 94 ++++++++-------------- tests/test_connection.py | 3 +- 10 files changed, 69 insertions(+), 93 deletions(-) diff --git a/aioesphomeapi/_frame_helper/base.pxd b/aioesphomeapi/_frame_helper/base.pxd index 78ad9c6..5933f48 100644 --- a/aioesphomeapi/_frame_helper/base.pxd +++ b/aioesphomeapi/_frame_helper/base.pxd @@ -1,30 +1,31 @@ import cython +from ..connection cimport APIConnection + cdef bint TYPE_CHECKING cdef class APIFrameHelper: cdef object _loop - cdef object _on_pkt - cdef object _on_error + cdef APIConnection _connection cdef object _transport cdef public object _writer cdef public object _ready_future cdef bytes _buffer - cdef cython.uint _buffer_len - cdef cython.uint _pos + cdef unsigned int _buffer_len + cdef unsigned int _pos cdef object _client_info cdef str _log_name cdef object _debug_enabled - @cython.locals(original_pos=cython.uint, new_pos=cython.uint) + @cython.locals(original_pos="unsigned int", new_pos="unsigned int") cdef bytes _read_exactly(self, int length) cdef _add_to_buffer(self, bytes data) - @cython.locals(end_of_frame_pos=cython.uint) + @cython.locals(end_of_frame_pos="unsigned int") cdef _remove_from_buffer(self) cpdef write_packets(self, list packets) diff --git a/aioesphomeapi/_frame_helper/base.py b/aioesphomeapi/_frame_helper/base.py index fa94759..8e4aad1 100644 --- a/aioesphomeapi/_frame_helper/base.py +++ b/aioesphomeapi/_frame_helper/base.py @@ -8,6 +8,9 @@ from typing import TYPE_CHECKING, Callable, cast from ..core import HandshakeAPIError, SocketClosedAPIError +if TYPE_CHECKING: + from ..connection import APIConnection + _LOGGER = logging.getLogger(__name__) SOCKET_ERRORS = ( @@ -27,8 +30,7 @@ class APIFrameHelper: __slots__ = ( "_loop", - "_on_pkt", - "_on_error", + "_connection", "_transport", "_writer", "_ready_future", @@ -42,16 +44,14 @@ class APIFrameHelper: def __init__( self, - on_pkt: Callable[[int, bytes], None], - on_error: Callable[[Exception], None], + connection: "APIConnection", client_info: str, log_name: str, ) -> None: """Initialize the API frame helper.""" loop = asyncio.get_event_loop() self._loop = loop - self._on_pkt = on_pkt - self._on_error = on_error + self._connection = connection self._transport: asyncio.Transport | None = None self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None self._ready_future = self._loop.create_future() @@ -143,7 +143,7 @@ class APIFrameHelper: self.close() def _handle_error(self, exc: Exception) -> None: - self._on_error(exc) + self._connection.report_fatal_error(exc) def connection_lost(self, exc: Exception | None) -> None: """Handle the connection being lost.""" diff --git a/aioesphomeapi/_frame_helper/noise.pxd b/aioesphomeapi/_frame_helper/noise.pxd index 674ac18..5ac7532 100644 --- a/aioesphomeapi/_frame_helper/noise.pxd +++ b/aioesphomeapi/_frame_helper/noise.pxd @@ -1,5 +1,6 @@ import cython +from ..connection cimport APIConnection from .base cimport APIFrameHelper diff --git a/aioesphomeapi/_frame_helper/noise.py b/aioesphomeapi/_frame_helper/noise.py index 79ea85f..e8da907 100644 --- a/aioesphomeapi/_frame_helper/noise.py +++ b/aioesphomeapi/_frame_helper/noise.py @@ -25,6 +25,9 @@ from ..core import ( ) from .base import WRITE_EXCEPTIONS, APIFrameHelper +if TYPE_CHECKING: + from ..connection import APIConnection + _LOGGER = logging.getLogger(__name__) @@ -81,15 +84,14 @@ class APINoiseFrameHelper(APIFrameHelper): def __init__( self, - on_pkt: Callable[[int, bytes], None], - on_error: Callable[[Exception], None], + connection: "APIConnection", noise_psk: str, expected_name: str | None, client_info: str, log_name: str, ) -> None: """Initialize the API frame helper.""" - super().__init__(on_pkt, on_error, client_info, log_name) + super().__init__(connection, client_info, log_name) self._noise_psk = noise_psk self._expected_name = expected_name self._set_state(NoiseConnectionState.HELLO) @@ -364,7 +366,7 @@ class APINoiseFrameHelper(APIFrameHelper): # N bytes: message data type_high = msg[0] type_low = msg[1] - self._on_pkt((type_high << 8) | type_low, msg[4:]) + self._connection.process_packet((type_high << 8) | type_low, msg[4:]) def _handle_closed(self, frame: bytes) -> None: # pylint: disable=unused-argument """Handle a closed frame.""" diff --git a/aioesphomeapi/_frame_helper/plain_text.pxd b/aioesphomeapi/_frame_helper/plain_text.pxd index 164e710..59896a9 100644 --- a/aioesphomeapi/_frame_helper/plain_text.pxd +++ b/aioesphomeapi/_frame_helper/plain_text.pxd @@ -1,5 +1,6 @@ import cython +from ..connection cimport APIConnection from .base cimport APIFrameHelper diff --git a/aioesphomeapi/_frame_helper/plain_text.py b/aioesphomeapi/_frame_helper/plain_text.py index 43bdf2c..79e02b3 100644 --- a/aioesphomeapi/_frame_helper/plain_text.py +++ b/aioesphomeapi/_frame_helper/plain_text.py @@ -166,7 +166,7 @@ class APIPlaintextFrameHelper(APIFrameHelper): packet_data = maybe_packet_data self._remove_from_buffer() - self._on_pkt(msg_type_int, packet_data) + self._connection.process_packet(msg_type_int, packet_data) # If we have more data, continue processing def _error_on_incorrect_preamble(self, preamble: _int) -> None: diff --git a/aioesphomeapi/connection.pxd b/aioesphomeapi/connection.pxd index e347349..82600d7 100644 --- a/aioesphomeapi/connection.pxd +++ b/aioesphomeapi/connection.pxd @@ -74,7 +74,7 @@ cdef class APIConnection: cdef send_messages(self, tuple messages) @cython.locals(handlers=set, handlers_copy=set) - cpdef _process_packet(self, object msg_type_proto, object data) + cpdef process_packet(self, object msg_type_proto, object data) cpdef _async_cancel_pong_timer(self) @@ -84,7 +84,7 @@ cdef class APIConnection: cpdef _set_connection_state(self, object state) - cpdef _report_fatal_error(self, Exception err) + cpdef report_fatal_error(self, Exception err) @cython.locals(handlers=set) cpdef _add_message_callback_without_remove(self, object on_message, tuple msg_types) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 7d53fa2..bfb0eb6 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -337,8 +337,7 @@ class APIConnection: if (noise_psk := self._params.noise_psk) is None: _, fh = await loop.create_connection( # type: ignore[type-var] lambda: APIPlaintextFrameHelper( - on_pkt=self._process_packet, - on_error=self._report_fatal_error, + connection=self, client_info=self._params.client_info, log_name=self.log_name, ), @@ -349,8 +348,7 @@ class APIConnection: lambda: APINoiseFrameHelper( noise_psk=noise_psk, expected_name=self._params.expected_name, - on_pkt=self._process_packet, - on_error=self._report_fatal_error, + connection=self, client_info=self._params.client_info, log_name=self.log_name, ), @@ -395,7 +393,7 @@ class APIConnection: CONNECT_REQUEST_TIMEOUT, ) except TimeoutAPIError as err: - self._report_fatal_error(err) + self.report_fatal_error(err) raise TimeoutAPIError("Hello timed out") from err resp = responses.pop(0) @@ -499,7 +497,7 @@ class APIConnection: self.log_name, self._keep_alive_timeout, ) - self._report_fatal_error( + self.report_fatal_error( PingFailedAPIError( f"Ping response not received after {self._keep_alive_timeout} seconds" ) @@ -653,7 +651,7 @@ class APIConnection: # If writing packet fails, we don't know what state the frames # are in anymore and we have to close the connection _LOGGER.info("%s: Error writing packets: %s", self.log_name, err) - self._report_fatal_error(err) + self.report_fatal_error(err) raise def _add_message_callback_without_remove( @@ -786,7 +784,7 @@ class APIConnection: ) return response - def _report_fatal_error(self, err: Exception) -> None: + def report_fatal_error(self, err: Exception) -> None: """Report a fatal error that occurred during an operation. This should only be called for errors that mean the connection @@ -806,8 +804,8 @@ class APIConnection: self._fatal_exception = err self._cleanup() - def _process_packet(self, msg_type_proto: _int, data: _bytes) -> None: - """Factory to make a packet processor.""" + def process_packet(self, msg_type_proto: _int, data: _bytes) -> None: + """Process an incoming packet.""" if (klass := MESSAGE_TYPE_TO_PROTO.get(msg_type_proto)) is None: _LOGGER.debug( "%s: Skipping message type %s", @@ -831,7 +829,7 @@ class APIConnection: e, exc_info=True, ) - self._report_fatal_error( + self.report_fatal_error( ProtocolAPIError( f"Invalid protobuf message: type={msg_type_proto} data={data!r}: {e}" ) diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index 8217bb7..7d2dcdb 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -3,11 +3,13 @@ from __future__ import annotations import asyncio import base64 from datetime import timedelta +from typing import Any from unittest.mock import MagicMock import pytest from noise.connection import NoiseConnection # type: ignore[import-untyped] +from aioesphomeapi import APIConnection from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper from aioesphomeapi._frame_helper.base import WRITE_EXCEPTIONS from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND, NOISE_HELLO @@ -31,6 +33,24 @@ from .common import async_fire_time_changed, utcnow PREAMBLE = b"\x00" +def _make_mock_connection() -> tuple[APIConnection, list[tuple[int, bytes]]]: + """Make a mock connection.""" + packets: list[tuple[int, bytes]] = [] + + class MockConnection(APIConnection): + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Swallow args.""" + + def process_packet(self, type_: int, data: bytes): + packets.append((type_, data)) + + def report_fatal_error(self, exc: Exception): + raise exc + + connection = MockConnection() + return connection, packets + + class MockAPINoiseFrameHelper(APINoiseFrameHelper): def mock_write_frame(self, frame: bytes) -> None: """Write a packet to the socket. @@ -97,16 +117,9 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper): ) async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type): for _ in range(3): - packets = [] - - def _packet(type_: int, data: bytes): - packets.append((type_, data)) - - def _on_error(exc: Exception): - raise exc - + connection, packets = _make_mock_connection() helper = APIPlaintextFrameHelper( - on_pkt=_packet, on_error=_on_error, client_info="my client", log_name="test" + connection=connection, client_info="my client", log_name="test" ) helper.data_received(in_bytes) @@ -139,17 +152,10 @@ async def test_noise_frame_helper_incorrect_key(): "01000d01736572766963657465737400", "0100160148616e647368616b65204d4143206661696c757265", ] - packets = [] - - def _packet(type_: int, data: bytes): - packets.append((type_, data)) - - def _on_error(exc: Exception): - raise exc + connection, _ = _make_mock_connection() helper = MockAPINoiseFrameHelper( - on_pkt=_packet, - on_error=_on_error, + connection=connection, noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", expected_name="servicetest", client_info="my client", @@ -180,17 +186,10 @@ async def test_noise_frame_helper_incorrect_key_fragments(): "01000d01736572766963657465737400", "0100160148616e647368616b65204d4143206661696c757265", ] - packets = [] - - def _packet(type_: int, data: bytes): - packets.append((type_, data)) - - def _on_error(exc: Exception): - raise exc + connection, _ = _make_mock_connection() helper = MockAPINoiseFrameHelper( - on_pkt=_packet, - on_error=_on_error, + connection=connection, noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", expected_name="servicetest", client_info="my client", @@ -223,17 +222,10 @@ async def test_noise_incorrect_name(): "01000d01736572766963657465737400", "0100160148616e647368616b65204d4143206661696c757265", ] - packets = [] - - def _packet(type_: int, data: bytes): - packets.append((type_, data)) - - def _on_error(exc: Exception): - raise exc + connection, _ = _make_mock_connection() helper = MockAPINoiseFrameHelper( - on_pkt=_packet, - on_error=_on_error, + connection=connection, noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", expected_name="wrongname", client_info="my client", @@ -260,17 +252,11 @@ async def test_noise_timeout(): "010000", # hello packet "010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4", ] - packets = [] - def _packet(type_: int, data: bytes): - packets.append((type_, data)) - - def _on_error(exc: Exception): - raise exc + connection, _ = _make_mock_connection() helper = MockAPINoiseFrameHelper( - on_pkt=_packet, - on_error=_on_error, + connection=connection, noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", expected_name="wrongname", client_info="my client", @@ -317,21 +303,15 @@ async def test_noise_frame_helper_handshake_failure(): """Test the noise frame helper handshake failure.""" noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=" psk_bytes = base64.b64decode(noise_psk) - packets = [] writes = [] - def _packet(type_: int, data: bytes): - packets.append((type_, data)) - def _writer(data: bytes): writes.append(data) - def _on_error(exc: Exception): - raise exc + connection, _ = _make_mock_connection() helper = MockAPINoiseFrameHelper( - on_pkt=_packet, - on_error=_on_error, + connection=connection, noise_psk=noise_psk, expected_name="servicetest", client_info="my client", @@ -398,21 +378,15 @@ async def test_noise_frame_helper_handshake_success_with_single_packet(): """Test the noise frame helper handshake success with a single packet.""" noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=" psk_bytes = base64.b64decode(noise_psk) - packets = [] writes = [] - def _packet(type_: int, data: bytes): - packets.append((type_, data)) - def _writer(data: bytes): writes.append(data) - def _on_error(exc: Exception): - raise exc + connection, packets = _make_mock_connection() helper = MockAPINoiseFrameHelper( - on_pkt=_packet, - on_error=_on_error, + connection=connection, noise_psk=noise_psk, expected_name="servicetest", client_info="my client", diff --git a/tests/test_connection.py b/tests/test_connection.py index 918088e..85967a7 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -35,8 +35,7 @@ from .common import ( def _get_mock_protocol(conn: APIConnection): protocol = APIPlaintextFrameHelper( - on_pkt=conn._process_packet, - on_error=conn._report_fatal_error, + connection=conn, client_info="mock", log_name="mock_device", )