Improve data throughput (#421)

This commit is contained in:
J. Nick Koston 2023-04-19 20:47:38 -10:00 committed by GitHub
parent d1951ebd90
commit a539a6e950
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 26 additions and 48 deletions

View File

@ -2,7 +2,6 @@ import asyncio
import base64
import logging
from abc import abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional, Union, cast
@ -30,18 +29,12 @@ SOCKET_ERRORS = (
)
@dataclass
class Packet:
type: int
data: bytes
class APIFrameHelper(asyncio.Protocol):
"""Helper class to handle the API frame protocol."""
def __init__(
self,
on_pkt: Callable[[Packet], None],
on_pkt: Callable[[int, bytes], None],
on_error: Callable[[Exception], None],
) -> None:
"""Initialize the API frame helper."""
@ -71,7 +64,7 @@ class APIFrameHelper(asyncio.Protocol):
"""Perform the handshake."""
@abstractmethod
def write_packet(self, packet: Packet) -> None:
def write_packet(self, type_: int, data: bytes) -> None:
"""Write a packet to the socket."""
def connection_made(self, transport: asyncio.BaseTransport) -> None:
@ -106,21 +99,16 @@ class APIPlaintextFrameHelper(APIFrameHelper):
def _callback_packet(self, type_: int, data: Union[bytes, bytearray]) -> None:
"""Complete reading a packet from the buffer."""
del self._buffer[: self._pos]
self._on_pkt(Packet(type_, data))
self._on_pkt(type_, data)
def write_packet(self, packet: Packet) -> None:
def write_packet(self, type_: int, data: bytes) -> None:
"""Write a packet to the socket, the caller should not have the lock.
The entire packet must be written in a single call to write
to avoid locking.
"""
assert self._transport is not None, "Transport should be set"
data = (
b"\0"
+ varuint_to_bytes(len(packet.data))
+ varuint_to_bytes(packet.type)
+ packet.data
)
data = b"\0" + varuint_to_bytes(len(data)) + varuint_to_bytes(type_) + data
_LOGGER.debug("Sending plaintext frame %s", data.hex())
try:
@ -224,7 +212,7 @@ class APINoiseFrameHelper(APIFrameHelper):
def __init__(
self,
on_pkt: Callable[[Packet], None],
on_pkt: Callable[[int, bytes], None],
on_error: Callable[[Exception], None],
noise_psk: str,
expected_name: Optional[str],
@ -354,20 +342,20 @@ class APINoiseFrameHelper(APIFrameHelper):
self._state = NoiseConnectionState.READY
self._ready_event.set()
def write_packet(self, packet: Packet) -> None:
def write_packet(self, type_: int, data: bytes) -> None:
"""Write a packet to the socket."""
self._write_frame(
self._proto.encrypt(
(
bytes(
[
(packet.type >> 8) & 0xFF,
(packet.type >> 0) & 0xFF,
(len(packet.data) >> 8) & 0xFF,
(len(packet.data) >> 0) & 0xFF,
(type_ >> 8) & 0xFF,
(type_ >> 0) & 0xFF,
(len(data) >> 8) & 0xFF,
(len(data) >> 0) & 0xFF,
]
)
+ packet.data
+ data
)
)
)
@ -383,7 +371,7 @@ class APINoiseFrameHelper(APIFrameHelper):
if data_len + 4 > len(msg):
raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
data = msg[4 : 4 + data_len]
return self._on_pkt(Packet(pkt_type, data))
return self._on_pkt(pkt_type, data)
def _handle_closed( # pylint: disable=unused-argument
self, frame: bytearray

View File

@ -13,12 +13,7 @@ from google.protobuf import message
import aioesphomeapi.host_resolver as hr
from ._frame_helper import (
APIFrameHelper,
APINoiseFrameHelper,
APIPlaintextFrameHelper,
Packet,
)
from ._frame_helper import APIFrameHelper, APINoiseFrameHelper, APIPlaintextFrameHelper
from .api_pb2 import ( # type: ignore
ConnectRequest,
ConnectResponse,
@ -498,12 +493,7 @@ class APIConnection:
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
try:
frame_helper.write_packet(
Packet(
type=message_type,
data=encoded,
)
)
frame_helper.write_packet(message_type, encoded)
except SocketAPIError as err: # pylint: disable=broad-except
# If writing packet fails, we don't know what state the frames
# are in anymore and we have to close the connection
@ -646,9 +636,8 @@ class APIConnection:
self._read_exception_handlers.clear()
self._cleanup()
def _process_packet(self, pkt: Packet) -> None:
def _process_packet(self, msg_type_proto: int, data: bytes) -> None:
"""Process a packet from the socket."""
msg_type_proto = pkt.type
if msg_type_proto not in MESSAGE_TYPE_TO_PROTO:
_LOGGER.debug("%s: Skipping message type %s", self.log_name, msg_type_proto)
return
@ -658,19 +647,19 @@ class APIConnection:
# MergeFromString instead of ParseFromString since
# ParseFromString will clear the message first and
# the msg is already empty.
msg.MergeFromString(pkt.data)
msg.MergeFromString(data)
except Exception as e:
_LOGGER.info(
"%s: Invalid protobuf message: type=%s data=%s: %s",
self.log_name,
pkt.type,
pkt.data,
msg_type_proto,
data,
e,
exc_info=True,
)
self._report_fatal_error(
ProtocolAPIError(
f"Invalid protobuf message: type={pkt.type} data={pkt.data!r}: {e}"
f"Invalid protobuf message: type={msg_type_proto} data={data!r}: {e}"
)
)
raise

View File

@ -3,7 +3,7 @@ from unittest.mock import MagicMock
import pytest
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper, Packet
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi.util import varuint_to_bytes
PREAMBLE = b"\x00"
@ -48,8 +48,8 @@ async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
for _ in range(5):
packets = []
def _packet(pkt: Packet):
packets.append(pkt)
def _packet(type_: int, data: bytes):
packets.append((type_, data))
def _on_error(exc: Exception):
raise exc
@ -59,6 +59,7 @@ async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
helper.data_received(in_bytes)
pkt = packets.pop()
type_, data = pkt
assert pkt.type == pkt_type
assert pkt.data == pkt_data
assert type_ == pkt_type
assert data == pkt_data