Refactor frame helper to avoid py conversions when processing packets (#641)

This commit is contained in:
J. Nick Koston 2023-11-16 12:24:50 -06:00 committed by GitHub
parent 371143d383
commit 3ccb36b6fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 69 additions and 93 deletions

View File

@ -1,30 +1,31 @@
import cython
from ..connection cimport APIConnection
cdef bint TYPE_CHECKING
cdef class APIFrameHelper:
cdef object _loop
cdef object _on_pkt
cdef object _on_error
cdef APIConnection _connection
cdef object _transport
cdef public object _writer
cdef public object _ready_future
cdef bytes _buffer
cdef cython.uint _buffer_len
cdef cython.uint _pos
cdef unsigned int _buffer_len
cdef unsigned int _pos
cdef object _client_info
cdef str _log_name
cdef object _debug_enabled
@cython.locals(original_pos=cython.uint, new_pos=cython.uint)
@cython.locals(original_pos="unsigned int", new_pos="unsigned int")
cdef bytes _read_exactly(self, int length)
cdef _add_to_buffer(self, bytes data)
@cython.locals(end_of_frame_pos=cython.uint)
@cython.locals(end_of_frame_pos="unsigned int")
cdef _remove_from_buffer(self)
cpdef write_packets(self, list packets)

View File

@ -8,6 +8,9 @@ from typing import TYPE_CHECKING, Callable, cast
from ..core import HandshakeAPIError, SocketClosedAPIError
if TYPE_CHECKING:
from ..connection import APIConnection
_LOGGER = logging.getLogger(__name__)
SOCKET_ERRORS = (
@ -27,8 +30,7 @@ class APIFrameHelper:
__slots__ = (
"_loop",
"_on_pkt",
"_on_error",
"_connection",
"_transport",
"_writer",
"_ready_future",
@ -42,16 +44,14 @@ class APIFrameHelper:
def __init__(
self,
on_pkt: Callable[[int, bytes], None],
on_error: Callable[[Exception], None],
connection: "APIConnection",
client_info: str,
log_name: str,
) -> None:
"""Initialize the API frame helper."""
loop = asyncio.get_event_loop()
self._loop = loop
self._on_pkt = on_pkt
self._on_error = on_error
self._connection = connection
self._transport: asyncio.Transport | None = None
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
self._ready_future = self._loop.create_future()
@ -143,7 +143,7 @@ class APIFrameHelper:
self.close()
def _handle_error(self, exc: Exception) -> None:
self._on_error(exc)
self._connection.report_fatal_error(exc)
def connection_lost(self, exc: Exception | None) -> None:
"""Handle the connection being lost."""

View File

@ -1,5 +1,6 @@
import cython
from ..connection cimport APIConnection
from .base cimport APIFrameHelper

View File

@ -25,6 +25,9 @@ from ..core import (
)
from .base import WRITE_EXCEPTIONS, APIFrameHelper
if TYPE_CHECKING:
from ..connection import APIConnection
_LOGGER = logging.getLogger(__name__)
@ -81,15 +84,14 @@ class APINoiseFrameHelper(APIFrameHelper):
def __init__(
self,
on_pkt: Callable[[int, bytes], None],
on_error: Callable[[Exception], None],
connection: "APIConnection",
noise_psk: str,
expected_name: str | None,
client_info: str,
log_name: str,
) -> None:
"""Initialize the API frame helper."""
super().__init__(on_pkt, on_error, client_info, log_name)
super().__init__(connection, client_info, log_name)
self._noise_psk = noise_psk
self._expected_name = expected_name
self._set_state(NoiseConnectionState.HELLO)
@ -364,7 +366,7 @@ class APINoiseFrameHelper(APIFrameHelper):
# N bytes: message data
type_high = msg[0]
type_low = msg[1]
self._on_pkt((type_high << 8) | type_low, msg[4:])
self._connection.process_packet((type_high << 8) | type_low, msg[4:])
def _handle_closed(self, frame: bytes) -> None: # pylint: disable=unused-argument
"""Handle a closed frame."""

View File

@ -1,5 +1,6 @@
import cython
from ..connection cimport APIConnection
from .base cimport APIFrameHelper

View File

@ -166,7 +166,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
packet_data = maybe_packet_data
self._remove_from_buffer()
self._on_pkt(msg_type_int, packet_data)
self._connection.process_packet(msg_type_int, packet_data)
# If we have more data, continue processing
def _error_on_incorrect_preamble(self, preamble: _int) -> None:

View File

@ -74,7 +74,7 @@ cdef class APIConnection:
cdef send_messages(self, tuple messages)
@cython.locals(handlers=set, handlers_copy=set)
cpdef _process_packet(self, object msg_type_proto, object data)
cpdef process_packet(self, object msg_type_proto, object data)
cpdef _async_cancel_pong_timer(self)
@ -84,7 +84,7 @@ cdef class APIConnection:
cpdef _set_connection_state(self, object state)
cpdef _report_fatal_error(self, Exception err)
cpdef report_fatal_error(self, Exception err)
@cython.locals(handlers=set)
cpdef _add_message_callback_without_remove(self, object on_message, tuple msg_types)

View File

@ -337,8 +337,7 @@ class APIConnection:
if (noise_psk := self._params.noise_psk) is None:
_, fh = await loop.create_connection( # type: ignore[type-var]
lambda: APIPlaintextFrameHelper(
on_pkt=self._process_packet,
on_error=self._report_fatal_error,
connection=self,
client_info=self._params.client_info,
log_name=self.log_name,
),
@ -349,8 +348,7 @@ class APIConnection:
lambda: APINoiseFrameHelper(
noise_psk=noise_psk,
expected_name=self._params.expected_name,
on_pkt=self._process_packet,
on_error=self._report_fatal_error,
connection=self,
client_info=self._params.client_info,
log_name=self.log_name,
),
@ -395,7 +393,7 @@ class APIConnection:
CONNECT_REQUEST_TIMEOUT,
)
except TimeoutAPIError as err:
self._report_fatal_error(err)
self.report_fatal_error(err)
raise TimeoutAPIError("Hello timed out") from err
resp = responses.pop(0)
@ -499,7 +497,7 @@ class APIConnection:
self.log_name,
self._keep_alive_timeout,
)
self._report_fatal_error(
self.report_fatal_error(
PingFailedAPIError(
f"Ping response not received after {self._keep_alive_timeout} seconds"
)
@ -653,7 +651,7 @@ class APIConnection:
# If writing packet fails, we don't know what state the frames
# are in anymore and we have to close the connection
_LOGGER.info("%s: Error writing packets: %s", self.log_name, err)
self._report_fatal_error(err)
self.report_fatal_error(err)
raise
def _add_message_callback_without_remove(
@ -786,7 +784,7 @@ class APIConnection:
)
return response
def _report_fatal_error(self, err: Exception) -> None:
def report_fatal_error(self, err: Exception) -> None:
"""Report a fatal error that occurred during an operation.
This should only be called for errors that mean the connection
@ -806,8 +804,8 @@ class APIConnection:
self._fatal_exception = err
self._cleanup()
def _process_packet(self, msg_type_proto: _int, data: _bytes) -> None:
"""Factory to make a packet processor."""
def process_packet(self, msg_type_proto: _int, data: _bytes) -> None:
"""Process an incoming packet."""
if (klass := MESSAGE_TYPE_TO_PROTO.get(msg_type_proto)) is None:
_LOGGER.debug(
"%s: Skipping message type %s",
@ -831,7 +829,7 @@ class APIConnection:
e,
exc_info=True,
)
self._report_fatal_error(
self.report_fatal_error(
ProtocolAPIError(
f"Invalid protobuf message: type={msg_type_proto} data={data!r}: {e}"
)

View File

@ -3,11 +3,13 @@ from __future__ import annotations
import asyncio
import base64
from datetime import timedelta
from typing import Any
from unittest.mock import MagicMock
import pytest
from noise.connection import NoiseConnection # type: ignore[import-untyped]
from aioesphomeapi import APIConnection
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
from aioesphomeapi._frame_helper.base import WRITE_EXCEPTIONS
from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND, NOISE_HELLO
@ -31,6 +33,24 @@ from .common import async_fire_time_changed, utcnow
PREAMBLE = b"\x00"
def _make_mock_connection() -> tuple[APIConnection, list[tuple[int, bytes]]]:
"""Make a mock connection."""
packets: list[tuple[int, bytes]] = []
class MockConnection(APIConnection):
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Swallow args."""
def process_packet(self, type_: int, data: bytes):
packets.append((type_, data))
def report_fatal_error(self, exc: Exception):
raise exc
connection = MockConnection()
return connection, packets
class MockAPINoiseFrameHelper(APINoiseFrameHelper):
def mock_write_frame(self, frame: bytes) -> None:
"""Write a packet to the socket.
@ -97,16 +117,9 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
)
async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
for _ in range(3):
packets = []
def _packet(type_: int, data: bytes):
packets.append((type_, data))
def _on_error(exc: Exception):
raise exc
connection, packets = _make_mock_connection()
helper = APIPlaintextFrameHelper(
on_pkt=_packet, on_error=_on_error, client_info="my client", log_name="test"
connection=connection, client_info="my client", log_name="test"
)
helper.data_received(in_bytes)
@ -139,17 +152,10 @@ async def test_noise_frame_helper_incorrect_key():
"01000d01736572766963657465737400",
"0100160148616e647368616b65204d4143206661696c757265",
]
packets = []
def _packet(type_: int, data: bytes):
packets.append((type_, data))
def _on_error(exc: Exception):
raise exc
connection, _ = _make_mock_connection()
helper = MockAPINoiseFrameHelper(
on_pkt=_packet,
on_error=_on_error,
connection=connection,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="servicetest",
client_info="my client",
@ -180,17 +186,10 @@ async def test_noise_frame_helper_incorrect_key_fragments():
"01000d01736572766963657465737400",
"0100160148616e647368616b65204d4143206661696c757265",
]
packets = []
def _packet(type_: int, data: bytes):
packets.append((type_, data))
def _on_error(exc: Exception):
raise exc
connection, _ = _make_mock_connection()
helper = MockAPINoiseFrameHelper(
on_pkt=_packet,
on_error=_on_error,
connection=connection,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="servicetest",
client_info="my client",
@ -223,17 +222,10 @@ async def test_noise_incorrect_name():
"01000d01736572766963657465737400",
"0100160148616e647368616b65204d4143206661696c757265",
]
packets = []
def _packet(type_: int, data: bytes):
packets.append((type_, data))
def _on_error(exc: Exception):
raise exc
connection, _ = _make_mock_connection()
helper = MockAPINoiseFrameHelper(
on_pkt=_packet,
on_error=_on_error,
connection=connection,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="wrongname",
client_info="my client",
@ -260,17 +252,11 @@ async def test_noise_timeout():
"010000", # hello packet
"010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4",
]
packets = []
def _packet(type_: int, data: bytes):
packets.append((type_, data))
def _on_error(exc: Exception):
raise exc
connection, _ = _make_mock_connection()
helper = MockAPINoiseFrameHelper(
on_pkt=_packet,
on_error=_on_error,
connection=connection,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="wrongname",
client_info="my client",
@ -317,21 +303,15 @@ async def test_noise_frame_helper_handshake_failure():
"""Test the noise frame helper handshake failure."""
noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc="
psk_bytes = base64.b64decode(noise_psk)
packets = []
writes = []
def _packet(type_: int, data: bytes):
packets.append((type_, data))
def _writer(data: bytes):
writes.append(data)
def _on_error(exc: Exception):
raise exc
connection, _ = _make_mock_connection()
helper = MockAPINoiseFrameHelper(
on_pkt=_packet,
on_error=_on_error,
connection=connection,
noise_psk=noise_psk,
expected_name="servicetest",
client_info="my client",
@ -398,21 +378,15 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
"""Test the noise frame helper handshake success with a single packet."""
noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc="
psk_bytes = base64.b64decode(noise_psk)
packets = []
writes = []
def _packet(type_: int, data: bytes):
packets.append((type_, data))
def _writer(data: bytes):
writes.append(data)
def _on_error(exc: Exception):
raise exc
connection, packets = _make_mock_connection()
helper = MockAPINoiseFrameHelper(
on_pkt=_packet,
on_error=_on_error,
connection=connection,
noise_psk=noise_psk,
expected_name="servicetest",
client_info="my client",

View File

@ -35,8 +35,7 @@ from .common import (
def _get_mock_protocol(conn: APIConnection):
protocol = APIPlaintextFrameHelper(
on_pkt=conn._process_packet,
on_error=conn._report_fatal_error,
connection=conn,
client_info="mock",
log_name="mock_device",
)