From 36924784554bdf9f42996f46527e3e8c4ae76ff0 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 2 Dec 2022 09:12:19 -1000 Subject: [PATCH] Optimize throughput of api to decrease latency (#327) --- aioesphomeapi/_frame_helper.py | 267 +++++++++++++++++++-------------- aioesphomeapi/connection.py | 56 ++++--- 2 files changed, 195 insertions(+), 128 deletions(-) diff --git a/aioesphomeapi/_frame_helper.py b/aioesphomeapi/_frame_helper.py index 8f5250c..8d8c435 100644 --- a/aioesphomeapi/_frame_helper.py +++ b/aioesphomeapi/_frame_helper.py @@ -1,7 +1,7 @@ import asyncio import base64 import logging -from abc import ABC, abstractmethod +from abc import ABC, abstractmethod, abstractproperty from dataclasses import dataclass from typing import Optional @@ -29,101 +29,129 @@ class Packet: class APIFrameHelper(ABC): - @abstractmethod - async def close(self) -> None: - pass + """Helper class to handle the API frame protocol.""" - @abstractmethod - async def write_packet(self, packet: Packet) -> None: - pass - - @abstractmethod - async def read_packet(self) -> Packet: - pass - - -class APIPlaintextFrameHelper(APIFrameHelper): def __init__( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, - ): + ) -> None: + """Initialize the API frame helper.""" self._reader = reader self._writer = writer - self._write_lock = asyncio.Lock() - self._read_lock = asyncio.Lock() + self.read_lock = asyncio.Lock() self._closed_event = asyncio.Event() + @abstractproperty # pylint: disable=deprecated-decorator + def ready(self) -> bool: + """Return if the connection is ready.""" + + @abstractmethod async def close(self) -> None: + """Close the connection.""" + + @abstractmethod + def write_packet(self, packet: Packet) -> None: + """Write a packet to the socket.""" + + @abstractmethod + async def read_packet_with_lock(self) -> Packet: + """Read a packet from the socket, the caller is responsible for having the lock.""" + + @abstractmethod + async def wait_for_ready(self) -> None: + """Wait for the connection to be ready.""" + + +class APIPlaintextFrameHelper(APIFrameHelper): + """Frame helper for plaintext API connections.""" + + async def close(self) -> None: + """Close the connection.""" self._closed_event.set() self._writer.close() - async def write_packet(self, packet: Packet) -> None: - data = b"\0" - data += varuint_to_bytes(len(packet.data)) - data += varuint_to_bytes(packet.type) - data += packet.data + @property + def ready(self) -> bool: + """Return if the connection is ready.""" + # Plaintext is always ready + return True + + def write_packet(self, packet: Packet) -> None: + """Write a packet to the socket, the caller should not have the lock. + + The entire packet must be written in a single call to write + to avoid locking. + """ + data = ( + b"\0" + + varuint_to_bytes(len(packet.data)) + + varuint_to_bytes(packet.type) + + packet.data + ) + _LOGGER.debug("Sending plaintext frame %s", data.hex()) + try: - async with self._write_lock: - _LOGGER.debug("Sending plaintext frame %s", data.hex()) - self._writer.write(data) - await self._writer.drain() + self._writer.write(data) except (ConnectionResetError, OSError) as err: raise SocketAPIError(f"Error while writing data: {err}") from err - async def read_packet(self) -> Packet: - async with self._read_lock: - try: - # Read preamble, which should always 0x00 - # Also try to get the length and msg type - # to avoid multiple calls to readexactly - init_bytes = await self._reader.readexactly(3) - if init_bytes[0] != 0x00: - if init_bytes[0] == 0x01: - raise RequiresEncryptionAPIError( - "Connection requires encryption" - ) - raise ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}") + async def wait_for_ready(self) -> None: + """Wait for the connection to be ready.""" + # No handshake for plaintext - if init_bytes[1] & 0x80 == 0x80: - # Length is longer than 1 byte - length = init_bytes[1:3] - msg_type = b"" - else: - # This is the most common case with 99% of messages - # needing a single byte for length and type which means - # we avoid 2 calls to readexactly - length = init_bytes[1:2] - msg_type = init_bytes[2:3] + async def read_packet_with_lock(self) -> Packet: + """Read a packet from the socket, the caller is responsible for having the lock.""" + assert self.read_lock.locked(), "read_packet_with_lock called without lock" + try: + # Read preamble, which should always 0x00 + # Also try to get the length and msg type + # to avoid multiple calls to readexactly + init_bytes = await self._reader.readexactly(3) + if init_bytes[0] != 0x00: + if init_bytes[0] == 0x01: + raise RequiresEncryptionAPIError("Connection requires encryption") + raise ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}") - # If the message is long, we need to read the rest of the length - while length[-1] & 0x80 == 0x80: - length += await self._reader.readexactly(1) + if init_bytes[1] & 0x80 == 0x80: + # Length is longer than 1 byte + length = init_bytes[1:3] + msg_type = b"" + else: + # This is the most common case with 99% of messages + # needing a single byte for length and type which means + # we avoid 2 calls to readexactly + length = init_bytes[1:2] + msg_type = init_bytes[2:3] - # If the message length was longer than 1 byte, we need to read the - # message type - while not msg_type or (msg_type[-1] & 0x80) == 0x80: - msg_type += await self._reader.readexactly(1) + # If the message is long, we need to read the rest of the length + while length[-1] & 0x80 == 0x80: + length += await self._reader.readexactly(1) - length_int = bytes_to_varuint(length) - assert length_int is not None - msg_type_int = bytes_to_varuint(msg_type) - assert msg_type_int is not None + # If the message length was longer than 1 byte, we need to read the + # message type + while not msg_type or (msg_type[-1] & 0x80) == 0x80: + msg_type += await self._reader.readexactly(1) - if length_int == 0: - return Packet(type=msg_type_int, data=b"") + length_int = bytes_to_varuint(length) + assert length_int is not None + msg_type_int = bytes_to_varuint(msg_type) + assert msg_type_int is not None - data = await self._reader.readexactly(length_int) - return Packet(type=msg_type_int, data=data) - except (asyncio.IncompleteReadError, OSError, TimeoutError) as err: - if ( - isinstance(err, asyncio.IncompleteReadError) - and self._closed_event.is_set() - ): - raise SocketClosedAPIError( - f"Socket closed while reading data: {err}" - ) from err - raise SocketAPIError(f"Error while reading data: {err}") from err + if length_int == 0: + return Packet(type=msg_type_int, data=b"") + + data = await self._reader.readexactly(length_int) + return Packet(type=msg_type_int, data=data) + except (asyncio.IncompleteReadError, OSError, TimeoutError) as err: + if ( + isinstance(err, asyncio.IncompleteReadError) + and self._closed_event.is_set() + ): + raise SocketClosedAPIError( + f"Socket closed while reading data: {err}" + ) from err + raise SocketAPIError(f"Error while reading data: {err}") from err def _decode_noise_psk(psk: str) -> bytes: @@ -142,49 +170,63 @@ def _decode_noise_psk(psk: str) -> bytes: class APINoiseFrameHelper(APIFrameHelper): + """Frame helper for noise encrypted connections.""" + def __init__( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, noise_psk: str, - ): - self._reader = reader - self._writer = writer - self._write_lock = asyncio.Lock() - self._read_lock = asyncio.Lock() + ) -> None: + """Initialize the API frame helper.""" + super().__init__(reader, writer) self._ready_event = asyncio.Event() - self._closed_event = asyncio.Event() self._proto: Optional[NoiseConnection] = None self._noise_psk = noise_psk + @property + def ready(self) -> bool: + """Return if the connection is ready.""" + return self._ready_event.is_set() + async def close(self) -> None: + """Close the connection.""" + # Make sure we set the ready event if its not already set + # so that we don't block forever on the ready event if we + # are waiting for the handshake to complete. + self._ready_event.set() self._closed_event.set() self._writer.close() - async def _write_frame(self, frame: bytes) -> None: + def _write_frame(self, frame: bytes) -> None: + """Write a packet to the socket, the caller should not have the lock. + + The entire packet must be written in a single call to write + to avoid locking. + """ + _LOGGER.debug("Sending frame %s", frame.hex()) + try: - async with self._write_lock: - _LOGGER.debug("Sending frame %s", frame.hex()) - header = bytes( - [ - 0x01, - (len(frame) >> 8) & 0xFF, - len(frame) & 0xFF, - ] - ) - self._writer.write(header + frame) - await self._writer.drain() + header = bytes( + [ + 0x01, + (len(frame) >> 8) & 0xFF, + len(frame) & 0xFF, + ] + ) + self._writer.write(header + frame) except OSError as err: raise SocketAPIError(f"Error while writing data: {err}") from err - async def _read_frame(self) -> bytes: + async def _read_frame_with_lock(self) -> bytes: + """Read a frame from the socket, the caller is responsible for having the lock.""" + assert self.read_lock.locked(), "_read_frame_with_lock called without lock" try: - async with self._read_lock: - header = await self._reader.readexactly(3) - if header[0] != 0x01: - raise ProtocolAPIError(f"Marker byte invalid: {header[0]}") - msg_size = (header[1] << 8) | header[2] - frame = await self._reader.readexactly(msg_size) + header = await self._reader.readexactly(3) + if header[0] != 0x01: + raise ProtocolAPIError(f"Marker byte invalid: {header[0]}") + msg_size = (header[1] << 8) | header[2] + frame = await self._reader.readexactly(msg_size) except (asyncio.IncompleteReadError, OSError, TimeoutError) as err: if ( isinstance(err, asyncio.IncompleteReadError) @@ -199,10 +241,12 @@ class APINoiseFrameHelper(APIFrameHelper): return frame async def _perform_handshake(self, expected_name: Optional[str]) -> None: - await self._write_frame(b"") # ClientHello + """Perform the handshake with the server, the caller is responsible for having the lock.""" + assert self.read_lock.locked(), "_perform_handshake called without lock" + self._write_frame(b"") # ClientHello prologue = b"NoiseAPIInit" + b"\x00\x00" - server_hello = await self._read_frame() # ServerHello + server_hello = await self._read_frame_with_lock() # ServerHello if not server_hello: raise HandshakeAPIError("ServerHello is empty") @@ -238,9 +282,9 @@ class APINoiseFrameHelper(APIFrameHelper): while not self._proto.handshake_finished: if do_write: msg = self._proto.write_message() - await self._write_frame(b"\x00" + msg) + self._write_frame(b"\x00" + msg) else: - msg = await self._read_frame() + msg = await self._read_frame_with_lock() if not msg: raise HandshakeAPIError("Handshake message too short") if msg[0] != 0: @@ -256,16 +300,16 @@ class APINoiseFrameHelper(APIFrameHelper): self._ready_event.set() async def perform_handshake(self, expected_name: Optional[str]) -> None: + """Perform the handshake with the server.""" # Allow up to 60 seconds for handhsake try: - async with async_timeout.timeout(60.0): + async with self.read_lock, async_timeout.timeout(60.0): await self._perform_handshake(expected_name) except asyncio.TimeoutError as err: raise HandshakeAPIError("Timeout during handshake") from err - async def write_packet(self, packet: Packet) -> None: - # Wait for handshake to complete - await self._ready_event.wait() + def write_packet(self, packet: Packet) -> None: + """Write a packet to the socket.""" padding = 0 data = ( bytes( @@ -281,12 +325,15 @@ class APINoiseFrameHelper(APIFrameHelper): ) assert self._proto is not None frame = self._proto.encrypt(data) - await self._write_frame(frame) + self._write_frame(frame) - async def read_packet(self) -> Packet: - # Wait for handshake to complete + async def wait_for_ready(self) -> None: + """Wait for the connection to be ready.""" await self._ready_event.wait() - frame = await self._read_frame() + + async def read_packet_with_lock(self) -> Packet: + """Read a packet from the socket, the caller is responsible for having the lock.""" + frame = await self._read_frame_with_lock() assert self._proto is not None msg = self._proto.decrypt(frame) if len(msg) < 4: diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 0688f4a..56b2187 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -49,6 +49,8 @@ _LOGGER = logging.getLogger(__name__) BUFFER_SIZE = 1024 * 1024 # Set buffer limit to 1MB +INTERNAL_MESSAGE_TYPES = (GetTimeRequest, PingRequest, DisconnectRequest) + @dataclass class ConnectionParams: @@ -105,7 +107,7 @@ class APIConnection: self._ping_stop_event = asyncio.Event() - self._to_process: asyncio.Queue[Packet] = asyncio.Queue() + self._to_process: asyncio.Queue[Optional[Packet]] = asyncio.Queue() self._process_task: Optional[asyncio.Task[None]] = None @@ -120,6 +122,9 @@ class APIConnection: async def _do_cleanup() -> None: async with self._connect_lock: + # Tell the process loop to stop + self._to_process.put_nowait(None) + if self._frame_helper is not None: await self._frame_helper.close() self._frame_helper = None @@ -388,10 +393,13 @@ class APIConnection: encoded = msg.SerializeToString() _LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg)) + frame_helper = self._frame_helper + assert frame_helper is not None + if not frame_helper.ready: + await frame_helper.wait_for_ready() + try: - assert self._frame_helper is not None - # pylint: disable=undefined-loop-variable - await self._frame_helper.write_packet( + frame_helper.write_packet( Packet( type=message_type, data=encoded, @@ -512,48 +520,60 @@ class APIConnection: await self._cleanup() async def _process_loop(self) -> None: + to_process = self._to_process while True: - if not self._is_socket_open: - # Socket closed but task isn't cancelled yet - break - try: - pkt = await self._to_process.get() + pkt = await to_process.get() except RuntimeError: break + if pkt is None: + # Socket closed but task isn't cancelled yet + break + msg_type = pkt.type - raw_msg = pkt.data if msg_type not in MESSAGE_TYPE_TO_PROTO: _LOGGER.debug("%s: Skipping message type %s", self.log_name, msg_type) continue msg = MESSAGE_TYPE_TO_PROTO[msg_type]() try: - msg.ParseFromString(raw_msg) + msg.ParseFromString(pkt.data) except Exception as e: await self._report_fatal_error( ProtocolAPIError(f"Invalid protobuf message: {e}") ) raise + _LOGGER.debug( "%s: Got message of type %s: %s", self.log_name, type(msg), msg ) for handler in self._message_handlers[:]: handler(msg) - await self._handle_internal_messages(msg) + + # Pre-check the message type to avoid awaiting + # since most messages are not internal messages + if isinstance(msg, INTERNAL_MESSAGE_TYPES): + await self._handle_internal_messages(msg) async def _read_loop(self) -> None: - assert self._frame_helper is not None + frame_helper = self._frame_helper + assert frame_helper is not None + await frame_helper.wait_for_ready() + to_process = self._to_process try: - while True: - if not self._is_socket_open: - # Socket closed but task isn't cancelled yet - break - self._to_process.put_nowait(await self._frame_helper.read_packet()) + # Once its ready, we hold the lock for the duration of the + # connection so we don't have to keep locking/unlocking + async with frame_helper.read_lock: + while True: + to_process.put_nowait(await frame_helper.read_packet_with_lock()) except SocketClosedAPIError as err: # don't log with info, if closed the site that closed the connection should log + if not self._is_socket_open: + # If we expected the socket to be closed, don't log + # the error. + return _LOGGER.debug( "%s: Socket closed, stopping read loop", self.log_name,