Cooldown before reconnecting on expected disconnect (#397)

This commit is contained in:
J. Nick Koston 2023-03-05 18:54:54 -10:00 committed by GitHub
parent 81f6e67038
commit 51d581dd9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 40 additions and 12 deletions

View File

@ -233,17 +233,17 @@ class APIClient:
async def connect( async def connect(
self, self,
on_stop: Optional[Callable[[], Awaitable[None]]] = None, on_stop: Optional[Callable[[bool], Awaitable[None]]] = None,
login: bool = False, login: bool = False,
) -> None: ) -> None:
if self._connection is not None: if self._connection is not None:
raise APIConnectionError(f"Already connected to {self._log_name}!") raise APIConnectionError(f"Already connected to {self._log_name}!")
async def _on_stop() -> None: async def _on_stop(expected_disconnect: bool) -> None:
# Hook into on_stop handler to clear connection when stopped # Hook into on_stop handler to clear connection when stopped
self._connection = None self._connection = None
if on_stop is not None: if on_stop is not None:
await on_stop() await on_stop(expected_disconnect)
self._connection = APIConnection( self._connection = APIConnection(
self._params, _on_stop, log_name=self._log_name self._params, _on_stop, log_name=self._log_name

View File

@ -91,11 +91,11 @@ class APIConnection:
def __init__( def __init__(
self, self,
params: ConnectionParams, params: ConnectionParams,
on_stop: Callable[[], Coroutine[Any, Any, None]], on_stop: Callable[[bool], Coroutine[Any, Any, None]],
log_name: Optional[str] = None, log_name: Optional[str] = None,
) -> None: ) -> None:
self._params = params self._params = params
self.on_stop: Optional[Callable[[], Coroutine[Any, Any, None]]] = on_stop self.on_stop: Optional[Callable[[bool], Coroutine[Any, Any, None]]] = on_stop
self._on_stop_task: Optional[asyncio.Task[None]] = None self._on_stop_task: Optional[asyncio.Task[None]] = None
self._socket: Optional[socket.socket] = None self._socket: Optional[socket.socket] = None
self._frame_helper: Optional[APIFrameHelper] = None self._frame_helper: Optional[APIFrameHelper] = None
@ -168,7 +168,9 @@ class APIConnection:
self._on_stop_task = None self._on_stop_task = None
# Ensure on_stop is called only once # Ensure on_stop is called only once
self._on_stop_task = asyncio.create_task(self.on_stop()) self._on_stop_task = asyncio.create_task(
self.on_stop(self._expected_disconnect)
)
self._on_stop_task.add_done_callback(_remove_on_stop_task) self._on_stop_task.add_done_callback(_remove_on_stop_task)
self.on_stop = None self.on_stop = None
@ -228,7 +230,13 @@ class APIConnection:
except asyncio.TimeoutError as err: except asyncio.TimeoutError as err:
raise SocketAPIError(f"Timeout while connecting to {sockaddr}") from err raise SocketAPIError(f"Timeout while connecting to {sockaddr}") from err
_LOGGER.debug("%s: Opened socket", self._params.address) _LOGGER.debug(
"%s: Opened socket to %s:%s (%s)",
self.log_name,
self._params.address,
self._params.port,
addr,
)
async def _connect_init_frame_helper(self) -> None: async def _connect_init_frame_helper(self) -> None:
"""Step 3 in connect process: initialize the frame helper and init read loop.""" """Step 3 in connect process: initialize the frame helper and init read loop."""

View File

@ -9,6 +9,8 @@ from .core import APIConnectionError
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
EXPECTED_DISCONNECT_COOLDOWN = 3.0
class ReconnectLogic(zeroconf.RecordUpdateListener): class ReconnectLogic(zeroconf.RecordUpdateListener):
"""Reconnectiong logic handler for ESPHome config entries. """Reconnectiong logic handler for ESPHome config entries.
@ -71,7 +73,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
return f"{self.name} @ {self._cli.address}" return f"{self.name} @ {self._cli.address}"
return self._cli.address return self._cli.address
async def _on_disconnect(self) -> None: async def _on_disconnect(self, expected_disconnect: bool) -> None:
"""Log and issue callbacks when disconnecting.""" """Log and issue callbacks when disconnecting."""
if self._is_stopped: if self._is_stopped:
return return
@ -79,11 +81,15 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
# So therefore all these connection warnings are logged # So therefore all these connection warnings are logged
# as infos. The "unavailable" logic will still trigger so the # as infos. The "unavailable" logic will still trigger so the
# user knows if the device is not connected. # user knows if the device is not connected.
_LOGGER.info("Disconnected from ESPHome API for %s", self._log_name) disconnect_type = "expected" if expected_disconnect else "unexpected"
_LOGGER.info(
"Processing %s disconnect from ESPHome API for %s",
disconnect_type,
self._log_name,
)
# Run disconnect hook # Run disconnect hook
await self._on_disconnect_cb() await self._on_disconnect_cb()
await self._start_zc_listen()
# Reset tries # Reset tries
async with self._tries_lock: async with self._tries_lock:
@ -91,8 +97,22 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
# Connected needs to be reset before the reconnect event (opposite order of check) # Connected needs to be reset before the reconnect event (opposite order of check)
async with self._connected_lock: async with self._connected_lock:
self._connected = False self._connected = False
if expected_disconnect:
# If we expected the disconnect we need
# to cooldown before reconnecting in case the remote
# is rebooting so we don't establish a connection right
# before its about to reboot in the event we are too fast.
await asyncio.sleep(EXPECTED_DISCONNECT_COOLDOWN)
self._reconnect_event.set() self._reconnect_event.set()
# Start listening for zeroconf records
# only after setting the reconnect_event
# since we only want to accept zeroconf records
# after the reconnect has failed.
await self._start_zc_listen()
async def _wait_and_start_reconnect(self) -> None: async def _wait_and_start_reconnect(self) -> None:
"""Wait for exponentially increasing time to issue next reconnect event.""" """Wait for exponentially increasing time to issue next reconnect event."""
async with self._tries_lock: async with self._tries_lock:
@ -197,7 +217,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
async with self._wait_task_lock: async with self._wait_task_lock:
if self._wait_task is not None: if self._wait_task is not None:
self._wait_task.cancel() self._wait_task.cancel()
self._wait_task = None self._wait_task = None
await self._stop_zc_listen() await self._stop_zc_listen()
def stop_callback(self) -> None: def stop_callback(self) -> None:

View File

@ -27,7 +27,7 @@ def connection_params() -> ConnectionParams:
@pytest.fixture @pytest.fixture
def conn(connection_params) -> APIConnection: def conn(connection_params) -> APIConnection:
async def on_stop(): async def on_stop(expected_disconnect: bool) -> None:
pass pass
return APIConnection(connection_params, on_stop) return APIConnection(connection_params, on_stop)