Add support for passing multiple addresses to the client

If we have multiple IP addresses for the ESPHome device, and
we do not know which one we should connect to, they should
be passed as `addresses` when creating the `APIClient`
This commit is contained in:
J. Nick Koston 2023-12-12 10:08:01 -10:00
parent 4668b1ff54
commit 44aee612a4
No known key found for this signature in database
4 changed files with 41 additions and 25 deletions

View File

@ -220,6 +220,7 @@ class APIClient:
zeroconf_instance: ZeroconfInstanceType | None = None, zeroconf_instance: ZeroconfInstanceType | None = None,
noise_psk: str | None = None, noise_psk: str | None = None,
expected_name: str | None = None, expected_name: str | None = None,
addresses: list[str] | None = None,
) -> None: ) -> None:
"""Create a client, this object is shared across sessions. """Create a client, this object is shared across sessions.
@ -235,10 +236,14 @@ class APIClient:
:param expected_name: Require the devices name to match the given expected name. :param expected_name: Require the devices name to match the given expected name.
Can be used to prevent accidentally connecting to a different device if Can be used to prevent accidentally connecting to a different device if
IP passed as address but DHCP reassigned IP. IP passed as address but DHCP reassigned IP.
:param addresses: Optional list of IP addresses to connect to which takes
precedence over the address parameter. This is most commonly used when
the device has dual stack IPv4 and IPv6 addresses and you do not know
which one to connect to.
""" """
self._debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) self._debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
self._params = ConnectionParams( self._params = ConnectionParams(
address=str(address), addresses=addresses if addresses else [str(address)],
port=port, port=port,
password=password, password=password,
client_info=client_info, client_info=client_info,
@ -274,17 +279,20 @@ class APIClient:
@property @property
def address(self) -> str: def address(self) -> str:
return self._params.address return self._params.addresses[0]
def _set_log_name(self) -> None: def _set_log_name(self) -> None:
"""Set the log name of the device.""" """Set the log name of the device."""
resolved_address: str | None = None ip_address: str | None = None
if self._connection and self._connection.resolved_addr_info: if self._connection:
resolved_address = self._connection.resolved_addr_info[0].sockaddr.address if self._connection.connected_address:
ip_address = self._connection.connected_address
elif self._connection.resolved_addr_info:
ip_address = self._connection.resolved_addr_info[0].sockaddr.address
self.log_name = build_log_name( self.log_name = build_log_name(
self.cached_name, self.cached_name,
self.address, self.address,
resolved_address, ip_address,
) )
if self._connection: if self._connection:
self._connection.set_log_name(self.log_name) self._connection.set_log_name(self.log_name)

View File

@ -74,7 +74,7 @@ cdef object _handle_complex_message
@cython.dataclasses.dataclass @cython.dataclasses.dataclass
cdef class ConnectionParams: cdef class ConnectionParams:
cdef public str address cdef public list addresses
cdef public object port cdef public object port
cdef public object password cdef public object password
cdef public object client_info cdef public object client_info
@ -109,6 +109,7 @@ cdef class APIConnection:
cdef bint _debug_enabled cdef bint _debug_enabled
cdef public str received_name cdef public str received_name
cdef public object resolved_addr_info cdef public object resolved_addr_info
cdef public str connected_address
cpdef void send_message(self, object msg) cpdef void send_message(self, object msg)

View File

@ -107,7 +107,7 @@ _float = float
@dataclass @dataclass
class ConnectionParams: class ConnectionParams:
address: str addresses: list[str]
port: int port: int
password: str | None password: str | None
client_info: str client_info: str
@ -208,6 +208,7 @@ class APIConnection:
"_debug_enabled", "_debug_enabled",
"received_name", "received_name",
"resolved_addr_info", "resolved_addr_info",
"connected_address",
) )
def __init__( def __init__(
@ -230,7 +231,7 @@ class APIConnection:
# Message handlers currently subscribed to incoming messages # Message handlers currently subscribed to incoming messages
self._message_handlers: dict[Any, set[Callable[[message.Message], None]]] = {} self._message_handlers: dict[Any, set[Callable[[message.Message], None]]] = {}
# The friendly name to show for this connection in the logs # The friendly name to show for this connection in the logs
self.log_name = log_name or params.address self.log_name = log_name or params.addresses
# futures currently subscribed to exceptions in the read task # futures currently subscribed to exceptions in the read task
self._read_exception_futures: set[asyncio.Future[None]] = set() self._read_exception_futures: set[asyncio.Future[None]] = set()
@ -252,6 +253,7 @@ class APIConnection:
self._debug_enabled = debug_enabled self._debug_enabled = debug_enabled
self.received_name: str = "" self.received_name: str = ""
self.resolved_addr_info: list[hr.AddrInfo] = [] self.resolved_addr_info: list[hr.AddrInfo] = []
self.connected_address: str | None = None
def set_log_name(self, name: str) -> None: def set_log_name(self, name: str) -> None:
"""Set the friendly log name for this connection.""" """Set the friendly log name for this connection."""
@ -325,7 +327,7 @@ class APIConnection:
try: try:
async with asyncio_timeout(RESOLVE_TIMEOUT): async with asyncio_timeout(RESOLVE_TIMEOUT):
return await hr.async_resolve_host( return await hr.async_resolve_host(
self._params.address, self._params.addresses,
self._params.port, self._params.port,
self._params.zeroconf_manager, self._params.zeroconf_manager,
) )
@ -340,7 +342,7 @@ class APIConnection:
_LOGGER.debug( _LOGGER.debug(
"%s: Connecting to %s:%s (%s)", "%s: Connecting to %s:%s (%s)",
self.log_name, self.log_name,
self._params.address, self._params.addresses,
self._params.port, self._params.port,
addrs, addrs,
) )
@ -350,7 +352,7 @@ class APIConnection:
addr.family, addr.family,
addr.type, addr.type,
addr.proto, addr.proto,
self._params.address, "",
astuple(addr.sockaddr), astuple(addr.sockaddr),
) )
for addr in addrs for addr in addrs
@ -361,9 +363,11 @@ class APIConnection:
while addr_infos: while addr_infos:
try: try:
async with asyncio_timeout(TCP_CONNECT_TIMEOUT): async with asyncio_timeout(TCP_CONNECT_TIMEOUT):
# Devices are likely on the local network so we
# only use a 100ms happy eyeballs delay
sock = await aiohappyeyeballs.start_connection( sock = await aiohappyeyeballs.start_connection(
addr_infos, addr_infos,
happy_eyeballs_delay=0.25, happy_eyeballs_delay=0.1,
interleave=interleave, interleave=interleave,
loop=self._loop, loop=self._loop,
) )
@ -387,12 +391,13 @@ class APIConnection:
# Try to reduce the pressure on esphome device as it measures # Try to reduce the pressure on esphome device as it measures
# ram in bytes and we measure ram in megabytes. # ram in bytes and we measure ram in megabytes.
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE) sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE)
self.connected_address = sock.getpeername()[0]
if self._debug_enabled: if self._debug_enabled:
_LOGGER.debug( _LOGGER.debug(
"%s: Opened socket to %s:%s (%s)", "%s: Opened socket to %s:%s (%s)",
self.log_name, self.log_name,
self._params.address, self.connected_address,
self._params.port, self._params.port,
addrs, addrs,
) )

