Improve performance of processing incoming packets (#573)

This commit is contained in:
J. Nick Koston 2023-10-13 18:01:34 -10:00 committed by GitHub
parent 32c0933bfd
commit 74facc8fef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 172 additions and 104 deletions

View File

@ -0,0 +1,69 @@
import cython
cdef dict MESSAGE_TYPE_TO_PROTO
cdef dict PROTO_TO_MESSAGE_TYPE
cdef set OPEN_STATES
cdef float KEEP_ALIVE_TIMEOUT_RATIO
cdef bint TYPE_CHECKING
cdef object DISCONNECT_REQUEST_MESSAGE
cdef object PING_REQUEST_MESSAGE
cdef object PING_RESPONSE_MESSAGE
cdef object DisconnectRequest
cdef object PingRequest
cdef object GetTimeRequest
cdef class APIConnection:
cdef object _params
cdef public object on_stop
cdef object _on_stop_task
cdef public object _socket
cdef public object _frame_helper
cdef public object api_version
cdef public object _connection_state
cdef object _connect_complete
cdef dict _message_handlers
cdef public str log_name
cdef set _read_exception_futures
cdef object _ping_timer
cdef object _pong_timer
cdef float _keep_alive_interval
cdef float _keep_alive_timeout
cdef object _connect_task
cdef object _fatal_exception
cdef bint _expected_disconnect
cdef object _loop
cdef bint _send_pending_ping
cdef public bint is_connected
cdef public bint is_authenticated
cdef bint _is_socket_open
cdef object _debug_enabled
cpdef send_message(self, object msg)
@cython.locals(handlers=set, handlers_copy=set)
cpdef _process_packet(self, object msg_type_proto, object data)
cpdef _async_cancel_pong_timer(self)
cpdef _async_schedule_keep_alive(self, object now)
cpdef _cleanup(self)
cpdef _set_connection_state(self, object state)
cpdef _report_fatal_error(self, Exception err)
@cython.locals(handlers=set)
cpdef _add_message_callback_without_remove(self, object on_message, object msg_types)
cpdef add_message_callback(self, object on_message, object msg_types)
@cython.locals(handlers=set)
cpdef _remove_message_callback(self, object on_message, object msg_types)

View File

@ -59,6 +59,7 @@ BUFFER_SIZE = 1024 * 1024 * 2 # Set buffer limit to 2MB
INTERNAL_MESSAGE_TYPES = {GetTimeRequest, PingRequest, DisconnectRequest} INTERNAL_MESSAGE_TYPES = {GetTimeRequest, PingRequest, DisconnectRequest}
DISCONNECT_REQUEST_MESSAGE = DisconnectRequest()
PING_REQUEST_MESSAGE = PingRequest() PING_REQUEST_MESSAGE = PingRequest()
PING_RESPONSE_MESSAGE = PingResponse() PING_RESPONSE_MESSAGE = PingResponse()
@ -97,6 +98,11 @@ in_do_connect: contextvars.ContextVar[bool | None] = contextvars.ContextVar(
) )
_int = int
_bytes = bytes
_float = float
@dataclass @dataclass
class ConnectionParams: class ConnectionParams:
address: str address: str
@ -246,8 +252,15 @@ class APIConnection:
self._ping_timer = None self._ping_timer = None
if self.on_stop and self._connect_complete: if self.on_stop and self._connect_complete:
# Ensure on_stop is called only once
self._on_stop_task = asyncio.create_task(
self.on_stop(self._expected_disconnect),
name=f"{self.log_name} aioesphomeapi connection on_stop",
)
self._on_stop_task.add_done_callback(self._remove_on_stop_task)
self.on_stop = None
def _remove_on_stop_task(_fut: asyncio.Future[None]) -> None: def _remove_on_stop_task(self, _fut: asyncio.Future[None]) -> None:
"""Remove the stop task. """Remove the stop task.
We need to do this because the asyncio does not hold We need to do this because the asyncio does not hold
@ -256,14 +269,6 @@ class APIConnection:
""" """
self._on_stop_task = None self._on_stop_task = None
# Ensure on_stop is called only once
self._on_stop_task = asyncio.create_task(
self.on_stop(self._expected_disconnect),
name=f"{self.log_name} aioesphomeapi connection on_stop",
)
self._on_stop_task.add_done_callback(_remove_on_stop_task)
self.on_stop = None
async def _connect_resolve_host(self) -> hr.AddrInfo: async def _connect_resolve_host(self) -> hr.AddrInfo:
"""Step 1 in connect process: resolve the address.""" """Step 1 in connect process: resolve the address."""
try: try:
@ -328,13 +333,12 @@ class APIConnection:
"""Step 3 in connect process: initialize the frame helper and init read loop.""" """Step 3 in connect process: initialize the frame helper and init read loop."""
fh: APIPlaintextFrameHelper | APINoiseFrameHelper fh: APIPlaintextFrameHelper | APINoiseFrameHelper
loop = self._loop loop = self._loop
process_packet = self._process_packet_factory()
assert self._socket is not None assert self._socket is not None
if self._params.noise_psk is None: if self._params.noise_psk is None:
_, fh = await loop.create_connection( # type: ignore[type-var] _, fh = await loop.create_connection( # type: ignore[type-var]
lambda: APIPlaintextFrameHelper( lambda: APIPlaintextFrameHelper(
on_pkt=process_packet, on_pkt=self._process_packet,
on_error=self._report_fatal_error, on_error=self._report_fatal_error,
client_info=self._params.client_info, client_info=self._params.client_info,
log_name=self.log_name, log_name=self.log_name,
@ -348,7 +352,7 @@ class APIConnection:
lambda: APINoiseFrameHelper( lambda: APINoiseFrameHelper(
noise_psk=noise_psk, noise_psk=noise_psk,
expected_name=self._params.expected_name, expected_name=self._params.expected_name,
on_pkt=process_packet, on_pkt=self._process_packet,
on_error=self._report_fatal_error, on_error=self._report_fatal_error,
client_info=self._params.client_info, client_info=self._params.client_info,
log_name=self.log_name, log_name=self.log_name,
@ -406,7 +410,7 @@ class APIConnection:
received_name, received_name,
) )
def _async_schedule_keep_alive(self, now: float) -> None: def _async_schedule_keep_alive(self, now: _float) -> None:
"""Start the keep alive task.""" """Start the keep alive task."""
self._send_pending_ping = True self._send_pending_ping = True
self._ping_timer = self._loop.call_at( self._ping_timer = self._loop.call_at(
@ -559,11 +563,12 @@ class APIConnection:
f"Connection isn't established yet ({self._connection_state})" f"Connection isn't established yet ({self._connection_state})"
) )
if not (message_type := PROTO_TO_MESSAGE_TYPE.get(type(msg))): msg_type = type(msg)
raise ValueError(f"Message type id not found for type {type(msg)}") if (message_type := PROTO_TO_MESSAGE_TYPE.get(msg_type)) is None:
raise ValueError(f"Message type id not found for type {msg_type}")
if self._debug_enabled(): if self._debug_enabled():
_LOGGER.debug("%s: Sending %s: %s", self.log_name, type(msg).__name__, msg) _LOGGER.debug("%s: Sending %s: %s", self.log_name, msg_type.__name__, msg)
if TYPE_CHECKING: if TYPE_CHECKING:
assert self._frame_helper is not None assert self._frame_helper is not None
@ -578,13 +583,22 @@ class APIConnection:
self._report_fatal_error(err) self._report_fatal_error(err)
raise raise
def _add_message_callback_without_remove(
self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]]
) -> None:
"""Add a message callback without returning a remove callable."""
message_handlers = self._message_handlers
for msg_type in msg_types:
if (handlers := message_handlers.get(msg_type)) is None:
message_handlers[msg_type] = {on_message}
else:
handlers.add(on_message)
def add_message_callback( def add_message_callback(
self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]] self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]]
) -> Callable[[], None]: ) -> Callable[[], None]:
"""Add a message callback.""" """Add a message callback."""
message_handlers = self._message_handlers self._add_message_callback_without_remove(on_message, msg_types)
for msg_type in msg_types:
message_handlers.setdefault(msg_type, set()).add(on_message)
return partial(self._remove_message_callback, on_message, msg_types) return partial(self._remove_message_callback, on_message, msg_types)
def _remove_message_callback( def _remove_message_callback(
@ -593,7 +607,8 @@ class APIConnection:
"""Remove a message callback.""" """Remove a message callback."""
message_handlers = self._message_handlers message_handlers = self._message_handlers
for msg_type in msg_types: for msg_type in msg_types:
message_handlers[msg_type].discard(on_message) handlers = message_handlers[msg_type]
handlers.discard(on_message)
def send_message_callback_response( def send_message_callback_response(
self, self,
@ -607,9 +622,7 @@ class APIConnection:
# between sending the message and registering the handler # between sending the message and registering the handler
# we can be sure that we will not miss any messages even though # we can be sure that we will not miss any messages even though
# we register the handler after sending the message # we register the handler after sending the message
for msg_type in msg_types: return self.add_message_callback(on_message, msg_types)
self._message_handlers.setdefault(msg_type, set()).add(on_message)
return partial(self._remove_message_callback, on_message, msg_types)
def _handle_timeout(self, fut: asyncio.Future[None]) -> None: def _handle_timeout(self, fut: asyncio.Future[None]) -> None:
"""Handle a timeout.""" """Handle a timeout."""
@ -663,10 +676,8 @@ class APIConnection:
self._handle_complex_message, fut, responses, do_append, do_stop self._handle_complex_message, fut, responses, do_append, do_stop
) )
message_handlers = self._message_handlers
read_exception_futures = self._read_exception_futures read_exception_futures = self._read_exception_futures
for msg_type in msg_types: self._add_message_callback_without_remove(on_message, msg_types)
message_handlers.setdefault(msg_type, set()).add(on_message)
read_exception_futures.add(fut) read_exception_futures.add(fut)
# Now safe to await since we have registered the handler # Now safe to await since we have registered the handler
@ -686,8 +697,7 @@ class APIConnection:
finally: finally:
if not timeout_expired: if not timeout_expired:
timeout_handle.cancel() timeout_handle.cancel()
for msg_type in msg_types: self._remove_message_callback(on_message, msg_types)
message_handlers[msg_type].discard(on_message)
read_exception_futures.discard(fut) read_exception_futures.discard(fut)
return responses return responses
@ -725,30 +735,24 @@ class APIConnection:
self._set_connection_state(ConnectionState.CLOSED) self._set_connection_state(ConnectionState.CLOSED)
self._cleanup() self._cleanup()
def _process_packet_factory(self) -> Callable[[int, bytes], None]: def _process_packet(self, msg_type_proto: _int, data: _bytes) -> None:
"""Factory to make a packet processor.""" """Factory to make a packet processor."""
message_type_to_proto = MESSAGE_TYPE_TO_PROTO if (klass := MESSAGE_TYPE_TO_PROTO.get(msg_type_proto)) is None:
debug_enabled = self._debug_enabled
message_handlers_get = self._message_handlers.get
internal_message_types = INTERNAL_MESSAGE_TYPES
def _process_packet(msg_type_proto: int, data: bytes) -> None:
"""Process a packet from the socket."""
try:
msg = message_type_to_proto[msg_type_proto]()
# MergeFromString instead of ParseFromString since
# ParseFromString will clear the message first and
# the msg is already empty.
msg.MergeFromString(data)
except KeyError:
_LOGGER.debug( _LOGGER.debug(
"%s: Skipping message type %s", "%s: Skipping message type %s",
self.log_name, self.log_name,
msg_type_proto, msg_type_proto,
) )
return return
try:
msg = klass()
# MergeFromString instead of ParseFromString since
# ParseFromString will clear the message first and
# the msg is already empty.
msg.MergeFromString(data)
except Exception as e: except Exception as e:
_LOGGER.info( _LOGGER.error(
"%s: Invalid protobuf message: type=%s data=%s: %s", "%s: Invalid protobuf message: type=%s data=%s: %s",
self.log_name, self.log_name,
msg_type_proto, msg_type_proto,
@ -765,7 +769,7 @@ class APIConnection:
msg_type = type(msg) msg_type = type(msg)
if debug_enabled(): if self._debug_enabled():
_LOGGER.debug( _LOGGER.debug(
"%s: Got message of type %s: %s", "%s: Got message of type %s: %s",
self.log_name, self.log_name,
@ -773,7 +777,7 @@ class APIConnection:
msg, msg,
) )
if self._pong_timer: if self._pong_timer is not None:
# Any valid message from the remote cancels the pong timer # Any valid message from the remote cancels the pong timer
# as we know the connection is still alive # as we know the connection is still alive
self._async_cancel_pong_timer() self._async_cancel_pong_timer()
@ -783,15 +787,11 @@ class APIConnection:
# since we know the connection is still alive # since we know the connection is still alive
self._send_pending_ping = False self._send_pending_ping = False
if handlers := message_handlers_get(msg_type): if (handlers := self._message_handlers.get(msg_type)) is not None:
for handler in handlers.copy(): handlers_copy = handlers.copy()
for handler in handlers_copy:
handler(msg) handler(msg)
# Pre-check the message type to avoid awaiting
# since most messages are not internal messages
if msg_type not in internal_message_types:
return
if msg_type is DisconnectRequest: if msg_type is DisconnectRequest:
self.send_message(DisconnectResponse()) self.send_message(DisconnectResponse())
self._set_connection_state(ConnectionState.CLOSED) self._set_connection_state(ConnectionState.CLOSED)
@ -804,8 +804,6 @@ class APIConnection:
resp.epoch_seconds = int(time.time()) resp.epoch_seconds = int(time.time())
self.send_message(resp) self.send_message(resp)
return _process_packet
async def disconnect(self) -> None: async def disconnect(self) -> None:
"""Disconnect from the API.""" """Disconnect from the API."""
if self._connect_task: if self._connect_task:
@ -827,7 +825,7 @@ class APIConnection:
# as possible. # as possible.
try: try:
await self.send_message_await_response( await self.send_message_await_response(
DisconnectRequest(), DisconnectResponse DISCONNECT_REQUEST_MESSAGE, DisconnectResponse
) )
except APIConnectionError as err: except APIConnectionError as err:
_LOGGER.error( _LOGGER.error(
@ -844,7 +842,7 @@ class APIConnection:
# Still try to tell the esp to disconnect gracefully # Still try to tell the esp to disconnect gracefully
# but don't wait for it to finish # but don't wait for it to finish
try: try:
self.send_message(DisconnectRequest()) self.send_message(DISCONNECT_REQUEST_MESSAGE)
except APIConnectionError as err: except APIConnectionError as err:
_LOGGER.error( _LOGGER.error(
"%s: Failed to send (forced) disconnect request: %s", "%s: Failed to send (forced) disconnect request: %s",

View File

@ -72,6 +72,7 @@ def cythonize_if_available(setup_kwargs):
dict( dict(
ext_modules=cythonize( ext_modules=cythonize(
[ [
"aioesphomeapi/connection.py",
"aioesphomeapi/_frame_helper/plain_text.py", "aioesphomeapi/_frame_helper/plain_text.py",
"aioesphomeapi/_frame_helper/noise.py", "aioesphomeapi/_frame_helper/noise.py",
"aioesphomeapi/_frame_helper/base.py", "aioesphomeapi/_frame_helper/base.py",

View File

@ -54,7 +54,7 @@ def socket_socket():
def _get_mock_protocol(conn: APIConnection): def _get_mock_protocol(conn: APIConnection):
protocol = APIPlaintextFrameHelper( protocol = APIPlaintextFrameHelper(
on_pkt=conn._process_packet_factory(), on_pkt=conn._process_packet,
on_error=conn._report_fatal_error, on_error=conn._report_fatal_error,
client_info="mock", client_info="mock",
log_name="mock_device", log_name="mock_device",