diff --git a/aioesphomeapi/_frame_helper/base.pxd b/aioesphomeapi/_frame_helper/base.pxd index 2f16b11..9f68caa 100644 --- a/aioesphomeapi/_frame_helper/base.pxd +++ b/aioesphomeapi/_frame_helper/base.pxd @@ -42,3 +42,5 @@ cdef class APIFrameHelper: cpdef void write_packets(self, list packets, bint debug_enabled) cdef void _write_bytes(self, object data, bint debug_enabled) + + cdef get_handshake_future(self) diff --git a/aioesphomeapi/_frame_helper/base.py b/aioesphomeapi/_frame_helper/base.py index 597f41b..27822bc 100644 --- a/aioesphomeapi/_frame_helper/base.py +++ b/aioesphomeapi/_frame_helper/base.py @@ -5,7 +5,7 @@ import logging from abc import abstractmethod from typing import TYPE_CHECKING, Callable, cast -from ..core import HandshakeAPIError, SocketClosedAPIError +from ..core import SocketClosedAPIError if TYPE_CHECKING: from ..connection import APIConnection @@ -23,6 +23,7 @@ WRITE_EXCEPTIONS = (RuntimeError, ConnectionResetError, OSError) _int = int _bytes = bytes +_float = float class APIFrameHelper: @@ -135,21 +136,9 @@ class APIFrameHelper: bitpos += 7 return -1 - async def perform_handshake(self, timeout: float) -> None: - """Perform the handshake with the server.""" - handshake_handle = self._loop.call_at( - self._loop.time() + timeout, - self._set_ready_future_exception, - asyncio.TimeoutError, - ) - try: - await self._ready_future - except asyncio.TimeoutError as err: - raise HandshakeAPIError( - f"{self._log_name}: Timeout during handshake" - ) from err - finally: - handshake_handle.cancel() + def get_handshake_future(self) -> None: + """Get the handshake future.""" + return self._ready_future @abstractmethod def write_packets( diff --git a/aioesphomeapi/connection.pxd b/aioesphomeapi/connection.pxd index 9d420ac..b5f7472 100644 --- a/aioesphomeapi/connection.pxd +++ b/aioesphomeapi/connection.pxd @@ -9,6 +9,7 @@ cdef dict PROTO_TO_MESSAGE_TYPE cdef set OPEN_STATES cdef float KEEP_ALIVE_TIMEOUT_RATIO +cdef object HANDSHAKE_TIMEOUT cdef bint TYPE_CHECKING @@ -16,12 +17,13 @@ cdef object DISCONNECT_REQUEST_MESSAGE cdef object DISCONNECT_RESPONSE_MESSAGE cdef object PING_REQUEST_MESSAGE cdef object PING_RESPONSE_MESSAGE +cdef object NO_PASSWORD_CONNECT_REQUEST cdef object asyncio_timeout cdef object CancelledError cdef object asyncio_TimeoutError -cdef object ConnectResponse +cdef object ConnectRequest, ConnectResponse cdef object DisconnectRequest cdef object PingRequest cdef object GetTimeRequest, GetTimeResponse @@ -53,6 +55,10 @@ cdef object CONNECTION_STATE_HANDSHAKE_COMPLETE cdef object CONNECTION_STATE_CONNECTED cdef object CONNECTION_STATE_CLOSED +cdef object make_hello_request +cdef object handle_timeout +cdef object handle_complex_message + @cython.dataclasses.dataclass cdef class ConnectionParams: cdef public str address @@ -91,43 +97,45 @@ cdef class APIConnection: cdef public str received_name cdef public object resolved_addr_info - cpdef send_message(self, object msg) + cpdef void send_message(self, object msg) - cdef send_messages(self, tuple messages) + cdef void send_messages(self, tuple messages) @cython.locals(handlers=set, handlers_copy=set) cpdef void process_packet(self, object msg_type_proto, object data) - cpdef _async_cancel_pong_timer(self) + cdef void _async_cancel_pong_timer(self) - cpdef _async_schedule_keep_alive(self, object now) + cdef void _async_schedule_keep_alive(self, object now) - cdef _cleanup(self) + cdef void _cleanup(self) cpdef set_log_name(self, str name) cdef _make_connect_request(self) - cdef _process_hello_resp(self, object resp) + cdef void _process_hello_resp(self, object resp) - cdef _process_login_response(self, object hello_response) + cdef void _process_login_response(self, object hello_response) - cdef _set_connection_state(self, object state) + cdef void _set_connection_state(self, object state) cpdef report_fatal_error(self, Exception err) @cython.locals(handlers=set) - cpdef _add_message_callback_without_remove(self, object on_message, tuple msg_types) + cdef void _add_message_callback_without_remove(self, object on_message, tuple msg_types) cpdef add_message_callback(self, object on_message, tuple msg_types) @cython.locals(handlers=set) - cpdef _remove_message_callback(self, object on_message, tuple msg_types) + cpdef void _remove_message_callback(self, object on_message, tuple msg_types) - cpdef _handle_disconnect_request_internal(self, object msg) + cpdef void _handle_disconnect_request_internal(self, object msg) - cpdef _handle_ping_request_internal(self, object msg) + cpdef void _handle_ping_request_internal(self, object msg) - cpdef _handle_get_time_request_internal(self, object msg) + cpdef void _handle_get_time_request_internal(self, object msg) - cdef _set_fatal_exception_if_unset(self, Exception err) + cdef void _set_fatal_exception_if_unset(self, Exception err) + + cdef void _register_internal_message_handlers(self) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 521c2a3..a3816ca 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -12,7 +12,7 @@ import time from asyncio import CancelledError from asyncio import TimeoutError as asyncio_TimeoutError from dataclasses import astuple, dataclass -from functools import partial +from functools import lru_cache, partial from typing import TYPE_CHECKING, Any, Callable from google.protobuf import message @@ -66,6 +66,7 @@ DISCONNECT_REQUEST_MESSAGE = DisconnectRequest() DISCONNECT_RESPONSE_MESSAGE = DisconnectResponse() PING_REQUEST_MESSAGE = PingRequest() PING_RESPONSE_MESSAGE = PingResponse() +NO_PASSWORD_CONNECT_REQUEST = ConnectRequest() PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()} @@ -131,6 +132,46 @@ CONNECTION_STATE_CONNECTED = ConnectionState.CONNECTED CONNECTION_STATE_CLOSED = ConnectionState.CLOSED +def _make_hello_request(client_info: str) -> HelloRequest: + """Make a HelloRequest.""" + return HelloRequest( + client_info=client_info, + api_version_major=1, + api_version_minor=9, + ) + + +_cached_make_hello_request = lru_cache(maxsize=16)(_make_hello_request) +make_hello_request = _cached_make_hello_request + + +def _handle_timeout(fut: asyncio.Future[None]) -> None: + """Handle a timeout.""" + if not fut.done(): + fut.set_exception(asyncio_TimeoutError) + + +handle_timeout = _handle_timeout + + +def _handle_complex_message( + fut: asyncio.Future[None], + responses: list[message.Message], + do_append: Callable[[message.Message], bool] | None, + do_stop: Callable[[message.Message], bool] | None, + resp: message.Message, +) -> None: + """Handle a message that is part of a response.""" + if not fut.done(): + if do_append is None or do_append(resp): + responses.append(resp) + if do_stop is None or do_stop(resp): + fut.set_result(None) + + +handle_complex_message = _handle_complex_message + + class APIConnection: """This class represents _one_ connection to a remote native API device. @@ -359,24 +400,23 @@ class APIConnection: # Set the frame helper right away to ensure # the socket gets closed if we fail to handshake self._frame_helper = fh - + future = self._frame_helper.get_handshake_future() + handshake_handle = self._loop.call_at( + self._loop.time() + HANDSHAKE_TIMEOUT, handle_timeout, future + ) try: - await fh.perform_handshake(HANDSHAKE_TIMEOUT) + await future except asyncio_TimeoutError as err: raise TimeoutAPIError("Handshake timed out") from err except OSError as err: raise HandshakeAPIError(f"Handshake failed: {err}") from err + finally: + handshake_handle.cancel() self._set_connection_state(CONNECTION_STATE_HANDSHAKE_COMPLETE) async def _connect_hello_login(self, login: bool) -> None: """Step 4 in connect process: send hello and login and get api version.""" - messages = [ - HelloRequest( - client_info=self._params.client_info, - api_version_major=1, - api_version_minor=9, - ) - ] + messages = [make_hello_request(self._params.client_info)] msg_types = [HelloResponse] if login: messages.append(self._make_connect_request()) @@ -600,10 +640,9 @@ class APIConnection: def _make_connect_request(self) -> ConnectRequest: """Make a ConnectRequest.""" - connect = ConnectRequest() if self._params.password is not None: - connect.password = self._params.password - return connect + return ConnectRequest(password=self._params.password) + return NO_PASSWORD_CONNECT_REQUEST def send_message(self, msg: message.Message) -> None: """Send a message to the remote.""" @@ -679,26 +718,6 @@ class APIConnection: # we register the handler after sending the message return self.add_message_callback(on_message, msg_types) - def _handle_timeout(self, fut: asyncio.Future[None]) -> None: - """Handle a timeout.""" - if not fut.done(): - fut.set_exception(asyncio_TimeoutError) - - def _handle_complex_message( - self, - fut: asyncio.Future[None], - responses: list[message.Message], - do_append: Callable[[message.Message], bool] | None, - do_stop: Callable[[message.Message], bool] | None, - resp: message.Message, - ) -> None: - """Handle a message that is part of a response.""" - if not fut.done(): - if do_append is None or do_append(resp): - responses.append(resp) - if do_stop is None or do_stop(resp): - fut.set_result(None) - async def send_messages_await_response_complex( # pylint: disable=too-many-locals self, messages: tuple[message.Message, ...], @@ -724,19 +743,16 @@ class APIConnection: # Unsafe to await between sending the message and registering the handler fut: asyncio.Future[None] = loop.create_future() responses: list[message.Message] = [] - handler = self._handle_complex_message - on_message = partial(handler, fut, responses, do_append, do_stop) - - read_exception_futures = self._read_exception_futures + on_message = partial(handle_complex_message, fut, responses, do_append, do_stop) self._add_message_callback_without_remove(on_message, msg_types) - read_exception_futures.add(fut) + self._read_exception_futures.add(fut) # Now safe to await since we have registered the handler # We must not await without a finally or # the message could fail to be removed if the # the await is cancelled - timeout_handle = loop.call_at(loop.time() + timeout, self._handle_timeout, fut) + timeout_handle = loop.call_at(loop.time() + timeout, handle_timeout, fut) timeout_expired = False try: await fut @@ -750,7 +766,7 @@ class APIConnection: if not timeout_expired: timeout_handle.cancel() self._remove_message_callback(on_message, msg_types) - read_exception_futures.discard(fut) + self._read_exception_futures.discard(fut) return responses @@ -775,7 +791,7 @@ class APIConnection: The connection will be closed, all exception handlers notified. This method does not log the error, the call site should do so. """ - if not self._fatal_exception: + if self._fatal_exception is None: if self._expected_disconnect is False: # Only log the first error _LOGGER.warning( @@ -810,7 +826,7 @@ class APIConnection: return try: - msg = klass() + msg: message.Message = klass() # MergeFromString instead of ParseFromString since # ParseFromString will clear the message first and # the msg is already empty. @@ -895,7 +911,7 @@ class APIConnection: async def disconnect(self) -> None: """Disconnect from the API.""" - if self._finish_connect_task: + if self._finish_connect_task is not None: # Try to wait for the handshake to finish so we can send # a disconnect request. If it doesn't finish in time # we will just close the socket.