mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-31 18:17:46 +01:00
Cooldown before reconnecting on expected disconnect (#397)
This commit is contained in:
parent
81f6e67038
commit
51d581dd9c
@ -233,17 +233,17 @@ class APIClient:
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
on_stop: Optional[Callable[[], Awaitable[None]]] = None,
|
||||
on_stop: Optional[Callable[[bool], Awaitable[None]]] = None,
|
||||
login: bool = False,
|
||||
) -> None:
|
||||
if self._connection is not None:
|
||||
raise APIConnectionError(f"Already connected to {self._log_name}!")
|
||||
|
||||
async def _on_stop() -> None:
|
||||
async def _on_stop(expected_disconnect: bool) -> None:
|
||||
# Hook into on_stop handler to clear connection when stopped
|
||||
self._connection = None
|
||||
if on_stop is not None:
|
||||
await on_stop()
|
||||
await on_stop(expected_disconnect)
|
||||
|
||||
self._connection = APIConnection(
|
||||
self._params, _on_stop, log_name=self._log_name
|
||||
|
@ -91,11 +91,11 @@ class APIConnection:
|
||||
def __init__(
|
||||
self,
|
||||
params: ConnectionParams,
|
||||
on_stop: Callable[[], Coroutine[Any, Any, None]],
|
||||
on_stop: Callable[[bool], Coroutine[Any, Any, None]],
|
||||
log_name: Optional[str] = None,
|
||||
) -> None:
|
||||
self._params = params
|
||||
self.on_stop: Optional[Callable[[], Coroutine[Any, Any, None]]] = on_stop
|
||||
self.on_stop: Optional[Callable[[bool], Coroutine[Any, Any, None]]] = on_stop
|
||||
self._on_stop_task: Optional[asyncio.Task[None]] = None
|
||||
self._socket: Optional[socket.socket] = None
|
||||
self._frame_helper: Optional[APIFrameHelper] = None
|
||||
@ -168,7 +168,9 @@ class APIConnection:
|
||||
self._on_stop_task = None
|
||||
|
||||
# Ensure on_stop is called only once
|
||||
self._on_stop_task = asyncio.create_task(self.on_stop())
|
||||
self._on_stop_task = asyncio.create_task(
|
||||
self.on_stop(self._expected_disconnect)
|
||||
)
|
||||
self._on_stop_task.add_done_callback(_remove_on_stop_task)
|
||||
self.on_stop = None
|
||||
|
||||
@ -228,7 +230,13 @@ class APIConnection:
|
||||
except asyncio.TimeoutError as err:
|
||||
raise SocketAPIError(f"Timeout while connecting to {sockaddr}") from err
|
||||
|
||||
_LOGGER.debug("%s: Opened socket", self._params.address)
|
||||
_LOGGER.debug(
|
||||
"%s: Opened socket to %s:%s (%s)",
|
||||
self.log_name,
|
||||
self._params.address,
|
||||
self._params.port,
|
||||
addr,
|
||||
)
|
||||
|
||||
async def _connect_init_frame_helper(self) -> None:
|
||||
"""Step 3 in connect process: initialize the frame helper and init read loop."""
|
||||
|
@ -9,6 +9,8 @@ from .core import APIConnectionError
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
EXPECTED_DISCONNECT_COOLDOWN = 3.0
|
||||
|
||||
|
||||
class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
"""Reconnectiong logic handler for ESPHome config entries.
|
||||
@ -71,7 +73,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
return f"{self.name} @ {self._cli.address}"
|
||||
return self._cli.address
|
||||
|
||||
async def _on_disconnect(self) -> None:
|
||||
async def _on_disconnect(self, expected_disconnect: bool) -> None:
|
||||
"""Log and issue callbacks when disconnecting."""
|
||||
if self._is_stopped:
|
||||
return
|
||||
@ -79,11 +81,15 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
# So therefore all these connection warnings are logged
|
||||
# as infos. The "unavailable" logic will still trigger so the
|
||||
# user knows if the device is not connected.
|
||||
_LOGGER.info("Disconnected from ESPHome API for %s", self._log_name)
|
||||
disconnect_type = "expected" if expected_disconnect else "unexpected"
|
||||
_LOGGER.info(
|
||||
"Processing %s disconnect from ESPHome API for %s",
|
||||
disconnect_type,
|
||||
self._log_name,
|
||||
)
|
||||
|
||||
# Run disconnect hook
|
||||
await self._on_disconnect_cb()
|
||||
await self._start_zc_listen()
|
||||
|
||||
# Reset tries
|
||||
async with self._tries_lock:
|
||||
@ -91,8 +97,22 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
# Connected needs to be reset before the reconnect event (opposite order of check)
|
||||
async with self._connected_lock:
|
||||
self._connected = False
|
||||
|
||||
if expected_disconnect:
|
||||
# If we expected the disconnect we need
|
||||
# to cooldown before reconnecting in case the remote
|
||||
# is rebooting so we don't establish a connection right
|
||||
# before its about to reboot in the event we are too fast.
|
||||
await asyncio.sleep(EXPECTED_DISCONNECT_COOLDOWN)
|
||||
|
||||
self._reconnect_event.set()
|
||||
|
||||
# Start listening for zeroconf records
|
||||
# only after setting the reconnect_event
|
||||
# since we only want to accept zeroconf records
|
||||
# after the reconnect has failed.
|
||||
await self._start_zc_listen()
|
||||
|
||||
async def _wait_and_start_reconnect(self) -> None:
|
||||
"""Wait for exponentially increasing time to issue next reconnect event."""
|
||||
async with self._tries_lock:
|
||||
@ -197,7 +217,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
async with self._wait_task_lock:
|
||||
if self._wait_task is not None:
|
||||
self._wait_task.cancel()
|
||||
self._wait_task = None
|
||||
self._wait_task = None
|
||||
await self._stop_zc_listen()
|
||||
|
||||
def stop_callback(self) -> None:
|
||||
|
@ -27,7 +27,7 @@ def connection_params() -> ConnectionParams:
|
||||
|
||||
@pytest.fixture
|
||||
def conn(connection_params) -> APIConnection:
|
||||
async def on_stop():
|
||||
async def on_stop(expected_disconnect: bool) -> None:
|
||||
pass
|
||||
|
||||
return APIConnection(connection_params, on_stop)
|
||||
|
Loading…
Reference in New Issue
Block a user