mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-10 10:19:30 +01:00
Refactor frame helper to avoid py conversions when processing packets (#641)
This commit is contained in:
parent
371143d383
commit
3ccb36b6fc
@ -1,30 +1,31 @@
|
||||
|
||||
import cython
|
||||
|
||||
from ..connection cimport APIConnection
|
||||
|
||||
|
||||
cdef bint TYPE_CHECKING
|
||||
|
||||
cdef class APIFrameHelper:
|
||||
|
||||
cdef object _loop
|
||||
cdef object _on_pkt
|
||||
cdef object _on_error
|
||||
cdef APIConnection _connection
|
||||
cdef object _transport
|
||||
cdef public object _writer
|
||||
cdef public object _ready_future
|
||||
cdef bytes _buffer
|
||||
cdef cython.uint _buffer_len
|
||||
cdef cython.uint _pos
|
||||
cdef unsigned int _buffer_len
|
||||
cdef unsigned int _pos
|
||||
cdef object _client_info
|
||||
cdef str _log_name
|
||||
cdef object _debug_enabled
|
||||
|
||||
@cython.locals(original_pos=cython.uint, new_pos=cython.uint)
|
||||
@cython.locals(original_pos="unsigned int", new_pos="unsigned int")
|
||||
cdef bytes _read_exactly(self, int length)
|
||||
|
||||
cdef _add_to_buffer(self, bytes data)
|
||||
|
||||
@cython.locals(end_of_frame_pos=cython.uint)
|
||||
@cython.locals(end_of_frame_pos="unsigned int")
|
||||
cdef _remove_from_buffer(self)
|
||||
|
||||
cpdef write_packets(self, list packets)
|
||||
|
@ -8,6 +8,9 @@ from typing import TYPE_CHECKING, Callable, cast
|
||||
|
||||
from ..core import HandshakeAPIError, SocketClosedAPIError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..connection import APIConnection
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
SOCKET_ERRORS = (
|
||||
@ -27,8 +30,7 @@ class APIFrameHelper:
|
||||
|
||||
__slots__ = (
|
||||
"_loop",
|
||||
"_on_pkt",
|
||||
"_on_error",
|
||||
"_connection",
|
||||
"_transport",
|
||||
"_writer",
|
||||
"_ready_future",
|
||||
@ -42,16 +44,14 @@ class APIFrameHelper:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
on_pkt: Callable[[int, bytes], None],
|
||||
on_error: Callable[[Exception], None],
|
||||
connection: "APIConnection",
|
||||
client_info: str,
|
||||
log_name: str,
|
||||
) -> None:
|
||||
"""Initialize the API frame helper."""
|
||||
loop = asyncio.get_event_loop()
|
||||
self._loop = loop
|
||||
self._on_pkt = on_pkt
|
||||
self._on_error = on_error
|
||||
self._connection = connection
|
||||
self._transport: asyncio.Transport | None = None
|
||||
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
|
||||
self._ready_future = self._loop.create_future()
|
||||
@ -143,7 +143,7 @@ class APIFrameHelper:
|
||||
self.close()
|
||||
|
||||
def _handle_error(self, exc: Exception) -> None:
|
||||
self._on_error(exc)
|
||||
self._connection.report_fatal_error(exc)
|
||||
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
"""Handle the connection being lost."""
|
||||
|
@ -1,5 +1,6 @@
|
||||
import cython
|
||||
|
||||
from ..connection cimport APIConnection
|
||||
from .base cimport APIFrameHelper
|
||||
|
||||
|
||||
|
@ -25,6 +25,9 @@ from ..core import (
|
||||
)
|
||||
from .base import WRITE_EXCEPTIONS, APIFrameHelper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..connection import APIConnection
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -81,15 +84,14 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
on_pkt: Callable[[int, bytes], None],
|
||||
on_error: Callable[[Exception], None],
|
||||
connection: "APIConnection",
|
||||
noise_psk: str,
|
||||
expected_name: str | None,
|
||||
client_info: str,
|
||||
log_name: str,
|
||||
) -> None:
|
||||
"""Initialize the API frame helper."""
|
||||
super().__init__(on_pkt, on_error, client_info, log_name)
|
||||
super().__init__(connection, client_info, log_name)
|
||||
self._noise_psk = noise_psk
|
||||
self._expected_name = expected_name
|
||||
self._set_state(NoiseConnectionState.HELLO)
|
||||
@ -364,7 +366,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
# N bytes: message data
|
||||
type_high = msg[0]
|
||||
type_low = msg[1]
|
||||
self._on_pkt((type_high << 8) | type_low, msg[4:])
|
||||
self._connection.process_packet((type_high << 8) | type_low, msg[4:])
|
||||
|
||||
def _handle_closed(self, frame: bytes) -> None: # pylint: disable=unused-argument
|
||||
"""Handle a closed frame."""
|
||||
|
@ -1,5 +1,6 @@
|
||||
import cython
|
||||
|
||||
from ..connection cimport APIConnection
|
||||
from .base cimport APIFrameHelper
|
||||
|
||||
|
||||
|
@ -166,7 +166,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
packet_data = maybe_packet_data
|
||||
|
||||
self._remove_from_buffer()
|
||||
self._on_pkt(msg_type_int, packet_data)
|
||||
self._connection.process_packet(msg_type_int, packet_data)
|
||||
# If we have more data, continue processing
|
||||
|
||||
def _error_on_incorrect_preamble(self, preamble: _int) -> None:
|
||||
|
@ -74,7 +74,7 @@ cdef class APIConnection:
|
||||
cdef send_messages(self, tuple messages)
|
||||
|
||||
@cython.locals(handlers=set, handlers_copy=set)
|
||||
cpdef _process_packet(self, object msg_type_proto, object data)
|
||||
cpdef process_packet(self, object msg_type_proto, object data)
|
||||
|
||||
cpdef _async_cancel_pong_timer(self)
|
||||
|
||||
@ -84,7 +84,7 @@ cdef class APIConnection:
|
||||
|
||||
cpdef _set_connection_state(self, object state)
|
||||
|
||||
cpdef _report_fatal_error(self, Exception err)
|
||||
cpdef report_fatal_error(self, Exception err)
|
||||
|
||||
@cython.locals(handlers=set)
|
||||
cpdef _add_message_callback_without_remove(self, object on_message, tuple msg_types)
|
||||
|
@ -337,8 +337,7 @@ class APIConnection:
|
||||
if (noise_psk := self._params.noise_psk) is None:
|
||||
_, fh = await loop.create_connection( # type: ignore[type-var]
|
||||
lambda: APIPlaintextFrameHelper(
|
||||
on_pkt=self._process_packet,
|
||||
on_error=self._report_fatal_error,
|
||||
connection=self,
|
||||
client_info=self._params.client_info,
|
||||
log_name=self.log_name,
|
||||
),
|
||||
@ -349,8 +348,7 @@ class APIConnection:
|
||||
lambda: APINoiseFrameHelper(
|
||||
noise_psk=noise_psk,
|
||||
expected_name=self._params.expected_name,
|
||||
on_pkt=self._process_packet,
|
||||
on_error=self._report_fatal_error,
|
||||
connection=self,
|
||||
client_info=self._params.client_info,
|
||||
log_name=self.log_name,
|
||||
),
|
||||
@ -395,7 +393,7 @@ class APIConnection:
|
||||
CONNECT_REQUEST_TIMEOUT,
|
||||
)
|
||||
except TimeoutAPIError as err:
|
||||
self._report_fatal_error(err)
|
||||
self.report_fatal_error(err)
|
||||
raise TimeoutAPIError("Hello timed out") from err
|
||||
|
||||
resp = responses.pop(0)
|
||||
@ -499,7 +497,7 @@ class APIConnection:
|
||||
self.log_name,
|
||||
self._keep_alive_timeout,
|
||||
)
|
||||
self._report_fatal_error(
|
||||
self.report_fatal_error(
|
||||
PingFailedAPIError(
|
||||
f"Ping response not received after {self._keep_alive_timeout} seconds"
|
||||
)
|
||||
@ -653,7 +651,7 @@ class APIConnection:
|
||||
# If writing packet fails, we don't know what state the frames
|
||||
# are in anymore and we have to close the connection
|
||||
_LOGGER.info("%s: Error writing packets: %s", self.log_name, err)
|
||||
self._report_fatal_error(err)
|
||||
self.report_fatal_error(err)
|
||||
raise
|
||||
|
||||
def _add_message_callback_without_remove(
|
||||
@ -786,7 +784,7 @@ class APIConnection:
|
||||
)
|
||||
return response
|
||||
|
||||
def _report_fatal_error(self, err: Exception) -> None:
|
||||
def report_fatal_error(self, err: Exception) -> None:
|
||||
"""Report a fatal error that occurred during an operation.
|
||||
|
||||
This should only be called for errors that mean the connection
|
||||
@ -806,8 +804,8 @@ class APIConnection:
|
||||
self._fatal_exception = err
|
||||
self._cleanup()
|
||||
|
||||
def _process_packet(self, msg_type_proto: _int, data: _bytes) -> None:
|
||||
"""Factory to make a packet processor."""
|
||||
def process_packet(self, msg_type_proto: _int, data: _bytes) -> None:
|
||||
"""Process an incoming packet."""
|
||||
if (klass := MESSAGE_TYPE_TO_PROTO.get(msg_type_proto)) is None:
|
||||
_LOGGER.debug(
|
||||
"%s: Skipping message type %s",
|
||||
@ -831,7 +829,7 @@ class APIConnection:
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
self._report_fatal_error(
|
||||
self.report_fatal_error(
|
||||
ProtocolAPIError(
|
||||
f"Invalid protobuf message: type={msg_type_proto} data={data!r}: {e}"
|
||||
)
|
||||
|
@ -3,11 +3,13 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import base64
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from noise.connection import NoiseConnection # type: ignore[import-untyped]
|
||||
|
||||
from aioesphomeapi import APIConnection
|
||||
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
|
||||
from aioesphomeapi._frame_helper.base import WRITE_EXCEPTIONS
|
||||
from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND, NOISE_HELLO
|
||||
@ -31,6 +33,24 @@ from .common import async_fire_time_changed, utcnow
|
||||
PREAMBLE = b"\x00"
|
||||
|
||||
|
||||
def _make_mock_connection() -> tuple[APIConnection, list[tuple[int, bytes]]]:
|
||||
"""Make a mock connection."""
|
||||
packets: list[tuple[int, bytes]] = []
|
||||
|
||||
class MockConnection(APIConnection):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Swallow args."""
|
||||
|
||||
def process_packet(self, type_: int, data: bytes):
|
||||
packets.append((type_, data))
|
||||
|
||||
def report_fatal_error(self, exc: Exception):
|
||||
raise exc
|
||||
|
||||
connection = MockConnection()
|
||||
return connection, packets
|
||||
|
||||
|
||||
class MockAPINoiseFrameHelper(APINoiseFrameHelper):
|
||||
def mock_write_frame(self, frame: bytes) -> None:
|
||||
"""Write a packet to the socket.
|
||||
@ -97,16 +117,9 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
|
||||
)
|
||||
async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
|
||||
for _ in range(3):
|
||||
packets = []
|
||||
|
||||
def _packet(type_: int, data: bytes):
|
||||
packets.append((type_, data))
|
||||
|
||||
def _on_error(exc: Exception):
|
||||
raise exc
|
||||
|
||||
connection, packets = _make_mock_connection()
|
||||
helper = APIPlaintextFrameHelper(
|
||||
on_pkt=_packet, on_error=_on_error, client_info="my client", log_name="test"
|
||||
connection=connection, client_info="my client", log_name="test"
|
||||
)
|
||||
|
||||
helper.data_received(in_bytes)
|
||||
@ -139,17 +152,10 @@ async def test_noise_frame_helper_incorrect_key():
|
||||
"01000d01736572766963657465737400",
|
||||
"0100160148616e647368616b65204d4143206661696c757265",
|
||||
]
|
||||
packets = []
|
||||
|
||||
def _packet(type_: int, data: bytes):
|
||||
packets.append((type_, data))
|
||||
|
||||
def _on_error(exc: Exception):
|
||||
raise exc
|
||||
connection, _ = _make_mock_connection()
|
||||
|
||||
helper = MockAPINoiseFrameHelper(
|
||||
on_pkt=_packet,
|
||||
on_error=_on_error,
|
||||
connection=connection,
|
||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||
expected_name="servicetest",
|
||||
client_info="my client",
|
||||
@ -180,17 +186,10 @@ async def test_noise_frame_helper_incorrect_key_fragments():
|
||||
"01000d01736572766963657465737400",
|
||||
"0100160148616e647368616b65204d4143206661696c757265",
|
||||
]
|
||||
packets = []
|
||||
|
||||
def _packet(type_: int, data: bytes):
|
||||
packets.append((type_, data))
|
||||
|
||||
def _on_error(exc: Exception):
|
||||
raise exc
|
||||
connection, _ = _make_mock_connection()
|
||||
|
||||
helper = MockAPINoiseFrameHelper(
|
||||
on_pkt=_packet,
|
||||
on_error=_on_error,
|
||||
connection=connection,
|
||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||
expected_name="servicetest",
|
||||
client_info="my client",
|
||||
@ -223,17 +222,10 @@ async def test_noise_incorrect_name():
|
||||
"01000d01736572766963657465737400",
|
||||
"0100160148616e647368616b65204d4143206661696c757265",
|
||||
]
|
||||
packets = []
|
||||
|
||||
def _packet(type_: int, data: bytes):
|
||||
packets.append((type_, data))
|
||||
|
||||
def _on_error(exc: Exception):
|
||||
raise exc
|
||||
connection, _ = _make_mock_connection()
|
||||
|
||||
helper = MockAPINoiseFrameHelper(
|
||||
on_pkt=_packet,
|
||||
on_error=_on_error,
|
||||
connection=connection,
|
||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||
expected_name="wrongname",
|
||||
client_info="my client",
|
||||
@ -260,17 +252,11 @@ async def test_noise_timeout():
|
||||
"010000", # hello packet
|
||||
"010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4",
|
||||
]
|
||||
packets = []
|
||||
|
||||
def _packet(type_: int, data: bytes):
|
||||
packets.append((type_, data))
|
||||
|
||||
def _on_error(exc: Exception):
|
||||
raise exc
|
||||
connection, _ = _make_mock_connection()
|
||||
|
||||
helper = MockAPINoiseFrameHelper(
|
||||
on_pkt=_packet,
|
||||
on_error=_on_error,
|
||||
connection=connection,
|
||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||
expected_name="wrongname",
|
||||
client_info="my client",
|
||||
@ -317,21 +303,15 @@ async def test_noise_frame_helper_handshake_failure():
|
||||
"""Test the noise frame helper handshake failure."""
|
||||
noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc="
|
||||
psk_bytes = base64.b64decode(noise_psk)
|
||||
packets = []
|
||||
writes = []
|
||||
|
||||
def _packet(type_: int, data: bytes):
|
||||
packets.append((type_, data))
|
||||
|
||||
def _writer(data: bytes):
|
||||
writes.append(data)
|
||||
|
||||
def _on_error(exc: Exception):
|
||||
raise exc
|
||||
connection, _ = _make_mock_connection()
|
||||
|
||||
helper = MockAPINoiseFrameHelper(
|
||||
on_pkt=_packet,
|
||||
on_error=_on_error,
|
||||
connection=connection,
|
||||
noise_psk=noise_psk,
|
||||
expected_name="servicetest",
|
||||
client_info="my client",
|
||||
@ -398,21 +378,15 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
||||
"""Test the noise frame helper handshake success with a single packet."""
|
||||
noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc="
|
||||
psk_bytes = base64.b64decode(noise_psk)
|
||||
packets = []
|
||||
writes = []
|
||||
|
||||
def _packet(type_: int, data: bytes):
|
||||
packets.append((type_, data))
|
||||
|
||||
def _writer(data: bytes):
|
||||
writes.append(data)
|
||||
|
||||
def _on_error(exc: Exception):
|
||||
raise exc
|
||||
connection, packets = _make_mock_connection()
|
||||
|
||||
helper = MockAPINoiseFrameHelper(
|
||||
on_pkt=_packet,
|
||||
on_error=_on_error,
|
||||
connection=connection,
|
||||
noise_psk=noise_psk,
|
||||
expected_name="servicetest",
|
||||
client_info="my client",
|
||||
|
@ -35,8 +35,7 @@ from .common import (
|
||||
|
||||
def _get_mock_protocol(conn: APIConnection):
|
||||
protocol = APIPlaintextFrameHelper(
|
||||
on_pkt=conn._process_packet,
|
||||
on_error=conn._report_fatal_error,
|
||||
connection=conn,
|
||||
client_info="mock",
|
||||
log_name="mock_device",
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user