mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-02-15 01:31:21 +01:00
Refactor reconnect logic to reduce complexity (#426)
This commit is contained in:
parent
3f29ac92ad
commit
de9b7266f1
@ -255,6 +255,11 @@ class APIClient:
|
||||
return f"{self._cached_name} @ {self.address}"
|
||||
return self.address
|
||||
|
||||
def set_cached_name_if_unset(self, name: str) -> None:
|
||||
"""Set the cached name of the device if not set."""
|
||||
if not self._cached_name:
|
||||
self._cached_name = name
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
on_stop: Optional[Callable[[bool], Awaitable[None]]] = None,
|
||||
|
@ -190,7 +190,8 @@ class APIConnection:
|
||||
|
||||
# Ensure on_stop is called only once
|
||||
self._on_stop_task = asyncio.create_task(
|
||||
self.on_stop(self._expected_disconnect)
|
||||
self.on_stop(self._expected_disconnect),
|
||||
name=f"{self.log_name} aioesphomeapi connection on_stop",
|
||||
)
|
||||
self._on_stop_task.add_done_callback(_remove_on_stop_task)
|
||||
self.on_stop = None
|
||||
@ -405,7 +406,9 @@ class APIConnection:
|
||||
if login:
|
||||
await self.login(check_connected=False)
|
||||
|
||||
self._connect_task = asyncio.create_task(_do_connect())
|
||||
self._connect_task = asyncio.create_task(
|
||||
_do_connect(), name=f"{self.log_name}: aioesphomeapi do_connect"
|
||||
)
|
||||
|
||||
try:
|
||||
# Allow 2 minutes for connect; this is only as a last measure
|
||||
|
@ -38,35 +38,26 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
:param on_connect: Coroutine Function to call when connected.
|
||||
:param on_disconnect: Coroutine Function to call when disconnected.
|
||||
"""
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self._cli = client
|
||||
self.name = name
|
||||
self._on_connect_cb = on_connect
|
||||
self._on_disconnect_cb = on_disconnect
|
||||
self._on_connect_error_cb = on_connect_error
|
||||
self._zc = zeroconf_instance
|
||||
self._filter_alias: Optional[str] = None
|
||||
# Flag to check if the device is connected
|
||||
self._connected = True
|
||||
self._connected = False
|
||||
self._connected_lock = asyncio.Lock()
|
||||
self._zc_lock = asyncio.Lock()
|
||||
self._is_stopped = True
|
||||
self._zc_listening = False
|
||||
# Event the different strategies use for issuing a reconnect attempt.
|
||||
self._reconnect_event = asyncio.Event()
|
||||
# The task containing the infinite reconnect loop while running
|
||||
self._loop_task: Optional[asyncio.Task[None]] = None
|
||||
# How many reconnect attempts have there been already, used for exponential wait time
|
||||
# How many connect attempts have there been already, used for exponential wait time
|
||||
self._tries = 0
|
||||
self._tries_lock = asyncio.Lock()
|
||||
# Track the wait task to cancel it on shutdown
|
||||
self._wait_task: Optional[asyncio.Task[None]] = None
|
||||
self._wait_task_lock = asyncio.Lock()
|
||||
# Event for tracking when logic should stop
|
||||
self._stop_event = asyncio.Event()
|
||||
self._connect_task: Optional[asyncio.Task[None]] = None
|
||||
self._connect_timer: Optional[asyncio.TimerHandle] = None
|
||||
self._stop_task: Optional[asyncio.Task[None]] = None
|
||||
|
||||
@property
|
||||
def _is_stopped(self) -> bool:
|
||||
return self._stop_event.is_set()
|
||||
|
||||
@property
|
||||
def _log_name(self) -> str:
|
||||
if self.name is not None:
|
||||
@ -91,59 +82,25 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
# Run disconnect hook
|
||||
await self._on_disconnect_cb()
|
||||
|
||||
# Reset tries
|
||||
async with self._tries_lock:
|
||||
self._tries = 0
|
||||
# Connected needs to be reset before the reconnect event (opposite order of check)
|
||||
async with self._connected_lock:
|
||||
self._connected = False
|
||||
|
||||
if expected_disconnect:
|
||||
# If we expected the disconnect we need
|
||||
# to cooldown before reconnecting in case the remote
|
||||
# is rebooting so we don't establish a connection right
|
||||
# before its about to reboot in the event we are too fast.
|
||||
await asyncio.sleep(EXPECTED_DISCONNECT_COOLDOWN)
|
||||
wait = EXPECTED_DISCONNECT_COOLDOWN if expected_disconnect else 0
|
||||
# If we expected the disconnect we need
|
||||
# to cooldown before connecting in case the remote
|
||||
# is rebooting so we don't establish a connection right
|
||||
# before its about to reboot in the event we are too fast.
|
||||
self._schedule_connect(wait)
|
||||
|
||||
self._reconnect_event.set()
|
||||
|
||||
# Start listening for zeroconf records
|
||||
# only after setting the reconnect_event
|
||||
# since we only want to accept zeroconf records
|
||||
# after the reconnect has failed.
|
||||
await self._start_zc_listen()
|
||||
|
||||
async def _wait_and_start_reconnect(self) -> None:
|
||||
"""Wait for exponentially increasing time to issue next reconnect event."""
|
||||
async with self._tries_lock:
|
||||
tries = self._tries
|
||||
# If not first re-try, wait and print message
|
||||
# Cap wait time at 1 minute. This is because while working on the
|
||||
# device (e.g. soldering stuff), users don't want to have to wait
|
||||
# a long time for their device to show up in HA again (this was
|
||||
# mentioned a lot in early feedback)
|
||||
tries = min(tries, 10) # prevent OverflowError
|
||||
wait_time = int(round(min(1.8**tries, 60.0)))
|
||||
if tries == 1:
|
||||
_LOGGER.info("Trying to reconnect to %s in the background", self._log_name)
|
||||
_LOGGER.debug("Retrying %s in %d seconds", self._log_name, wait_time)
|
||||
await asyncio.sleep(wait_time)
|
||||
async with self._wait_task_lock:
|
||||
self._wait_task = None
|
||||
self._reconnect_event.set()
|
||||
|
||||
async def _try_connect(self) -> None:
|
||||
async def _try_connect(self) -> bool:
|
||||
"""Try connecting to the API client."""
|
||||
async with self._tries_lock:
|
||||
tries = self._tries
|
||||
self._tries += 1
|
||||
|
||||
assert self._connected_lock.locked(), "connected_lock must be locked"
|
||||
try:
|
||||
await self._cli.connect(on_stop=self._on_disconnect, login=True)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
if self._on_connect_error_cb is not None:
|
||||
await self._on_connect_error_cb(err)
|
||||
level = logging.WARNING if tries == 0 else logging.DEBUG
|
||||
level = logging.WARNING if self._tries == 0 else logging.DEBUG
|
||||
_LOGGER.log(
|
||||
level,
|
||||
"Can't connect to ESPHome API for %s: %s",
|
||||
@ -152,104 +109,119 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
# Print stacktrace if unhandled (not APIConnectionError)
|
||||
exc_info=not isinstance(err, APIConnectionError),
|
||||
)
|
||||
await self._start_zc_listen()
|
||||
# Schedule re-connect in event loop in order not to delay HA
|
||||
# startup. First connect is scheduled in tracked tasks.
|
||||
async with self._wait_task_lock:
|
||||
# Allow only one wait task at a time
|
||||
# can happen if mDNS record received while waiting, then use existing wait task
|
||||
if self._wait_task is not None:
|
||||
return
|
||||
self._tries += 1
|
||||
return False
|
||||
_LOGGER.info("Successfully connected to %s", self._log_name)
|
||||
self._connected = True
|
||||
self._tries = 0
|
||||
await self._on_connect_cb()
|
||||
return True
|
||||
|
||||
self._wait_task = asyncio.create_task(self._wait_and_start_reconnect())
|
||||
else:
|
||||
_LOGGER.info("Successfully connected to %s", self._log_name)
|
||||
async with self._tries_lock:
|
||||
self._tries = 0
|
||||
async with self._connected_lock:
|
||||
self._connected = True
|
||||
await self._stop_zc_listen()
|
||||
await self._on_connect_cb()
|
||||
|
||||
async def _reconnect_once(self) -> None:
|
||||
# Wait and clear reconnection event
|
||||
await self._reconnect_event.wait()
|
||||
self._reconnect_event.clear()
|
||||
|
||||
# If in connected state, do not try to connect again.
|
||||
async with self._connected_lock:
|
||||
if self._connected:
|
||||
return
|
||||
|
||||
if self._is_stopped:
|
||||
def _schedule_connect(self, delay: float) -> None:
|
||||
"""Schedule a connect attempt."""
|
||||
self._cancel_connect()
|
||||
if not delay:
|
||||
self._call_connect_once()
|
||||
return
|
||||
self._connect_timer = self.loop.call_later(delay, self._call_connect_once)
|
||||
|
||||
await self._try_connect()
|
||||
def _call_connect_once(self) -> None:
|
||||
"""Call the connect logic once.
|
||||
|
||||
async def _reconnect_loop(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
await self._reconnect_once()
|
||||
except asyncio.CancelledError: # pylint: disable=try-except-raise
|
||||
raise
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.error(
|
||||
"Caught exception while reconnecting to %s",
|
||||
self._log_name,
|
||||
exc_info=True,
|
||||
)
|
||||
Must only be called from _schedule_connect.
|
||||
"""
|
||||
self._connect_task = asyncio.create_task(
|
||||
self._connect_once_or_reschedule(),
|
||||
name=f"{self._log_name}: aioesphomeapi connect",
|
||||
)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the reconnecting logic background task."""
|
||||
# Create reconnection loop outside of HA's tracked tasks in order
|
||||
# not to delay startup.
|
||||
self._loop_task = asyncio.create_task(self._reconnect_loop())
|
||||
def _cancel_connect(self) -> None:
|
||||
"""Cancel the connect."""
|
||||
if self._connect_timer:
|
||||
self._connect_timer.cancel()
|
||||
self._connect_timer = None
|
||||
if self._connect_task:
|
||||
self._connect_task.cancel()
|
||||
self._connect_task = None
|
||||
|
||||
async def _connect_once_or_reschedule(self) -> None:
|
||||
"""Connect once or schedule connect.
|
||||
|
||||
Must only be called from _call_connect_once
|
||||
"""
|
||||
async with self._connected_lock:
|
||||
self._connected = False
|
||||
self._reconnect_event.set()
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the reconnecting logic background task. Does not disconnect the client."""
|
||||
if self._loop_task is not None:
|
||||
self._loop_task.cancel()
|
||||
self._loop_task = None
|
||||
async with self._wait_task_lock:
|
||||
if self._wait_task is not None:
|
||||
self._wait_task.cancel()
|
||||
self._wait_task = None
|
||||
await self._stop_zc_listen()
|
||||
self._stop_zc_listen()
|
||||
if self._connected or self._is_stopped:
|
||||
return
|
||||
if await self._try_connect():
|
||||
return
|
||||
tries = min(self._tries, 10) # prevent OverflowError
|
||||
wait_time = int(round(min(1.8**tries, 60.0)))
|
||||
if tries == 1:
|
||||
_LOGGER.info(
|
||||
"Trying to connect to %s in the background", self._log_name
|
||||
)
|
||||
_LOGGER.debug("Retrying %s in %d seconds", self._log_name, wait_time)
|
||||
if wait_time:
|
||||
# If we are waiting, start listening for mDNS records
|
||||
self._start_zc_listen()
|
||||
self._schedule_connect(wait_time)
|
||||
|
||||
def stop_callback(self) -> None:
|
||||
def _remove_stop_task(_fut: asyncio.Future[None]) -> None:
|
||||
"""Remove the stop task from the reconnect loop.
|
||||
"""Stop the connect logic."""
|
||||
|
||||
def _remove_stop_task(_fut: asyncio.Future[None]) -> None:
|
||||
"""Remove the stop task from the connect loop.
|
||||
We need to do this because the asyncio does not hold
|
||||
a strong reference to the task, so it can be garbage
|
||||
collected unexpectedly.
|
||||
"""
|
||||
self._stop_task = None
|
||||
|
||||
self._stop_task = asyncio.create_task(self.stop())
|
||||
self._stop_task = asyncio.create_task(
|
||||
self.stop(),
|
||||
name=f"{self._log_name}: aioesphomeapi reconnect_logic stop_callback",
|
||||
)
|
||||
self._stop_task.add_done_callback(_remove_stop_task)
|
||||
|
||||
async def _start_zc_listen(self) -> None:
|
||||
async def start(self) -> None:
|
||||
"""Start the connecting logic background task."""
|
||||
if self.name:
|
||||
self._cli.set_cached_name_if_unset(self.name)
|
||||
async with self._connected_lock:
|
||||
self._is_stopped = False
|
||||
if self._connected:
|
||||
return
|
||||
self._tries = 0
|
||||
self._schedule_connect(0.0)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the connecting logic background task. Does not disconnect the client."""
|
||||
self._cancel_connect()
|
||||
async with self._connected_lock:
|
||||
self._is_stopped = True
|
||||
# Cancel again while holding the lock
|
||||
self._cancel_connect()
|
||||
self._stop_zc_listen()
|
||||
|
||||
def _start_zc_listen(self) -> None:
|
||||
"""Listen for mDNS records.
|
||||
|
||||
This listener allows us to schedule a reconnect as soon as a
|
||||
This listener allows us to schedule a connect as soon as a
|
||||
received mDNS record indicates the node is up again.
|
||||
"""
|
||||
async with self._zc_lock:
|
||||
if not self._zc_listening:
|
||||
self._zc.async_add_listener(self, None)
|
||||
self._zc_listening = True
|
||||
if not self._zc_listening and self.name:
|
||||
_LOGGER.debug("Starting zeroconf listener for %s", self.name)
|
||||
self._filter_alias = f"{self.name}._esphomelib._tcp.local."
|
||||
self._zc.async_add_listener(self, None)
|
||||
self._zc_listening = True
|
||||
|
||||
async def _stop_zc_listen(self) -> None:
|
||||
def _stop_zc_listen(self) -> None:
|
||||
"""Stop listening for zeroconf updates."""
|
||||
async with self._zc_lock:
|
||||
if self._zc_listening:
|
||||
self._zc.async_remove_listener(self)
|
||||
self._zc_listening = False
|
||||
if self._zc_listening:
|
||||
_LOGGER.debug("Removing zeroconf listener for %s", self.name)
|
||||
self._zc.async_remove_listener(self)
|
||||
self._zc_listening = False
|
||||
|
||||
def async_update_records(
|
||||
self,
|
||||
@ -264,29 +236,23 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
|
||||
# Check if already connected, no lock needed for this access and
|
||||
# bail if either the already stopped or we haven't received device info yet
|
||||
if (
|
||||
self._connected
|
||||
or self._reconnect_event.is_set()
|
||||
or self._is_stopped
|
||||
or self.name is None
|
||||
):
|
||||
if self._connected or self._is_stopped or self._filter_alias is None:
|
||||
return
|
||||
|
||||
filter_alias = f"{self.name}._esphomelib._tcp.local."
|
||||
|
||||
for record_update in records:
|
||||
# We only consider PTR records and match using the alias name
|
||||
if (
|
||||
not isinstance(record_update.new, zeroconf.DNSPointer) # type: ignore[attr-defined]
|
||||
or record_update.new.alias != filter_alias
|
||||
or record_update.new.alias != self._filter_alias
|
||||
):
|
||||
continue
|
||||
|
||||
# Tell reconnection logic to retry connection attempt now (even before reconnect timer finishes)
|
||||
# Tell connection logic to retry connection attempt now (even before connect timer finishes)
|
||||
_LOGGER.debug(
|
||||
"%s: Triggering reconnect because of received mDNS record %s",
|
||||
"%s: Triggering connect because of received mDNS record %s",
|
||||
self._log_name,
|
||||
record_update.new,
|
||||
)
|
||||
self._reconnect_event.set()
|
||||
self._stop_zc_listen()
|
||||
self._schedule_connect(0.0)
|
||||
return
|
||||
|
Loading…
Reference in New Issue
Block a user