From ff10a20bce2509a7f8fce15e5da40aeb1e1dd4a5 Mon Sep 17 00:00:00 2001 From: Otto Winter Date: Wed, 30 Jun 2021 17:10:30 +0200 Subject: [PATCH] Add reconnect logic class (#54) --- aioesphomeapi/__init__.py | 1 + aioesphomeapi/client.py | 29 +++- aioesphomeapi/connection.py | 23 +-- aioesphomeapi/reconnect_logic.py | 251 +++++++++++++++++++++++++++++++ 4 files changed, 288 insertions(+), 16 deletions(-) create mode 100644 aioesphomeapi/reconnect_logic.py diff --git a/aioesphomeapi/__init__.py b/aioesphomeapi/__init__.py index 1f27050..92d12b4 100644 --- a/aioesphomeapi/__init__.py +++ b/aioesphomeapi/__init__.py @@ -3,3 +3,4 @@ from .client import APIClient from .connection import APIConnection, ConnectionParams from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError from .model import * +from .reconnect_logic import ReconnectLogic diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 2f597b9..1584ff0 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -130,6 +130,17 @@ class APIClient: zeroconf_instance=zeroconf_instance, ) self._connection: Optional[APIConnection] = None + self._cached_name: Optional[str] = None + + @property + def address(self) -> str: + return self._params.address + + @property + def _log_name(self) -> str: + if self._cached_name is not None: + return f"{self._cached_name} @ {self.address}" + return self.address async def connect( self, @@ -137,7 +148,7 @@ class APIClient: login: bool = False, ) -> None: if self._connection is not None: - raise APIConnectionError("Already connected!") + raise APIConnectionError(f"Already connected to {self._log_name}!") connected = False stopped = False @@ -153,6 +164,7 @@ class APIClient: await on_stop() self._connection = APIConnection(self._params, _on_stop) + self._connection.log_name = self._log_name try: await self._connection.connect() @@ -163,7 +175,9 @@ class APIClient: raise except Exception as e: await _on_stop() - raise APIConnectionError("Unexpected error while 
connecting: {}".format(e)) + raise APIConnectionError( + f"Unexpected error while connecting to {self._log_name}: {e}" + ) from e connected = True @@ -174,15 +188,15 @@ class APIClient: def _check_connected(self) -> None: if self._connection is None: - raise APIConnectionError("Not connected!") + raise APIConnectionError(f"Not connected to {self._log_name}!") if not self._connection.is_connected: - raise APIConnectionError("Connection not done!") + raise APIConnectionError(f"Connection not done for {self._log_name}!") def _check_authenticated(self) -> None: self._check_connected() assert self._connection is not None if not self._connection.is_authenticated: - raise APIConnectionError("Not authenticated!") + raise APIConnectionError(f"Not authenticated for {self._log_name}!") async def device_info(self) -> DeviceInfo: self._check_connected() @@ -190,7 +204,10 @@ class APIClient: resp = await self._connection.send_message_await_response( DeviceInfoRequest(), DeviceInfoResponse ) - return DeviceInfo.from_pb(resp) + info = DeviceInfo.from_pb(resp) + self._cached_name = info.name + self._connection.log_name = self._log_name + return info async def list_entities_services( self, diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 5848d61..c3002b7 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -56,6 +56,7 @@ class APIConnection: self._api_version: Optional[APIVersion] = None self._message_handlers: List[Callable[[message.Message], None]] = [] + self.log_name = params.address def _start_ping(self) -> None: async def func() -> None: @@ -68,7 +69,7 @@ class APIConnection: try: await self.ping() except APIConnectionError: - _LOGGER.info("%s: Ping Failed!", self._params.address) + _LOGGER.info("%s: Ping Failed!", self.log_name) await self._on_error() return @@ -87,7 +88,7 @@ class APIConnection: self._socket_connected = False self._connected = False self._authenticated = False - _LOGGER.debug("%s: Closed socket", 
self._params.address) + _LOGGER.debug("%s: Closed socket", self.log_name) async def stop(self, force: bool = False) -> None: if self._stopped: @@ -106,9 +107,9 @@ class APIConnection: async def connect(self) -> None: if self._stopped: - raise APIConnectionError("Connection is closed!") + raise APIConnectionError(f"Connection is closed for {self.log_name}!") if self._connected: - raise APIConnectionError("Already connected!") + raise APIConnectionError(f"Already connected for {self.log_name}!") try: coro = async_resolve_host( @@ -123,7 +124,9 @@ class APIConnection: raise err except asyncio.TimeoutError: await self._on_error() - raise APIConnectionError("Timeout while resolving IP address") + raise APIConnectionError( + f"Timeout while resolving IP address for {self.log_name}" + ) self._socket = socket.socket( family=addr.family, type=addr.type, proto=addr.proto @@ -133,7 +136,7 @@ class APIConnection: _LOGGER.debug( "%s: Connecting to %s:%s (%s)", - self._params.address, + self.log_name, self._params.address, self._params.port, addr, @@ -165,7 +168,7 @@ class APIConnection: raise err _LOGGER.debug( "%s: Successfully connected ('%s' API=%s.%s)", - self._params.address, + self.log_name, resp.server_info, resp.api_version_major, resp.api_version_minor, @@ -174,7 +177,7 @@ class APIConnection: if self._api_version.major > 2: _LOGGER.error( "%s: Incompatible version %s! 
Closing connection", - self._params.address, + self.log_name, self._api_version.major, ) await self._on_error() @@ -346,7 +349,7 @@ class APIConnection: except APIConnectionError as err: _LOGGER.info( "%s: Error while reading incoming messages: %s", - self._params.address, + self.log_name, err, ) await self._on_error() @@ -354,7 +357,7 @@ class APIConnection: except Exception as err: # pylint: disable=broad-except _LOGGER.info( "%s: Unexpected error while reading incoming messages: %s", - self._params.address, + self.log_name, err, ) await self._on_error() diff --git a/aioesphomeapi/reconnect_logic.py b/aioesphomeapi/reconnect_logic.py new file mode 100644 index 0000000..5e4ec33 --- /dev/null +++ b/aioesphomeapi/reconnect_logic.py @@ -0,0 +1,251 @@ +import asyncio +import logging +from typing import Awaitable, Callable, List, Optional + +from zeroconf import ( # type: ignore[attr-defined] + DNSPointer, + DNSRecord, + RecordUpdate, + RecordUpdateListener, + Zeroconf, +) + +from .client import APIClient +from .core import APIConnectionError + +_LOGGER = logging.getLogger(__name__) + + +class ReconnectLogic(RecordUpdateListener): # type: ignore[misc] + """Reconnecting logic handler for ESPHome config entries. + + Contains two reconnect strategies: + - Connect with increasing time between connection attempts. + - Listen to zeroconf mDNS records, if any records are found for this device, try reconnecting immediately. + + All methods in this class should be run inside the eventloop unless stated otherwise. + """ + + def __init__( + self, + *, + client: APIClient, + on_connect: Callable[[], Awaitable[None]], + on_disconnect: Callable[[], Awaitable[None]], + zeroconf_instance: Zeroconf, + name: Optional[str] = None, + ) -> None: + """Initialize ReconnectLogic. + + :param client: initialized :class:`APIClient` to reconnect for + :param on_connect: Coroutine Function to call when connected. + :param on_disconnect: Coroutine Function to call when disconnected. 
+ """ + self._cli = client + self.name = name + self._on_connect_cb = on_connect + self._on_disconnect_cb = on_disconnect + self._zc = zeroconf_instance + # Flag to check if the device is connected + self._connected = True + self._connected_lock = asyncio.Lock() + self._zc_lock = asyncio.Lock() + self._zc_listening = False + # Event the different strategies use for issuing a reconnect attempt. + self._reconnect_event = asyncio.Event() + # The task containing the infinite reconnect loop while running + self._loop_task: Optional[asyncio.Task[None]] = None + # How many reconnect attempts have there been already, used for exponential wait time + self._tries = 0 + self._tries_lock = asyncio.Lock() + # Track the wait task to cancel it on shutdown + self._wait_task: Optional[asyncio.Task[None]] = None + self._wait_task_lock = asyncio.Lock() + # Event for tracking when logic should stop + self._stop_event = asyncio.Event() + + @property + def _is_stopped(self) -> bool: + return self._stop_event.is_set() + + @property + def _log_name(self) -> str: + if self.name is not None: + return f"{self.name} @ {self._cli.address}" + return self._cli.address + + async def _on_disconnect(self) -> None: + """Log and issue callbacks when disconnecting.""" + if self._is_stopped: + return + # This can happen often depending on WiFi signal strength. + # So therefore all these connection warnings are logged + # as infos. The "unavailable" logic will still trigger so the + # user knows if the device is not connected. 
+ _LOGGER.info("Disconnected from ESPHome API for %s", self._log_name) + + # Run disconnect hook + await self._on_disconnect_cb() + await self._start_zc_listen() + + # Reset tries + async with self._tries_lock: + self._tries = 0 + # Connected needs to be reset before the reconnect event (opposite order of check) + async with self._connected_lock: + self._connected = False + self._reconnect_event.set() + + async def _wait_and_start_reconnect(self) -> None: + """Wait for exponentially increasing time to issue next reconnect event.""" + async with self._tries_lock: + tries = self._tries + # If not first re-try, wait and print message + # Cap wait time at 1 minute. This is because while working on the + # device (e.g. soldering stuff), users don't want to have to wait + # a long time for their device to show up in HA again (this was + # mentioned a lot in early feedback) + tries = min(tries, 10) # prevent OverflowError + wait_time = int(round(min(1.8 ** tries, 60.0))) + if tries == 1: + _LOGGER.info("Trying to reconnect to %s in the background", self._log_name) + _LOGGER.debug("Retrying %s in %d seconds", self._log_name, wait_time) + await asyncio.sleep(wait_time) + async with self._wait_task_lock: + self._wait_task = None + self._reconnect_event.set() + + async def _try_connect(self) -> None: + """Try connecting to the API client.""" + async with self._tries_lock: + tries = self._tries + self._tries += 1 + + try: + await self._cli.connect(on_stop=self._on_disconnect, login=True) + except APIConnectionError as error: + level = logging.WARNING if tries == 0 else logging.DEBUG + _LOGGER.log( + level, + "Can't connect to ESPHome API for %s: %s", + self._log_name, + error, + ) + await self._start_zc_listen() + # Schedule re-connect in event loop in order not to delay HA + # startup. First connect is scheduled in tracked tasks. 
+ async with self._wait_task_lock: + # Allow only one wait task at a time + # can happen if mDNS record received while waiting, then use existing wait task + if self._wait_task is not None: + return + + self._wait_task = asyncio.create_task(self._wait_and_start_reconnect()) + else: + _LOGGER.info("Successfully connected to %s", self._log_name) + async with self._tries_lock: + self._tries = 0 + async with self._connected_lock: + self._connected = True + await self._stop_zc_listen() + await self._on_connect_cb() + + async def _reconnect_once(self) -> None: + # Wait and clear reconnection event + await self._reconnect_event.wait() + self._reconnect_event.clear() + + # If in connected state, do not try to connect again. + async with self._connected_lock: + if self._connected: + return + + if self._is_stopped: + return + + await self._try_connect() + + async def _reconnect_loop(self) -> None: + while True: + try: + await self._reconnect_once() + except asyncio.CancelledError: # pylint: disable=try-except-raise + raise + except Exception: # pylint: disable=broad-except + _LOGGER.error( + "Caught exception while reconnecting to %s", + self._log_name, + exc_info=True, + ) + + async def start(self) -> None: + """Start the reconnecting logic background task.""" + # Create reconnection loop outside of HA's tracked tasks in order + # not to delay startup. + self._loop_task = asyncio.create_task(self._reconnect_loop()) + + async with self._connected_lock: + self._connected = False + self._reconnect_event.set() + + async def stop(self) -> None: + """Stop the reconnecting logic background task. 
Does not disconnect the client.""" + if self._loop_task is not None: + self._loop_task.cancel() + self._loop_task = None + async with self._wait_task_lock: + if self._wait_task is not None: + self._wait_task.cancel() + self._wait_task = None + await self._stop_zc_listen() + + def stop_callback(self) -> None: + asyncio.get_event_loop().create_task(self.stop()) + + async def _start_zc_listen(self) -> None: + """Listen for mDNS records. + + This listener allows us to schedule a reconnect as soon as a + received mDNS record indicates the node is up again. + """ + async with self._zc_lock: + if not self._zc_listening: + self._zc.async_add_listener(self, None) + self._zc_listening = True + + async def _stop_zc_listen(self) -> None: + """Stop listening for zeroconf updates.""" + async with self._zc_lock: + if self._zc_listening: + self._zc.async_remove_listener(self) + self._zc_listening = False + + def _async_on_record(self, record: DNSRecord) -> None: + if not isinstance(record, DNSPointer): + # We only consider PTR records and match using the alias name + return + if self._is_stopped or self.name is None: + return + filter_alias = f"{self.name}._esphomelib._tcp.local." + if record.alias != filter_alias: + return + + # This is a mDNS record from the device and could mean it just woke up + # Check if already connected, no lock needed for this access + if self._connected: + return + + # Tell reconnection logic to retry connection attempt now (even before reconnect timer finishes) + _LOGGER.debug( + "%s: Triggering reconnect because of received mDNS record %s", + self._log_name, + record, + ) + self._reconnect_event.set() + + def async_update_records( + self, zc: Zeroconf, now: float, records: List[RecordUpdate] + ) -> None: + """Listen to zeroconf updated mDNS records. This must be called from the eventloop.""" + for update in records: + self._async_on_record(update.new)