mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-22 16:48:04 +01:00
Optimize throughput of api to decrease latency (#327)
This commit is contained in:
parent
e2527878cb
commit
3692478455
@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from abc import ABC, abstractmethod, abstractproperty
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
@ -29,50 +29,80 @@ class Packet:
|
||||
|
||||
|
||||
class APIFrameHelper(ABC):
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
"""Helper class to handle the API frame protocol."""
|
||||
|
||||
@abstractmethod
|
||||
async def write_packet(self, packet: Packet) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def read_packet(self) -> Packet:
|
||||
pass
|
||||
|
||||
|
||||
class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
def __init__(
|
||||
self,
|
||||
reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter,
|
||||
):
|
||||
) -> None:
|
||||
"""Initialize the API frame helper."""
|
||||
self._reader = reader
|
||||
self._writer = writer
|
||||
self._write_lock = asyncio.Lock()
|
||||
self._read_lock = asyncio.Lock()
|
||||
self.read_lock = asyncio.Lock()
|
||||
self._closed_event = asyncio.Event()
|
||||
|
||||
@abstractproperty # pylint: disable=deprecated-decorator
|
||||
def ready(self) -> bool:
|
||||
"""Return if the connection is ready."""
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""Close the connection."""
|
||||
|
||||
@abstractmethod
|
||||
def write_packet(self, packet: Packet) -> None:
|
||||
"""Write a packet to the socket."""
|
||||
|
||||
@abstractmethod
|
||||
async def read_packet_with_lock(self) -> Packet:
|
||||
"""Read a packet from the socket, the caller is responsible for having the lock."""
|
||||
|
||||
@abstractmethod
|
||||
async def wait_for_ready(self) -> None:
|
||||
"""Wait for the connection to be ready."""
|
||||
|
||||
|
||||
class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
"""Frame helper for plaintext API connections."""
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the connection."""
|
||||
self._closed_event.set()
|
||||
self._writer.close()
|
||||
|
||||
async def write_packet(self, packet: Packet) -> None:
|
||||
data = b"\0"
|
||||
data += varuint_to_bytes(len(packet.data))
|
||||
data += varuint_to_bytes(packet.type)
|
||||
data += packet.data
|
||||
try:
|
||||
async with self._write_lock:
|
||||
@property
|
||||
def ready(self) -> bool:
|
||||
"""Return if the connection is ready."""
|
||||
# Plaintext is always ready
|
||||
return True
|
||||
|
||||
def write_packet(self, packet: Packet) -> None:
|
||||
"""Write a packet to the socket, the caller should not have the lock.
|
||||
|
||||
The entire packet must be written in a single call to write
|
||||
to avoid locking.
|
||||
"""
|
||||
data = (
|
||||
b"\0"
|
||||
+ varuint_to_bytes(len(packet.data))
|
||||
+ varuint_to_bytes(packet.type)
|
||||
+ packet.data
|
||||
)
|
||||
_LOGGER.debug("Sending plaintext frame %s", data.hex())
|
||||
|
||||
try:
|
||||
self._writer.write(data)
|
||||
await self._writer.drain()
|
||||
except (ConnectionResetError, OSError) as err:
|
||||
raise SocketAPIError(f"Error while writing data: {err}") from err
|
||||
|
||||
async def read_packet(self) -> Packet:
|
||||
async with self._read_lock:
|
||||
async def wait_for_ready(self) -> None:
|
||||
"""Wait for the connection to be ready."""
|
||||
# No handshake for plaintext
|
||||
|
||||
async def read_packet_with_lock(self) -> Packet:
|
||||
"""Read a packet from the socket, the caller is responsible for having the lock."""
|
||||
assert self.read_lock.locked(), "read_packet_with_lock called without lock"
|
||||
try:
|
||||
# Read preamble, which should always 0x00
|
||||
# Also try to get the length and msg type
|
||||
@ -80,9 +110,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
init_bytes = await self._reader.readexactly(3)
|
||||
if init_bytes[0] != 0x00:
|
||||
if init_bytes[0] == 0x01:
|
||||
raise RequiresEncryptionAPIError(
|
||||
"Connection requires encryption"
|
||||
)
|
||||
raise RequiresEncryptionAPIError("Connection requires encryption")
|
||||
raise ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}")
|
||||
|
||||
if init_bytes[1] & 0x80 == 0x80:
|
||||
@ -142,29 +170,43 @@ def _decode_noise_psk(psk: str) -> bytes:
|
||||
|
||||
|
||||
class APINoiseFrameHelper(APIFrameHelper):
|
||||
"""Frame helper for noise encrypted connections."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter,
|
||||
noise_psk: str,
|
||||
):
|
||||
self._reader = reader
|
||||
self._writer = writer
|
||||
self._write_lock = asyncio.Lock()
|
||||
self._read_lock = asyncio.Lock()
|
||||
) -> None:
|
||||
"""Initialize the API frame helper."""
|
||||
super().__init__(reader, writer)
|
||||
self._ready_event = asyncio.Event()
|
||||
self._closed_event = asyncio.Event()
|
||||
self._proto: Optional[NoiseConnection] = None
|
||||
self._noise_psk = noise_psk
|
||||
|
||||
@property
|
||||
def ready(self) -> bool:
|
||||
"""Return if the connection is ready."""
|
||||
return self._ready_event.is_set()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the connection."""
|
||||
# Make sure we set the ready event if its not already set
|
||||
# so that we don't block forever on the ready event if we
|
||||
# are waiting for the handshake to complete.
|
||||
self._ready_event.set()
|
||||
self._closed_event.set()
|
||||
self._writer.close()
|
||||
|
||||
async def _write_frame(self, frame: bytes) -> None:
|
||||
try:
|
||||
async with self._write_lock:
|
||||
def _write_frame(self, frame: bytes) -> None:
|
||||
"""Write a packet to the socket, the caller should not have the lock.
|
||||
|
||||
The entire packet must be written in a single call to write
|
||||
to avoid locking.
|
||||
"""
|
||||
_LOGGER.debug("Sending frame %s", frame.hex())
|
||||
|
||||
try:
|
||||
header = bytes(
|
||||
[
|
||||
0x01,
|
||||
@ -173,13 +215,13 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
]
|
||||
)
|
||||
self._writer.write(header + frame)
|
||||
await self._writer.drain()
|
||||
except OSError as err:
|
||||
raise SocketAPIError(f"Error while writing data: {err}") from err
|
||||
|
||||
async def _read_frame(self) -> bytes:
|
||||
async def _read_frame_with_lock(self) -> bytes:
|
||||
"""Read a frame from the socket, the caller is responsible for having the lock."""
|
||||
assert self.read_lock.locked(), "_read_frame_with_lock called without lock"
|
||||
try:
|
||||
async with self._read_lock:
|
||||
header = await self._reader.readexactly(3)
|
||||
if header[0] != 0x01:
|
||||
raise ProtocolAPIError(f"Marker byte invalid: {header[0]}")
|
||||
@ -199,10 +241,12 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
return frame
|
||||
|
||||
async def _perform_handshake(self, expected_name: Optional[str]) -> None:
|
||||
await self._write_frame(b"") # ClientHello
|
||||
"""Perform the handshake with the server, the caller is responsible for having the lock."""
|
||||
assert self.read_lock.locked(), "_perform_handshake called without lock"
|
||||
self._write_frame(b"") # ClientHello
|
||||
prologue = b"NoiseAPIInit" + b"\x00\x00"
|
||||
|
||||
server_hello = await self._read_frame() # ServerHello
|
||||
server_hello = await self._read_frame_with_lock() # ServerHello
|
||||
if not server_hello:
|
||||
raise HandshakeAPIError("ServerHello is empty")
|
||||
|
||||
@ -238,9 +282,9 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
while not self._proto.handshake_finished:
|
||||
if do_write:
|
||||
msg = self._proto.write_message()
|
||||
await self._write_frame(b"\x00" + msg)
|
||||
self._write_frame(b"\x00" + msg)
|
||||
else:
|
||||
msg = await self._read_frame()
|
||||
msg = await self._read_frame_with_lock()
|
||||
if not msg:
|
||||
raise HandshakeAPIError("Handshake message too short")
|
||||
if msg[0] != 0:
|
||||
@ -256,16 +300,16 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
self._ready_event.set()
|
||||
|
||||
async def perform_handshake(self, expected_name: Optional[str]) -> None:
|
||||
"""Perform the handshake with the server."""
|
||||
# Allow up to 60 seconds for handhsake
|
||||
try:
|
||||
async with async_timeout.timeout(60.0):
|
||||
async with self.read_lock, async_timeout.timeout(60.0):
|
||||
await self._perform_handshake(expected_name)
|
||||
except asyncio.TimeoutError as err:
|
||||
raise HandshakeAPIError("Timeout during handshake") from err
|
||||
|
||||
async def write_packet(self, packet: Packet) -> None:
|
||||
# Wait for handshake to complete
|
||||
await self._ready_event.wait()
|
||||
def write_packet(self, packet: Packet) -> None:
|
||||
"""Write a packet to the socket."""
|
||||
padding = 0
|
||||
data = (
|
||||
bytes(
|
||||
@ -281,12 +325,15 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
)
|
||||
assert self._proto is not None
|
||||
frame = self._proto.encrypt(data)
|
||||
await self._write_frame(frame)
|
||||
self._write_frame(frame)
|
||||
|
||||
async def read_packet(self) -> Packet:
|
||||
# Wait for handshake to complete
|
||||
async def wait_for_ready(self) -> None:
|
||||
"""Wait for the connection to be ready."""
|
||||
await self._ready_event.wait()
|
||||
frame = await self._read_frame()
|
||||
|
||||
async def read_packet_with_lock(self) -> Packet:
|
||||
"""Read a packet from the socket, the caller is responsible for having the lock."""
|
||||
frame = await self._read_frame_with_lock()
|
||||
assert self._proto is not None
|
||||
msg = self._proto.decrypt(frame)
|
||||
if len(msg) < 4:
|
||||
|
@ -49,6 +49,8 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
BUFFER_SIZE = 1024 * 1024 # Set buffer limit to 1MB
|
||||
|
||||
INTERNAL_MESSAGE_TYPES = (GetTimeRequest, PingRequest, DisconnectRequest)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionParams:
|
||||
@ -105,7 +107,7 @@ class APIConnection:
|
||||
|
||||
self._ping_stop_event = asyncio.Event()
|
||||
|
||||
self._to_process: asyncio.Queue[Packet] = asyncio.Queue()
|
||||
self._to_process: asyncio.Queue[Optional[Packet]] = asyncio.Queue()
|
||||
|
||||
self._process_task: Optional[asyncio.Task[None]] = None
|
||||
|
||||
@ -120,6 +122,9 @@ class APIConnection:
|
||||
|
||||
async def _do_cleanup() -> None:
|
||||
async with self._connect_lock:
|
||||
# Tell the process loop to stop
|
||||
self._to_process.put_nowait(None)
|
||||
|
||||
if self._frame_helper is not None:
|
||||
await self._frame_helper.close()
|
||||
self._frame_helper = None
|
||||
@ -388,10 +393,13 @@ class APIConnection:
|
||||
encoded = msg.SerializeToString()
|
||||
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
|
||||
|
||||
frame_helper = self._frame_helper
|
||||
assert frame_helper is not None
|
||||
if not frame_helper.ready:
|
||||
await frame_helper.wait_for_ready()
|
||||
|
||||
try:
|
||||
assert self._frame_helper is not None
|
||||
# pylint: disable=undefined-loop-variable
|
||||
await self._frame_helper.write_packet(
|
||||
frame_helper.write_packet(
|
||||
Packet(
|
||||
type=message_type,
|
||||
data=encoded,
|
||||
@ -512,48 +520,60 @@ class APIConnection:
|
||||
await self._cleanup()
|
||||
|
||||
async def _process_loop(self) -> None:
|
||||
to_process = self._to_process
|
||||
while True:
|
||||
if not self._is_socket_open:
|
||||
# Socket closed but task isn't cancelled yet
|
||||
break
|
||||
|
||||
try:
|
||||
pkt = await self._to_process.get()
|
||||
pkt = await to_process.get()
|
||||
except RuntimeError:
|
||||
break
|
||||
|
||||
if pkt is None:
|
||||
# Socket closed but task isn't cancelled yet
|
||||
break
|
||||
|
||||
msg_type = pkt.type
|
||||
raw_msg = pkt.data
|
||||
if msg_type not in MESSAGE_TYPE_TO_PROTO:
|
||||
_LOGGER.debug("%s: Skipping message type %s", self.log_name, msg_type)
|
||||
continue
|
||||
|
||||
msg = MESSAGE_TYPE_TO_PROTO[msg_type]()
|
||||
try:
|
||||
msg.ParseFromString(raw_msg)
|
||||
msg.ParseFromString(pkt.data)
|
||||
except Exception as e:
|
||||
await self._report_fatal_error(
|
||||
ProtocolAPIError(f"Invalid protobuf message: {e}")
|
||||
)
|
||||
raise
|
||||
|
||||
_LOGGER.debug(
|
||||
"%s: Got message of type %s: %s", self.log_name, type(msg), msg
|
||||
)
|
||||
|
||||
for handler in self._message_handlers[:]:
|
||||
handler(msg)
|
||||
|
||||
# Pre-check the message type to avoid awaiting
|
||||
# since most messages are not internal messages
|
||||
if isinstance(msg, INTERNAL_MESSAGE_TYPES):
|
||||
await self._handle_internal_messages(msg)
|
||||
|
||||
async def _read_loop(self) -> None:
|
||||
assert self._frame_helper is not None
|
||||
frame_helper = self._frame_helper
|
||||
assert frame_helper is not None
|
||||
await frame_helper.wait_for_ready()
|
||||
to_process = self._to_process
|
||||
try:
|
||||
# Once its ready, we hold the lock for the duration of the
|
||||
# connection so we don't have to keep locking/unlocking
|
||||
async with frame_helper.read_lock:
|
||||
while True:
|
||||
if not self._is_socket_open:
|
||||
# Socket closed but task isn't cancelled yet
|
||||
break
|
||||
self._to_process.put_nowait(await self._frame_helper.read_packet())
|
||||
to_process.put_nowait(await frame_helper.read_packet_with_lock())
|
||||
except SocketClosedAPIError as err:
|
||||
# don't log with info, if closed the site that closed the connection should log
|
||||
if not self._is_socket_open:
|
||||
# If we expected the socket to be closed, don't log
|
||||
# the error.
|
||||
return
|
||||
_LOGGER.debug(
|
||||
"%s: Socket closed, stopping read loop",
|
||||
self.log_name,
|
||||
|
Loading…
Reference in New Issue
Block a user