From 44aee612a48c5bd0b2f9f4a1140216dd9bc2b413 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 12 Dec 2023 10:08:01 -1000 Subject: [PATCH] Add support for passing multiple addresses to the client If we have multiple IP addresses for the ESPHome device, and we do not know which one we should connect to, they should be passed as `addresses` when creating the `APIClient` --- aioesphomeapi/client.py | 20 ++++++++++++++------ aioesphomeapi/connection.pxd | 3 ++- aioesphomeapi/connection.py | 19 ++++++++++++------- aioesphomeapi/host_resolver.py | 24 +++++++++++++----------- 4 files changed, 41 insertions(+), 25 deletions(-) diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 4691ecf..007bc11 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -220,6 +220,7 @@ class APIClient: zeroconf_instance: ZeroconfInstanceType | None = None, noise_psk: str | None = None, expected_name: str | None = None, + addresses: list[str] | None = None, ) -> None: """Create a client, this object is shared across sessions. @@ -235,10 +236,14 @@ class APIClient: :param expected_name: Require the devices name to match the given expected name. Can be used to prevent accidentally connecting to a different device if IP passed as address but DHCP reassigned IP. + :param addresses: Optional list of IP addresses to connect to which takes + precedence over the address parameter. This is most commonly used when + the device has dual stack IPv4 and IPv6 addresses and you do not know + which one to connect to. """ self._debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) self._params = ConnectionParams( - address=str(address), + addresses=addresses if addresses else [str(address)], port=port, password=password, client_info=client_info, @@ -274,17 +279,20 @@ class APIClient: @property def address(self) -> str: - return self._params.address + return self._params.addresses[0] def _set_log_name(self) -> None: """Set the log name of the device.""" - resolved_address: str | None = None - if self._connection and self._connection.resolved_addr_info: - resolved_address = self._connection.resolved_addr_info[0].sockaddr.address + ip_address: str | None = None + if self._connection: + if self._connection.connected_address: + ip_address = self._connection.connected_address + elif self._connection.resolved_addr_info: + ip_address = self._connection.resolved_addr_info[0].sockaddr.address self.log_name = build_log_name( self.cached_name, self.address, - resolved_address, + ip_address, ) if self._connection: self._connection.set_log_name(self.log_name) diff --git a/aioesphomeapi/connection.pxd b/aioesphomeapi/connection.pxd index 62aa99c..af556ae 100644 --- a/aioesphomeapi/connection.pxd +++ b/aioesphomeapi/connection.pxd @@ -74,7 +74,7 @@ cdef object _handle_complex_message @cython.dataclasses.dataclass cdef class ConnectionParams: - cdef public str address + cdef public list addresses cdef public object port cdef public object password cdef public object client_info @@ -109,6 +109,7 @@ cdef class APIConnection: cdef bint _debug_enabled cdef public str received_name cdef public object resolved_addr_info + cdef public str connected_address cpdef void send_message(self, object msg) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index b97565c..efd9cf4 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -107,7 +107,7 @@ _float = float @dataclass class ConnectionParams: - address: str + addresses: list[str] port: int password: str | None client_info: str @@ -208,6 +208,7 @@ class APIConnection: "_debug_enabled", "received_name", "resolved_addr_info", + "connected_address", ) def __init__( @@ -230,7 +231,7 @@ class APIConnection: # Message handlers currently subscribed to incoming messages self._message_handlers: dict[Any, set[Callable[[message.Message], None]]] = {} # The friendly name to show for this connection in the logs - self.log_name = log_name or params.address + self.log_name = log_name or params.addresses # futures currently subscribed to exceptions in the read task self._read_exception_futures: set[asyncio.Future[None]] = set() @@ -252,6 +253,7 @@ class APIConnection: self._debug_enabled = debug_enabled self.received_name: str = "" self.resolved_addr_info: list[hr.AddrInfo] = [] + self.connected_address: str | None = None def set_log_name(self, name: str) -> None: """Set the friendly log name for this connection.""" @@ -325,7 +327,7 @@ class APIConnection: try: async with asyncio_timeout(RESOLVE_TIMEOUT): return await hr.async_resolve_host( - self._params.address, + self._params.addresses, self._params.port, self._params.zeroconf_manager, ) @@ -340,7 +342,7 @@ class APIConnection: _LOGGER.debug( "%s: Connecting to %s:%s (%s)", self.log_name, - self._params.address, + self._params.addresses, self._params.port, addrs, ) @@ -350,7 +352,7 @@ class APIConnection: addr.family, addr.type, addr.proto, - self._params.address, + "", astuple(addr.sockaddr), ) for addr in addrs @@ -361,9 +363,11 @@ class APIConnection: while addr_infos: try: async with asyncio_timeout(TCP_CONNECT_TIMEOUT): + # Devices are likely on the local network so we + # only use a 100ms happy eyeballs delay sock = await aiohappyeyeballs.start_connection( addr_infos, - happy_eyeballs_delay=0.25, + happy_eyeballs_delay=0.1, interleave=interleave, loop=self._loop, ) @@ -387,12 +391,13 @@ class APIConnection: # Try to reduce the pressure on esphome device as it measures # ram in bytes and we measure ram in megabytes. sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE) + self.connected_address = sock.getpeername()[0] if self._debug_enabled: _LOGGER.debug( "%s: Opened socket to %s:%s (%s)", self.log_name, - self._params.address, + self.connected_address, self._params.port, addrs, ) diff --git a/aioesphomeapi/host_resolver.py b/aioesphomeapi/host_resolver.py index 153f6ef..bc873df 100644 --- a/aioesphomeapi/host_resolver.py +++ b/aioesphomeapi/host_resolver.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import contextlib import logging import socket from dataclasses import dataclass @@ -181,14 +180,23 @@ def _async_ip_address_to_addrs( async def async_resolve_host( - host: str, + hosts: list[str], port: int, zeroconf_manager: ZeroconfManager | None = None, ) -> list[AddrInfo]: addrs: list[AddrInfo] = [] + zc_error: Exception | None = None + + for host in hosts: + host_is_name = host_is_name_part(host) or address_is_local(host) + + if not host_is_name: + try: + addrs.extend(_async_ip_address_to_addrs(ip_address(host), port)) + except ValueError: + # Not an IP address + continue - zc_error = None - if host_is_name_part(host) or address_is_local(host): name = host.partition(".")[0] try: addrs.extend( @@ -198,13 +206,7 @@ async def async_resolve_host( ) except ResolveAPIError as err: zc_error = err - - else: - with contextlib.suppress(ValueError): - addrs.extend(_async_ip_address_to_addrs(ip_address(host), port)) - - if not addrs: - addrs.extend(await _async_resolve_host_getaddrinfo(host, port)) + addrs.extend(await _async_resolve_host_getaddrinfo(host, port)) if not addrs: if zc_error: