diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 4691ecf..007bc11 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -220,6 +220,7 @@ class APIClient: zeroconf_instance: ZeroconfInstanceType | None = None, noise_psk: str | None = None, expected_name: str | None = None, + addresses: list[str] | None = None, ) -> None: """Create a client, this object is shared across sessions. @@ -235,10 +236,14 @@ class APIClient: :param expected_name: Require the devices name to match the given expected name. Can be used to prevent accidentally connecting to a different device if IP passed as address but DHCP reassigned IP. + :param addresses: Optional list of IP addresses to connect to which takes + precedence over the address parameter. This is most commonly used when + the device has dual stack IPv4 and IPv6 addresses and you do not know + which one to connect to. """ self._debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) self._params = ConnectionParams( - address=str(address), + addresses=addresses if addresses else [str(address)], port=port, password=password, client_info=client_info, @@ -274,17 +279,20 @@ class APIClient: @property def address(self) -> str: - return self._params.address + return self._params.addresses[0] def _set_log_name(self) -> None: """Set the log name of the device.""" - resolved_address: str | None = None - if self._connection and self._connection.resolved_addr_info: - resolved_address = self._connection.resolved_addr_info[0].sockaddr.address + ip_address: str | None = None + if self._connection: + if self._connection.connected_address: + ip_address = self._connection.connected_address + elif self._connection.resolved_addr_info: + ip_address = self._connection.resolved_addr_info[0].sockaddr.address self.log_name = build_log_name( self.cached_name, self.address, - resolved_address, + ip_address, ) if self._connection: self._connection.set_log_name(self.log_name) diff --git a/aioesphomeapi/connection.pxd b/aioesphomeapi/connection.pxd index 62aa99c..af556ae 100644 --- a/aioesphomeapi/connection.pxd +++ b/aioesphomeapi/connection.pxd @@ -74,7 +74,7 @@ cdef object _handle_complex_message @cython.dataclasses.dataclass cdef class ConnectionParams: - cdef public str address + cdef public list addresses cdef public object port cdef public object password cdef public object client_info @@ -109,6 +109,7 @@ cdef class APIConnection: cdef bint _debug_enabled cdef public str received_name cdef public object resolved_addr_info + cdef public str connected_address cpdef void send_message(self, object msg) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index b97565c..efd9cf4 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -107,7 +107,7 @@ _float = float @dataclass class ConnectionParams: - address: str + addresses: list[str] port: int password: str | None client_info: str @@ -208,6 +208,7 @@ class APIConnection: "_debug_enabled", "received_name", "resolved_addr_info", + "connected_address", ) def __init__( @@ -230,7 +231,7 @@ class APIConnection: # Message handlers currently subscribed to incoming messages self._message_handlers: dict[Any, set[Callable[[message.Message], None]]] = {} # The friendly name to show for this connection in the logs - self.log_name = log_name or params.address + self.log_name = log_name or params.addresses # futures currently subscribed to exceptions in the read task self._read_exception_futures: set[asyncio.Future[None]] = set() @@ -252,6 +253,7 @@ class APIConnection: self._debug_enabled = debug_enabled self.received_name: str = "" self.resolved_addr_info: list[hr.AddrInfo] = [] + self.connected_address: str | None = None def set_log_name(self, name: str) -> None: """Set the friendly log name for this connection.""" @@ -325,7 +327,7 @@ class APIConnection: try: async with asyncio_timeout(RESOLVE_TIMEOUT): return await hr.async_resolve_host( - self._params.address, + self._params.addresses, self._params.port, self._params.zeroconf_manager, ) @@ -340,7 +342,7 @@ class APIConnection: _LOGGER.debug( "%s: Connecting to %s:%s (%s)", self.log_name, - self._params.address, + self._params.addresses, self._params.port, addrs, ) @@ -350,7 +352,7 @@ class APIConnection: addr.family, addr.type, addr.proto, - self._params.address, + "", astuple(addr.sockaddr), ) for addr in addrs @@ -361,9 +363,11 @@ class APIConnection: while addr_infos: try: async with asyncio_timeout(TCP_CONNECT_TIMEOUT): + # Devices are likely on the local network so we + # only use a 100ms happy eyeballs delay sock = await aiohappyeyeballs.start_connection( addr_infos, - happy_eyeballs_delay=0.25, + happy_eyeballs_delay=0.1, interleave=interleave, loop=self._loop, ) @@ -387,12 +391,13 @@ class APIConnection: # Try to reduce the pressure on esphome device as it measures # ram in bytes and we measure ram in megabytes. sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE) + self.connected_address = sock.getpeername()[0] if self._debug_enabled: _LOGGER.debug( "%s: Opened socket to %s:%s (%s)", self.log_name, - self._params.address, + self.connected_address, self._params.port, addrs, ) diff --git a/aioesphomeapi/host_resolver.py b/aioesphomeapi/host_resolver.py index 153f6ef..bc873df 100644 --- a/aioesphomeapi/host_resolver.py +++ b/aioesphomeapi/host_resolver.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import contextlib import logging import socket from dataclasses import dataclass @@ -181,14 +180,23 @@ def _async_ip_address_to_addrs( async def async_resolve_host( - host: str, + hosts: list[str], port: int, zeroconf_manager: ZeroconfManager | None = None, ) -> list[AddrInfo]: addrs: list[AddrInfo] = [] + zc_error: Exception | None = None + + for host in hosts: + host_is_name = host_is_name_part(host) or address_is_local(host) + + if not host_is_name: + try: + addrs.extend(_async_ip_address_to_addrs(ip_address(host), port)) + except ValueError: + # Not an IP address + continue - zc_error = None - if host_is_name_part(host) or address_is_local(host): name = host.partition(".")[0] try: addrs.extend( @@ -198,13 +206,7 @@ async def async_resolve_host( ) except ResolveAPIError as err: zc_error = err - - else: - with contextlib.suppress(ValueError): - addrs.extend(_async_ip_address_to_addrs(ip_address(host), port)) - - if not addrs: - addrs.extend(await _async_resolve_host_getaddrinfo(host, port)) + addrs.extend(await _async_resolve_host_getaddrinfo(host, port)) if not addrs: if zc_error: