diff --git a/aioesphomeapi/_frame_helper.py b/aioesphomeapi/_frame_helper.py index 74751b0..0d7058c 100644 --- a/aioesphomeapi/_frame_helper.py +++ b/aioesphomeapi/_frame_helper.py @@ -2,7 +2,6 @@ import asyncio import base64 import logging from abc import abstractmethod -from dataclasses import dataclass from enum import Enum from typing import Callable, Optional, Union, cast @@ -30,18 +29,12 @@ SOCKET_ERRORS = ( ) -@dataclass -class Packet: - type: int - data: bytes - - class APIFrameHelper(asyncio.Protocol): """Helper class to handle the API frame protocol.""" def __init__( self, - on_pkt: Callable[[Packet], None], + on_pkt: Callable[[int, bytes], None], on_error: Callable[[Exception], None], ) -> None: """Initialize the API frame helper.""" @@ -71,7 +64,7 @@ class APIFrameHelper(asyncio.Protocol): """Perform the handshake.""" @abstractmethod - def write_packet(self, packet: Packet) -> None: + def write_packet(self, type_: int, data: bytes) -> None: """Write a packet to the socket.""" def connection_made(self, transport: asyncio.BaseTransport) -> None: @@ -106,21 +99,16 @@ class APIPlaintextFrameHelper(APIFrameHelper): def _callback_packet(self, type_: int, data: Union[bytes, bytearray]) -> None: """Complete reading a packet from the buffer.""" del self._buffer[: self._pos] - self._on_pkt(Packet(type_, data)) + self._on_pkt(type_, data) - def write_packet(self, packet: Packet) -> None: + def write_packet(self, type_: int, data: bytes) -> None: """Write a packet to the socket, the caller should not have the lock. The entire packet must be written in a single call to write to avoid locking. """ assert self._transport is not None, "Transport should be set" - data = ( - b"\0" - + varuint_to_bytes(len(packet.data)) - + varuint_to_bytes(packet.type) - + packet.data - ) + data = b"\0" + varuint_to_bytes(len(data)) + varuint_to_bytes(type_) + data _LOGGER.debug("Sending plaintext frame %s", data.hex()) try: @@ -224,7 +212,7 @@ class APINoiseFrameHelper(APIFrameHelper): def __init__( self, - on_pkt: Callable[[Packet], None], + on_pkt: Callable[[int, bytes], None], on_error: Callable[[Exception], None], noise_psk: str, expected_name: Optional[str], @@ -354,20 +342,20 @@ class APINoiseFrameHelper(APIFrameHelper): self._state = NoiseConnectionState.READY self._ready_event.set() - def write_packet(self, packet: Packet) -> None: + def write_packet(self, type_: int, data: bytes) -> None: """Write a packet to the socket.""" self._write_frame( self._proto.encrypt( ( bytes( [ - (packet.type >> 8) & 0xFF, - (packet.type >> 0) & 0xFF, - (len(packet.data) >> 8) & 0xFF, - (len(packet.data) >> 0) & 0xFF, + (type_ >> 8) & 0xFF, + (type_ >> 0) & 0xFF, + (len(data) >> 8) & 0xFF, + (len(data) >> 0) & 0xFF, ] ) - + packet.data + + data ) ) ) @@ -383,7 +371,7 @@ class APINoiseFrameHelper(APIFrameHelper): if data_len + 4 > len(msg): raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}") data = msg[4 : 4 + data_len] - return self._on_pkt(Packet(pkt_type, data)) + return self._on_pkt(pkt_type, data) def _handle_closed( # pylint: disable=unused-argument self, frame: bytearray diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index f9c0c9a..2f4cfc2 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -13,12 +13,7 @@ from google.protobuf import message import aioesphomeapi.host_resolver as hr -from ._frame_helper import ( - APIFrameHelper, - APINoiseFrameHelper, - APIPlaintextFrameHelper, - Packet, -) +from ._frame_helper import APIFrameHelper, APINoiseFrameHelper, APIPlaintextFrameHelper from .api_pb2 import ( # type: ignore ConnectRequest, ConnectResponse, @@ -498,12 +493,7 @@ class APIConnection: _LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg)) try: - frame_helper.write_packet( - Packet( - type=message_type, - data=encoded, - ) - ) + frame_helper.write_packet(message_type, encoded) except SocketAPIError as err: # pylint: disable=broad-except # If writing packet fails, we don't know what state the frames # are in anymore and we have to close the connection @@ -646,9 +636,8 @@ class APIConnection: self._read_exception_handlers.clear() self._cleanup() - def _process_packet(self, pkt: Packet) -> None: + def _process_packet(self, msg_type_proto: int, data: bytes) -> None: """Process a packet from the socket.""" - msg_type_proto = pkt.type if msg_type_proto not in MESSAGE_TYPE_TO_PROTO: _LOGGER.debug("%s: Skipping message type %s", self.log_name, msg_type_proto) return @@ -658,19 +647,19 @@ class APIConnection: # MergeFromString instead of ParseFromString since # ParseFromString will clear the message first and # the msg is already empty. - msg.MergeFromString(pkt.data) + msg.MergeFromString(data) except Exception as e: _LOGGER.info( "%s: Invalid protobuf message: type=%s data=%s: %s", self.log_name, - pkt.type, - pkt.data, + msg_type_proto, + data, e, exc_info=True, ) self._report_fatal_error( ProtocolAPIError( - f"Invalid protobuf message: type={pkt.type} data={pkt.data!r}: {e}" + f"Invalid protobuf message: type={msg_type_proto} data={data!r}: {e}" ) ) raise diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index 396156e..05517fd 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock import pytest -from aioesphomeapi._frame_helper import APIPlaintextFrameHelper, Packet +from aioesphomeapi._frame_helper import APIPlaintextFrameHelper from aioesphomeapi.util import varuint_to_bytes PREAMBLE = b"\x00" @@ -48,8 +48,8 @@ async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type): for _ in range(5): packets = [] - def _packet(pkt: Packet): - packets.append(pkt) + def _packet(type_: int, data: bytes): + packets.append((type_, data)) def _on_error(exc: Exception): raise exc @@ -59,6 +59,7 @@ async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type): helper.data_received(in_bytes) pkt = packets.pop() + type_, data = pkt - assert pkt.type == pkt_type - assert pkt.data == pkt_data + assert type_ == pkt_type + assert data == pkt_data