mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-26 12:45:26 +01:00
424 lines
16 KiB
Python
424 lines
16 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
from collections.abc import Awaitable
|
|
from enum import Enum
|
|
from typing import Callable
|
|
|
|
import zeroconf
|
|
from zeroconf.const import _TYPE_A as TYPE_A
|
|
from zeroconf.const import _TYPE_PTR as TYPE_PTR
|
|
|
|
from .client import APIClient
|
|
from .core import (
|
|
APIConnectionError,
|
|
InvalidAuthAPIError,
|
|
InvalidEncryptionKeyAPIError,
|
|
RequiresEncryptionAPIError,
|
|
UnhandledAPIConnectionError,
|
|
)
|
|
from .util import address_is_local, host_is_name_part
|
|
from .zeroconf import ZeroconfInstanceType
|
|
|
|
# Module-level logger named after this module.
_LOGGER = logging.getLogger(__name__)

# Delay, in seconds, applied before reconnecting after an expected disconnect.
EXPECTED_DISCONNECT_COOLDOWN = 5.0
# Value the retry counter is set to after an authentication/encryption error.
MAXIMUM_BACKOFF_TRIES = 100
|
|
|
|
|
|
class ReconnectLogicState(Enum):
    """Connection phases tracked by the reconnect logic."""

    # A connect attempt has started; the connection is not established yet.
    CONNECTING = 0
    # The connection is established and the handshake/login is in progress.
    HANDSHAKING = 1
    # The handshake finished; the connection is usable.
    READY = 2
    # No connection: initial state, and the state after a failure, a
    # disconnect, or a stop.
    DISCONNECTED = 3
|
|
|
|
|
|
# States in which no connection has been established yet.
NOT_YET_CONNECTED_STATES = {
    ReconnectLogicState.DISCONNECTED,
    ReconnectLogicState.CONNECTING,
}

