mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-22 16:48:04 +01:00
Improve performance of processing incoming packets (#573)
This commit is contained in:
parent
32c0933bfd
commit
74facc8fef
69
aioesphomeapi/connection.pxd
Normal file
69
aioesphomeapi/connection.pxd
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
import cython
|
||||||
|
|
||||||
|
|
||||||
|
cdef dict MESSAGE_TYPE_TO_PROTO
|
||||||
|
cdef dict PROTO_TO_MESSAGE_TYPE
|
||||||
|
|
||||||
|
cdef set OPEN_STATES
|
||||||
|
|
||||||
|
cdef float KEEP_ALIVE_TIMEOUT_RATIO
|
||||||
|
|
||||||
|
cdef bint TYPE_CHECKING
|
||||||
|
|
||||||
|
cdef object DISCONNECT_REQUEST_MESSAGE
|
||||||
|
cdef object PING_REQUEST_MESSAGE
|
||||||
|
cdef object PING_RESPONSE_MESSAGE
|
||||||
|
|
||||||
|
cdef object DisconnectRequest
|
||||||
|
cdef object PingRequest
|
||||||
|
cdef object GetTimeRequest
|
||||||
|
|
||||||
|
cdef class APIConnection:
|
||||||
|
|
||||||
|
cdef object _params
|
||||||
|
cdef public object on_stop
|
||||||
|
cdef object _on_stop_task
|
||||||
|
cdef public object _socket
|
||||||
|
cdef public object _frame_helper
|
||||||
|
cdef public object api_version
|
||||||
|
cdef public object _connection_state
|
||||||
|
cdef object _connect_complete
|
||||||
|
cdef dict _message_handlers
|
||||||
|
cdef public str log_name
|
||||||
|
cdef set _read_exception_futures
|
||||||
|
cdef object _ping_timer
|
||||||
|
cdef object _pong_timer
|
||||||
|
cdef float _keep_alive_interval
|
||||||
|
cdef float _keep_alive_timeout
|
||||||
|
cdef object _connect_task
|
||||||
|
cdef object _fatal_exception
|
||||||
|
cdef bint _expected_disconnect
|
||||||
|
cdef object _loop
|
||||||
|
cdef bint _send_pending_ping
|
||||||
|
cdef public bint is_connected
|
||||||
|
cdef public bint is_authenticated
|
||||||
|
cdef bint _is_socket_open
|
||||||
|
cdef object _debug_enabled
|
||||||
|
|
||||||
|
cpdef send_message(self, object msg)
|
||||||
|
|
||||||
|
@cython.locals(handlers=set, handlers_copy=set)
|
||||||
|
cpdef _process_packet(self, object msg_type_proto, object data)
|
||||||
|
|
||||||
|
cpdef _async_cancel_pong_timer(self)
|
||||||
|
|
||||||
|
cpdef _async_schedule_keep_alive(self, object now)
|
||||||
|
|
||||||
|
cpdef _cleanup(self)
|
||||||
|
|
||||||
|
cpdef _set_connection_state(self, object state)
|
||||||
|
|
||||||
|
cpdef _report_fatal_error(self, Exception err)
|
||||||
|
|
||||||
|
@cython.locals(handlers=set)
|
||||||
|
cpdef _add_message_callback_without_remove(self, object on_message, object msg_types)
|
||||||
|
|
||||||
|
cpdef add_message_callback(self, object on_message, object msg_types)
|
||||||
|
|
||||||
|
@cython.locals(handlers=set)
|
||||||
|
cpdef _remove_message_callback(self, object on_message, object msg_types)
|
@ -59,6 +59,7 @@ BUFFER_SIZE = 1024 * 1024 * 2 # Set buffer limit to 2MB
|
|||||||
|
|
||||||
INTERNAL_MESSAGE_TYPES = {GetTimeRequest, PingRequest, DisconnectRequest}
|
INTERNAL_MESSAGE_TYPES = {GetTimeRequest, PingRequest, DisconnectRequest}
|
||||||
|
|
||||||
|
DISCONNECT_REQUEST_MESSAGE = DisconnectRequest()
|
||||||
PING_REQUEST_MESSAGE = PingRequest()
|
PING_REQUEST_MESSAGE = PingRequest()
|
||||||
PING_RESPONSE_MESSAGE = PingResponse()
|
PING_RESPONSE_MESSAGE = PingResponse()
|
||||||
|
|
||||||
@ -97,6 +98,11 @@ in_do_connect: contextvars.ContextVar[bool | None] = contextvars.ContextVar(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_int = int
|
||||||
|
_bytes = bytes
|
||||||
|
_float = float
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConnectionParams:
|
class ConnectionParams:
|
||||||
address: str
|
address: str
|
||||||
@ -246,8 +252,15 @@ class APIConnection:
|
|||||||
self._ping_timer = None
|
self._ping_timer = None
|
||||||
|
|
||||||
if self.on_stop and self._connect_complete:
|
if self.on_stop and self._connect_complete:
|
||||||
|
# Ensure on_stop is called only once
|
||||||
|
self._on_stop_task = asyncio.create_task(
|
||||||
|
self.on_stop(self._expected_disconnect),
|
||||||
|
name=f"{self.log_name} aioesphomeapi connection on_stop",
|
||||||
|
)
|
||||||
|
self._on_stop_task.add_done_callback(self._remove_on_stop_task)
|
||||||
|
self.on_stop = None
|
||||||
|
|
||||||
def _remove_on_stop_task(_fut: asyncio.Future[None]) -> None:
|
def _remove_on_stop_task(self, _fut: asyncio.Future[None]) -> None:
|
||||||
"""Remove the stop task.
|
"""Remove the stop task.
|
||||||
|
|
||||||
We need to do this because the asyncio does not hold
|
We need to do this because the asyncio does not hold
|
||||||
@ -256,14 +269,6 @@ class APIConnection:
|
|||||||
"""
|
"""
|
||||||
self._on_stop_task = None
|
self._on_stop_task = None
|
||||||
|
|
||||||
# Ensure on_stop is called only once
|
|
||||||
self._on_stop_task = asyncio.create_task(
|
|
||||||
self.on_stop(self._expected_disconnect),
|
|
||||||
name=f"{self.log_name} aioesphomeapi connection on_stop",
|
|
||||||
)
|
|
||||||
self._on_stop_task.add_done_callback(_remove_on_stop_task)
|
|
||||||
self.on_stop = None
|
|
||||||
|
|
||||||
async def _connect_resolve_host(self) -> hr.AddrInfo:
|
async def _connect_resolve_host(self) -> hr.AddrInfo:
|
||||||
"""Step 1 in connect process: resolve the address."""
|
"""Step 1 in connect process: resolve the address."""
|
||||||
try:
|
try:
|
||||||
@ -328,13 +333,12 @@ class APIConnection:
|
|||||||
"""Step 3 in connect process: initialize the frame helper and init read loop."""
|
"""Step 3 in connect process: initialize the frame helper and init read loop."""
|
||||||
fh: APIPlaintextFrameHelper | APINoiseFrameHelper
|
fh: APIPlaintextFrameHelper | APINoiseFrameHelper
|
||||||
loop = self._loop
|
loop = self._loop
|
||||||
process_packet = self._process_packet_factory()
|
|
||||||
assert self._socket is not None
|
assert self._socket is not None
|
||||||
|
|
||||||
if self._params.noise_psk is None:
|
if self._params.noise_psk is None:
|
||||||
_, fh = await loop.create_connection( # type: ignore[type-var]
|
_, fh = await loop.create_connection( # type: ignore[type-var]
|
||||||
lambda: APIPlaintextFrameHelper(
|
lambda: APIPlaintextFrameHelper(
|
||||||
on_pkt=process_packet,
|
on_pkt=self._process_packet,
|
||||||
on_error=self._report_fatal_error,
|
on_error=self._report_fatal_error,
|
||||||
client_info=self._params.client_info,
|
client_info=self._params.client_info,
|
||||||
log_name=self.log_name,
|
log_name=self.log_name,
|
||||||
@ -348,7 +352,7 @@ class APIConnection:
|
|||||||
lambda: APINoiseFrameHelper(
|
lambda: APINoiseFrameHelper(
|
||||||
noise_psk=noise_psk,
|
noise_psk=noise_psk,
|
||||||
expected_name=self._params.expected_name,
|
expected_name=self._params.expected_name,
|
||||||
on_pkt=process_packet,
|
on_pkt=self._process_packet,
|
||||||
on_error=self._report_fatal_error,
|
on_error=self._report_fatal_error,
|
||||||
client_info=self._params.client_info,
|
client_info=self._params.client_info,
|
||||||
log_name=self.log_name,
|
log_name=self.log_name,
|
||||||
@ -406,7 +410,7 @@ class APIConnection:
|
|||||||
received_name,
|
received_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _async_schedule_keep_alive(self, now: float) -> None:
|
def _async_schedule_keep_alive(self, now: _float) -> None:
|
||||||
"""Start the keep alive task."""
|
"""Start the keep alive task."""
|
||||||
self._send_pending_ping = True
|
self._send_pending_ping = True
|
||||||
self._ping_timer = self._loop.call_at(
|
self._ping_timer = self._loop.call_at(
|
||||||
@ -559,11 +563,12 @@ class APIConnection:
|
|||||||
f"Connection isn't established yet ({self._connection_state})"
|
f"Connection isn't established yet ({self._connection_state})"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not (message_type := PROTO_TO_MESSAGE_TYPE.get(type(msg))):
|
msg_type = type(msg)
|
||||||
raise ValueError(f"Message type id not found for type {type(msg)}")
|
if (message_type := PROTO_TO_MESSAGE_TYPE.get(msg_type)) is None:
|
||||||
|
raise ValueError(f"Message type id not found for type {msg_type}")
|
||||||
|
|
||||||
if self._debug_enabled():
|
if self._debug_enabled():
|
||||||
_LOGGER.debug("%s: Sending %s: %s", self.log_name, type(msg).__name__, msg)
|
_LOGGER.debug("%s: Sending %s: %s", self.log_name, msg_type.__name__, msg)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
assert self._frame_helper is not None
|
assert self._frame_helper is not None
|
||||||
@ -578,13 +583,22 @@ class APIConnection:
|
|||||||
self._report_fatal_error(err)
|
self._report_fatal_error(err)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def _add_message_callback_without_remove(
|
||||||
|
self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]]
|
||||||
|
) -> None:
|
||||||
|
"""Add a message callback without returning a remove callable."""
|
||||||
|
message_handlers = self._message_handlers
|
||||||
|
for msg_type in msg_types:
|
||||||
|
if (handlers := message_handlers.get(msg_type)) is None:
|
||||||
|
message_handlers[msg_type] = {on_message}
|
||||||
|
else:
|
||||||
|
handlers.add(on_message)
|
||||||
|
|
||||||
def add_message_callback(
|
def add_message_callback(
|
||||||
self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]]
|
self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]]
|
||||||
) -> Callable[[], None]:
|
) -> Callable[[], None]:
|
||||||
"""Add a message callback."""
|
"""Add a message callback."""
|
||||||
message_handlers = self._message_handlers
|
self._add_message_callback_without_remove(on_message, msg_types)
|
||||||
for msg_type in msg_types:
|
|
||||||
message_handlers.setdefault(msg_type, set()).add(on_message)
|
|
||||||
return partial(self._remove_message_callback, on_message, msg_types)
|
return partial(self._remove_message_callback, on_message, msg_types)
|
||||||
|
|
||||||
def _remove_message_callback(
|
def _remove_message_callback(
|
||||||
@ -593,7 +607,8 @@ class APIConnection:
|
|||||||
"""Remove a message callback."""
|
"""Remove a message callback."""
|
||||||
message_handlers = self._message_handlers
|
message_handlers = self._message_handlers
|
||||||
for msg_type in msg_types:
|
for msg_type in msg_types:
|
||||||
message_handlers[msg_type].discard(on_message)
|
handlers = message_handlers[msg_type]
|
||||||
|
handlers.discard(on_message)
|
||||||
|
|
||||||
def send_message_callback_response(
|
def send_message_callback_response(
|
||||||
self,
|
self,
|
||||||
@ -607,9 +622,7 @@ class APIConnection:
|
|||||||
# between sending the message and registering the handler
|
# between sending the message and registering the handler
|
||||||
# we can be sure that we will not miss any messages even though
|
# we can be sure that we will not miss any messages even though
|
||||||
# we register the handler after sending the message
|
# we register the handler after sending the message
|
||||||
for msg_type in msg_types:
|
return self.add_message_callback(on_message, msg_types)
|
||||||
self._message_handlers.setdefault(msg_type, set()).add(on_message)
|
|
||||||
return partial(self._remove_message_callback, on_message, msg_types)
|
|
||||||
|
|
||||||
def _handle_timeout(self, fut: asyncio.Future[None]) -> None:
|
def _handle_timeout(self, fut: asyncio.Future[None]) -> None:
|
||||||
"""Handle a timeout."""
|
"""Handle a timeout."""
|
||||||
@ -663,10 +676,8 @@ class APIConnection:
|
|||||||
self._handle_complex_message, fut, responses, do_append, do_stop
|
self._handle_complex_message, fut, responses, do_append, do_stop
|
||||||
)
|
)
|
||||||
|
|
||||||
message_handlers = self._message_handlers
|
|
||||||
read_exception_futures = self._read_exception_futures
|
read_exception_futures = self._read_exception_futures
|
||||||
for msg_type in msg_types:
|
self._add_message_callback_without_remove(on_message, msg_types)
|
||||||
message_handlers.setdefault(msg_type, set()).add(on_message)
|
|
||||||
|
|
||||||
read_exception_futures.add(fut)
|
read_exception_futures.add(fut)
|
||||||
# Now safe to await since we have registered the handler
|
# Now safe to await since we have registered the handler
|
||||||
@ -686,8 +697,7 @@ class APIConnection:
|
|||||||
finally:
|
finally:
|
||||||
if not timeout_expired:
|
if not timeout_expired:
|
||||||
timeout_handle.cancel()
|
timeout_handle.cancel()
|
||||||
for msg_type in msg_types:
|
self._remove_message_callback(on_message, msg_types)
|
||||||
message_handlers[msg_type].discard(on_message)
|
|
||||||
read_exception_futures.discard(fut)
|
read_exception_futures.discard(fut)
|
||||||
|
|
||||||
return responses
|
return responses
|
||||||
@ -725,30 +735,24 @@ class APIConnection:
|
|||||||
self._set_connection_state(ConnectionState.CLOSED)
|
self._set_connection_state(ConnectionState.CLOSED)
|
||||||
self._cleanup()
|
self._cleanup()
|
||||||
|
|
||||||
def _process_packet_factory(self) -> Callable[[int, bytes], None]:
|
def _process_packet(self, msg_type_proto: _int, data: _bytes) -> None:
|
||||||
"""Factory to make a packet processor."""
|
"""Factory to make a packet processor."""
|
||||||
message_type_to_proto = MESSAGE_TYPE_TO_PROTO
|
if (klass := MESSAGE_TYPE_TO_PROTO.get(msg_type_proto)) is None:
|
||||||
debug_enabled = self._debug_enabled
|
|
||||||
message_handlers_get = self._message_handlers.get
|
|
||||||
internal_message_types = INTERNAL_MESSAGE_TYPES
|
|
||||||
|
|
||||||
def _process_packet(msg_type_proto: int, data: bytes) -> None:
|
|
||||||
"""Process a packet from the socket."""
|
|
||||||
try:
|
|
||||||
msg = message_type_to_proto[msg_type_proto]()
|
|
||||||
# MergeFromString instead of ParseFromString since
|
|
||||||
# ParseFromString will clear the message first and
|
|
||||||
# the msg is already empty.
|
|
||||||
msg.MergeFromString(data)
|
|
||||||
except KeyError:
|
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"%s: Skipping message type %s",
|
"%s: Skipping message type %s",
|
||||||
self.log_name,
|
self.log_name,
|
||||||
msg_type_proto,
|
msg_type_proto,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
msg = klass()
|
||||||
|
# MergeFromString instead of ParseFromString since
|
||||||
|
# ParseFromString will clear the message first and
|
||||||
|
# the msg is already empty.
|
||||||
|
msg.MergeFromString(data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_LOGGER.info(
|
_LOGGER.error(
|
||||||
"%s: Invalid protobuf message: type=%s data=%s: %s",
|
"%s: Invalid protobuf message: type=%s data=%s: %s",
|
||||||
self.log_name,
|
self.log_name,
|
||||||
msg_type_proto,
|
msg_type_proto,
|
||||||
@ -765,7 +769,7 @@ class APIConnection:
|
|||||||
|
|
||||||
msg_type = type(msg)
|
msg_type = type(msg)
|
||||||
|
|
||||||
if debug_enabled():
|
if self._debug_enabled():
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"%s: Got message of type %s: %s",
|
"%s: Got message of type %s: %s",
|
||||||
self.log_name,
|
self.log_name,
|
||||||
@ -773,7 +777,7 @@ class APIConnection:
|
|||||||
msg,
|
msg,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._pong_timer:
|
if self._pong_timer is not None:
|
||||||
# Any valid message from the remote cancels the pong timer
|
# Any valid message from the remote cancels the pong timer
|
||||||
# as we know the connection is still alive
|
# as we know the connection is still alive
|
||||||
self._async_cancel_pong_timer()
|
self._async_cancel_pong_timer()
|
||||||
@ -783,15 +787,11 @@ class APIConnection:
|
|||||||
# since we know the connection is still alive
|
# since we know the connection is still alive
|
||||||
self._send_pending_ping = False
|
self._send_pending_ping = False
|
||||||
|
|
||||||
if handlers := message_handlers_get(msg_type):
|
if (handlers := self._message_handlers.get(msg_type)) is not None:
|
||||||
for handler in handlers.copy():
|
handlers_copy = handlers.copy()
|
||||||
|
for handler in handlers_copy:
|
||||||
handler(msg)
|
handler(msg)
|
||||||
|
|
||||||
# Pre-check the message type to avoid awaiting
|
|
||||||
# since most messages are not internal messages
|
|
||||||
if msg_type not in internal_message_types:
|
|
||||||
return
|
|
||||||
|
|
||||||
if msg_type is DisconnectRequest:
|
if msg_type is DisconnectRequest:
|
||||||
self.send_message(DisconnectResponse())
|
self.send_message(DisconnectResponse())
|
||||||
self._set_connection_state(ConnectionState.CLOSED)
|
self._set_connection_state(ConnectionState.CLOSED)
|
||||||
@ -804,8 +804,6 @@ class APIConnection:
|
|||||||
resp.epoch_seconds = int(time.time())
|
resp.epoch_seconds = int(time.time())
|
||||||
self.send_message(resp)
|
self.send_message(resp)
|
||||||
|
|
||||||
return _process_packet
|
|
||||||
|
|
||||||
async def disconnect(self) -> None:
|
async def disconnect(self) -> None:
|
||||||
"""Disconnect from the API."""
|
"""Disconnect from the API."""
|
||||||
if self._connect_task:
|
if self._connect_task:
|
||||||
@ -827,7 +825,7 @@ class APIConnection:
|
|||||||
# as possible.
|
# as possible.
|
||||||
try:
|
try:
|
||||||
await self.send_message_await_response(
|
await self.send_message_await_response(
|
||||||
DisconnectRequest(), DisconnectResponse
|
DISCONNECT_REQUEST_MESSAGE, DisconnectResponse
|
||||||
)
|
)
|
||||||
except APIConnectionError as err:
|
except APIConnectionError as err:
|
||||||
_LOGGER.error(
|
_LOGGER.error(
|
||||||
@ -844,7 +842,7 @@ class APIConnection:
|
|||||||
# Still try to tell the esp to disconnect gracefully
|
# Still try to tell the esp to disconnect gracefully
|
||||||
# but don't wait for it to finish
|
# but don't wait for it to finish
|
||||||
try:
|
try:
|
||||||
self.send_message(DisconnectRequest())
|
self.send_message(DISCONNECT_REQUEST_MESSAGE)
|
||||||
except APIConnectionError as err:
|
except APIConnectionError as err:
|
||||||
_LOGGER.error(
|
_LOGGER.error(
|
||||||
"%s: Failed to send (forced) disconnect request: %s",
|
"%s: Failed to send (forced) disconnect request: %s",
|
||||||
|
1
setup.py
1
setup.py
@ -72,6 +72,7 @@ def cythonize_if_available(setup_kwargs):
|
|||||||
dict(
|
dict(
|
||||||
ext_modules=cythonize(
|
ext_modules=cythonize(
|
||||||
[
|
[
|
||||||
|
"aioesphomeapi/connection.py",
|
||||||
"aioesphomeapi/_frame_helper/plain_text.py",
|
"aioesphomeapi/_frame_helper/plain_text.py",
|
||||||
"aioesphomeapi/_frame_helper/noise.py",
|
"aioesphomeapi/_frame_helper/noise.py",
|
||||||
"aioesphomeapi/_frame_helper/base.py",
|
"aioesphomeapi/_frame_helper/base.py",
|
||||||
|
@ -54,7 +54,7 @@ def socket_socket():
|
|||||||
|
|
||||||
def _get_mock_protocol(conn: APIConnection):
|
def _get_mock_protocol(conn: APIConnection):
|
||||||
protocol = APIPlaintextFrameHelper(
|
protocol = APIPlaintextFrameHelper(
|
||||||
on_pkt=conn._process_packet_factory(),
|
on_pkt=conn._process_packet,
|
||||||
on_error=conn._report_fatal_error,
|
on_error=conn._report_fatal_error,
|
||||||
client_info="mock",
|
client_info="mock",
|
||||||
log_name="mock_device",
|
log_name="mock_device",
|
||||||
|
Loading…
Reference in New Issue
Block a user