mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-01-06 19:18:03 +01:00
Cooldown before reconnecting on expected disconnect (#397)
This commit is contained in:
parent
81f6e67038
commit
51d581dd9c
@ -233,17 +233,17 @@ class APIClient:
|
|||||||
|
|
||||||
async def connect(
|
async def connect(
|
||||||
self,
|
self,
|
||||||
on_stop: Optional[Callable[[], Awaitable[None]]] = None,
|
on_stop: Optional[Callable[[bool], Awaitable[None]]] = None,
|
||||||
login: bool = False,
|
login: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
if self._connection is not None:
|
if self._connection is not None:
|
||||||
raise APIConnectionError(f"Already connected to {self._log_name}!")
|
raise APIConnectionError(f"Already connected to {self._log_name}!")
|
||||||
|
|
||||||
async def _on_stop() -> None:
|
async def _on_stop(expected_disconnect: bool) -> None:
|
||||||
# Hook into on_stop handler to clear connection when stopped
|
# Hook into on_stop handler to clear connection when stopped
|
||||||
self._connection = None
|
self._connection = None
|
||||||
if on_stop is not None:
|
if on_stop is not None:
|
||||||
await on_stop()
|
await on_stop(expected_disconnect)
|
||||||
|
|
||||||
self._connection = APIConnection(
|
self._connection = APIConnection(
|
||||||
self._params, _on_stop, log_name=self._log_name
|
self._params, _on_stop, log_name=self._log_name
|
||||||
|
@ -91,11 +91,11 @@ class APIConnection:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
params: ConnectionParams,
|
params: ConnectionParams,
|
||||||
on_stop: Callable[[], Coroutine[Any, Any, None]],
|
on_stop: Callable[[bool], Coroutine[Any, Any, None]],
|
||||||
log_name: Optional[str] = None,
|
log_name: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._params = params
|
self._params = params
|
||||||
self.on_stop: Optional[Callable[[], Coroutine[Any, Any, None]]] = on_stop
|
self.on_stop: Optional[Callable[[bool], Coroutine[Any, Any, None]]] = on_stop
|
||||||
self._on_stop_task: Optional[asyncio.Task[None]] = None
|
self._on_stop_task: Optional[asyncio.Task[None]] = None
|
||||||
self._socket: Optional[socket.socket] = None
|
self._socket: Optional[socket.socket] = None
|
||||||
self._frame_helper: Optional[APIFrameHelper] = None
|
self._frame_helper: Optional[APIFrameHelper] = None
|
||||||
@ -168,7 +168,9 @@ class APIConnection:
|
|||||||
self._on_stop_task = None
|
self._on_stop_task = None
|
||||||
|
|
||||||
# Ensure on_stop is called only once
|
# Ensure on_stop is called only once
|
||||||
self._on_stop_task = asyncio.create_task(self.on_stop())
|
self._on_stop_task = asyncio.create_task(
|
||||||
|
self.on_stop(self._expected_disconnect)
|
||||||
|
)
|
||||||
self._on_stop_task.add_done_callback(_remove_on_stop_task)
|
self._on_stop_task.add_done_callback(_remove_on_stop_task)
|
||||||
self.on_stop = None
|
self.on_stop = None
|
||||||
|
|
||||||
@ -228,7 +230,13 @@ class APIConnection:
|
|||||||
except asyncio.TimeoutError as err:
|
except asyncio.TimeoutError as err:
|
||||||
raise SocketAPIError(f"Timeout while connecting to {sockaddr}") from err
|
raise SocketAPIError(f"Timeout while connecting to {sockaddr}") from err
|
||||||
|
|
||||||
_LOGGER.debug("%s: Opened socket", self._params.address)
|
_LOGGER.debug(
|
||||||
|
"%s: Opened socket to %s:%s (%s)",
|
||||||
|
self.log_name,
|
||||||
|
self._params.address,
|
||||||
|
self._params.port,
|
||||||
|
addr,
|
||||||
|
)
|
||||||
|
|
||||||
async def _connect_init_frame_helper(self) -> None:
|
async def _connect_init_frame_helper(self) -> None:
|
||||||
"""Step 3 in connect process: initialize the frame helper and init read loop."""
|
"""Step 3 in connect process: initialize the frame helper and init read loop."""
|
||||||
|
@ -9,6 +9,8 @@ from .core import APIConnectionError
|
|||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
EXPECTED_DISCONNECT_COOLDOWN = 3.0
|
||||||
|
|
||||||
|
|
||||||
class ReconnectLogic(zeroconf.RecordUpdateListener):
|
class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||||
"""Reconnectiong logic handler for ESPHome config entries.
|
"""Reconnectiong logic handler for ESPHome config entries.
|
||||||
@ -71,7 +73,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||||||
return f"{self.name} @ {self._cli.address}"
|
return f"{self.name} @ {self._cli.address}"
|
||||||
return self._cli.address
|
return self._cli.address
|
||||||
|
|
||||||
async def _on_disconnect(self) -> None:
|
async def _on_disconnect(self, expected_disconnect: bool) -> None:
|
||||||
"""Log and issue callbacks when disconnecting."""
|
"""Log and issue callbacks when disconnecting."""
|
||||||
if self._is_stopped:
|
if self._is_stopped:
|
||||||
return
|
return
|
||||||
@ -79,11 +81,15 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||||||
# So therefore all these connection warnings are logged
|
# So therefore all these connection warnings are logged
|
||||||
# as infos. The "unavailable" logic will still trigger so the
|
# as infos. The "unavailable" logic will still trigger so the
|
||||||
# user knows if the device is not connected.
|
# user knows if the device is not connected.
|
||||||
_LOGGER.info("Disconnected from ESPHome API for %s", self._log_name)
|
disconnect_type = "expected" if expected_disconnect else "unexpected"
|
||||||
|
_LOGGER.info(
|
||||||
|
"Processing %s disconnect from ESPHome API for %s",
|
||||||
|
disconnect_type,
|
||||||
|
self._log_name,
|
||||||
|
)
|
||||||
|
|
||||||
# Run disconnect hook
|
# Run disconnect hook
|
||||||
await self._on_disconnect_cb()
|
await self._on_disconnect_cb()
|
||||||
await self._start_zc_listen()
|
|
||||||
|
|
||||||
# Reset tries
|
# Reset tries
|
||||||
async with self._tries_lock:
|
async with self._tries_lock:
|
||||||
@ -91,8 +97,22 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||||||
# Connected needs to be reset before the reconnect event (opposite order of check)
|
# Connected needs to be reset before the reconnect event (opposite order of check)
|
||||||
async with self._connected_lock:
|
async with self._connected_lock:
|
||||||
self._connected = False
|
self._connected = False
|
||||||
|
|
||||||
|
if expected_disconnect:
|
||||||
|
# If we expected the disconnect we need
|
||||||
|
# to cooldown before reconnecting in case the remote
|
||||||
|
# is rebooting so we don't establish a connection right
|
||||||
|
# before its about to reboot in the event we are too fast.
|
||||||
|
await asyncio.sleep(EXPECTED_DISCONNECT_COOLDOWN)
|
||||||
|
|
||||||
self._reconnect_event.set()
|
self._reconnect_event.set()
|
||||||
|
|
||||||
|
# Start listening for zeroconf records
|
||||||
|
# only after setting the reconnect_event
|
||||||
|
# since we only want to accept zeroconf records
|
||||||
|
# after the reconnect has failed.
|
||||||
|
await self._start_zc_listen()
|
||||||
|
|
||||||
async def _wait_and_start_reconnect(self) -> None:
|
async def _wait_and_start_reconnect(self) -> None:
|
||||||
"""Wait for exponentially increasing time to issue next reconnect event."""
|
"""Wait for exponentially increasing time to issue next reconnect event."""
|
||||||
async with self._tries_lock:
|
async with self._tries_lock:
|
||||||
|
@ -27,7 +27,7 @@ def connection_params() -> ConnectionParams:
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def conn(connection_params) -> APIConnection:
|
def conn(connection_params) -> APIConnection:
|
||||||
async def on_stop():
|
async def on_stop(expected_disconnect: bool) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return APIConnection(connection_params, on_stop)
|
return APIConnection(connection_params, on_stop)
|
||||||
|
Loading…
Reference in New Issue
Block a user