diff --git a/aioesphomeapi/connection.pxd b/aioesphomeapi/connection.pxd index 7db54cc..981f5e1 100644 --- a/aioesphomeapi/connection.pxd +++ b/aioesphomeapi/connection.pxd @@ -11,14 +11,36 @@ cdef float KEEP_ALIVE_TIMEOUT_RATIO cdef bint TYPE_CHECKING cdef object DISCONNECT_REQUEST_MESSAGE +cdef object DISCONNECT_RESPONSE_MESSAGE cdef object PING_REQUEST_MESSAGE cdef object PING_RESPONSE_MESSAGE +cdef object asyncio_timeout +cdef object CancelledError +cdef object asyncio_TimeoutError + +cdef object ConnectResponse cdef object DisconnectRequest cdef object PingRequest -cdef object GetTimeRequest +cdef object GetTimeRequest, GetTimeResponse + +cdef object APIVersion + cdef object partial +cdef object hr + +cdef object RESOLVE_TIMEOUT +cdef object CONNECT_AND_SETUP_TIMEOUT + +cdef object APIConnectionError +cdef object BadNameAPIError +cdef object HandshakeAPIError +cdef object PingFailedAPIError +cdef object ReadFailedAPIError +cdef object TimeoutAPIError + + cdef class APIConnection: cdef object _params diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 3078985..2a3c82f 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -7,6 +7,11 @@ import logging import socket import sys import time + +# After we drop support for Python 3.10, we can use the built-in TimeoutError +# instead of the one from asyncio since they are the same in Python 3.11+ +from asyncio import CancelledError +from asyncio import TimeoutError as asyncio_TimeoutError from collections.abc import Coroutine from dataclasses import astuple, dataclass from functools import partial @@ -60,6 +65,7 @@ BUFFER_SIZE = 1024 * 1024 * 2 # Set buffer limit to 2MB INTERNAL_MESSAGE_TYPES = {GetTimeRequest, PingRequest, DisconnectRequest} DISCONNECT_REQUEST_MESSAGE = DisconnectRequest() +DISCONNECT_RESPONSE_MESSAGE = DisconnectResponse() PING_REQUEST_MESSAGE = PingRequest() PING_RESPONSE_MESSAGE = PingResponse() @@ -187,8 +193,9 @@ class APIConnection: self._ping_timer: asyncio.TimerHandle | None = None self._pong_timer: asyncio.TimerHandle | None = None - self._keep_alive_interval = params.keepalive - self._keep_alive_timeout = params.keepalive * KEEP_ALIVE_TIMEOUT_RATIO + keepalive = params.keepalive + self._keep_alive_interval = keepalive + self._keep_alive_timeout = keepalive * KEEP_ALIVE_TIMEOUT_RATIO self._start_connect_task: asyncio.Task[None] | None = None self._finish_connect_task: asyncio.Task[None] | None = None @@ -209,7 +216,7 @@ class APIConnection: Safe to call multiple times. """ - if self.connection_state == ConnectionState.CLOSED: + if self.connection_state is ConnectionState.CLOSED: return was_connected = self.is_connected self._set_connection_state(ConnectionState.CLOSED) @@ -249,7 +256,7 @@ class APIConnection: self._ping_timer.cancel() self._ping_timer = None - if self.on_stop and was_connected: + if self.on_stop is not None and was_connected: # Ensure on_stop is called only once self._on_stop_task = asyncio.create_task( self.on_stop(self._expected_disconnect), @@ -277,22 +284,21 @@ class APIConnection: ) async with asyncio_timeout(RESOLVE_TIMEOUT): return await coro - except asyncio.TimeoutError as err: + except asyncio_TimeoutError as err: raise ResolveAPIError( f"Timeout while resolving IP address for {self.log_name}" ) from err async def _connect_socket_connect(self, addr: hr.AddrInfo) -> None: """Step 2 in connect process: connect the socket.""" - self._socket = socket.socket( - family=addr.family, type=addr.type, proto=addr.proto - ) - self._socket.setblocking(False) - self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + debug_enable = self._debug_enabled() + sock = socket.socket(family=addr.family, type=addr.type, proto=addr.proto) + sock.setblocking(False) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) # Try to reduce the pressure on esphome device as it measures # ram in bytes and we measure ram in megabytes. try: - self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE) except OSError as err: _LOGGER.warning( "%s: Failed to set socket receive buffer size: %s", @@ -300,7 +306,7 @@ class APIConnection: err, ) - if self._debug_enabled(): + if debug_enable is True: _LOGGER.debug( "%s: Connecting to %s:%s (%s)", self.log_name, @@ -311,21 +317,22 @@ class APIConnection: sockaddr = astuple(addr.sockaddr) try: - coro = self._loop.sock_connect(self._socket, sockaddr) async with asyncio_timeout(TCP_CONNECT_TIMEOUT): - await coro - except asyncio.TimeoutError as err: + await self._loop.sock_connect(sock, sockaddr) + except asyncio_TimeoutError as err: raise SocketAPIError(f"Timeout while connecting to {sockaddr}") from err except OSError as err: raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err - _LOGGER.debug( - "%s: Opened socket to %s:%s (%s)", - self.log_name, - self._params.address, - self._params.port, - addr, - ) + self._socket = sock + if debug_enable is True: + _LOGGER.debug( + "%s: Opened socket to %s:%s (%s)", + self.log_name, + self._params.address, + self._params.port, + addr, + ) async def _connect_init_frame_helper(self) -> None: """Step 3 in connect process: initialize the frame helper and init read loop.""" @@ -333,7 +340,7 @@ class APIConnection: loop = self._loop assert self._socket is not None - if self._params.noise_psk is None: + if (noise_psk := self._params.noise_psk) is None: _, fh = await loop.create_connection( # type: ignore[type-var] lambda: APIPlaintextFrameHelper( on_pkt=self._process_packet, @@ -345,11 +352,9 @@ class APIConnection: ) else: # Ensure noise_psk is a string and not an EStr - noise_psk = str(self._params.noise_psk) - assert noise_psk is not None _, fh = await loop.create_connection( # type: ignore[type-var] lambda: APINoiseFrameHelper( - noise_psk=noise_psk, + noise_psk=str(noise_psk), expected_name=self._params.expected_name, on_pkt=self._process_packet, on_error=self._report_fatal_error, @@ -359,14 +364,14 @@ class APIConnection: sock=self._socket, ) - self._frame_helper = fh try: await fh.perform_handshake(HANDSHAKE_TIMEOUT) - except asyncio.TimeoutError as err: + except asyncio_TimeoutError as err: raise TimeoutAPIError("Handshake timed out") from err except OSError as err: raise HandshakeAPIError(f"Handshake failed: {err}") from err self._set_connection_state(ConnectionState.HANDSHAKE_COMPLETE) + self._frame_helper = fh async def _connect_hello(self) -> None: """Step 4 in connect process: send hello and get api version.""" @@ -433,7 +438,7 @@ class APIConnection: self._pong_timer = loop.call_at( now + self._keep_alive_timeout, self._async_pong_not_received ) - elif self._debug_enabled(): + elif self._debug_enabled() is True: # # We haven't reached the ping response (pong) timeout yet # and we haven't seen a response to the last ping @@ -500,11 +505,11 @@ class APIConnection: # does not have a timeout async with asyncio_timeout(CONNECT_AND_SETUP_TIMEOUT): await start_connect_task - except (Exception, asyncio.CancelledError) as ex: + except (Exception, CancelledError) as ex: # If the task was cancelled, we need to clean up the connection # and raise the CancelledError self._cleanup() - if isinstance(ex, asyncio.CancelledError): + if isinstance(ex, CancelledError): raise self._fatal_exception or APIConnectionError( "Connection cancelled" ) @@ -547,11 +552,11 @@ class APIConnection: # does not have a timeout async with asyncio_timeout(CONNECT_AND_SETUP_TIMEOUT): await self._finish_connect_task - except (Exception, asyncio.CancelledError) as ex: + except (Exception, CancelledError) as ex: # If the task was cancelled, we need to clean up the connection # and raise the CancelledError self._cleanup() - if isinstance(ex, asyncio.CancelledError): + if isinstance(ex, CancelledError): raise self._fatal_exception or APIConnectionError( "Connection cancelled" ) @@ -567,8 +572,8 @@ class APIConnection: def _set_connection_state(self, state: ConnectionState) -> None: """Set the connection state and log the change.""" self.connection_state = state - self.is_connected = state == ConnectionState.CONNECTED - self._handshake_complete = state == ConnectionState.HANDSHAKE_COMPLETE + self.is_connected = state is ConnectionState.CONNECTED + self._handshake_complete = state is ConnectionState.HANDSHAKE_COMPLETE async def _login(self) -> None: """Send a login (ConnectRequest) and await the response.""" @@ -606,7 +611,7 @@ class APIConnection: if (message_type := PROTO_TO_MESSAGE_TYPE.get(msg_type)) is None: raise ValueError(f"Message type id not found for type {msg_type}") - if self._debug_enabled(): + if self._debug_enabled() is True: _LOGGER.debug("%s: Sending %s: %s", self.log_name, msg_type.__name__, msg) if TYPE_CHECKING: @@ -667,7 +672,7 @@ class APIConnection: """Handle a timeout.""" if fut.done(): return - fut.set_exception(asyncio.TimeoutError) + fut.set_exception(asyncio_TimeoutError) def _handle_complex_message( self, @@ -727,7 +732,7 @@ class APIConnection: timeout_expired = False try: await fut - except asyncio.TimeoutError as err: + except asyncio_TimeoutError as err: timeout_expired = True raise TimeoutAPIError( f"Timeout waiting for response for {type(send_msg)} after {timeout}s" @@ -761,7 +766,7 @@ class APIConnection: The connection will be closed, all exception handlers notified. This method does not log the error, the call site should do so. """ - if not self._expected_disconnect and not self._fatal_exception: + if self._expected_disconnect is False and not self._fatal_exception: # Only log the first error _LOGGER.warning( "%s: Connection error occurred: %s", @@ -806,7 +811,7 @@ class APIConnection: msg_type = type(msg) - if self._debug_enabled(): + if self._debug_enabled() is True: _LOGGER.debug( "%s: Got message of type %s: %s", self.log_name, @@ -830,7 +835,7 @@ class APIConnection: handler(msg) if msg_type is DisconnectRequest: - self.send_message(DisconnectResponse()) + self.send_message(DISCONNECT_RESPONSE_MESSAGE) self._expected_disconnect = True self._cleanup() elif msg_type is PingRequest: