mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-22 16:48:04 +01:00
Small reduction in connect overhead (#578)
This commit is contained in:
parent
4be79eab88
commit
c5f4bfa561
@ -11,14 +11,36 @@ cdef float KEEP_ALIVE_TIMEOUT_RATIO
|
||||
cdef bint TYPE_CHECKING
|
||||
|
||||
cdef object DISCONNECT_REQUEST_MESSAGE
|
||||
cdef object DISCONNECT_RESPONSE_MESSAGE
|
||||
cdef object PING_REQUEST_MESSAGE
|
||||
cdef object PING_RESPONSE_MESSAGE
|
||||
|
||||
cdef object asyncio_timeout
|
||||
cdef object CancelledError
|
||||
cdef object asyncio_TimeoutError
|
||||
|
||||
cdef object ConnectResponse
|
||||
cdef object DisconnectRequest
|
||||
cdef object PingRequest
|
||||
cdef object GetTimeRequest
|
||||
cdef object GetTimeRequest, GetTimeResponse
|
||||
|
||||
cdef object APIVersion
|
||||
|
||||
cdef object partial
|
||||
|
||||
cdef object hr
|
||||
|
||||
cdef object RESOLVE_TIMEOUT
|
||||
cdef object CONNECT_AND_SETUP_TIMEOUT
|
||||
|
||||
cdef object APIConnectionError
|
||||
cdef object BadNameAPIError
|
||||
cdef object HandshakeAPIError
|
||||
cdef object PingFailedAPIError
|
||||
cdef object ReadFailedAPIError
|
||||
cdef object TimeoutAPIError
|
||||
|
||||
|
||||
cdef class APIConnection:
|
||||
|
||||
cdef object _params
|
||||
|
@ -7,6 +7,11 @@ import logging
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
|
||||
# After we drop support for Python 3.10, we can use the built-in TimeoutError
|
||||
# instead of the one from asyncio since they are the same in Python 3.11+
|
||||
from asyncio import CancelledError
|
||||
from asyncio import TimeoutError as asyncio_TimeoutError
|
||||
from collections.abc import Coroutine
|
||||
from dataclasses import astuple, dataclass
|
||||
from functools import partial
|
||||
@ -60,6 +65,7 @@ BUFFER_SIZE = 1024 * 1024 * 2 # Set buffer limit to 2MB
|
||||
INTERNAL_MESSAGE_TYPES = {GetTimeRequest, PingRequest, DisconnectRequest}
|
||||
|
||||
DISCONNECT_REQUEST_MESSAGE = DisconnectRequest()
|
||||
DISCONNECT_RESPONSE_MESSAGE = DisconnectResponse()
|
||||
PING_REQUEST_MESSAGE = PingRequest()
|
||||
PING_RESPONSE_MESSAGE = PingResponse()
|
||||
|
||||
@ -187,8 +193,9 @@ class APIConnection:
|
||||
|
||||
self._ping_timer: asyncio.TimerHandle | None = None
|
||||
self._pong_timer: asyncio.TimerHandle | None = None
|
||||
self._keep_alive_interval = params.keepalive
|
||||
self._keep_alive_timeout = params.keepalive * KEEP_ALIVE_TIMEOUT_RATIO
|
||||
keepalive = params.keepalive
|
||||
self._keep_alive_interval = keepalive
|
||||
self._keep_alive_timeout = keepalive * KEEP_ALIVE_TIMEOUT_RATIO
|
||||
|
||||
self._start_connect_task: asyncio.Task[None] | None = None
|
||||
self._finish_connect_task: asyncio.Task[None] | None = None
|
||||
@ -209,7 +216,7 @@ class APIConnection:
|
||||
|
||||
Safe to call multiple times.
|
||||
"""
|
||||
if self.connection_state == ConnectionState.CLOSED:
|
||||
if self.connection_state is ConnectionState.CLOSED:
|
||||
return
|
||||
was_connected = self.is_connected
|
||||
self._set_connection_state(ConnectionState.CLOSED)
|
||||
@ -249,7 +256,7 @@ class APIConnection:
|
||||
self._ping_timer.cancel()
|
||||
self._ping_timer = None
|
||||
|
||||
if self.on_stop and was_connected:
|
||||
if self.on_stop is not None and was_connected:
|
||||
# Ensure on_stop is called only once
|
||||
self._on_stop_task = asyncio.create_task(
|
||||
self.on_stop(self._expected_disconnect),
|
||||
@ -277,22 +284,21 @@ class APIConnection:
|
||||
)
|
||||
async with asyncio_timeout(RESOLVE_TIMEOUT):
|
||||
return await coro
|
||||
except asyncio.TimeoutError as err:
|
||||
except asyncio_TimeoutError as err:
|
||||
raise ResolveAPIError(
|
||||
f"Timeout while resolving IP address for {self.log_name}"
|
||||
) from err
|
||||
|
||||
async def _connect_socket_connect(self, addr: hr.AddrInfo) -> None:
|
||||
"""Step 2 in connect process: connect the socket."""
|
||||
self._socket = socket.socket(
|
||||
family=addr.family, type=addr.type, proto=addr.proto
|
||||
)
|
||||
self._socket.setblocking(False)
|
||||
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
debug_enable = self._debug_enabled()
|
||||
sock = socket.socket(family=addr.family, type=addr.type, proto=addr.proto)
|
||||
sock.setblocking(False)
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
# Try to reduce the pressure on esphome device as it measures
|
||||
# ram in bytes and we measure ram in megabytes.
|
||||
try:
|
||||
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE)
|
||||
except OSError as err:
|
||||
_LOGGER.warning(
|
||||
"%s: Failed to set socket receive buffer size: %s",
|
||||
@ -300,7 +306,7 @@ class APIConnection:
|
||||
err,
|
||||
)
|
||||
|
||||
if self._debug_enabled():
|
||||
if debug_enable is True:
|
||||
_LOGGER.debug(
|
||||
"%s: Connecting to %s:%s (%s)",
|
||||
self.log_name,
|
||||
@ -311,21 +317,22 @@ class APIConnection:
|
||||
sockaddr = astuple(addr.sockaddr)
|
||||
|
||||
try:
|
||||
coro = self._loop.sock_connect(self._socket, sockaddr)
|
||||
async with asyncio_timeout(TCP_CONNECT_TIMEOUT):
|
||||
await coro
|
||||
except asyncio.TimeoutError as err:
|
||||
await self._loop.sock_connect(sock, sockaddr)
|
||||
except asyncio_TimeoutError as err:
|
||||
raise SocketAPIError(f"Timeout while connecting to {sockaddr}") from err
|
||||
except OSError as err:
|
||||
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err
|
||||
|
||||
_LOGGER.debug(
|
||||
"%s: Opened socket to %s:%s (%s)",
|
||||
self.log_name,
|
||||
self._params.address,
|
||||
self._params.port,
|
||||
addr,
|
||||
)
|
||||
self._socket = sock
|
||||
if debug_enable is True:
|
||||
_LOGGER.debug(
|
||||
"%s: Opened socket to %s:%s (%s)",
|
||||
self.log_name,
|
||||
self._params.address,
|
||||
self._params.port,
|
||||
addr,
|
||||
)
|
||||
|
||||
async def _connect_init_frame_helper(self) -> None:
|
||||
"""Step 3 in connect process: initialize the frame helper and init read loop."""
|
||||
@ -333,7 +340,7 @@ class APIConnection:
|
||||
loop = self._loop
|
||||
assert self._socket is not None
|
||||
|
||||
if self._params.noise_psk is None:
|
||||
if (noise_psk := self._params.noise_psk) is None:
|
||||
_, fh = await loop.create_connection( # type: ignore[type-var]
|
||||
lambda: APIPlaintextFrameHelper(
|
||||
on_pkt=self._process_packet,
|
||||
@ -345,11 +352,9 @@ class APIConnection:
|
||||
)
|
||||
else:
|
||||
# Ensure noise_psk is a string and not an EStr
|
||||
noise_psk = str(self._params.noise_psk)
|
||||
assert noise_psk is not None
|
||||
_, fh = await loop.create_connection( # type: ignore[type-var]
|
||||
lambda: APINoiseFrameHelper(
|
||||
noise_psk=noise_psk,
|
||||
noise_psk=str(noise_psk),
|
||||
expected_name=self._params.expected_name,
|
||||
on_pkt=self._process_packet,
|
||||
on_error=self._report_fatal_error,
|
||||
@ -359,14 +364,14 @@ class APIConnection:
|
||||
sock=self._socket,
|
||||
)
|
||||
|
||||
self._frame_helper = fh
|
||||
try:
|
||||
await fh.perform_handshake(HANDSHAKE_TIMEOUT)
|
||||
except asyncio.TimeoutError as err:
|
||||
except asyncio_TimeoutError as err:
|
||||
raise TimeoutAPIError("Handshake timed out") from err
|
||||
except OSError as err:
|
||||
raise HandshakeAPIError(f"Handshake failed: {err}") from err
|
||||
self._set_connection_state(ConnectionState.HANDSHAKE_COMPLETE)
|
||||
self._frame_helper = fh
|
||||
|
||||
async def _connect_hello(self) -> None:
|
||||
"""Step 4 in connect process: send hello and get api version."""
|
||||
@ -433,7 +438,7 @@ class APIConnection:
|
||||
self._pong_timer = loop.call_at(
|
||||
now + self._keep_alive_timeout, self._async_pong_not_received
|
||||
)
|
||||
elif self._debug_enabled():
|
||||
elif self._debug_enabled() is True:
|
||||
#
|
||||
# We haven't reached the ping response (pong) timeout yet
|
||||
# and we haven't seen a response to the last ping
|
||||
@ -500,11 +505,11 @@ class APIConnection:
|
||||
# does not have a timeout
|
||||
async with asyncio_timeout(CONNECT_AND_SETUP_TIMEOUT):
|
||||
await start_connect_task
|
||||
except (Exception, asyncio.CancelledError) as ex:
|
||||
except (Exception, CancelledError) as ex:
|
||||
# If the task was cancelled, we need to clean up the connection
|
||||
# and raise the CancelledError
|
||||
self._cleanup()
|
||||
if isinstance(ex, asyncio.CancelledError):
|
||||
if isinstance(ex, CancelledError):
|
||||
raise self._fatal_exception or APIConnectionError(
|
||||
"Connection cancelled"
|
||||
)
|
||||
@ -547,11 +552,11 @@ class APIConnection:
|
||||
# does not have a timeout
|
||||
async with asyncio_timeout(CONNECT_AND_SETUP_TIMEOUT):
|
||||
await self._finish_connect_task
|
||||
except (Exception, asyncio.CancelledError) as ex:
|
||||
except (Exception, CancelledError) as ex:
|
||||
# If the task was cancelled, we need to clean up the connection
|
||||
# and raise the CancelledError
|
||||
self._cleanup()
|
||||
if isinstance(ex, asyncio.CancelledError):
|
||||
if isinstance(ex, CancelledError):
|
||||
raise self._fatal_exception or APIConnectionError(
|
||||
"Connection cancelled"
|
||||
)
|
||||
@ -567,8 +572,8 @@ class APIConnection:
|
||||
def _set_connection_state(self, state: ConnectionState) -> None:
|
||||
"""Set the connection state and log the change."""
|
||||
self.connection_state = state
|
||||
self.is_connected = state == ConnectionState.CONNECTED
|
||||
self._handshake_complete = state == ConnectionState.HANDSHAKE_COMPLETE
|
||||
self.is_connected = state is ConnectionState.CONNECTED
|
||||
self._handshake_complete = state is ConnectionState.HANDSHAKE_COMPLETE
|
||||
|
||||
async def _login(self) -> None:
|
||||
"""Send a login (ConnectRequest) and await the response."""
|
||||
@ -606,7 +611,7 @@ class APIConnection:
|
||||
if (message_type := PROTO_TO_MESSAGE_TYPE.get(msg_type)) is None:
|
||||
raise ValueError(f"Message type id not found for type {msg_type}")
|
||||
|
||||
if self._debug_enabled():
|
||||
if self._debug_enabled() is True:
|
||||
_LOGGER.debug("%s: Sending %s: %s", self.log_name, msg_type.__name__, msg)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -667,7 +672,7 @@ class APIConnection:
|
||||
"""Handle a timeout."""
|
||||
if fut.done():
|
||||
return
|
||||
fut.set_exception(asyncio.TimeoutError)
|
||||
fut.set_exception(asyncio_TimeoutError)
|
||||
|
||||
def _handle_complex_message(
|
||||
self,
|
||||
@ -727,7 +732,7 @@ class APIConnection:
|
||||
timeout_expired = False
|
||||
try:
|
||||
await fut
|
||||
except asyncio.TimeoutError as err:
|
||||
except asyncio_TimeoutError as err:
|
||||
timeout_expired = True
|
||||
raise TimeoutAPIError(
|
||||
f"Timeout waiting for response for {type(send_msg)} after {timeout}s"
|
||||
@ -761,7 +766,7 @@ class APIConnection:
|
||||
The connection will be closed, all exception handlers notified.
|
||||
This method does not log the error, the call site should do so.
|
||||
"""
|
||||
if not self._expected_disconnect and not self._fatal_exception:
|
||||
if self._expected_disconnect is False and not self._fatal_exception:
|
||||
# Only log the first error
|
||||
_LOGGER.warning(
|
||||
"%s: Connection error occurred: %s",
|
||||
@ -806,7 +811,7 @@ class APIConnection:
|
||||
|
||||
msg_type = type(msg)
|
||||
|
||||
if self._debug_enabled():
|
||||
if self._debug_enabled() is True:
|
||||
_LOGGER.debug(
|
||||
"%s: Got message of type %s: %s",
|
||||
self.log_name,
|
||||
@ -830,7 +835,7 @@ class APIConnection:
|
||||
handler(msg)
|
||||
|
||||
if msg_type is DisconnectRequest:
|
||||
self.send_message(DisconnectResponse())
|
||||
self.send_message(DISCONNECT_RESPONSE_MESSAGE)
|
||||
self._expected_disconnect = True
|
||||
self._cleanup()
|
||||
elif msg_type is PingRequest:
|
||||
|
Loading…
Reference in New Issue
Block a user