mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-01-25 22:12:28 +01:00
Improve data throughput (#421)
This commit is contained in:
parent
d1951ebd90
commit
a539a6e950
@ -2,7 +2,6 @@ import asyncio
|
||||
import base64
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Callable, Optional, Union, cast
|
||||
|
||||
@ -30,18 +29,12 @@ SOCKET_ERRORS = (
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Packet:
|
||||
type: int
|
||||
data: bytes
|
||||
|
||||
|
||||
class APIFrameHelper(asyncio.Protocol):
|
||||
"""Helper class to handle the API frame protocol."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
on_pkt: Callable[[Packet], None],
|
||||
on_pkt: Callable[[int, bytes], None],
|
||||
on_error: Callable[[Exception], None],
|
||||
) -> None:
|
||||
"""Initialize the API frame helper."""
|
||||
@ -71,7 +64,7 @@ class APIFrameHelper(asyncio.Protocol):
|
||||
"""Perform the handshake."""
|
||||
|
||||
@abstractmethod
|
||||
def write_packet(self, packet: Packet) -> None:
|
||||
def write_packet(self, type_: int, data: bytes) -> None:
|
||||
"""Write a packet to the socket."""
|
||||
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
@ -106,21 +99,16 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
def _callback_packet(self, type_: int, data: Union[bytes, bytearray]) -> None:
|
||||
"""Complete reading a packet from the buffer."""
|
||||
del self._buffer[: self._pos]
|
||||
self._on_pkt(Packet(type_, data))
|
||||
self._on_pkt(type_, data)
|
||||
|
||||
def write_packet(self, packet: Packet) -> None:
|
||||
def write_packet(self, type_: int, data: bytes) -> None:
|
||||
"""Write a packet to the socket, the caller should not have the lock.
|
||||
|
||||
The entire packet must be written in a single call to write
|
||||
to avoid locking.
|
||||
"""
|
||||
assert self._transport is not None, "Transport should be set"
|
||||
data = (
|
||||
b"\0"
|
||||
+ varuint_to_bytes(len(packet.data))
|
||||
+ varuint_to_bytes(packet.type)
|
||||
+ packet.data
|
||||
)
|
||||
data = b"\0" + varuint_to_bytes(len(data)) + varuint_to_bytes(type_) + data
|
||||
_LOGGER.debug("Sending plaintext frame %s", data.hex())
|
||||
|
||||
try:
|
||||
@ -224,7 +212,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
on_pkt: Callable[[Packet], None],
|
||||
on_pkt: Callable[[int, bytes], None],
|
||||
on_error: Callable[[Exception], None],
|
||||
noise_psk: str,
|
||||
expected_name: Optional[str],
|
||||
@ -354,20 +342,20 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
self._state = NoiseConnectionState.READY
|
||||
self._ready_event.set()
|
||||
|
||||
def write_packet(self, packet: Packet) -> None:
|
||||
def write_packet(self, type_: int, data: bytes) -> None:
|
||||
"""Write a packet to the socket."""
|
||||
self._write_frame(
|
||||
self._proto.encrypt(
|
||||
(
|
||||
bytes(
|
||||
[
|
||||
(packet.type >> 8) & 0xFF,
|
||||
(packet.type >> 0) & 0xFF,
|
||||
(len(packet.data) >> 8) & 0xFF,
|
||||
(len(packet.data) >> 0) & 0xFF,
|
||||
(type_ >> 8) & 0xFF,
|
||||
(type_ >> 0) & 0xFF,
|
||||
(len(data) >> 8) & 0xFF,
|
||||
(len(data) >> 0) & 0xFF,
|
||||
]
|
||||
)
|
||||
+ packet.data
|
||||
+ data
|
||||
)
|
||||
)
|
||||
)
|
||||
@ -383,7 +371,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
if data_len + 4 > len(msg):
|
||||
raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
|
||||
data = msg[4 : 4 + data_len]
|
||||
return self._on_pkt(Packet(pkt_type, data))
|
||||
return self._on_pkt(pkt_type, data)
|
||||
|
||||
def _handle_closed( # pylint: disable=unused-argument
|
||||
self, frame: bytearray
|
||||
|
@ -13,12 +13,7 @@ from google.protobuf import message
|
||||
|
||||
import aioesphomeapi.host_resolver as hr
|
||||
|
||||
from ._frame_helper import (
|
||||
APIFrameHelper,
|
||||
APINoiseFrameHelper,
|
||||
APIPlaintextFrameHelper,
|
||||
Packet,
|
||||
)
|
||||
from ._frame_helper import APIFrameHelper, APINoiseFrameHelper, APIPlaintextFrameHelper
|
||||
from .api_pb2 import ( # type: ignore
|
||||
ConnectRequest,
|
||||
ConnectResponse,
|
||||
@ -498,12 +493,7 @@ class APIConnection:
|
||||
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
|
||||
|
||||
try:
|
||||
frame_helper.write_packet(
|
||||
Packet(
|
||||
type=message_type,
|
||||
data=encoded,
|
||||
)
|
||||
)
|
||||
frame_helper.write_packet(message_type, encoded)
|
||||
except SocketAPIError as err: # pylint: disable=broad-except
|
||||
# If writing packet fails, we don't know what state the frames
|
||||
# are in anymore and we have to close the connection
|
||||
@ -646,9 +636,8 @@ class APIConnection:
|
||||
self._read_exception_handlers.clear()
|
||||
self._cleanup()
|
||||
|
||||
def _process_packet(self, pkt: Packet) -> None:
|
||||
def _process_packet(self, msg_type_proto: int, data: bytes) -> None:
|
||||
"""Process a packet from the socket."""
|
||||
msg_type_proto = pkt.type
|
||||
if msg_type_proto not in MESSAGE_TYPE_TO_PROTO:
|
||||
_LOGGER.debug("%s: Skipping message type %s", self.log_name, msg_type_proto)
|
||||
return
|
||||
@ -658,19 +647,19 @@ class APIConnection:
|
||||
# MergeFromString instead of ParseFromString since
|
||||
# ParseFromString will clear the message first and
|
||||
# the msg is already empty.
|
||||
msg.MergeFromString(pkt.data)
|
||||
msg.MergeFromString(data)
|
||||
except Exception as e:
|
||||
_LOGGER.info(
|
||||
"%s: Invalid protobuf message: type=%s data=%s: %s",
|
||||
self.log_name,
|
||||
pkt.type,
|
||||
pkt.data,
|
||||
msg_type_proto,
|
||||
data,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
self._report_fatal_error(
|
||||
ProtocolAPIError(
|
||||
f"Invalid protobuf message: type={pkt.type} data={pkt.data!r}: {e}"
|
||||
f"Invalid protobuf message: type={msg_type_proto} data={data!r}: {e}"
|
||||
)
|
||||
)
|
||||
raise
|
||||
|
@ -3,7 +3,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper, Packet
|
||||
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
||||
from aioesphomeapi.util import varuint_to_bytes
|
||||
|
||||
PREAMBLE = b"\x00"
|
||||
@ -48,8 +48,8 @@ async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
|
||||
for _ in range(5):
|
||||
packets = []
|
||||
|
||||
def _packet(pkt: Packet):
|
||||
packets.append(pkt)
|
||||
def _packet(type_: int, data: bytes):
|
||||
packets.append((type_, data))
|
||||
|
||||
def _on_error(exc: Exception):
|
||||
raise exc
|
||||
@ -59,6 +59,7 @@ async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
|
||||
helper.data_received(in_bytes)
|
||||
|
||||
pkt = packets.pop()
|
||||
type_, data = pkt
|
||||
|
||||
assert pkt.type == pkt_type
|
||||
assert pkt.data == pkt_data
|
||||
assert type_ == pkt_type
|
||||
assert data == pkt_data
|
||||
|
Loading…
Reference in New Issue
Block a user