# Exception types that are handled as authentication/encryption failures.
AUTH_EXCEPTIONS = (
    RequiresEncryptionAPIError,
    InvalidEncryptionKeyAPIError,
    InvalidAuthAPIError,
)
|
|
|
|
|
|
class ReconnectLogic(zeroconf.RecordUpdateListener):
    """Reconnecting logic handler for ESPHome config entries.

    Contains two reconnect strategies:
     - Connect with increasing time between connection attempts.
     - Listen to zeroconf mDNS records, if any records are found for this device, try reconnecting immediately.

    All methods in this class should be run inside the eventloop unless stated otherwise.
    """

    def __init__(
        self,
        *,
        client: APIClient,
        on_connect: Callable[[], Awaitable[None]],
        on_disconnect: Callable[[bool], Awaitable[None]],
        zeroconf_instance: ZeroconfInstanceType | None = None,
        name: str | None = None,
        on_connect_error: Callable[[Exception], Awaitable[None]] | None = None,
    ) -> None:
        """Initialize ReconnectingLogic.

        :param client: initialized :class:`APIClient` to reconnect for
        :param on_connect: Coroutine Function to call when connected.
        :param on_disconnect: Coroutine Function to call when disconnected.
            It is awaited with ``True`` when the disconnect was expected.
        :param zeroconf_instance: Optional zeroconf instance that is handed
            to the client's zeroconf manager.
        :param name: Optional device name. When omitted, the part of the
            client address before the first dot is used, provided the
            address is accepted by ``host_is_name_part`` or
            ``address_is_local``.
        :param on_connect_error: Optional Coroutine Function awaited with
            the exception whenever a connect attempt fails.
        """
        self.loop = asyncio.get_event_loop()
        self._cli = client
        self.name: str | None = None
        if name:
            self.name = name
        elif host_is_name_part(client.address) or address_is_local(client.address):
            self.name = client.address.partition(".")[0]
        if self.name:
            self._cli.set_cached_name_if_unset(self.name)
        self._on_connect_cb = on_connect
        self._on_disconnect_cb = on_disconnect
        self._on_connect_error_cb = on_connect_error
        self._zeroconf_manager = client.zeroconf_manager
        if zeroconf_instance is not None:
            self._zeroconf_manager.set_instance(zeroconf_instance)
        # Record names matched in async_update_records; filled in when the
        # zeroconf listener is started.
        self._ptr_alias: str | None = None
        self._a_name: str | None = None
        # Current phase of the connection
        self._connection_state = ReconnectLogicState.DISCONNECTED
        # Mirrors "state in NOT_YET_CONNECTED_STATES" so the zeroconf
        # callback can bail out cheaply.
        self._accept_zeroconf_records: bool = True
        # Serializes connect attempts, disconnect handling, start and stop.
        self._connected_lock = asyncio.Lock()
        self._is_stopped = True
        self._zc_listening = False
        # How many connect attempts have there been already, used for exponential wait time
        self._tries = 0
        # Task running the current connect attempt, if any
        self._connect_task: asyncio.Task[None] | None = None
        # Timer for the next delayed connect attempt, if any
        self._connect_timer: asyncio.TimerHandle | None = None
        # Task created by stop_callback; referenced here until it is done
        self._stop_task: asyncio.Task[None] | None = None

    async def _on_disconnect(self, expected_disconnect: bool) -> None:
        """Log and issue callbacks when disconnecting."""
        # This can happen often depending on WiFi signal strength.
        # So therefore all these connection warnings are logged
        # as infos. The "unavailable" logic will still trigger so the
        # user knows if the device is not connected.
        if expected_disconnect:
            # If we expected the disconnect we need
            # to cooldown before connecting in case the remote
            # is rebooting so we don't establish a connection right
            # before its about to reboot in the event we are too fast.
            disconnect_type = "expected"
            wait = EXPECTED_DISCONNECT_COOLDOWN
        else:
            disconnect_type = "unexpected"
            wait = 0.0

        _LOGGER.info(
            "Processing %s disconnect from ESPHome API for %s",
            disconnect_type,
            self._cli.log_name,
        )

        # Run disconnect hook
        async with self._connected_lock:
            self._async_set_connection_state_while_locked(
                ReconnectLogicState.DISCONNECTED
            )
            await self._on_disconnect_cb(expected_disconnect)

        if not self._is_stopped:
            self._schedule_connect(wait)

    def _async_set_connection_state_while_locked(
        self, state: ReconnectLogicState
    ) -> None:
        """Set the connection state while holding the lock."""
        assert self._connected_lock.locked(), "connected_lock must be locked"
        self._async_set_connection_state_without_lock(state)

    def _async_set_connection_state_without_lock(
        self, state: ReconnectLogicState
    ) -> None:
        """Set the connection state without holding the lock.

        This should only be used for setting the state to DISCONNECTED
        when the state is CONNECTING.
        """
        self._connection_state = state
        self._accept_zeroconf_records = state in NOT_YET_CONNECTED_STATES

    def _async_log_connection_error(self, err: Exception) -> None:
        """Log connection errors."""
        # UnhandledAPIConnectionError is a special case in client
        # for when the connection raises an exception that is not
        # handled by the client. This is usually a bug in the connection
        # code and should be logged as an error.
        is_handled_exception = not isinstance(
            err, UnhandledAPIConnectionError
        ) and isinstance(err, APIConnectionError)
        if not is_handled_exception:
            level = logging.ERROR
        elif self._tries == 0:
            # Only the first failure in a row is surfaced as a warning
            level = logging.WARNING
        else:
            level = logging.DEBUG
        _LOGGER.log(
            level,
            "Can't connect to ESPHome API for %s: %s (%s)",
            self._cli.log_name,
            err,
            type(err).__name__,
            # Print stacktrace if unhandled
            exc_info=not is_handled_exception,
        )

    async def _try_connect(self) -> bool:
        """Try connecting to the API client.

        Must be called while holding the connected lock.
        Returns True when both the connection and the handshake succeeded.
        """
        self._async_set_connection_state_while_locked(ReconnectLogicState.CONNECTING)
        start_connect_time = time.perf_counter()
        try:
            await self._cli.start_connection(on_stop=self._on_disconnect)
        except Exception as err:  # pylint: disable=broad-except
            await self._handle_connection_failure(err)
            return False
        finish_connect_time = time.perf_counter()
        connect_time = finish_connect_time - start_connect_time
        _LOGGER.info(
            "Successfully connected to %s in %0.3fs", self._cli.log_name, connect_time
        )
        self._stop_zc_listen()
        self._async_set_connection_state_while_locked(ReconnectLogicState.HANDSHAKING)
        try:
            await self._cli.finish_connection(login=True)
        except Exception as err:  # pylint: disable=broad-except
            await self._handle_connection_failure(err)
            return False
        self._tries = 0
        finish_handshake_time = time.perf_counter()
        handshake_time = finish_handshake_time - finish_connect_time
        _LOGGER.info(
            "Successful handshake with %s in %0.3fs", self._cli.log_name, handshake_time
        )
        self._async_set_connection_state_while_locked(ReconnectLogicState.READY)
        await self._on_connect_cb()
        return True

    async def _handle_connection_failure(self, err: Exception) -> None:
        """Handle a connection failure.

        Must be called while holding the connected lock.
        """
        self._async_set_connection_state_while_locked(ReconnectLogicState.DISCONNECTED)
        if self._on_connect_error_cb is not None:
            await self._on_connect_error_cb(err)
        self._async_log_connection_error(err)
        if isinstance(err, AUTH_EXCEPTIONS):
            # If we get an encryption or password error,
            # backoff for the maximum amount of time
            self._tries = MAXIMUM_BACKOFF_TRIES
        else:
            self._tries += 1

    def _schedule_connect(self, delay: float) -> None:
        """Schedule a connect attempt.

        A falsy delay starts the attempt right away; otherwise a timer is
        armed that starts the attempt after ``delay`` seconds.
        """
        # Always drop a pending timer first. An immediate attempt (for
        # example one triggered by an mDNS record) must not leave an older
        # timer armed: when that timer fired it would call
        # _call_connect_once again and cancel/restart an attempt that is
        # still in the CONNECTING state.
        self._cancel_connect_timer()
        if not delay:
            self._call_connect_once()
            return
        _LOGGER.debug("Scheduling new connect attempt in %.2f seconds", delay)
        self._connect_timer = self.loop.call_at(
            self.loop.time() + delay, self._call_connect_once
        )

    def _call_connect_once(self) -> None:
        """Call the connect logic once.

        Must only be called from _schedule_connect.
        """
        if self._connect_task and not self._connect_task.done():
            if self._connection_state != ReconnectLogicState.CONNECTING:
                # Connection state is far enough along that we should
                # not restart the connect task
                _LOGGER.debug(
                    "%s: Not cancelling existing connect task as its already %s!",
                    self._cli.log_name,
                    self._connection_state,
                )
                return
            _LOGGER.debug(
                "%s: Cancelling existing connect task with state %s, to try again now!",
                self._cli.log_name,
                self._connection_state,
            )
            self._cancel_connect_task("Scheduling new connect attempt")
            self._async_set_connection_state_without_lock(
                ReconnectLogicState.DISCONNECTED
            )

        self._connect_task = asyncio.create_task(
            self._connect_once_or_reschedule(),
            name=f"{self._cli.log_name}: aioesphomeapi connect",
        )

    def _cancel_connect_timer(self) -> None:
        """Cancel the connect timer."""
        if self._connect_timer:
            self._connect_timer.cancel()
            self._connect_timer = None

    def _cancel_connect_task(self, msg: str) -> None:
        """Cancel the connect task."""
        if self._connect_task:
            self._connect_task.cancel(msg)
            self._connect_task = None

    def _cancel_connect(self, msg: str) -> None:
        """Cancel the connect."""
        self._cancel_connect_timer()
        self._cancel_connect_task(msg)

    async def _connect_once_or_reschedule(self) -> None:
        """Connect once or schedule connect.

        Must only be called from _call_connect_once
        """
        _LOGGER.debug("Trying to connect to %s", self._cli.log_name)
        async with self._connected_lock:
            _LOGGER.debug("Connected lock acquired for %s", self._cli.log_name)
            if (
                self._connection_state != ReconnectLogicState.DISCONNECTED
                or self._is_stopped
            ):
                return
            if await self._try_connect():
                return
            tries = min(self._tries, 10)  # prevent OverflowError
            wait_time = int(round(min(1.8**tries, 60.0)))
            if tries == 1:
                _LOGGER.info(
                    "Trying to connect to %s in the background", self._cli.log_name
                )
            _LOGGER.debug("Retrying %s in %.2f seconds", self._cli.log_name, wait_time)
            if wait_time:
                # If we are waiting, start listening for mDNS records
                self._start_zc_listen()
            self._schedule_connect(wait_time)

    def _remove_stop_task(self, _fut: asyncio.Future[None]) -> None:
        """Remove the stop task from the connect loop.

        We need to do this because the asyncio does not hold
        a strong reference to the task, so it can be garbage
        collected unexpectedly.
        """
        self._stop_task = None

    def stop_callback(self) -> None:
        """Stop the connect logic."""
        self._stop_task = asyncio.create_task(
            self.stop(),
            name=f"{self._cli.log_name}: aioesphomeapi reconnect_logic stop_callback",
        )
        self._stop_task.add_done_callback(self._remove_stop_task)

    async def start(self) -> None:
        """Start the connecting logic background task."""
        async with self._connected_lock:
            self._is_stopped = False
            if self._connection_state != ReconnectLogicState.DISCONNECTED:
                return
            self._tries = 0
            self._schedule_connect(0.0)

    async def stop(self) -> None:
        """Stop the connecting logic background task. Does not disconnect the client."""
        if self._connection_state in NOT_YET_CONNECTED_STATES:
            # If we are still establishing a connection, we can safely
            # cancel the connect task here, otherwise we need to wait
            # for the connect task to finish so we can gracefully
            # disconnect.
            self._cancel_connect("Stopping")

        async with self._connected_lock:
            self._is_stopped = True
            # Cancel again while holding the lock
            self._cancel_connect("Stopping")
            self._stop_zc_listen()
            self._async_set_connection_state_while_locked(
                ReconnectLogicState.DISCONNECTED
            )

        await self._zeroconf_manager.async_close()

    def _start_zc_listen(self) -> None:
        """Listen for mDNS records.

        This listener allows us to schedule a connect as soon as a
        received mDNS record indicates the node is up again.
        """
        if not self._zc_listening and self.name:
            _LOGGER.debug("Starting zeroconf listener for %s", self.name)
            self._ptr_alias = f"{self.name}._esphomelib._tcp.local."
            self._a_name = f"{self.name}.local."
            self._zeroconf_manager.get_async_zeroconf().zeroconf.async_add_listener(
                self, None
            )
            self._zc_listening = True

    def _stop_zc_listen(self) -> None:
        """Stop listening for zeroconf updates."""
        if self._zc_listening:
            _LOGGER.debug("Removing zeroconf listener for %s", self.name)
            self._zeroconf_manager.get_async_zeroconf().zeroconf.async_remove_listener(
                self
            )
            self._zc_listening = False

    def _connect_from_zeroconf(self) -> None:
        """Connect from zeroconf."""
        self._stop_zc_listen()
        self._schedule_connect(0.0)

    def async_update_records(
        self,
        zc: zeroconf.Zeroconf,  # pylint: disable=unused-argument
        now: float,  # pylint: disable=unused-argument
        records: list[zeroconf.RecordUpdate],
    ) -> None:
        """Listen to zeroconf updated mDNS records. This must be called from the eventloop.

        This is a mDNS record from the device and could mean it just woke up.
        """
        # Check if already connected, no lock needed for this access and
        # bail if we are already stopped or past the connecting phase
        if not self._accept_zeroconf_records or self._is_stopped:
            return

        for record_update in records:
            # We only consider PTR records (matched by alias) and
            # A records (matched by name) that belong to this device
            new_record = record_update.new
            if not (
                (new_record.type == TYPE_PTR and new_record.alias == self._ptr_alias)  # type: ignore[attr-defined]
                or (new_record.type == TYPE_A and new_record.name == self._a_name)
            ):
                continue

            # Tell connection logic to retry connection attempt now (even before connect timer finishes)
            _LOGGER.debug(
                "%s: Triggering connect because of received mDNS record %s",
                self._cli.log_name,
                record_update.new,
            )
            #
            # If we scheduled the connect attempt immediately, the listener could fire
            # again before the connect attempt and we cancel and reschedule the connect
            # attempt again.
            #
            self._connect_from_zeroconf()
            # Must come after _connect_from_zeroconf: restarting the attempt
            # can reset the state to DISCONNECTED, which re-enables the flag.
            self._accept_zeroconf_records = False
            return
|