Small reduction in connect overhead (#578)

This commit is contained in:
J. Nick Koston 2023-10-15 12:01:00 -10:00 committed by GitHub
parent 4be79eab88
commit c5f4bfa561
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 69 additions and 42 deletions

View File

@ -11,14 +11,36 @@ cdef float KEEP_ALIVE_TIMEOUT_RATIO
cdef bint TYPE_CHECKING
cdef object DISCONNECT_REQUEST_MESSAGE
cdef object DISCONNECT_RESPONSE_MESSAGE
cdef object PING_REQUEST_MESSAGE
cdef object PING_RESPONSE_MESSAGE
cdef object asyncio_timeout
cdef object CancelledError
cdef object asyncio_TimeoutError
cdef object ConnectResponse
cdef object DisconnectRequest
cdef object PingRequest
cdef object GetTimeRequest
cdef object GetTimeRequest, GetTimeResponse
cdef object APIVersion
cdef object partial
cdef object hr
cdef object RESOLVE_TIMEOUT
cdef object CONNECT_AND_SETUP_TIMEOUT
cdef object APIConnectionError
cdef object BadNameAPIError
cdef object HandshakeAPIError
cdef object PingFailedAPIError
cdef object ReadFailedAPIError
cdef object TimeoutAPIError
cdef class APIConnection:
cdef object _params

View File

@ -7,6 +7,11 @@ import logging
import socket
import sys
import time
# After we drop support for Python 3.10, we can use the built-in TimeoutError
# instead of the one from asyncio since they are the same in Python 3.11+
from asyncio import CancelledError
from asyncio import TimeoutError as asyncio_TimeoutError
from collections.abc import Coroutine
from dataclasses import astuple, dataclass
from functools import partial
@ -60,6 +65,7 @@ BUFFER_SIZE = 1024 * 1024 * 2 # Set buffer limit to 2MB
INTERNAL_MESSAGE_TYPES = {GetTimeRequest, PingRequest, DisconnectRequest}
DISCONNECT_REQUEST_MESSAGE = DisconnectRequest()
DISCONNECT_RESPONSE_MESSAGE = DisconnectResponse()
PING_REQUEST_MESSAGE = PingRequest()
PING_RESPONSE_MESSAGE = PingResponse()
@ -187,8 +193,9 @@ class APIConnection:
self._ping_timer: asyncio.TimerHandle | None = None
self._pong_timer: asyncio.TimerHandle | None = None
self._keep_alive_interval = params.keepalive
self._keep_alive_timeout = params.keepalive * KEEP_ALIVE_TIMEOUT_RATIO
keepalive = params.keepalive
self._keep_alive_interval = keepalive
self._keep_alive_timeout = keepalive * KEEP_ALIVE_TIMEOUT_RATIO
self._start_connect_task: asyncio.Task[None] | None = None
self._finish_connect_task: asyncio.Task[None] | None = None
@ -209,7 +216,7 @@ class APIConnection:
Safe to call multiple times.
"""
if self.connection_state == ConnectionState.CLOSED:
if self.connection_state is ConnectionState.CLOSED:
return
was_connected = self.is_connected
self._set_connection_state(ConnectionState.CLOSED)
@ -249,7 +256,7 @@ class APIConnection:
self._ping_timer.cancel()
self._ping_timer = None
if self.on_stop and was_connected:
if self.on_stop is not None and was_connected:
# Ensure on_stop is called only once
self._on_stop_task = asyncio.create_task(
self.on_stop(self._expected_disconnect),
@ -277,22 +284,21 @@ class APIConnection:
)
async with asyncio_timeout(RESOLVE_TIMEOUT):
return await coro
except asyncio.TimeoutError as err:
except asyncio_TimeoutError as err:
raise ResolveAPIError(
f"Timeout while resolving IP address for {self.log_name}"
) from err
async def _connect_socket_connect(self, addr: hr.AddrInfo) -> None:
"""Step 2 in connect process: connect the socket."""
self._socket = socket.socket(
family=addr.family, type=addr.type, proto=addr.proto
)
self._socket.setblocking(False)
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
debug_enable = self._debug_enabled()
sock = socket.socket(family=addr.family, type=addr.type, proto=addr.proto)
sock.setblocking(False)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# Try to reduce the pressure on esphome device as it measures
# ram in bytes and we measure ram in megabytes.
try:
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE)
except OSError as err:
_LOGGER.warning(
"%s: Failed to set socket receive buffer size: %s",
@ -300,7 +306,7 @@ class APIConnection:
err,
)
if self._debug_enabled():
if debug_enable is True:
_LOGGER.debug(
"%s: Connecting to %s:%s (%s)",
self.log_name,
@ -311,14 +317,15 @@ class APIConnection:
sockaddr = astuple(addr.sockaddr)
try:
coro = self._loop.sock_connect(self._socket, sockaddr)
async with asyncio_timeout(TCP_CONNECT_TIMEOUT):
await coro
except asyncio.TimeoutError as err:
await self._loop.sock_connect(sock, sockaddr)
except asyncio_TimeoutError as err:
raise SocketAPIError(f"Timeout while connecting to {sockaddr}") from err
except OSError as err:
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err
self._socket = sock
if debug_enable is True:
_LOGGER.debug(
"%s: Opened socket to %s:%s (%s)",
self.log_name,
@ -333,7 +340,7 @@ class APIConnection:
loop = self._loop
assert self._socket is not None
if self._params.noise_psk is None:
if (noise_psk := self._params.noise_psk) is None:
_, fh = await loop.create_connection( # type: ignore[type-var]
lambda: APIPlaintextFrameHelper(
on_pkt=self._process_packet,
@ -345,11 +352,9 @@ class APIConnection:
)
else:
# Ensure noise_psk is a string and not an EStr
noise_psk = str(self._params.noise_psk)
assert noise_psk is not None
_, fh = await loop.create_connection( # type: ignore[type-var]
lambda: APINoiseFrameHelper(
noise_psk=noise_psk,
noise_psk=str(noise_psk),
expected_name=self._params.expected_name,
on_pkt=self._process_packet,
on_error=self._report_fatal_error,
@ -359,14 +364,14 @@ class APIConnection:
sock=self._socket,
)
self._frame_helper = fh
try:
await fh.perform_handshake(HANDSHAKE_TIMEOUT)
except asyncio.TimeoutError as err:
except asyncio_TimeoutError as err:
raise TimeoutAPIError("Handshake timed out") from err
except OSError as err:
raise HandshakeAPIError(f"Handshake failed: {err}") from err
self._set_connection_state(ConnectionState.HANDSHAKE_COMPLETE)
self._frame_helper = fh
async def _connect_hello(self) -> None:
"""Step 4 in connect process: send hello and get api version."""
@ -433,7 +438,7 @@ class APIConnection:
self._pong_timer = loop.call_at(
now + self._keep_alive_timeout, self._async_pong_not_received
)
elif self._debug_enabled():
elif self._debug_enabled() is True:
#
# We haven't reached the ping response (pong) timeout yet
# and we haven't seen a response to the last ping
@ -500,11 +505,11 @@ class APIConnection:
# does not have a timeout
async with asyncio_timeout(CONNECT_AND_SETUP_TIMEOUT):
await start_connect_task
except (Exception, asyncio.CancelledError) as ex:
except (Exception, CancelledError) as ex:
# If the task was cancelled, we need to clean up the connection
# and raise the CancelledError
self._cleanup()
if isinstance(ex, asyncio.CancelledError):
if isinstance(ex, CancelledError):
raise self._fatal_exception or APIConnectionError(
"Connection cancelled"
)
@ -547,11 +552,11 @@ class APIConnection:
# does not have a timeout
async with asyncio_timeout(CONNECT_AND_SETUP_TIMEOUT):
await self._finish_connect_task
except (Exception, asyncio.CancelledError) as ex:
except (Exception, CancelledError) as ex:
# If the task was cancelled, we need to clean up the connection
# and raise the CancelledError
self._cleanup()
if isinstance(ex, asyncio.CancelledError):
if isinstance(ex, CancelledError):
raise self._fatal_exception or APIConnectionError(
"Connection cancelled"
)
@ -567,8 +572,8 @@ class APIConnection:
def _set_connection_state(self, state: ConnectionState) -> None:
"""Set the connection state and log the change."""
self.connection_state = state
self.is_connected = state == ConnectionState.CONNECTED
self._handshake_complete = state == ConnectionState.HANDSHAKE_COMPLETE
self.is_connected = state is ConnectionState.CONNECTED
self._handshake_complete = state is ConnectionState.HANDSHAKE_COMPLETE
async def _login(self) -> None:
"""Send a login (ConnectRequest) and await the response."""
@ -606,7 +611,7 @@ class APIConnection:
if (message_type := PROTO_TO_MESSAGE_TYPE.get(msg_type)) is None:
raise ValueError(f"Message type id not found for type {msg_type}")
if self._debug_enabled():
if self._debug_enabled() is True:
_LOGGER.debug("%s: Sending %s: %s", self.log_name, msg_type.__name__, msg)
if TYPE_CHECKING:
@ -667,7 +672,7 @@ class APIConnection:
"""Handle a timeout."""
if fut.done():
return
fut.set_exception(asyncio.TimeoutError)
fut.set_exception(asyncio_TimeoutError)
def _handle_complex_message(
self,
@ -727,7 +732,7 @@ class APIConnection:
timeout_expired = False
try:
await fut
except asyncio.TimeoutError as err:
except asyncio_TimeoutError as err:
timeout_expired = True
raise TimeoutAPIError(
f"Timeout waiting for response for {type(send_msg)} after {timeout}s"
@ -761,7 +766,7 @@ class APIConnection:
The connection will be closed, all exception handlers notified.
This method does not log the error, the call site should do so.
"""
if not self._expected_disconnect and not self._fatal_exception:
if self._expected_disconnect is False and not self._fatal_exception:
# Only log the first error
_LOGGER.warning(
"%s: Connection error occurred: %s",
@ -806,7 +811,7 @@ class APIConnection:
msg_type = type(msg)
if self._debug_enabled():
if self._debug_enabled() is True:
_LOGGER.debug(
"%s: Got message of type %s: %s",
self.log_name,
@ -830,7 +835,7 @@ class APIConnection:
handler(msg)
if msg_type is DisconnectRequest:
self.send_message(DisconnectResponse())
self.send_message(DISCONNECT_RESPONSE_MESSAGE)
self._expected_disconnect = True
self._cleanup()
elif msg_type is PingRequest: