mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-01-19 21:11:44 +01:00
Add reconnect logic class (#54)
This commit is contained in:
parent
f4ca46c9d6
commit
ff10a20bce
@ -3,3 +3,4 @@ from .client import APIClient
|
|||||||
from .connection import APIConnection, ConnectionParams
|
from .connection import APIConnection, ConnectionParams
|
||||||
from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
|
from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
|
||||||
from .model import *
|
from .model import *
|
||||||
|
from .reconnect_logic import ReconnectLogic
|
||||||
|
@ -130,6 +130,17 @@ class APIClient:
|
|||||||
zeroconf_instance=zeroconf_instance,
|
zeroconf_instance=zeroconf_instance,
|
||||||
)
|
)
|
||||||
self._connection: Optional[APIConnection] = None
|
self._connection: Optional[APIConnection] = None
|
||||||
|
self._cached_name: Optional[str] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def address(self) -> str:
|
||||||
|
return self._params.address
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _log_name(self) -> str:
|
||||||
|
if self._cached_name is not None:
|
||||||
|
return f"{self._cached_name} @ {self.address}"
|
||||||
|
return self.address
|
||||||
|
|
||||||
async def connect(
|
async def connect(
|
||||||
self,
|
self,
|
||||||
@ -137,7 +148,7 @@ class APIClient:
|
|||||||
login: bool = False,
|
login: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
if self._connection is not None:
|
if self._connection is not None:
|
||||||
raise APIConnectionError("Already connected!")
|
raise APIConnectionError(f"Already connected to {self._log_name}!")
|
||||||
|
|
||||||
connected = False
|
connected = False
|
||||||
stopped = False
|
stopped = False
|
||||||
@ -153,6 +164,7 @@ class APIClient:
|
|||||||
await on_stop()
|
await on_stop()
|
||||||
|
|
||||||
self._connection = APIConnection(self._params, _on_stop)
|
self._connection = APIConnection(self._params, _on_stop)
|
||||||
|
self._connection.log_name = self._log_name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._connection.connect()
|
await self._connection.connect()
|
||||||
@ -163,7 +175,9 @@ class APIClient:
|
|||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await _on_stop()
|
await _on_stop()
|
||||||
raise APIConnectionError("Unexpected error while connecting: {}".format(e))
|
raise APIConnectionError(
|
||||||
|
f"Unexpected error while connecting to {self._log_name}: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
connected = True
|
connected = True
|
||||||
|
|
||||||
@ -174,15 +188,15 @@ class APIClient:
|
|||||||
|
|
||||||
def _check_connected(self) -> None:
|
def _check_connected(self) -> None:
|
||||||
if self._connection is None:
|
if self._connection is None:
|
||||||
raise APIConnectionError("Not connected!")
|
raise APIConnectionError(f"Not connected to {self._log_name}!")
|
||||||
if not self._connection.is_connected:
|
if not self._connection.is_connected:
|
||||||
raise APIConnectionError("Connection not done!")
|
raise APIConnectionError(f"Connection not done for {self._log_name}!")
|
||||||
|
|
||||||
def _check_authenticated(self) -> None:
|
def _check_authenticated(self) -> None:
|
||||||
self._check_connected()
|
self._check_connected()
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
if not self._connection.is_authenticated:
|
if not self._connection.is_authenticated:
|
||||||
raise APIConnectionError("Not authenticated!")
|
raise APIConnectionError(f"Not authenticated for {self._log_name}!")
|
||||||
|
|
||||||
async def device_info(self) -> DeviceInfo:
|
async def device_info(self) -> DeviceInfo:
|
||||||
self._check_connected()
|
self._check_connected()
|
||||||
@ -190,7 +204,10 @@ class APIClient:
|
|||||||
resp = await self._connection.send_message_await_response(
|
resp = await self._connection.send_message_await_response(
|
||||||
DeviceInfoRequest(), DeviceInfoResponse
|
DeviceInfoRequest(), DeviceInfoResponse
|
||||||
)
|
)
|
||||||
return DeviceInfo.from_pb(resp)
|
info = DeviceInfo.from_pb(resp)
|
||||||
|
self._cached_name = info.name
|
||||||
|
self._connection.log_name = self._log_name
|
||||||
|
return info
|
||||||
|
|
||||||
async def list_entities_services(
|
async def list_entities_services(
|
||||||
self,
|
self,
|
||||||
|
@ -56,6 +56,7 @@ class APIConnection:
|
|||||||
self._api_version: Optional[APIVersion] = None
|
self._api_version: Optional[APIVersion] = None
|
||||||
|
|
||||||
self._message_handlers: List[Callable[[message.Message], None]] = []
|
self._message_handlers: List[Callable[[message.Message], None]] = []
|
||||||
|
self.log_name = params.address
|
||||||
|
|
||||||
def _start_ping(self) -> None:
|
def _start_ping(self) -> None:
|
||||||
async def func() -> None:
|
async def func() -> None:
|
||||||
@ -68,7 +69,7 @@ class APIConnection:
|
|||||||
try:
|
try:
|
||||||
await self.ping()
|
await self.ping()
|
||||||
except APIConnectionError:
|
except APIConnectionError:
|
||||||
_LOGGER.info("%s: Ping Failed!", self._params.address)
|
_LOGGER.info("%s: Ping Failed!", self.log_name)
|
||||||
await self._on_error()
|
await self._on_error()
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -87,7 +88,7 @@ class APIConnection:
|
|||||||
self._socket_connected = False
|
self._socket_connected = False
|
||||||
self._connected = False
|
self._connected = False
|
||||||
self._authenticated = False
|
self._authenticated = False
|
||||||
_LOGGER.debug("%s: Closed socket", self._params.address)
|
_LOGGER.debug("%s: Closed socket", self.log_name)
|
||||||
|
|
||||||
async def stop(self, force: bool = False) -> None:
|
async def stop(self, force: bool = False) -> None:
|
||||||
if self._stopped:
|
if self._stopped:
|
||||||
@ -106,9 +107,9 @@ class APIConnection:
|
|||||||
|
|
||||||
async def connect(self) -> None:
|
async def connect(self) -> None:
|
||||||
if self._stopped:
|
if self._stopped:
|
||||||
raise APIConnectionError("Connection is closed!")
|
raise APIConnectionError(f"Connection is closed for {self.log_name}!")
|
||||||
if self._connected:
|
if self._connected:
|
||||||
raise APIConnectionError("Already connected!")
|
raise APIConnectionError(f"Already connected for {self.log_name}!")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
coro = async_resolve_host(
|
coro = async_resolve_host(
|
||||||
@ -123,7 +124,9 @@ class APIConnection:
|
|||||||
raise err
|
raise err
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
await self._on_error()
|
await self._on_error()
|
||||||
raise APIConnectionError("Timeout while resolving IP address")
|
raise APIConnectionError(
|
||||||
|
f"Timeout while resolving IP address for {self.log_name}"
|
||||||
|
)
|
||||||
|
|
||||||
self._socket = socket.socket(
|
self._socket = socket.socket(
|
||||||
family=addr.family, type=addr.type, proto=addr.proto
|
family=addr.family, type=addr.type, proto=addr.proto
|
||||||
@ -133,7 +136,7 @@ class APIConnection:
|
|||||||
|
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"%s: Connecting to %s:%s (%s)",
|
"%s: Connecting to %s:%s (%s)",
|
||||||
self._params.address,
|
self.log_name,
|
||||||
self._params.address,
|
self._params.address,
|
||||||
self._params.port,
|
self._params.port,
|
||||||
addr,
|
addr,
|
||||||
@ -165,7 +168,7 @@ class APIConnection:
|
|||||||
raise err
|
raise err
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"%s: Successfully connected ('%s' API=%s.%s)",
|
"%s: Successfully connected ('%s' API=%s.%s)",
|
||||||
self._params.address,
|
self.log_name,
|
||||||
resp.server_info,
|
resp.server_info,
|
||||||
resp.api_version_major,
|
resp.api_version_major,
|
||||||
resp.api_version_minor,
|
resp.api_version_minor,
|
||||||
@ -174,7 +177,7 @@ class APIConnection:
|
|||||||
if self._api_version.major > 2:
|
if self._api_version.major > 2:
|
||||||
_LOGGER.error(
|
_LOGGER.error(
|
||||||
"%s: Incompatible version %s! Closing connection",
|
"%s: Incompatible version %s! Closing connection",
|
||||||
self._params.address,
|
self.log_name,
|
||||||
self._api_version.major,
|
self._api_version.major,
|
||||||
)
|
)
|
||||||
await self._on_error()
|
await self._on_error()
|
||||||
@ -346,7 +349,7 @@ class APIConnection:
|
|||||||
except APIConnectionError as err:
|
except APIConnectionError as err:
|
||||||
_LOGGER.info(
|
_LOGGER.info(
|
||||||
"%s: Error while reading incoming messages: %s",
|
"%s: Error while reading incoming messages: %s",
|
||||||
self._params.address,
|
self.log_name,
|
||||||
err,
|
err,
|
||||||
)
|
)
|
||||||
await self._on_error()
|
await self._on_error()
|
||||||
@ -354,7 +357,7 @@ class APIConnection:
|
|||||||
except Exception as err: # pylint: disable=broad-except
|
except Exception as err: # pylint: disable=broad-except
|
||||||
_LOGGER.info(
|
_LOGGER.info(
|
||||||
"%s: Unexpected error while reading incoming messages: %s",
|
"%s: Unexpected error while reading incoming messages: %s",
|
||||||
self._params.address,
|
self.log_name,
|
||||||
err,
|
err,
|
||||||
)
|
)
|
||||||
await self._on_error()
|
await self._on_error()
|
||||||
|
251
aioesphomeapi/reconnect_logic.py
Normal file
251
aioesphomeapi/reconnect_logic.py
Normal file
@ -0,0 +1,251 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Awaitable, Callable, List, Optional
|
||||||
|
|
||||||
|
from zeroconf import ( # type: ignore[attr-defined]
|
||||||
|
DNSPointer,
|
||||||
|
DNSRecord,
|
||||||
|
RecordUpdate,
|
||||||
|
RecordUpdateListener,
|
||||||
|
Zeroconf,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .client import APIClient
|
||||||
|
from .core import APIConnectionError
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ReconnectLogic(RecordUpdateListener): # type: ignore[misc]
|
||||||
|
"""Reconnectiong logic handler for ESPHome config entries.
|
||||||
|
|
||||||
|
Contains two reconnect strategies:
|
||||||
|
- Connect with increasing time between connection attempts.
|
||||||
|
- Listen to zeroconf mDNS records, if any records are found for this device, try reconnecting immediately.
|
||||||
|
|
||||||
|
All methods in this class should be run inside the eventloop unless stated otherwise.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
client: APIClient,
|
||||||
|
on_connect: Callable[[], Awaitable[None]],
|
||||||
|
on_disconnect: Callable[[], Awaitable[None]],
|
||||||
|
zeroconf_instance: Zeroconf,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize ReconnectingLogic.
|
||||||
|
|
||||||
|
:param client: initialized :class:`APIClient` to reconnect for
|
||||||
|
:param on_connect: Coroutine Function to call when connected.
|
||||||
|
:param on_disconnect: Coroutine Function to call when disconnected.
|
||||||
|
"""
|
||||||
|
self._cli = client
|
||||||
|
self.name = name
|
||||||
|
self._on_connect_cb = on_connect
|
||||||
|
self._on_disconnect_cb = on_disconnect
|
||||||
|
self._zc = zeroconf_instance
|
||||||
|
# Flag to check if the device is connected
|
||||||
|
self._connected = True
|
||||||
|
self._connected_lock = asyncio.Lock()
|
||||||
|
self._zc_lock = asyncio.Lock()
|
||||||
|
self._zc_listening = False
|
||||||
|
# Event the different strategies use for issuing a reconnect attempt.
|
||||||
|
self._reconnect_event = asyncio.Event()
|
||||||
|
# The task containing the infinite reconnect loop while running
|
||||||
|
self._loop_task: Optional[asyncio.Task[None]] = None
|
||||||
|
# How many reconnect attempts have there been already, used for exponential wait time
|
||||||
|
self._tries = 0
|
||||||
|
self._tries_lock = asyncio.Lock()
|
||||||
|
# Track the wait task to cancel it on shutdown
|
||||||
|
self._wait_task: Optional[asyncio.Task[None]] = None
|
||||||
|
self._wait_task_lock = asyncio.Lock()
|
||||||
|
# Event for tracking when logic should stop
|
||||||
|
self._stop_event = asyncio.Event()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _is_stopped(self) -> bool:
|
||||||
|
return self._stop_event.is_set()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _log_name(self) -> str:
|
||||||
|
if self.name is not None:
|
||||||
|
return f"{self.name} @ {self._cli.address}"
|
||||||
|
return self._cli.address
|
||||||
|
|
||||||
|
async def _on_disconnect(self) -> None:
|
||||||
|
"""Log and issue callbacks when disconnecting."""
|
||||||
|
if self._is_stopped:
|
||||||
|
return
|
||||||
|
# This can happen often depending on WiFi signal strength.
|
||||||
|
# So therefore all these connection warnings are logged
|
||||||
|
# as infos. The "unavailable" logic will still trigger so the
|
||||||
|
# user knows if the device is not connected.
|
||||||
|
_LOGGER.info("Disconnected from ESPHome API for %s", self._log_name)
|
||||||
|
|
||||||
|
# Run disconnect hook
|
||||||
|
await self._on_disconnect_cb()
|
||||||
|
await self._start_zc_listen()
|
||||||
|
|
||||||
|
# Reset tries
|
||||||
|
async with self._tries_lock:
|
||||||
|
self._tries = 0
|
||||||
|
# Connected needs to be reset before the reconnect event (opposite order of check)
|
||||||
|
async with self._connected_lock:
|
||||||
|
self._connected = False
|
||||||
|
self._reconnect_event.set()
|
||||||
|
|
||||||
|
async def _wait_and_start_reconnect(self) -> None:
|
||||||
|
"""Wait for exponentially increasing time to issue next reconnect event."""
|
||||||
|
async with self._tries_lock:
|
||||||
|
tries = self._tries
|
||||||
|
# If not first re-try, wait and print message
|
||||||
|
# Cap wait time at 1 minute. This is because while working on the
|
||||||
|
# device (e.g. soldering stuff), users don't want to have to wait
|
||||||
|
# a long time for their device to show up in HA again (this was
|
||||||
|
# mentioned a lot in early feedback)
|
||||||
|
tries = min(tries, 10) # prevent OverflowError
|
||||||
|
wait_time = int(round(min(1.8 ** tries, 60.0)))
|
||||||
|
if tries == 1:
|
||||||
|
_LOGGER.info("Trying to reconnect to %s in the background", self._log_name)
|
||||||
|
_LOGGER.debug("Retrying %s in %d seconds", self._log_name, wait_time)
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
async with self._wait_task_lock:
|
||||||
|
self._wait_task = None
|
||||||
|
self._reconnect_event.set()
|
||||||
|
|
||||||
|
async def _try_connect(self) -> None:
|
||||||
|
"""Try connecting to the API client."""
|
||||||
|
async with self._tries_lock:
|
||||||
|
tries = self._tries
|
||||||
|
self._tries += 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._cli.connect(on_stop=self._on_disconnect, login=True)
|
||||||
|
except APIConnectionError as error:
|
||||||
|
level = logging.WARNING if tries == 0 else logging.DEBUG
|
||||||
|
_LOGGER.log(
|
||||||
|
level,
|
||||||
|
"Can't connect to ESPHome API for %s: %s",
|
||||||
|
self._log_name,
|
||||||
|
error,
|
||||||
|
)
|
||||||
|
await self._start_zc_listen()
|
||||||
|
# Schedule re-connect in event loop in order not to delay HA
|
||||||
|
# startup. First connect is scheduled in tracked tasks.
|
||||||
|
async with self._wait_task_lock:
|
||||||
|
# Allow only one wait task at a time
|
||||||
|
# can happen if mDNS record received while waiting, then use existing wait task
|
||||||
|
if self._wait_task is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._wait_task = asyncio.create_task(self._wait_and_start_reconnect())
|
||||||
|
else:
|
||||||
|
_LOGGER.info("Successfully connected to %s", self._log_name)
|
||||||
|
async with self._tries_lock:
|
||||||
|
self._tries = 0
|
||||||
|
async with self._connected_lock:
|
||||||
|
self._connected = True
|
||||||
|
await self._stop_zc_listen()
|
||||||
|
await self._on_connect_cb()
|
||||||
|
|
||||||
|
async def _reconnect_once(self) -> None:
|
||||||
|
# Wait and clear reconnection event
|
||||||
|
await self._reconnect_event.wait()
|
||||||
|
self._reconnect_event.clear()
|
||||||
|
|
||||||
|
# If in connected state, do not try to connect again.
|
||||||
|
async with self._connected_lock:
|
||||||
|
if self._connected:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._is_stopped:
|
||||||
|
return
|
||||||
|
|
||||||
|
await self._try_connect()
|
||||||
|
|
||||||
|
async def _reconnect_loop(self) -> None:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await self._reconnect_once()
|
||||||
|
except asyncio.CancelledError: # pylint: disable=try-except-raise
|
||||||
|
raise
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
_LOGGER.error(
|
||||||
|
"Caught exception while reconnecting to %s",
|
||||||
|
self._log_name,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
"""Start the reconnecting logic background task."""
|
||||||
|
# Create reconnection loop outside of HA's tracked tasks in order
|
||||||
|
# not to delay startup.
|
||||||
|
self._loop_task = asyncio.create_task(self._reconnect_loop())
|
||||||
|
|
||||||
|
async with self._connected_lock:
|
||||||
|
self._connected = False
|
||||||
|
self._reconnect_event.set()
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
"""Stop the reconnecting logic background task. Does not disconnect the client."""
|
||||||
|
if self._loop_task is not None:
|
||||||
|
self._loop_task.cancel()
|
||||||
|
self._loop_task = None
|
||||||
|
async with self._wait_task_lock:
|
||||||
|
if self._wait_task is not None:
|
||||||
|
self._wait_task.cancel()
|
||||||
|
self._wait_task = None
|
||||||
|
await self._stop_zc_listen()
|
||||||
|
|
||||||
|
def stop_callback(self) -> None:
|
||||||
|
asyncio.get_event_loop().create_task(self.stop())
|
||||||
|
|
||||||
|
async def _start_zc_listen(self) -> None:
|
||||||
|
"""Listen for mDNS records.
|
||||||
|
|
||||||
|
This listener allows us to schedule a reconnect as soon as a
|
||||||
|
received mDNS record indicates the node is up again.
|
||||||
|
"""
|
||||||
|
async with self._zc_lock:
|
||||||
|
if not self._zc_listening:
|
||||||
|
self._zc.async_add_listener(self, None)
|
||||||
|
self._zc_listening = True
|
||||||
|
|
||||||
|
async def _stop_zc_listen(self) -> None:
|
||||||
|
"""Stop listening for zeroconf updates."""
|
||||||
|
async with self._zc_lock:
|
||||||
|
if self._zc_listening:
|
||||||
|
self._zc.async_remove_listener(self)
|
||||||
|
self._zc_listening = False
|
||||||
|
|
||||||
|
def _async_on_record(self, record: DNSRecord) -> None:
|
||||||
|
if not isinstance(record, DNSPointer):
|
||||||
|
# We only consider PTR records and match using the alias name
|
||||||
|
return
|
||||||
|
if self._is_stopped or self.name is None:
|
||||||
|
return
|
||||||
|
filter_alias = f"{self.name}._esphomelib._tcp.local."
|
||||||
|
if record.alias != filter_alias:
|
||||||
|
return
|
||||||
|
|
||||||
|
# This is a mDNS record from the device and could mean it just woke up
|
||||||
|
# Check if already connected, no lock needed for this access
|
||||||
|
if self._connected:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Tell reconnection logic to retry connection attempt now (even before reconnect timer finishes)
|
||||||
|
_LOGGER.debug(
|
||||||
|
"%s: Triggering reconnect because of received mDNS record %s",
|
||||||
|
self._log_name,
|
||||||
|
record,
|
||||||
|
)
|
||||||
|
self._reconnect_event.set()
|
||||||
|
|
||||||
|
def async_update_records(
|
||||||
|
self, zc: Zeroconf, now: float, records: List[RecordUpdate]
|
||||||
|
) -> None:
|
||||||
|
"""Listen to zeroconf updated mDNS records. This must be called from the eventloop."""
|
||||||
|
for update in records:
|
||||||
|
self._async_on_record(update.new)
|
Loading…
Reference in New Issue
Block a user