Cooldown before reconnecting on expected disconnect (#397)

This commit is contained in:
J. Nick Koston 2023-03-05 18:54:54 -10:00 committed by GitHub
parent 81f6e67038
commit 51d581dd9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 40 additions and 12 deletions

View File

@ -233,17 +233,17 @@ class APIClient:
async def connect(
self,
on_stop: Optional[Callable[[], Awaitable[None]]] = None,
on_stop: Optional[Callable[[bool], Awaitable[None]]] = None,
login: bool = False,
) -> None:
if self._connection is not None:
raise APIConnectionError(f"Already connected to {self._log_name}!")
async def _on_stop() -> None:
async def _on_stop(expected_disconnect: bool) -> None:
# Hook into on_stop handler to clear connection when stopped
self._connection = None
if on_stop is not None:
await on_stop()
await on_stop(expected_disconnect)
self._connection = APIConnection(
self._params, _on_stop, log_name=self._log_name

View File

@ -91,11 +91,11 @@ class APIConnection:
def __init__(
self,
params: ConnectionParams,
on_stop: Callable[[], Coroutine[Any, Any, None]],
on_stop: Callable[[bool], Coroutine[Any, Any, None]],
log_name: Optional[str] = None,
) -> None:
self._params = params
self.on_stop: Optional[Callable[[], Coroutine[Any, Any, None]]] = on_stop
self.on_stop: Optional[Callable[[bool], Coroutine[Any, Any, None]]] = on_stop
self._on_stop_task: Optional[asyncio.Task[None]] = None
self._socket: Optional[socket.socket] = None
self._frame_helper: Optional[APIFrameHelper] = None
@ -168,7 +168,9 @@ class APIConnection:
self._on_stop_task = None
# Ensure on_stop is called only once
self._on_stop_task = asyncio.create_task(self.on_stop())
self._on_stop_task = asyncio.create_task(
self.on_stop(self._expected_disconnect)
)
self._on_stop_task.add_done_callback(_remove_on_stop_task)
self.on_stop = None
@ -228,7 +230,13 @@ class APIConnection:
except asyncio.TimeoutError as err:
raise SocketAPIError(f"Timeout while connecting to {sockaddr}") from err
_LOGGER.debug("%s: Opened socket", self._params.address)
_LOGGER.debug(
"%s: Opened socket to %s:%s (%s)",
self.log_name,
self._params.address,
self._params.port,
addr,
)
async def _connect_init_frame_helper(self) -> None:
"""Step 3 in connect process: initialize the frame helper and init read loop."""

View File

@ -9,6 +9,8 @@ from .core import APIConnectionError
_LOGGER = logging.getLogger(__name__)
EXPECTED_DISCONNECT_COOLDOWN = 3.0
class ReconnectLogic(zeroconf.RecordUpdateListener):
"""Reconnectiong logic handler for ESPHome config entries.
@ -71,7 +73,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
return f"{self.name} @ {self._cli.address}"
return self._cli.address
async def _on_disconnect(self) -> None:
async def _on_disconnect(self, expected_disconnect: bool) -> None:
"""Log and issue callbacks when disconnecting."""
if self._is_stopped:
return
@ -79,11 +81,15 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
# So therefore all these connection warnings are logged
# as infos. The "unavailable" logic will still trigger so the
# user knows if the device is not connected.
_LOGGER.info("Disconnected from ESPHome API for %s", self._log_name)
disconnect_type = "expected" if expected_disconnect else "unexpected"
_LOGGER.info(
"Processing %s disconnect from ESPHome API for %s",
disconnect_type,
self._log_name,
)
# Run disconnect hook
await self._on_disconnect_cb()
await self._start_zc_listen()
# Reset tries
async with self._tries_lock:
@ -91,8 +97,22 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
# Connected needs to be reset before the reconnect event (opposite order of check)
async with self._connected_lock:
self._connected = False
if expected_disconnect:
# If we expected the disconnect we need
# to cooldown before reconnecting in case the remote
# is rebooting so we don't establish a connection right
# before its about to reboot in the event we are too fast.
await asyncio.sleep(EXPECTED_DISCONNECT_COOLDOWN)
self._reconnect_event.set()
# Start listening for zeroconf records
# only after setting the reconnect_event
# since we only want to accept zeroconf records
# after the reconnect has failed.
await self._start_zc_listen()
async def _wait_and_start_reconnect(self) -> None:
"""Wait for exponentially increasing time to issue next reconnect event."""
async with self._tries_lock:
@ -197,7 +217,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
async with self._wait_task_lock:
if self._wait_task is not None:
self._wait_task.cancel()
self._wait_task = None
self._wait_task = None
await self._stop_zc_listen()
def stop_callback(self) -> None:

View File

@ -27,7 +27,7 @@ def connection_params() -> ConnectionParams:
@pytest.fixture
def conn(connection_params) -> APIConnection:
async def on_stop():
async def on_stop(expected_disconnect: bool) -> None:
pass
return APIConnection(connection_params, on_stop)