diff --git a/aioesphomeapi/connection.pxd b/aioesphomeapi/connection.pxd new file mode 100644 index 0000000..7dc314d --- /dev/null +++ b/aioesphomeapi/connection.pxd @@ -0,0 +1,69 @@ +import cython + + +cdef dict MESSAGE_TYPE_TO_PROTO +cdef dict PROTO_TO_MESSAGE_TYPE + +cdef set OPEN_STATES + +cdef float KEEP_ALIVE_TIMEOUT_RATIO + +cdef bint TYPE_CHECKING + +cdef object DISCONNECT_REQUEST_MESSAGE +cdef object PING_REQUEST_MESSAGE +cdef object PING_RESPONSE_MESSAGE + +cdef object DisconnectRequest +cdef object PingRequest +cdef object GetTimeRequest + +cdef class APIConnection: + + cdef object _params + cdef public object on_stop + cdef object _on_stop_task + cdef public object _socket + cdef public object _frame_helper + cdef public object api_version + cdef public object _connection_state + cdef object _connect_complete + cdef dict _message_handlers + cdef public str log_name + cdef set _read_exception_futures + cdef object _ping_timer + cdef object _pong_timer + cdef float _keep_alive_interval + cdef float _keep_alive_timeout + cdef object _connect_task + cdef object _fatal_exception + cdef bint _expected_disconnect + cdef object _loop + cdef bint _send_pending_ping + cdef public bint is_connected + cdef public bint is_authenticated + cdef bint _is_socket_open + cdef object _debug_enabled + + cpdef send_message(self, object msg) + + @cython.locals(handlers=set, handlers_copy=set) + cpdef _process_packet(self, object msg_type_proto, object data) + + cpdef _async_cancel_pong_timer(self) + + cpdef _async_schedule_keep_alive(self, object now) + + cpdef _cleanup(self) + + cpdef _set_connection_state(self, object state) + + cpdef _report_fatal_error(self, Exception err) + + @cython.locals(handlers=set) + cpdef _add_message_callback_without_remove(self, object on_message, object msg_types) + + cpdef add_message_callback(self, object on_message, object msg_types) + + @cython.locals(handlers=set) + cpdef _remove_message_callback(self, object on_message, object msg_types) \ No newline at end of file diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 6e7f634..3e0f4d1 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -59,6 +59,7 @@ BUFFER_SIZE = 1024 * 1024 * 2 # Set buffer limit to 2MB INTERNAL_MESSAGE_TYPES = {GetTimeRequest, PingRequest, DisconnectRequest} +DISCONNECT_REQUEST_MESSAGE = DisconnectRequest() PING_REQUEST_MESSAGE = PingRequest() PING_RESPONSE_MESSAGE = PingResponse() @@ -97,6 +98,11 @@ in_do_connect: contextvars.ContextVar[bool | None] = contextvars.ContextVar( ) +_int = int +_bytes = bytes +_float = float + + @dataclass class ConnectionParams: address: str @@ -246,24 +252,23 @@ class APIConnection: self._ping_timer = None if self.on_stop and self._connect_complete: - - def _remove_on_stop_task(_fut: asyncio.Future[None]) -> None: - """Remove the stop task. - - We need to do this because the asyncio does not hold - a strong reference to the task, so it can be garbage - collected unexpectedly. - """ - self._on_stop_task = None - # Ensure on_stop is called only once self._on_stop_task = asyncio.create_task( self.on_stop(self._expected_disconnect), name=f"{self.log_name} aioesphomeapi connection on_stop", ) - self._on_stop_task.add_done_callback(_remove_on_stop_task) + self._on_stop_task.add_done_callback(self._remove_on_stop_task) self.on_stop = None + def _remove_on_stop_task(self, _fut: asyncio.Future[None]) -> None: + """Remove the stop task. + + We need to do this because the asyncio does not hold + a strong reference to the task, so it can be garbage + collected unexpectedly. + """ + self._on_stop_task = None + async def _connect_resolve_host(self) -> hr.AddrInfo: """Step 1 in connect process: resolve the address.""" try: @@ -328,13 +333,12 @@ class APIConnection: """Step 3 in connect process: initialize the frame helper and init read loop.""" fh: APIPlaintextFrameHelper | APINoiseFrameHelper loop = self._loop - process_packet = self._process_packet_factory() assert self._socket is not None if self._params.noise_psk is None: _, fh = await loop.create_connection( # type: ignore[type-var] lambda: APIPlaintextFrameHelper( - on_pkt=process_packet, + on_pkt=self._process_packet, on_error=self._report_fatal_error, client_info=self._params.client_info, log_name=self.log_name, @@ -348,7 +352,7 @@ class APIConnection: lambda: APINoiseFrameHelper( noise_psk=noise_psk, expected_name=self._params.expected_name, - on_pkt=process_packet, + on_pkt=self._process_packet, on_error=self._report_fatal_error, client_info=self._params.client_info, log_name=self.log_name, @@ -406,7 +410,7 @@ class APIConnection: received_name, ) - def _async_schedule_keep_alive(self, now: float) -> None: + def _async_schedule_keep_alive(self, now: _float) -> None: """Start the keep alive task.""" self._send_pending_ping = True self._ping_timer = self._loop.call_at( @@ -559,11 +563,12 @@ class APIConnection: f"Connection isn't established yet ({self._connection_state})" ) - if not (message_type := PROTO_TO_MESSAGE_TYPE.get(type(msg))): - raise ValueError(f"Message type id not found for type {type(msg)}") + msg_type = type(msg) + if (message_type := PROTO_TO_MESSAGE_TYPE.get(msg_type)) is None: + raise ValueError(f"Message type id not found for type {msg_type}") if self._debug_enabled(): - _LOGGER.debug("%s: Sending %s: %s", self.log_name, type(msg).__name__, msg) + _LOGGER.debug("%s: Sending %s: %s", self.log_name, msg_type.__name__, msg) if TYPE_CHECKING: assert self._frame_helper is not None @@ -578,13 +583,22 @@ class APIConnection: self._report_fatal_error(err) raise + def _add_message_callback_without_remove( + self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]] + ) -> None: + """Add a message callback without returning a remove callable.""" + message_handlers = self._message_handlers + for msg_type in msg_types: + if (handlers := message_handlers.get(msg_type)) is None: + message_handlers[msg_type] = {on_message} + else: + handlers.add(on_message) + def add_message_callback( self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]] ) -> Callable[[], None]: """Add a message callback.""" - message_handlers = self._message_handlers - for msg_type in msg_types: - message_handlers.setdefault(msg_type, set()).add(on_message) + self._add_message_callback_without_remove(on_message, msg_types) return partial(self._remove_message_callback, on_message, msg_types) def _remove_message_callback( @@ -593,7 +607,8 @@ class APIConnection: """Remove a message callback.""" message_handlers = self._message_handlers for msg_type in msg_types: - message_handlers[msg_type].discard(on_message) + handlers = message_handlers[msg_type] + handlers.discard(on_message) def send_message_callback_response( self, @@ -607,9 +622,7 @@ class APIConnection: # between sending the message and registering the handler # we can be sure that we will not miss any messages even though # we register the handler after sending the message - for msg_type in msg_types: - self._message_handlers.setdefault(msg_type, set()).add(on_message) - return partial(self._remove_message_callback, on_message, msg_types) + return self.add_message_callback(on_message, msg_types) def _handle_timeout(self, fut: asyncio.Future[None]) -> None: """Handle a timeout.""" @@ -663,10 +676,8 @@ class APIConnection: self._handle_complex_message, fut, responses, do_append, do_stop ) - message_handlers = self._message_handlers read_exception_futures = self._read_exception_futures - for msg_type in msg_types: - message_handlers.setdefault(msg_type, set()).add(on_message) + self._add_message_callback_without_remove(on_message, msg_types) read_exception_futures.add(fut) # Now safe to await since we have registered the handler @@ -686,8 +697,7 @@ class APIConnection: finally: if not timeout_expired: timeout_handle.cancel() - for msg_type in msg_types: - message_handlers[msg_type].discard(on_message) + self._remove_message_callback(on_message, msg_types) read_exception_futures.discard(fut) return responses @@ -725,86 +735,74 @@ class APIConnection: self._set_connection_state(ConnectionState.CLOSED) self._cleanup() - def _process_packet_factory(self) -> Callable[[int, bytes], None]: + def _process_packet(self, msg_type_proto: _int, data: _bytes) -> None: """Factory to make a packet processor.""" - message_type_to_proto = MESSAGE_TYPE_TO_PROTO - debug_enabled = self._debug_enabled - message_handlers_get = self._message_handlers.get - internal_message_types = INTERNAL_MESSAGE_TYPES + if (klass := MESSAGE_TYPE_TO_PROTO.get(msg_type_proto)) is None: + _LOGGER.debug( + "%s: Skipping message type %s", + self.log_name, + msg_type_proto, + ) + return - def _process_packet(msg_type_proto: int, data: bytes) -> None: - """Process a packet from the socket.""" - try: - msg = message_type_to_proto[msg_type_proto]() - # MergeFromString instead of ParseFromString since - # ParseFromString will clear the message first and - # the msg is already empty. - msg.MergeFromString(data) - except KeyError: - _LOGGER.debug( - "%s: Skipping message type %s", - self.log_name, - msg_type_proto, + try: + msg = klass() + # MergeFromString instead of ParseFromString since + # ParseFromString will clear the message first and + # the msg is already empty. + msg.MergeFromString(data) + except Exception as e: + _LOGGER.error( + "%s: Invalid protobuf message: type=%s data=%s: %s", + self.log_name, + msg_type_proto, + data, + e, + exc_info=True, + ) + self._report_fatal_error( + ProtocolAPIError( + f"Invalid protobuf message: type={msg_type_proto} data={data!r}: {e}" ) - return - except Exception as e: - _LOGGER.info( - "%s: Invalid protobuf message: type=%s data=%s: %s", - self.log_name, - msg_type_proto, - data, - e, - exc_info=True, - ) - self._report_fatal_error( - ProtocolAPIError( - f"Invalid protobuf message: type={msg_type_proto} data={data!r}: {e}" - ) - ) - raise + ) + raise - msg_type = type(msg) + msg_type = type(msg) - if debug_enabled(): - _LOGGER.debug( - "%s: Got message of type %s: %s", - self.log_name, - msg_type.__name__, - msg, - ) + if self._debug_enabled(): + _LOGGER.debug( + "%s: Got message of type %s: %s", + self.log_name, + msg_type.__name__, + msg, + ) - if self._pong_timer: - # Any valid message from the remote cancels the pong timer - # as we know the connection is still alive - self._async_cancel_pong_timer() + if self._pong_timer is not None: + # Any valid message from the remote cancels the pong timer + # as we know the connection is still alive + self._async_cancel_pong_timer() - if self._send_pending_ping: - # Any valid message from the remove cancels the pending ping - # since we know the connection is still alive - self._send_pending_ping = False + if self._send_pending_ping: + # Any valid message from the remove cancels the pending ping + # since we know the connection is still alive + self._send_pending_ping = False - if handlers := message_handlers_get(msg_type): - for handler in handlers.copy(): - handler(msg) + if (handlers := self._message_handlers.get(msg_type)) is not None: + handlers_copy = handlers.copy() + for handler in handlers_copy: + handler(msg) - # Pre-check the message type to avoid awaiting - # since most messages are not internal messages - if msg_type not in internal_message_types: - return - - if msg_type is DisconnectRequest: - self.send_message(DisconnectResponse()) - self._set_connection_state(ConnectionState.CLOSED) - self._expected_disconnect = True - self._cleanup() - elif msg_type is PingRequest: - self.send_message(PING_RESPONSE_MESSAGE) - elif msg_type is GetTimeRequest: - resp = GetTimeResponse() - resp.epoch_seconds = int(time.time()) - self.send_message(resp) - - return _process_packet + if msg_type is DisconnectRequest: + self.send_message(DisconnectResponse()) + self._set_connection_state(ConnectionState.CLOSED) + self._expected_disconnect = True + self._cleanup() + elif msg_type is PingRequest: + self.send_message(PING_RESPONSE_MESSAGE) + elif msg_type is GetTimeRequest: + resp = GetTimeResponse() + resp.epoch_seconds = int(time.time()) + self.send_message(resp) async def disconnect(self) -> None: """Disconnect from the API.""" @@ -827,7 +825,7 @@ class APIConnection: # as possible. try: await self.send_message_await_response( - DisconnectRequest(), DisconnectResponse + DISCONNECT_REQUEST_MESSAGE, DisconnectResponse ) except APIConnectionError as err: _LOGGER.error( @@ -844,7 +842,7 @@ class APIConnection: # Still try to tell the esp to disconnect gracefully # but don't wait for it to finish try: - self.send_message(DisconnectRequest()) + self.send_message(DISCONNECT_REQUEST_MESSAGE) except APIConnectionError as err: _LOGGER.error( "%s: Failed to send (forced) disconnect request: %s", diff --git a/setup.py b/setup.py index 89f0458..d104f1b 100644 --- a/setup.py +++ b/setup.py @@ -72,6 +72,7 @@ def cythonize_if_available(setup_kwargs): dict( ext_modules=cythonize( [ + "aioesphomeapi/connection.py", "aioesphomeapi/_frame_helper/plain_text.py", "aioesphomeapi/_frame_helper/noise.py", "aioesphomeapi/_frame_helper/base.py", diff --git a/tests/test_connection.py b/tests/test_connection.py index f42570d..8005fee 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -54,7 +54,7 @@ def socket_socket(): def _get_mock_protocol(conn: APIConnection): protocol = APIPlaintextFrameHelper( - on_pkt=conn._process_packet_factory(), + on_pkt=conn._process_packet, on_error=conn._report_fatal_error, client_info="mock", log_name="mock_device",