Add reconnect logic class (#54)

This commit is contained in:
Otto Winter 2021-06-30 17:10:30 +02:00 committed by GitHub
parent f4ca46c9d6
commit ff10a20bce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 288 additions and 16 deletions

View File

@ -3,3 +3,4 @@ from .client import APIClient
from .connection import APIConnection, ConnectionParams from .connection import APIConnection, ConnectionParams
from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
from .model import * from .model import *
from .reconnect_logic import ReconnectLogic

View File

@ -130,6 +130,17 @@ class APIClient:
zeroconf_instance=zeroconf_instance, zeroconf_instance=zeroconf_instance,
) )
self._connection: Optional[APIConnection] = None self._connection: Optional[APIConnection] = None
self._cached_name: Optional[str] = None
@property
def address(self) -> str:
return self._params.address
@property
def _log_name(self) -> str:
if self._cached_name is not None:
return f"{self._cached_name} @ {self.address}"
return self.address
async def connect( async def connect(
self, self,
@ -137,7 +148,7 @@ class APIClient:
login: bool = False, login: bool = False,
) -> None: ) -> None:
if self._connection is not None: if self._connection is not None:
raise APIConnectionError("Already connected!") raise APIConnectionError(f"Already connected to {self._log_name}!")
connected = False connected = False
stopped = False stopped = False
@ -153,6 +164,7 @@ class APIClient:
await on_stop() await on_stop()
self._connection = APIConnection(self._params, _on_stop) self._connection = APIConnection(self._params, _on_stop)
self._connection.log_name = self._log_name
try: try:
await self._connection.connect() await self._connection.connect()
@ -163,7 +175,9 @@ class APIClient:
raise raise
except Exception as e: except Exception as e:
await _on_stop() await _on_stop()
raise APIConnectionError("Unexpected error while connecting: {}".format(e)) raise APIConnectionError(
f"Unexpected error while connecting to {self._log_name}: {e}"
) from e
connected = True connected = True
@ -174,15 +188,15 @@ class APIClient:
def _check_connected(self) -> None: def _check_connected(self) -> None:
if self._connection is None: if self._connection is None:
raise APIConnectionError("Not connected!") raise APIConnectionError(f"Not connected to {self._log_name}!")
if not self._connection.is_connected: if not self._connection.is_connected:
raise APIConnectionError("Connection not done!") raise APIConnectionError(f"Connection not done for {self._log_name}!")
def _check_authenticated(self) -> None: def _check_authenticated(self) -> None:
self._check_connected() self._check_connected()
assert self._connection is not None assert self._connection is not None
if not self._connection.is_authenticated: if not self._connection.is_authenticated:
raise APIConnectionError("Not authenticated!") raise APIConnectionError(f"Not authenticated for {self._log_name}!")
async def device_info(self) -> DeviceInfo: async def device_info(self) -> DeviceInfo:
self._check_connected() self._check_connected()
@ -190,7 +204,10 @@ class APIClient:
resp = await self._connection.send_message_await_response( resp = await self._connection.send_message_await_response(
DeviceInfoRequest(), DeviceInfoResponse DeviceInfoRequest(), DeviceInfoResponse
) )
return DeviceInfo.from_pb(resp) info = DeviceInfo.from_pb(resp)
self._cached_name = info.name
self._connection.log_name = self._log_name
return info
async def list_entities_services( async def list_entities_services(
self, self,

View File

@ -56,6 +56,7 @@ class APIConnection:
self._api_version: Optional[APIVersion] = None self._api_version: Optional[APIVersion] = None
self._message_handlers: List[Callable[[message.Message], None]] = [] self._message_handlers: List[Callable[[message.Message], None]] = []
self.log_name = params.address
def _start_ping(self) -> None: def _start_ping(self) -> None:
async def func() -> None: async def func() -> None:
@ -68,7 +69,7 @@ class APIConnection:
try: try:
await self.ping() await self.ping()
except APIConnectionError: except APIConnectionError:
_LOGGER.info("%s: Ping Failed!", self._params.address) _LOGGER.info("%s: Ping Failed!", self.log_name)
await self._on_error() await self._on_error()
return return
@ -87,7 +88,7 @@ class APIConnection:
self._socket_connected = False self._socket_connected = False
self._connected = False self._connected = False
self._authenticated = False self._authenticated = False
_LOGGER.debug("%s: Closed socket", self._params.address) _LOGGER.debug("%s: Closed socket", self.log_name)
async def stop(self, force: bool = False) -> None: async def stop(self, force: bool = False) -> None:
if self._stopped: if self._stopped:
@ -106,9 +107,9 @@ class APIConnection:
async def connect(self) -> None: async def connect(self) -> None:
if self._stopped: if self._stopped:
raise APIConnectionError("Connection is closed!") raise APIConnectionError(f"Connection is closed for {self.log_name}!")
if self._connected: if self._connected:
raise APIConnectionError("Already connected!") raise APIConnectionError(f"Already connected for {self.log_name}!")
try: try:
coro = async_resolve_host( coro = async_resolve_host(
@ -123,7 +124,9 @@ class APIConnection:
raise err raise err
except asyncio.TimeoutError: except asyncio.TimeoutError:
await self._on_error() await self._on_error()
raise APIConnectionError("Timeout while resolving IP address") raise APIConnectionError(
f"Timeout while resolving IP address for {self.log_name}"
)
self._socket = socket.socket( self._socket = socket.socket(
family=addr.family, type=addr.type, proto=addr.proto family=addr.family, type=addr.type, proto=addr.proto
@ -133,7 +136,7 @@ class APIConnection:
_LOGGER.debug( _LOGGER.debug(
"%s: Connecting to %s:%s (%s)", "%s: Connecting to %s:%s (%s)",
self._params.address, self.log_name,
self._params.address, self._params.address,
self._params.port, self._params.port,
addr, addr,
@ -165,7 +168,7 @@ class APIConnection:
raise err raise err
_LOGGER.debug( _LOGGER.debug(
"%s: Successfully connected ('%s' API=%s.%s)", "%s: Successfully connected ('%s' API=%s.%s)",
self._params.address, self.log_name,
resp.server_info, resp.server_info,
resp.api_version_major, resp.api_version_major,
resp.api_version_minor, resp.api_version_minor,
@ -174,7 +177,7 @@ class APIConnection:
if self._api_version.major > 2: if self._api_version.major > 2:
_LOGGER.error( _LOGGER.error(
"%s: Incompatible version %s! Closing connection", "%s: Incompatible version %s! Closing connection",
self._params.address, self.log_name,
self._api_version.major, self._api_version.major,
) )
await self._on_error() await self._on_error()
@ -346,7 +349,7 @@ class APIConnection:
except APIConnectionError as err: except APIConnectionError as err:
_LOGGER.info( _LOGGER.info(
"%s: Error while reading incoming messages: %s", "%s: Error while reading incoming messages: %s",
self._params.address, self.log_name,
err, err,
) )
await self._on_error() await self._on_error()
@ -354,7 +357,7 @@ class APIConnection:
except Exception as err: # pylint: disable=broad-except except Exception as err: # pylint: disable=broad-except
_LOGGER.info( _LOGGER.info(
"%s: Unexpected error while reading incoming messages: %s", "%s: Unexpected error while reading incoming messages: %s",
self._params.address, self.log_name,
err, err,
) )
await self._on_error() await self._on_error()

View File

@ -0,0 +1,251 @@
import asyncio
import logging
from typing import Awaitable, Callable, List, Optional
from zeroconf import ( # type: ignore[attr-defined]
DNSPointer,
DNSRecord,
RecordUpdate,
RecordUpdateListener,
Zeroconf,
)
from .client import APIClient
from .core import APIConnectionError
_LOGGER = logging.getLogger(__name__)
class ReconnectLogic(RecordUpdateListener): # type: ignore[misc]
"""Reconnectiong logic handler for ESPHome config entries.
Contains two reconnect strategies:
- Connect with increasing time between connection attempts.
- Listen to zeroconf mDNS records, if any records are found for this device, try reconnecting immediately.
All methods in this class should be run inside the eventloop unless stated otherwise.
"""
def __init__(
self,
*,
client: APIClient,
on_connect: Callable[[], Awaitable[None]],
on_disconnect: Callable[[], Awaitable[None]],
zeroconf_instance: Zeroconf,
name: Optional[str] = None,
) -> None:
"""Initialize ReconnectingLogic.
:param client: initialized :class:`APIClient` to reconnect for
:param on_connect: Coroutine Function to call when connected.
:param on_disconnect: Coroutine Function to call when disconnected.
"""
self._cli = client
self.name = name
self._on_connect_cb = on_connect
self._on_disconnect_cb = on_disconnect
self._zc = zeroconf_instance
# Flag to check if the device is connected
self._connected = True
self._connected_lock = asyncio.Lock()
self._zc_lock = asyncio.Lock()
self._zc_listening = False
# Event the different strategies use for issuing a reconnect attempt.
self._reconnect_event = asyncio.Event()
# The task containing the infinite reconnect loop while running
self._loop_task: Optional[asyncio.Task[None]] = None
# How many reconnect attempts have there been already, used for exponential wait time
self._tries = 0
self._tries_lock = asyncio.Lock()
# Track the wait task to cancel it on shutdown
self._wait_task: Optional[asyncio.Task[None]] = None
self._wait_task_lock = asyncio.Lock()
# Event for tracking when logic should stop
self._stop_event = asyncio.Event()
@property
def _is_stopped(self) -> bool:
return self._stop_event.is_set()
@property
def _log_name(self) -> str:
if self.name is not None:
return f"{self.name} @ {self._cli.address}"
return self._cli.address
async def _on_disconnect(self) -> None:
"""Log and issue callbacks when disconnecting."""
if self._is_stopped:
return
# This can happen often depending on WiFi signal strength.
# So therefore all these connection warnings are logged
# as infos. The "unavailable" logic will still trigger so the
# user knows if the device is not connected.
_LOGGER.info("Disconnected from ESPHome API for %s", self._log_name)
# Run disconnect hook
await self._on_disconnect_cb()
await self._start_zc_listen()
# Reset tries
async with self._tries_lock:
self._tries = 0
# Connected needs to be reset before the reconnect event (opposite order of check)
async with self._connected_lock:
self._connected = False
self._reconnect_event.set()
async def _wait_and_start_reconnect(self) -> None:
"""Wait for exponentially increasing time to issue next reconnect event."""
async with self._tries_lock:
tries = self._tries
# If not first re-try, wait and print message
# Cap wait time at 1 minute. This is because while working on the
# device (e.g. soldering stuff), users don't want to have to wait
# a long time for their device to show up in HA again (this was
# mentioned a lot in early feedback)
tries = min(tries, 10) # prevent OverflowError
wait_time = int(round(min(1.8 ** tries, 60.0)))
if tries == 1:
_LOGGER.info("Trying to reconnect to %s in the background", self._log_name)
_LOGGER.debug("Retrying %s in %d seconds", self._log_name, wait_time)
await asyncio.sleep(wait_time)
async with self._wait_task_lock:
self._wait_task = None
self._reconnect_event.set()
async def _try_connect(self) -> None:
"""Try connecting to the API client."""
async with self._tries_lock:
tries = self._tries
self._tries += 1
try:
await self._cli.connect(on_stop=self._on_disconnect, login=True)
except APIConnectionError as error:
level = logging.WARNING if tries == 0 else logging.DEBUG
_LOGGER.log(
level,
"Can't connect to ESPHome API for %s: %s",
self._log_name,
error,
)
await self._start_zc_listen()
# Schedule re-connect in event loop in order not to delay HA
# startup. First connect is scheduled in tracked tasks.
async with self._wait_task_lock:
# Allow only one wait task at a time
# can happen if mDNS record received while waiting, then use existing wait task
if self._wait_task is not None:
return
self._wait_task = asyncio.create_task(self._wait_and_start_reconnect())
else:
_LOGGER.info("Successfully connected to %s", self._log_name)
async with self._tries_lock:
self._tries = 0
async with self._connected_lock:
self._connected = True
await self._stop_zc_listen()
await self._on_connect_cb()
async def _reconnect_once(self) -> None:
# Wait and clear reconnection event
await self._reconnect_event.wait()
self._reconnect_event.clear()
# If in connected state, do not try to connect again.
async with self._connected_lock:
if self._connected:
return
if self._is_stopped:
return
await self._try_connect()
async def _reconnect_loop(self) -> None:
while True:
try:
await self._reconnect_once()
except asyncio.CancelledError: # pylint: disable=try-except-raise
raise
except Exception: # pylint: disable=broad-except
_LOGGER.error(
"Caught exception while reconnecting to %s",
self._log_name,
exc_info=True,
)
async def start(self) -> None:
"""Start the reconnecting logic background task."""
# Create reconnection loop outside of HA's tracked tasks in order
# not to delay startup.
self._loop_task = asyncio.create_task(self._reconnect_loop())
async with self._connected_lock:
self._connected = False
self._reconnect_event.set()
async def stop(self) -> None:
"""Stop the reconnecting logic background task. Does not disconnect the client."""
if self._loop_task is not None:
self._loop_task.cancel()
self._loop_task = None
async with self._wait_task_lock:
if self._wait_task is not None:
self._wait_task.cancel()
self._wait_task = None
await self._stop_zc_listen()
def stop_callback(self) -> None:
asyncio.get_event_loop().create_task(self.stop())
async def _start_zc_listen(self) -> None:
"""Listen for mDNS records.
This listener allows us to schedule a reconnect as soon as a
received mDNS record indicates the node is up again.
"""
async with self._zc_lock:
if not self._zc_listening:
self._zc.async_add_listener(self, None)
self._zc_listening = True
async def _stop_zc_listen(self) -> None:
"""Stop listening for zeroconf updates."""
async with self._zc_lock:
if self._zc_listening:
self._zc.async_remove_listener(self)
self._zc_listening = False
def _async_on_record(self, record: DNSRecord) -> None:
if not isinstance(record, DNSPointer):
# We only consider PTR records and match using the alias name
return
if self._is_stopped or self.name is None:
return
filter_alias = f"{self.name}._esphomelib._tcp.local."
if record.alias != filter_alias:
return
# This is a mDNS record from the device and could mean it just woke up
# Check if already connected, no lock needed for this access
if self._connected:
return
# Tell reconnection logic to retry connection attempt now (even before reconnect timer finishes)
_LOGGER.debug(
"%s: Triggering reconnect because of received mDNS record %s",
self._log_name,
record,
)
self._reconnect_event.set()
def async_update_records(
self, zc: Zeroconf, now: float, records: List[RecordUpdate]
) -> None:
"""Listen to zeroconf updated mDNS records. This must be called from the eventloop."""
for update in records:
self._async_on_record(update.new)