mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-22 16:48:04 +01:00
Optimize throughput of api to decrease latency (#327)
This commit is contained in:
parent
e2527878cb
commit
3692478455
@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod, abstractproperty
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@ -29,50 +29,80 @@ class Packet:
|
|||||||
|
|
||||||
|
|
||||||
class APIFrameHelper(ABC):
|
class APIFrameHelper(ABC):
|
||||||
@abstractmethod
|
"""Helper class to handle the API frame protocol."""
|
||||||
async def close(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def write_packet(self, packet: Packet) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def read_packet(self) -> Packet:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class APIPlaintextFrameHelper(APIFrameHelper):
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
reader: asyncio.StreamReader,
|
reader: asyncio.StreamReader,
|
||||||
writer: asyncio.StreamWriter,
|
writer: asyncio.StreamWriter,
|
||||||
):
|
) -> None:
|
||||||
|
"""Initialize the API frame helper."""
|
||||||
self._reader = reader
|
self._reader = reader
|
||||||
self._writer = writer
|
self._writer = writer
|
||||||
self._write_lock = asyncio.Lock()
|
self.read_lock = asyncio.Lock()
|
||||||
self._read_lock = asyncio.Lock()
|
|
||||||
self._closed_event = asyncio.Event()
|
self._closed_event = asyncio.Event()
|
||||||
|
|
||||||
|
@abstractproperty # pylint: disable=deprecated-decorator
|
||||||
|
def ready(self) -> bool:
|
||||||
|
"""Return if the connection is ready."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
|
"""Close the connection."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def write_packet(self, packet: Packet) -> None:
|
||||||
|
"""Write a packet to the socket."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def read_packet_with_lock(self) -> Packet:
|
||||||
|
"""Read a packet from the socket, the caller is responsible for having the lock."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def wait_for_ready(self) -> None:
|
||||||
|
"""Wait for the connection to be ready."""
|
||||||
|
|
||||||
|
|
||||||
|
class APIPlaintextFrameHelper(APIFrameHelper):
|
||||||
|
"""Frame helper for plaintext API connections."""
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close the connection."""
|
||||||
self._closed_event.set()
|
self._closed_event.set()
|
||||||
self._writer.close()
|
self._writer.close()
|
||||||
|
|
||||||
async def write_packet(self, packet: Packet) -> None:
|
@property
|
||||||
data = b"\0"
|
def ready(self) -> bool:
|
||||||
data += varuint_to_bytes(len(packet.data))
|
"""Return if the connection is ready."""
|
||||||
data += varuint_to_bytes(packet.type)
|
# Plaintext is always ready
|
||||||
data += packet.data
|
return True
|
||||||
try:
|
|
||||||
async with self._write_lock:
|
def write_packet(self, packet: Packet) -> None:
|
||||||
|
"""Write a packet to the socket, the caller should not have the lock.
|
||||||
|
|
||||||
|
The entire packet must be written in a single call to write
|
||||||
|
to avoid locking.
|
||||||
|
"""
|
||||||
|
data = (
|
||||||
|
b"\0"
|
||||||
|
+ varuint_to_bytes(len(packet.data))
|
||||||
|
+ varuint_to_bytes(packet.type)
|
||||||
|
+ packet.data
|
||||||
|
)
|
||||||
_LOGGER.debug("Sending plaintext frame %s", data.hex())
|
_LOGGER.debug("Sending plaintext frame %s", data.hex())
|
||||||
|
|
||||||
|
try:
|
||||||
self._writer.write(data)
|
self._writer.write(data)
|
||||||
await self._writer.drain()
|
|
||||||
except (ConnectionResetError, OSError) as err:
|
except (ConnectionResetError, OSError) as err:
|
||||||
raise SocketAPIError(f"Error while writing data: {err}") from err
|
raise SocketAPIError(f"Error while writing data: {err}") from err
|
||||||
|
|
||||||
async def read_packet(self) -> Packet:
|
async def wait_for_ready(self) -> None:
|
||||||
async with self._read_lock:
|
"""Wait for the connection to be ready."""
|
||||||
|
# No handshake for plaintext
|
||||||
|
|
||||||
|
async def read_packet_with_lock(self) -> Packet:
|
||||||
|
"""Read a packet from the socket, the caller is responsible for having the lock."""
|
||||||
|
assert self.read_lock.locked(), "read_packet_with_lock called without lock"
|
||||||
try:
|
try:
|
||||||
# Read preamble, which should always 0x00
|
# Read preamble, which should always 0x00
|
||||||
# Also try to get the length and msg type
|
# Also try to get the length and msg type
|
||||||
@ -80,9 +110,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
init_bytes = await self._reader.readexactly(3)
|
init_bytes = await self._reader.readexactly(3)
|
||||||
if init_bytes[0] != 0x00:
|
if init_bytes[0] != 0x00:
|
||||||
if init_bytes[0] == 0x01:
|
if init_bytes[0] == 0x01:
|
||||||
raise RequiresEncryptionAPIError(
|
raise RequiresEncryptionAPIError("Connection requires encryption")
|
||||||
"Connection requires encryption"
|
|
||||||
)
|
|
||||||
raise ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}")
|
raise ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}")
|
||||||
|
|
||||||
if init_bytes[1] & 0x80 == 0x80:
|
if init_bytes[1] & 0x80 == 0x80:
|
||||||
@ -142,29 +170,43 @@ def _decode_noise_psk(psk: str) -> bytes:
|
|||||||
|
|
||||||
|
|
||||||
class APINoiseFrameHelper(APIFrameHelper):
|
class APINoiseFrameHelper(APIFrameHelper):
|
||||||
|
"""Frame helper for noise encrypted connections."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
reader: asyncio.StreamReader,
|
reader: asyncio.StreamReader,
|
||||||
writer: asyncio.StreamWriter,
|
writer: asyncio.StreamWriter,
|
||||||
noise_psk: str,
|
noise_psk: str,
|
||||||
):
|
) -> None:
|
||||||
self._reader = reader
|
"""Initialize the API frame helper."""
|
||||||
self._writer = writer
|
super().__init__(reader, writer)
|
||||||
self._write_lock = asyncio.Lock()
|
|
||||||
self._read_lock = asyncio.Lock()
|
|
||||||
self._ready_event = asyncio.Event()
|
self._ready_event = asyncio.Event()
|
||||||
self._closed_event = asyncio.Event()
|
|
||||||
self._proto: Optional[NoiseConnection] = None
|
self._proto: Optional[NoiseConnection] = None
|
||||||
self._noise_psk = noise_psk
|
self._noise_psk = noise_psk
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ready(self) -> bool:
|
||||||
|
"""Return if the connection is ready."""
|
||||||
|
return self._ready_event.is_set()
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
|
"""Close the connection."""
|
||||||
|
# Make sure we set the ready event if its not already set
|
||||||
|
# so that we don't block forever on the ready event if we
|
||||||
|
# are waiting for the handshake to complete.
|
||||||
|
self._ready_event.set()
|
||||||
self._closed_event.set()
|
self._closed_event.set()
|
||||||
self._writer.close()
|
self._writer.close()
|
||||||
|
|
||||||
async def _write_frame(self, frame: bytes) -> None:
|
def _write_frame(self, frame: bytes) -> None:
|
||||||
try:
|
"""Write a packet to the socket, the caller should not have the lock.
|
||||||
async with self._write_lock:
|
|
||||||
|
The entire packet must be written in a single call to write
|
||||||
|
to avoid locking.
|
||||||
|
"""
|
||||||
_LOGGER.debug("Sending frame %s", frame.hex())
|
_LOGGER.debug("Sending frame %s", frame.hex())
|
||||||
|
|
||||||
|
try:
|
||||||
header = bytes(
|
header = bytes(
|
||||||
[
|
[
|
||||||
0x01,
|
0x01,
|
||||||
@ -173,13 +215,13 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
self._writer.write(header + frame)
|
self._writer.write(header + frame)
|
||||||
await self._writer.drain()
|
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
raise SocketAPIError(f"Error while writing data: {err}") from err
|
raise SocketAPIError(f"Error while writing data: {err}") from err
|
||||||
|
|
||||||
async def _read_frame(self) -> bytes:
|
async def _read_frame_with_lock(self) -> bytes:
|
||||||
|
"""Read a frame from the socket, the caller is responsible for having the lock."""
|
||||||
|
assert self.read_lock.locked(), "_read_frame_with_lock called without lock"
|
||||||
try:
|
try:
|
||||||
async with self._read_lock:
|
|
||||||
header = await self._reader.readexactly(3)
|
header = await self._reader.readexactly(3)
|
||||||
if header[0] != 0x01:
|
if header[0] != 0x01:
|
||||||
raise ProtocolAPIError(f"Marker byte invalid: {header[0]}")
|
raise ProtocolAPIError(f"Marker byte invalid: {header[0]}")
|
||||||
@ -199,10 +241,12 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
return frame
|
return frame
|
||||||
|
|
||||||
async def _perform_handshake(self, expected_name: Optional[str]) -> None:
|
async def _perform_handshake(self, expected_name: Optional[str]) -> None:
|
||||||
await self._write_frame(b"") # ClientHello
|
"""Perform the handshake with the server, the caller is responsible for having the lock."""
|
||||||
|
assert self.read_lock.locked(), "_perform_handshake called without lock"
|
||||||
|
self._write_frame(b"") # ClientHello
|
||||||
prologue = b"NoiseAPIInit" + b"\x00\x00"
|
prologue = b"NoiseAPIInit" + b"\x00\x00"
|
||||||
|
|
||||||
server_hello = await self._read_frame() # ServerHello
|
server_hello = await self._read_frame_with_lock() # ServerHello
|
||||||
if not server_hello:
|
if not server_hello:
|
||||||
raise HandshakeAPIError("ServerHello is empty")
|
raise HandshakeAPIError("ServerHello is empty")
|
||||||
|
|
||||||
@ -238,9 +282,9 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
while not self._proto.handshake_finished:
|
while not self._proto.handshake_finished:
|
||||||
if do_write:
|
if do_write:
|
||||||
msg = self._proto.write_message()
|
msg = self._proto.write_message()
|
||||||
await self._write_frame(b"\x00" + msg)
|
self._write_frame(b"\x00" + msg)
|
||||||
else:
|
else:
|
||||||
msg = await self._read_frame()
|
msg = await self._read_frame_with_lock()
|
||||||
if not msg:
|
if not msg:
|
||||||
raise HandshakeAPIError("Handshake message too short")
|
raise HandshakeAPIError("Handshake message too short")
|
||||||
if msg[0] != 0:
|
if msg[0] != 0:
|
||||||
@ -256,16 +300,16 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
self._ready_event.set()
|
self._ready_event.set()
|
||||||
|
|
||||||
async def perform_handshake(self, expected_name: Optional[str]) -> None:
|
async def perform_handshake(self, expected_name: Optional[str]) -> None:
|
||||||
|
"""Perform the handshake with the server."""
|
||||||
# Allow up to 60 seconds for handhsake
|
# Allow up to 60 seconds for handhsake
|
||||||
try:
|
try:
|
||||||
async with async_timeout.timeout(60.0):
|
async with self.read_lock, async_timeout.timeout(60.0):
|
||||||
await self._perform_handshake(expected_name)
|
await self._perform_handshake(expected_name)
|
||||||
except asyncio.TimeoutError as err:
|
except asyncio.TimeoutError as err:
|
||||||
raise HandshakeAPIError("Timeout during handshake") from err
|
raise HandshakeAPIError("Timeout during handshake") from err
|
||||||
|
|
||||||
async def write_packet(self, packet: Packet) -> None:
|
def write_packet(self, packet: Packet) -> None:
|
||||||
# Wait for handshake to complete
|
"""Write a packet to the socket."""
|
||||||
await self._ready_event.wait()
|
|
||||||
padding = 0
|
padding = 0
|
||||||
data = (
|
data = (
|
||||||
bytes(
|
bytes(
|
||||||
@ -281,12 +325,15 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
)
|
)
|
||||||
assert self._proto is not None
|
assert self._proto is not None
|
||||||
frame = self._proto.encrypt(data)
|
frame = self._proto.encrypt(data)
|
||||||
await self._write_frame(frame)
|
self._write_frame(frame)
|
||||||
|
|
||||||
async def read_packet(self) -> Packet:
|
async def wait_for_ready(self) -> None:
|
||||||
# Wait for handshake to complete
|
"""Wait for the connection to be ready."""
|
||||||
await self._ready_event.wait()
|
await self._ready_event.wait()
|
||||||
frame = await self._read_frame()
|
|
||||||
|
async def read_packet_with_lock(self) -> Packet:
|
||||||
|
"""Read a packet from the socket, the caller is responsible for having the lock."""
|
||||||
|
frame = await self._read_frame_with_lock()
|
||||||
assert self._proto is not None
|
assert self._proto is not None
|
||||||
msg = self._proto.decrypt(frame)
|
msg = self._proto.decrypt(frame)
|
||||||
if len(msg) < 4:
|
if len(msg) < 4:
|
||||||
|
@ -49,6 +49,8 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
|
|
||||||
BUFFER_SIZE = 1024 * 1024 # Set buffer limit to 1MB
|
BUFFER_SIZE = 1024 * 1024 # Set buffer limit to 1MB
|
||||||
|
|
||||||
|
INTERNAL_MESSAGE_TYPES = (GetTimeRequest, PingRequest, DisconnectRequest)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConnectionParams:
|
class ConnectionParams:
|
||||||
@ -105,7 +107,7 @@ class APIConnection:
|
|||||||
|
|
||||||
self._ping_stop_event = asyncio.Event()
|
self._ping_stop_event = asyncio.Event()
|
||||||
|
|
||||||
self._to_process: asyncio.Queue[Packet] = asyncio.Queue()
|
self._to_process: asyncio.Queue[Optional[Packet]] = asyncio.Queue()
|
||||||
|
|
||||||
self._process_task: Optional[asyncio.Task[None]] = None
|
self._process_task: Optional[asyncio.Task[None]] = None
|
||||||
|
|
||||||
@ -120,6 +122,9 @@ class APIConnection:
|
|||||||
|
|
||||||
async def _do_cleanup() -> None:
|
async def _do_cleanup() -> None:
|
||||||
async with self._connect_lock:
|
async with self._connect_lock:
|
||||||
|
# Tell the process loop to stop
|
||||||
|
self._to_process.put_nowait(None)
|
||||||
|
|
||||||
if self._frame_helper is not None:
|
if self._frame_helper is not None:
|
||||||
await self._frame_helper.close()
|
await self._frame_helper.close()
|
||||||
self._frame_helper = None
|
self._frame_helper = None
|
||||||
@ -388,10 +393,13 @@ class APIConnection:
|
|||||||
encoded = msg.SerializeToString()
|
encoded = msg.SerializeToString()
|
||||||
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
|
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
|
||||||
|
|
||||||
|
frame_helper = self._frame_helper
|
||||||
|
assert frame_helper is not None
|
||||||
|
if not frame_helper.ready:
|
||||||
|
await frame_helper.wait_for_ready()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
assert self._frame_helper is not None
|
frame_helper.write_packet(
|
||||||
# pylint: disable=undefined-loop-variable
|
|
||||||
await self._frame_helper.write_packet(
|
|
||||||
Packet(
|
Packet(
|
||||||
type=message_type,
|
type=message_type,
|
||||||
data=encoded,
|
data=encoded,
|
||||||
@ -512,48 +520,60 @@ class APIConnection:
|
|||||||
await self._cleanup()
|
await self._cleanup()
|
||||||
|
|
||||||
async def _process_loop(self) -> None:
|
async def _process_loop(self) -> None:
|
||||||
|
to_process = self._to_process
|
||||||
while True:
|
while True:
|
||||||
if not self._is_socket_open:
|
|
||||||
# Socket closed but task isn't cancelled yet
|
|
||||||
break
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
pkt = await self._to_process.get()
|
pkt = await to_process.get()
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if pkt is None:
|
||||||
|
# Socket closed but task isn't cancelled yet
|
||||||
|
break
|
||||||
|
|
||||||
msg_type = pkt.type
|
msg_type = pkt.type
|
||||||
raw_msg = pkt.data
|
|
||||||
if msg_type not in MESSAGE_TYPE_TO_PROTO:
|
if msg_type not in MESSAGE_TYPE_TO_PROTO:
|
||||||
_LOGGER.debug("%s: Skipping message type %s", self.log_name, msg_type)
|
_LOGGER.debug("%s: Skipping message type %s", self.log_name, msg_type)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
msg = MESSAGE_TYPE_TO_PROTO[msg_type]()
|
msg = MESSAGE_TYPE_TO_PROTO[msg_type]()
|
||||||
try:
|
try:
|
||||||
msg.ParseFromString(raw_msg)
|
msg.ParseFromString(pkt.data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await self._report_fatal_error(
|
await self._report_fatal_error(
|
||||||
ProtocolAPIError(f"Invalid protobuf message: {e}")
|
ProtocolAPIError(f"Invalid protobuf message: {e}")
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"%s: Got message of type %s: %s", self.log_name, type(msg), msg
|
"%s: Got message of type %s: %s", self.log_name, type(msg), msg
|
||||||
)
|
)
|
||||||
|
|
||||||
for handler in self._message_handlers[:]:
|
for handler in self._message_handlers[:]:
|
||||||
handler(msg)
|
handler(msg)
|
||||||
|
|
||||||
|
# Pre-check the message type to avoid awaiting
|
||||||
|
# since most messages are not internal messages
|
||||||
|
if isinstance(msg, INTERNAL_MESSAGE_TYPES):
|
||||||
await self._handle_internal_messages(msg)
|
await self._handle_internal_messages(msg)
|
||||||
|
|
||||||
async def _read_loop(self) -> None:
|
async def _read_loop(self) -> None:
|
||||||
assert self._frame_helper is not None
|
frame_helper = self._frame_helper
|
||||||
|
assert frame_helper is not None
|
||||||
|
await frame_helper.wait_for_ready()
|
||||||
|
to_process = self._to_process
|
||||||
try:
|
try:
|
||||||
|
# Once its ready, we hold the lock for the duration of the
|
||||||
|
# connection so we don't have to keep locking/unlocking
|
||||||
|
async with frame_helper.read_lock:
|
||||||
while True:
|
while True:
|
||||||
if not self._is_socket_open:
|
to_process.put_nowait(await frame_helper.read_packet_with_lock())
|
||||||
# Socket closed but task isn't cancelled yet
|
|
||||||
break
|
|
||||||
self._to_process.put_nowait(await self._frame_helper.read_packet())
|
|
||||||
except SocketClosedAPIError as err:
|
except SocketClosedAPIError as err:
|
||||||
# don't log with info, if closed the site that closed the connection should log
|
# don't log with info, if closed the site that closed the connection should log
|
||||||
|
if not self._is_socket_open:
|
||||||
|
# If we expected the socket to be closed, don't log
|
||||||
|
# the error.
|
||||||
|
return
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"%s: Socket closed, stopping read loop",
|
"%s: Socket closed, stopping read loop",
|
||||||
self.log_name,
|
self.log_name,
|
||||||
|
Loading…
Reference in New Issue
Block a user