diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index fb46293..ea15539 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -233,17 +233,17 @@ class APIClient: async def connect( self, - on_stop: Optional[Callable[[], Awaitable[None]]] = None, + on_stop: Optional[Callable[[bool], Awaitable[None]]] = None, login: bool = False, ) -> None: if self._connection is not None: raise APIConnectionError(f"Already connected to {self._log_name}!") - async def _on_stop() -> None: + async def _on_stop(expected_disconnect: bool) -> None: # Hook into on_stop handler to clear connection when stopped self._connection = None if on_stop is not None: - await on_stop() + await on_stop(expected_disconnect) self._connection = APIConnection( self._params, _on_stop, log_name=self._log_name diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index d530a5a..45e4e6b 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -91,11 +91,11 @@ class APIConnection: def __init__( self, params: ConnectionParams, - on_stop: Callable[[], Coroutine[Any, Any, None]], + on_stop: Callable[[bool], Coroutine[Any, Any, None]], log_name: Optional[str] = None, ) -> None: self._params = params - self.on_stop: Optional[Callable[[], Coroutine[Any, Any, None]]] = on_stop + self.on_stop: Optional[Callable[[bool], Coroutine[Any, Any, None]]] = on_stop self._on_stop_task: Optional[asyncio.Task[None]] = None self._socket: Optional[socket.socket] = None self._frame_helper: Optional[APIFrameHelper] = None @@ -168,7 +168,9 @@ class APIConnection: self._on_stop_task = None # Ensure on_stop is called only once - self._on_stop_task = asyncio.create_task(self.on_stop()) + self._on_stop_task = asyncio.create_task( + self.on_stop(self._expected_disconnect) + ) self._on_stop_task.add_done_callback(_remove_on_stop_task) self.on_stop = None @@ -228,7 +230,13 @@ class APIConnection: except asyncio.TimeoutError as err: raise SocketAPIError(f"Timeout while connecting to {sockaddr}") from err - _LOGGER.debug("%s: Opened socket", self._params.address) + _LOGGER.debug( + "%s: Opened socket to %s:%s (%s)", + self.log_name, + self._params.address, + self._params.port, + addr, + ) async def _connect_init_frame_helper(self) -> None: """Step 3 in connect process: initialize the frame helper and init read loop.""" diff --git a/aioesphomeapi/reconnect_logic.py b/aioesphomeapi/reconnect_logic.py index 717631c..6fd8e6b 100644 --- a/aioesphomeapi/reconnect_logic.py +++ b/aioesphomeapi/reconnect_logic.py @@ -9,6 +9,8 @@ from .core import APIConnectionError _LOGGER = logging.getLogger(__name__) +EXPECTED_DISCONNECT_COOLDOWN = 3.0 + class ReconnectLogic(zeroconf.RecordUpdateListener): """Reconnectiong logic handler for ESPHome config entries. @@ -71,7 +73,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): return f"{self.name} @ {self._cli.address}" return self._cli.address - async def _on_disconnect(self) -> None: + async def _on_disconnect(self, expected_disconnect: bool) -> None: """Log and issue callbacks when disconnecting.""" if self._is_stopped: return @@ -79,11 +81,15 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): # So therefore all these connection warnings are logged # as infos. The "unavailable" logic will still trigger so the # user knows if the device is not connected. - _LOGGER.info("Disconnected from ESPHome API for %s", self._log_name) + disconnect_type = "expected" if expected_disconnect else "unexpected" + _LOGGER.info( + "Processing %s disconnect from ESPHome API for %s", + disconnect_type, + self._log_name, + ) # Run disconnect hook await self._on_disconnect_cb() - await self._start_zc_listen() # Reset tries async with self._tries_lock: @@ -91,8 +97,22 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): # Connected needs to be reset before the reconnect event (opposite order of check) async with self._connected_lock: self._connected = False + + if expected_disconnect: + # If we expected the disconnect we need + # to cooldown before reconnecting in case the remote + # is rebooting so we don't establish a connection right + # before its about to reboot in the event we are too fast. + await asyncio.sleep(EXPECTED_DISCONNECT_COOLDOWN) + self._reconnect_event.set() + # Start listening for zeroconf records + # only after setting the reconnect_event + # since we only want to accept zeroconf records + # after the reconnect has failed. + await self._start_zc_listen() + async def _wait_and_start_reconnect(self) -> None: """Wait for exponentially increasing time to issue next reconnect event.""" async with self._tries_lock: @@ -197,7 +217,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): async with self._wait_task_lock: if self._wait_task is not None: self._wait_task.cancel() - self._wait_task = None + self._wait_task = None await self._stop_zc_listen() def stop_callback(self) -> None: diff --git a/tests/test_connection.py b/tests/test_connection.py index e62f79c..38f3444 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -27,7 +27,7 @@ def connection_params() -> ConnectionParams: @pytest.fixture def conn(connection_params) -> APIConnection: - async def on_stop(): + async def on_stop(expected_disconnect: bool) -> None: pass return APIConnection(connection_params, on_stop)