Adjust ping timeout to prevent premature disconnections (#395)

This commit is contained in:
J. Nick Koston 2023-03-05 18:56:22 -10:00 committed by GitHub
parent 51d581dd9c
commit 0327f75414
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 116 additions and 55 deletions

View File

@ -165,6 +165,16 @@ _LOGGER = logging.getLogger(__name__)
DEFAULT_BLE_TIMEOUT = 30.0 DEFAULT_BLE_TIMEOUT = 30.0
DEFAULT_BLE_DISCONNECT_TIMEOUT = 5.0 DEFAULT_BLE_DISCONNECT_TIMEOUT = 5.0
# We send a ping every 20 seconds, and the timeout ratio is 4.5x the
# ping interval. This means that if we don't receive a ping response
# (pong) for 90.0 seconds, we'll consider the connection dead and reconnect.
#
# This was chosen because 20s is around the expected time for a
# device to reboot and reconnect to WiFi, and 90 seconds is the absolute
# maximum time a device can take to respond when it's behind and the
# WiFi connection is poor.
KEEP_ALIVE_FREQUENCY = 20.0
ExecuteServiceDataType = Dict[ ExecuteServiceDataType = Dict[
str, Union[bool, int, float, str, List[bool], List[int], List[float], List[str]] str, Union[bool, int, float, str, List[bool], List[int], List[float], List[str]]
] ]
@ -179,7 +189,7 @@ class APIClient:
password: Optional[str], password: Optional[str],
*, *,
client_info: str = "aioesphomeapi", client_info: str = "aioesphomeapi",
keepalive: float = 15.0, keepalive: float = KEEP_ALIVE_FREQUENCY,
zeroconf_instance: ZeroconfInstanceType = None, zeroconf_instance: ZeroconfInstanceType = None,
noise_psk: Optional[str] = None, noise_psk: Optional[str] = None,
expected_name: Optional[str] = None, expected_name: Optional[str] = None,

View File

@ -50,10 +50,32 @@ _LOGGER = logging.getLogger(__name__)
BUFFER_SIZE = 1024 * 1024 # Set buffer limit to 1MB BUFFER_SIZE = 1024 * 1024 # Set buffer limit to 1MB
INTERNAL_MESSAGE_TYPES = {GetTimeRequest, PingRequest, DisconnectRequest}
INTERNAL_MESSAGE_TYPES = {GetTimeRequest, PingRequest, PingResponse, DisconnectRequest}
PING_REQUEST_MESSAGE = PingRequest()
PING_RESPONSE_MESSAGE = PingResponse()
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()} PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
KEEP_ALIVE_TIMEOUT_RATIO = 4.5
#
# We use 4.5x the keep-alive interval as the timeout for the pong.
# The default ping interval is 20s, which is about the time a device
# takes to reboot and reconnect to the network, so the device has at
# most 90s to respond to a ping. That is enough time to know that
# the device has truly disconnected from the network.
#
HANDSHAKE_TIMEOUT = 30.0
RESOLVE_TIMEOUT = 30.0
CONNECT_REQUEST_TIMEOUT = 30.0
# The connect timeout should be the maximum time we expect the esp to take
# to reboot and connect to the network/WiFi.
CONNECT_TIMEOUT = 60.0
in_do_connect: contextvars.ContextVar[Optional[bool]] = contextvars.ContextVar( in_do_connect: contextvars.ContextVar[Optional[bool]] = contextvars.ContextVar(
"in_do_connect" "in_do_connect"
) )
@ -115,10 +137,12 @@ class APIConnection:
# Handlers currently subscribed to exceptions in the read task # Handlers currently subscribed to exceptions in the read task
self._read_exception_handlers: List[Callable[[Exception], None]] = [] self._read_exception_handlers: List[Callable[[Exception], None]] = []
self._ping_stop_event = asyncio.Event() self._ping_timer: Optional[asyncio.TimerHandle] = None
self._pong_timer: Optional[asyncio.TimerHandle] = None
self._keep_alive_interval = params.keepalive
self._keep_alive_timeout = params.keepalive * KEEP_ALIVE_TIMEOUT_RATIO
self._connect_task: Optional[asyncio.Task[None]] = None self._connect_task: Optional[asyncio.Task[None]] = None
self._keep_alive_task: Optional[asyncio.Task[None]] = None
self._fatal_exception: Optional[Exception] = None self._fatal_exception: Optional[Exception] = None
self._expected_disconnect = False self._expected_disconnect = False
@ -144,10 +168,6 @@ class APIConnection:
self._connect_task.cancel() self._connect_task.cancel()
self._connect_task = None self._connect_task = None
if self._keep_alive_task is not None:
self._keep_alive_task.cancel()
self._keep_alive_task = None
if self._frame_helper is not None: if self._frame_helper is not None:
self._frame_helper.close() self._frame_helper.close()
self._frame_helper = None self._frame_helper = None
@ -156,6 +176,12 @@ class APIConnection:
self._socket.close() self._socket.close()
self._socket = None self._socket = None
self._async_cancel_pong_timer()
if self._ping_timer is not None:
self._ping_timer.cancel()
self._ping_timer = None
if self.on_stop and self._connect_complete: if self.on_stop and self._connect_complete:
def _remove_on_stop_task(_fut: asyncio.Future[None]) -> None: def _remove_on_stop_task(_fut: asyncio.Future[None]) -> None:
@ -174,11 +200,6 @@ class APIConnection:
self._on_stop_task.add_done_callback(_remove_on_stop_task) self._on_stop_task.add_done_callback(_remove_on_stop_task)
self.on_stop = None self.on_stop = None
# Note: we don't explicitly cancel the ping/read task here
# That's because if not written right the ping/read task could cancel
# themselves, effectively ending execution after _cleanup which may be unexpected
self._ping_stop_event.set()
async def _connect_resolve_host(self) -> hr.AddrInfo: async def _connect_resolve_host(self) -> hr.AddrInfo:
"""Step 1 in connect process: resolve the address.""" """Step 1 in connect process: resolve the address."""
try: try:
@ -187,7 +208,7 @@ class APIConnection:
self._params.port, self._params.port,
self._params.zeroconf_instance, self._params.zeroconf_instance,
) )
async with async_timeout.timeout(30.0): async with async_timeout.timeout(RESOLVE_TIMEOUT):
return await coro return await coro
except asyncio.TimeoutError as err: except asyncio.TimeoutError as err:
raise ResolveAPIError( raise ResolveAPIError(
@ -223,7 +244,7 @@ class APIConnection:
try: try:
coro = asyncio.get_event_loop().sock_connect(self._socket, sockaddr) coro = asyncio.get_event_loop().sock_connect(self._socket, sockaddr)
async with async_timeout.timeout(30.0): async with async_timeout.timeout(CONNECT_TIMEOUT):
await coro await coro
except OSError as err: except OSError as err:
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err
@ -265,7 +286,7 @@ class APIConnection:
self._frame_helper = fh self._frame_helper = fh
self._connection_state = ConnectionState.SOCKET_OPENED self._connection_state = ConnectionState.SOCKET_OPENED
try: try:
async with async_timeout.timeout(30.0): async with async_timeout.timeout(HANDSHAKE_TIMEOUT):
await fh.perform_handshake() await fh.perform_handshake()
except OSError as err: except OSError as err:
raise HandshakeAPIError(f"Handshake failed: {err}") from err raise HandshakeAPIError(f"Handshake failed: {err}") from err
@ -310,40 +331,68 @@ class APIConnection:
async def _connect_start_ping(self) -> None: async def _connect_start_ping(self) -> None:
"""Step 5 in connect process: start the ping loop.""" """Step 5 in connect process: start the ping loop."""
self._async_schedule_keep_alive(asyncio.get_running_loop())
async def _keep_alive_loop() -> None: def _async_schedule_keep_alive(self, loop: asyncio.AbstractEventLoop) -> None:
while self._is_socket_open: """Start the keep alive task."""
# Wait for keepalive seconds, or ping stop event, whichever happens first self._ping_timer = loop.call_later(
try: self._keep_alive_interval, self._async_send_keep_alive
async with async_timeout.timeout(self._params.keepalive): )
await self._ping_stop_event.wait()
except asyncio.TimeoutError:
pass
# Re-check connection state def _async_send_keep_alive(self) -> None:
if not self._is_socket_open or self._ping_stop_event.is_set(): """Send a keep alive message."""
return if not self._is_socket_open:
return
try: loop = asyncio.get_running_loop()
await self._ping() self.send_message(PING_REQUEST_MESSAGE)
except TimeoutAPIError:
_LOGGER.debug("%s: Ping timed out!", self.log_name)
self._report_fatal_error(PingFailedAPIError())
return
except APIConnectionError as err:
_LOGGER.debug("%s: Ping Failed: %s", self.log_name, err)
self._report_fatal_error(err)
return
except Exception as err: # pylint: disable=broad-except
_LOGGER.error(
"%s: Unexpected error during ping:",
self.log_name,
exc_info=True,
)
self._report_fatal_error(err)
return
self._keep_alive_task = asyncio.create_task(_keep_alive_loop()) if self._pong_timer is None:
# Do not reset the timer if it's already set,
# since the only thing that should reset the timer
# is receiving a pong.
self._pong_timer = loop.call_later(
self._keep_alive_timeout, self._async_pong_not_received
)
else:
#
# We haven't reached the ping response (pong) timeout yet
# and we haven't seen a response to the last ping.
#
# We send another ping in case the device has rebooted
# and dropped the connection without telling us; the
# new ping forces a TCP RST (aka connection reset by peer).
#
_LOGGER.debug(
"%s: PingResponse (pong) was not received "
"since last keep alive after %s seconds; "
"rescheduling keep alive",
self.log_name,
self._keep_alive_interval,
)
self._async_schedule_keep_alive(loop)
def _async_cancel_pong_timer(self) -> None:
"""Cancel the pong timer."""
if self._pong_timer is not None:
self._pong_timer.cancel()
self._pong_timer = None
def _async_pong_not_received(self) -> None:
"""Handle the ping response (pong) not being received in time."""
if not self._is_socket_open:
return
_LOGGER.debug(
"%s: Ping response not received after %s seconds",
self.log_name,
self._keep_alive_timeout,
)
self._report_fatal_error(
PingFailedAPIError(
f"Ping response not received after {self._keep_alive_timeout} seconds"
)
)
async def connect(self, *, login: bool) -> None: async def connect(self, *, login: bool) -> None:
if self._connection_state != ConnectionState.INITIALIZED: if self._connection_state != ConnectionState.INITIALIZED:
@ -398,7 +447,9 @@ class APIConnection:
if self._params.password is not None: if self._params.password is not None:
connect.password = self._params.password connect.password = self._params.password
try: try:
resp = await self.send_message_await_response(connect, ConnectResponse) resp = await self.send_message_await_response(
connect, ConnectResponse, timeout=CONNECT_REQUEST_TIMEOUT
)
except TimeoutAPIError as err: except TimeoutAPIError as err:
# After a timeout for connect the connection can no longer be used # After a timeout for connect the connection can no longer be used
# We don't know what state the device may be in after ConnectRequest # We don't know what state the device may be in after ConnectRequest
@ -636,22 +687,22 @@ class APIConnection:
if msg_type not in INTERNAL_MESSAGE_TYPES: if msg_type not in INTERNAL_MESSAGE_TYPES:
return return
if isinstance(msg, DisconnectRequest): if msg_type is PingResponse:
# We got a pong so we know the ESP is alive, cancel the timer
# that will disconnect us
self._async_cancel_pong_timer()
elif msg_type is DisconnectRequest:
self.send_message(DisconnectResponse()) self.send_message(DisconnectResponse())
self._connection_state = ConnectionState.CLOSED self._connection_state = ConnectionState.CLOSED
self._expected_disconnect = True self._expected_disconnect = True
self._cleanup() self._cleanup()
elif isinstance(msg, PingRequest): elif msg_type is PingRequest:
self.send_message(PingResponse()) self.send_message(PING_RESPONSE_MESSAGE)
elif isinstance(msg, GetTimeRequest): elif msg_type is GetTimeRequest:
resp = GetTimeResponse() resp = GetTimeResponse()
resp.epoch_seconds = int(time.time()) resp.epoch_seconds = int(time.time())
self.send_message(resp) self.send_message(resp)
async def _ping(self) -> None:
self._check_connected()
await self.send_message_await_response(PingRequest(), PingResponse)
async def disconnect(self) -> None: async def disconnect(self) -> None:
if not self._is_socket_open or not self._frame_helper: if not self._is_socket_open or not self._frame_helper:
# We still want to send a disconnect request even # We still want to send a disconnect request even