Refactor frame helper to avoid py conversions when processing packets (#641)

This commit is contained in:
J. Nick Koston 2023-11-16 12:24:50 -06:00 committed by GitHub
parent 371143d383
commit 3ccb36b6fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 69 additions and 93 deletions

View File

@ -1,30 +1,31 @@
import cython import cython
from ..connection cimport APIConnection
cdef bint TYPE_CHECKING cdef bint TYPE_CHECKING
cdef class APIFrameHelper: cdef class APIFrameHelper:
cdef object _loop cdef object _loop
cdef object _on_pkt cdef APIConnection _connection
cdef object _on_error
cdef object _transport cdef object _transport
cdef public object _writer cdef public object _writer
cdef public object _ready_future cdef public object _ready_future
cdef bytes _buffer cdef bytes _buffer
cdef cython.uint _buffer_len cdef unsigned int _buffer_len
cdef cython.uint _pos cdef unsigned int _pos
cdef object _client_info cdef object _client_info
cdef str _log_name cdef str _log_name
cdef object _debug_enabled cdef object _debug_enabled
@cython.locals(original_pos=cython.uint, new_pos=cython.uint) @cython.locals(original_pos="unsigned int", new_pos="unsigned int")
cdef bytes _read_exactly(self, int length) cdef bytes _read_exactly(self, int length)
cdef _add_to_buffer(self, bytes data) cdef _add_to_buffer(self, bytes data)
@cython.locals(end_of_frame_pos=cython.uint) @cython.locals(end_of_frame_pos="unsigned int")
cdef _remove_from_buffer(self) cdef _remove_from_buffer(self)
cpdef write_packets(self, list packets) cpdef write_packets(self, list packets)

View File

@ -8,6 +8,9 @@ from typing import TYPE_CHECKING, Callable, cast
from ..core import HandshakeAPIError, SocketClosedAPIError from ..core import HandshakeAPIError, SocketClosedAPIError
if TYPE_CHECKING:
from ..connection import APIConnection
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
SOCKET_ERRORS = ( SOCKET_ERRORS = (
@ -27,8 +30,7 @@ class APIFrameHelper:
__slots__ = ( __slots__ = (
"_loop", "_loop",
"_on_pkt", "_connection",
"_on_error",
"_transport", "_transport",
"_writer", "_writer",
"_ready_future", "_ready_future",
@ -42,16 +44,14 @@ class APIFrameHelper:
def __init__( def __init__(
self, self,
on_pkt: Callable[[int, bytes], None], connection: "APIConnection",
on_error: Callable[[Exception], None],
client_info: str, client_info: str,
log_name: str, log_name: str,
) -> None: ) -> None:
"""Initialize the API frame helper.""" """Initialize the API frame helper."""
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
self._loop = loop self._loop = loop
self._on_pkt = on_pkt self._connection = connection
self._on_error = on_error
self._transport: asyncio.Transport | None = None self._transport: asyncio.Transport | None = None
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
self._ready_future = self._loop.create_future() self._ready_future = self._loop.create_future()
@ -143,7 +143,7 @@ class APIFrameHelper:
self.close() self.close()
def _handle_error(self, exc: Exception) -> None: def _handle_error(self, exc: Exception) -> None:
self._on_error(exc) self._connection.report_fatal_error(exc)
def connection_lost(self, exc: Exception | None) -> None: def connection_lost(self, exc: Exception | None) -> None:
"""Handle the connection being lost.""" """Handle the connection being lost."""

View File

@ -1,5 +1,6 @@
import cython import cython
from ..connection cimport APIConnection
from .base cimport APIFrameHelper from .base cimport APIFrameHelper

View File

@ -25,6 +25,9 @@ from ..core import (
) )
from .base import WRITE_EXCEPTIONS, APIFrameHelper from .base import WRITE_EXCEPTIONS, APIFrameHelper
if TYPE_CHECKING:
from ..connection import APIConnection
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -81,15 +84,14 @@ class APINoiseFrameHelper(APIFrameHelper):
def __init__( def __init__(
self, self,
on_pkt: Callable[[int, bytes], None], connection: "APIConnection",
on_error: Callable[[Exception], None],
noise_psk: str, noise_psk: str,
expected_name: str | None, expected_name: str | None,
client_info: str, client_info: str,
log_name: str, log_name: str,
) -> None: ) -> None:
"""Initialize the API frame helper.""" """Initialize the API frame helper."""
super().__init__(on_pkt, on_error, client_info, log_name) super().__init__(connection, client_info, log_name)
self._noise_psk = noise_psk self._noise_psk = noise_psk
self._expected_name = expected_name self._expected_name = expected_name
self._set_state(NoiseConnectionState.HELLO) self._set_state(NoiseConnectionState.HELLO)
@ -364,7 +366,7 @@ class APINoiseFrameHelper(APIFrameHelper):
# N bytes: message data # N bytes: message data
type_high = msg[0] type_high = msg[0]
type_low = msg[1] type_low = msg[1]
self._on_pkt((type_high << 8) | type_low, msg[4:]) self._connection.process_packet((type_high << 8) | type_low, msg[4:])
def _handle_closed(self, frame: bytes) -> None: # pylint: disable=unused-argument def _handle_closed(self, frame: bytes) -> None: # pylint: disable=unused-argument
"""Handle a closed frame.""" """Handle a closed frame."""

View File

@ -1,5 +1,6 @@
import cython import cython
from ..connection cimport APIConnection
from .base cimport APIFrameHelper from .base cimport APIFrameHelper

View File

@ -166,7 +166,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
packet_data = maybe_packet_data packet_data = maybe_packet_data
self._remove_from_buffer() self._remove_from_buffer()
self._on_pkt(msg_type_int, packet_data) self._connection.process_packet(msg_type_int, packet_data)
# If we have more data, continue processing # If we have more data, continue processing
def _error_on_incorrect_preamble(self, preamble: _int) -> None: def _error_on_incorrect_preamble(self, preamble: _int) -> None:

View File

@ -74,7 +74,7 @@ cdef class APIConnection:
cdef send_messages(self, tuple messages) cdef send_messages(self, tuple messages)
@cython.locals(handlers=set, handlers_copy=set) @cython.locals(handlers=set, handlers_copy=set)
cpdef _process_packet(self, object msg_type_proto, object data) cpdef process_packet(self, object msg_type_proto, object data)
cpdef _async_cancel_pong_timer(self) cpdef _async_cancel_pong_timer(self)
@ -84,7 +84,7 @@ cdef class APIConnection:
cpdef _set_connection_state(self, object state) cpdef _set_connection_state(self, object state)
cpdef _report_fatal_error(self, Exception err) cpdef report_fatal_error(self, Exception err)
@cython.locals(handlers=set) @cython.locals(handlers=set)
cpdef _add_message_callback_without_remove(self, object on_message, tuple msg_types) cpdef _add_message_callback_without_remove(self, object on_message, tuple msg_types)

View File

@ -337,8 +337,7 @@ class APIConnection:
if (noise_psk := self._params.noise_psk) is None: if (noise_psk := self._params.noise_psk) is None:
_, fh = await loop.create_connection( # type: ignore[type-var] _, fh = await loop.create_connection( # type: ignore[type-var]
lambda: APIPlaintextFrameHelper( lambda: APIPlaintextFrameHelper(
on_pkt=self._process_packet, connection=self,
on_error=self._report_fatal_error,
client_info=self._params.client_info, client_info=self._params.client_info,
log_name=self.log_name, log_name=self.log_name,
), ),
@ -349,8 +348,7 @@ class APIConnection:
lambda: APINoiseFrameHelper( lambda: APINoiseFrameHelper(
noise_psk=noise_psk, noise_psk=noise_psk,
expected_name=self._params.expected_name, expected_name=self._params.expected_name,
on_pkt=self._process_packet, connection=self,
on_error=self._report_fatal_error,
client_info=self._params.client_info, client_info=self._params.client_info,
log_name=self.log_name, log_name=self.log_name,
), ),
@ -395,7 +393,7 @@ class APIConnection:
CONNECT_REQUEST_TIMEOUT, CONNECT_REQUEST_TIMEOUT,
) )
except TimeoutAPIError as err: except TimeoutAPIError as err:
self._report_fatal_error(err) self.report_fatal_error(err)
raise TimeoutAPIError("Hello timed out") from err raise TimeoutAPIError("Hello timed out") from err
resp = responses.pop(0) resp = responses.pop(0)
@ -499,7 +497,7 @@ class APIConnection:
self.log_name, self.log_name,
self._keep_alive_timeout, self._keep_alive_timeout,
) )
self._report_fatal_error( self.report_fatal_error(
PingFailedAPIError( PingFailedAPIError(
f"Ping response not received after {self._keep_alive_timeout} seconds" f"Ping response not received after {self._keep_alive_timeout} seconds"
) )
@ -653,7 +651,7 @@ class APIConnection:
# If writing packet fails, we don't know what state the frames # If writing packet fails, we don't know what state the frames
# are in anymore and we have to close the connection # are in anymore and we have to close the connection
_LOGGER.info("%s: Error writing packets: %s", self.log_name, err) _LOGGER.info("%s: Error writing packets: %s", self.log_name, err)
self._report_fatal_error(err) self.report_fatal_error(err)
raise raise
def _add_message_callback_without_remove( def _add_message_callback_without_remove(
@ -786,7 +784,7 @@ class APIConnection:
) )
return response return response
def _report_fatal_error(self, err: Exception) -> None: def report_fatal_error(self, err: Exception) -> None:
"""Report a fatal error that occurred during an operation. """Report a fatal error that occurred during an operation.
This should only be called for errors that mean the connection This should only be called for errors that mean the connection
@ -806,8 +804,8 @@ class APIConnection:
self._fatal_exception = err self._fatal_exception = err
self._cleanup() self._cleanup()
def _process_packet(self, msg_type_proto: _int, data: _bytes) -> None: def process_packet(self, msg_type_proto: _int, data: _bytes) -> None:
"""Factory to make a packet processor.""" """Process an incoming packet."""
if (klass := MESSAGE_TYPE_TO_PROTO.get(msg_type_proto)) is None: if (klass := MESSAGE_TYPE_TO_PROTO.get(msg_type_proto)) is None:
_LOGGER.debug( _LOGGER.debug(
"%s: Skipping message type %s", "%s: Skipping message type %s",
@ -831,7 +829,7 @@ class APIConnection:
e, e,
exc_info=True, exc_info=True,
) )
self._report_fatal_error( self.report_fatal_error(
ProtocolAPIError( ProtocolAPIError(
f"Invalid protobuf message: type={msg_type_proto} data={data!r}: {e}" f"Invalid protobuf message: type={msg_type_proto} data={data!r}: {e}"
) )

View File

@ -3,11 +3,13 @@ from __future__ import annotations
import asyncio import asyncio
import base64 import base64
from datetime import timedelta from datetime import timedelta
from typing import Any
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
from noise.connection import NoiseConnection # type: ignore[import-untyped] from noise.connection import NoiseConnection # type: ignore[import-untyped]
from aioesphomeapi import APIConnection
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
from aioesphomeapi._frame_helper.base import WRITE_EXCEPTIONS from aioesphomeapi._frame_helper.base import WRITE_EXCEPTIONS
from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND, NOISE_HELLO from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND, NOISE_HELLO
@ -31,6 +33,24 @@ from .common import async_fire_time_changed, utcnow
PREAMBLE = b"\x00" PREAMBLE = b"\x00"
def _make_mock_connection() -> tuple[APIConnection, list[tuple[int, bytes]]]:
"""Make a mock connection."""
packets: list[tuple[int, bytes]] = []
class MockConnection(APIConnection):
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Swallow args."""
def process_packet(self, type_: int, data: bytes):
packets.append((type_, data))
def report_fatal_error(self, exc: Exception):
raise exc
connection = MockConnection()
return connection, packets
class MockAPINoiseFrameHelper(APINoiseFrameHelper): class MockAPINoiseFrameHelper(APINoiseFrameHelper):
def mock_write_frame(self, frame: bytes) -> None: def mock_write_frame(self, frame: bytes) -> None:
"""Write a packet to the socket. """Write a packet to the socket.
@ -97,16 +117,9 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
) )
async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type): async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
for _ in range(3): for _ in range(3):
packets = [] connection, packets = _make_mock_connection()
def _packet(type_: int, data: bytes):
packets.append((type_, data))
def _on_error(exc: Exception):
raise exc
helper = APIPlaintextFrameHelper( helper = APIPlaintextFrameHelper(
on_pkt=_packet, on_error=_on_error, client_info="my client", log_name="test" connection=connection, client_info="my client", log_name="test"
) )
helper.data_received(in_bytes) helper.data_received(in_bytes)
@ -139,17 +152,10 @@ async def test_noise_frame_helper_incorrect_key():
"01000d01736572766963657465737400", "01000d01736572766963657465737400",
"0100160148616e647368616b65204d4143206661696c757265", "0100160148616e647368616b65204d4143206661696c757265",
] ]
packets = [] connection, _ = _make_mock_connection()
def _packet(type_: int, data: bytes):
packets.append((type_, data))
def _on_error(exc: Exception):
raise exc
helper = MockAPINoiseFrameHelper( helper = MockAPINoiseFrameHelper(
on_pkt=_packet, connection=connection,
on_error=_on_error,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="servicetest", expected_name="servicetest",
client_info="my client", client_info="my client",
@ -180,17 +186,10 @@ async def test_noise_frame_helper_incorrect_key_fragments():
"01000d01736572766963657465737400", "01000d01736572766963657465737400",
"0100160148616e647368616b65204d4143206661696c757265", "0100160148616e647368616b65204d4143206661696c757265",
] ]
packets = [] connection, _ = _make_mock_connection()
def _packet(type_: int, data: bytes):
packets.append((type_, data))
def _on_error(exc: Exception):
raise exc
helper = MockAPINoiseFrameHelper( helper = MockAPINoiseFrameHelper(
on_pkt=_packet, connection=connection,
on_error=_on_error,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="servicetest", expected_name="servicetest",
client_info="my client", client_info="my client",
@ -223,17 +222,10 @@ async def test_noise_incorrect_name():
"01000d01736572766963657465737400", "01000d01736572766963657465737400",
"0100160148616e647368616b65204d4143206661696c757265", "0100160148616e647368616b65204d4143206661696c757265",
] ]
packets = [] connection, _ = _make_mock_connection()
def _packet(type_: int, data: bytes):
packets.append((type_, data))
def _on_error(exc: Exception):
raise exc
helper = MockAPINoiseFrameHelper( helper = MockAPINoiseFrameHelper(
on_pkt=_packet, connection=connection,
on_error=_on_error,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="wrongname", expected_name="wrongname",
client_info="my client", client_info="my client",
@ -260,17 +252,11 @@ async def test_noise_timeout():
"010000", # hello packet "010000", # hello packet
"010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4", "010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4",
] ]
packets = []
def _packet(type_: int, data: bytes): connection, _ = _make_mock_connection()
packets.append((type_, data))
def _on_error(exc: Exception):
raise exc
helper = MockAPINoiseFrameHelper( helper = MockAPINoiseFrameHelper(
on_pkt=_packet, connection=connection,
on_error=_on_error,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="wrongname", expected_name="wrongname",
client_info="my client", client_info="my client",
@ -317,21 +303,15 @@ async def test_noise_frame_helper_handshake_failure():
"""Test the noise frame helper handshake failure.""" """Test the noise frame helper handshake failure."""
noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=" noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc="
psk_bytes = base64.b64decode(noise_psk) psk_bytes = base64.b64decode(noise_psk)
packets = []
writes = [] writes = []
def _packet(type_: int, data: bytes):
packets.append((type_, data))
def _writer(data: bytes): def _writer(data: bytes):
writes.append(data) writes.append(data)
def _on_error(exc: Exception): connection, _ = _make_mock_connection()
raise exc
helper = MockAPINoiseFrameHelper( helper = MockAPINoiseFrameHelper(
on_pkt=_packet, connection=connection,
on_error=_on_error,
noise_psk=noise_psk, noise_psk=noise_psk,
expected_name="servicetest", expected_name="servicetest",
client_info="my client", client_info="my client",
@ -398,21 +378,15 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
"""Test the noise frame helper handshake success with a single packet.""" """Test the noise frame helper handshake success with a single packet."""
noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=" noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc="
psk_bytes = base64.b64decode(noise_psk) psk_bytes = base64.b64decode(noise_psk)
packets = []
writes = [] writes = []
def _packet(type_: int, data: bytes):
packets.append((type_, data))
def _writer(data: bytes): def _writer(data: bytes):
writes.append(data) writes.append(data)
def _on_error(exc: Exception): connection, packets = _make_mock_connection()
raise exc
helper = MockAPINoiseFrameHelper( helper = MockAPINoiseFrameHelper(
on_pkt=_packet, connection=connection,
on_error=_on_error,
noise_psk=noise_psk, noise_psk=noise_psk,
expected_name="servicetest", expected_name="servicetest",
client_info="my client", client_info="my client",

View File

@ -35,8 +35,7 @@ from .common import (
def _get_mock_protocol(conn: APIConnection): def _get_mock_protocol(conn: APIConnection):
protocol = APIPlaintextFrameHelper( protocol = APIPlaintextFrameHelper(
on_pkt=conn._process_packet, connection=conn,
on_error=conn._report_fatal_error,
client_info="mock", client_info="mock",
log_name="mock_device", log_name="mock_device",
) )