View File

@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import contextlib
import logging import logging
import socket import socket
from dataclasses import dataclass from dataclasses import dataclass
@ -181,14 +180,23 @@ def _async_ip_address_to_addrs(
async def async_resolve_host( async def async_resolve_host(
host: str, hosts: list[str],
port: int, port: int,
zeroconf_manager: ZeroconfManager | None = None, zeroconf_manager: ZeroconfManager | None = None,
) -> list[AddrInfo]: ) -> list[AddrInfo]:
addrs: list[AddrInfo] = [] addrs: list[AddrInfo] = []
zc_error: Exception | None = None
for host in hosts:
host_is_name = host_is_name_part(host) or address_is_local(host)
if not host_is_name:
try:
addrs.extend(_async_ip_address_to_addrs(ip_address(host), port))
except ValueError:
# Not an IP address
continue
zc_error = None
if host_is_name_part(host) or address_is_local(host):
name = host.partition(".")[0] name = host.partition(".")[0]
try: try:
addrs.extend( addrs.extend(
@ -198,13 +206,7 @@ async def async_resolve_host(
) )
except ResolveAPIError as err: except ResolveAPIError as err:
zc_error = err zc_error = err
addrs.extend(await _async_resolve_host_getaddrinfo(host, port))
else:
with contextlib.suppress(ValueError):
addrs.extend(_async_ip_address_to_addrs(ip_address(host), port))
if not addrs:
addrs.extend(await _async_resolve_host_getaddrinfo(host, port))
if not addrs: if not addrs:
if zc_error: if zc_error: