Add support for passing multiple addresses to the client

If we have multiple IP addresses for the ESPHome device, and
we do not know which one we should connect to, they should
be passed as `addresses` when creating the `APIClient`
This commit is contained in:
J. Nick Koston 2023-12-12 10:08:01 -10:00
parent 4668b1ff54
commit 44aee612a4
No known key found for this signature in database
4 changed files with 41 additions and 25 deletions

View File

@ -220,6 +220,7 @@ class APIClient:
zeroconf_instance: ZeroconfInstanceType | None = None,
noise_psk: str | None = None,
expected_name: str | None = None,
addresses: list[str] | None = None,
) -> None:
"""Create a client, this object is shared across sessions.
@ -235,10 +236,14 @@ class APIClient:
:param expected_name: Require the devices name to match the given expected name.
Can be used to prevent accidentally connecting to a different device if
IP passed as address but DHCP reassigned IP.
:param addresses: Optional list of IP addresses to connect to which takes
precedence over the address parameter. This is most commonly used when
the device has dual stack IPv4 and IPv6 addresses and you do not know
which one to connect to.
"""
self._debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
self._params = ConnectionParams(
address=str(address),
addresses=addresses if addresses else [str(address)],
port=port,
password=password,
client_info=client_info,
@ -274,17 +279,20 @@ class APIClient:
@property
def address(self) -> str:
return self._params.address
return self._params.addresses[0]
def _set_log_name(self) -> None:
"""Set the log name of the device."""
resolved_address: str | None = None
if self._connection and self._connection.resolved_addr_info:
resolved_address = self._connection.resolved_addr_info[0].sockaddr.address
ip_address: str | None = None
if self._connection:
if self._connection.connected_address:
ip_address = self._connection.connected_address
elif self._connection.resolved_addr_info:
ip_address = self._connection.resolved_addr_info[0].sockaddr.address
self.log_name = build_log_name(
self.cached_name,
self.address,
resolved_address,
ip_address,
)
if self._connection:
self._connection.set_log_name(self.log_name)

View File

@ -74,7 +74,7 @@ cdef object _handle_complex_message
@cython.dataclasses.dataclass
cdef class ConnectionParams:
cdef public str address
cdef public list addresses
cdef public object port
cdef public object password
cdef public object client_info
@ -109,6 +109,7 @@ cdef class APIConnection:
cdef bint _debug_enabled
cdef public str received_name
cdef public object resolved_addr_info
cdef public str connected_address
cpdef void send_message(self, object msg)

View File

@ -107,7 +107,7 @@ _float = float
@dataclass
class ConnectionParams:
address: str
addresses: list[str]
port: int
password: str | None
client_info: str
@ -208,6 +208,7 @@ class APIConnection:
"_debug_enabled",
"received_name",
"resolved_addr_info",
"connected_address",
)
def __init__(
@ -230,7 +231,7 @@ class APIConnection:
# Message handlers currently subscribed to incoming messages
self._message_handlers: dict[Any, set[Callable[[message.Message], None]]] = {}
# The friendly name to show for this connection in the logs
self.log_name = log_name or params.address
self.log_name = log_name or params.addresses
# futures currently subscribed to exceptions in the read task
self._read_exception_futures: set[asyncio.Future[None]] = set()
@ -252,6 +253,7 @@ class APIConnection:
self._debug_enabled = debug_enabled
self.received_name: str = ""
self.resolved_addr_info: list[hr.AddrInfo] = []
self.connected_address: str | None = None
def set_log_name(self, name: str) -> None:
"""Set the friendly log name for this connection."""
@ -325,7 +327,7 @@ class APIConnection:
try:
async with asyncio_timeout(RESOLVE_TIMEOUT):
return await hr.async_resolve_host(
self._params.address,
self._params.addresses,
self._params.port,
self._params.zeroconf_manager,
)
@ -340,7 +342,7 @@ class APIConnection:
_LOGGER.debug(
"%s: Connecting to %s:%s (%s)",
self.log_name,
self._params.address,
self._params.addresses,
self._params.port,
addrs,
)
@ -350,7 +352,7 @@ class APIConnection:
addr.family,
addr.type,
addr.proto,
self._params.address,
"",
astuple(addr.sockaddr),
)
for addr in addrs
@ -361,9 +363,11 @@ class APIConnection:
while addr_infos:
try:
async with asyncio_timeout(TCP_CONNECT_TIMEOUT):
# Devices are likely on the local network so we
# only use a 100ms happy eyeballs delay
sock = await aiohappyeyeballs.start_connection(
addr_infos,
happy_eyeballs_delay=0.25,
happy_eyeballs_delay=0.1,
interleave=interleave,
loop=self._loop,
)
@ -387,12 +391,13 @@ class APIConnection:
# Try to reduce the pressure on esphome device as it measures
# ram in bytes and we measure ram in megabytes.
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE)
self.connected_address = sock.getpeername()[0]
if self._debug_enabled:
_LOGGER.debug(
"%s: Opened socket to %s:%s (%s)",
self.log_name,
self._params.address,
self.connected_address,
self._params.port,
addrs,
)

View File

@ -1,7 +1,6 @@
from __future__ import annotations
import asyncio
import contextlib
import logging
import socket
from dataclasses import dataclass
@ -181,14 +180,23 @@ def _async_ip_address_to_addrs(
async def async_resolve_host(
host: str,
hosts: list[str],
port: int,
zeroconf_manager: ZeroconfManager | None = None,
) -> list[AddrInfo]:
addrs: list[AddrInfo] = []
zc_error: Exception | None = None
for host in hosts:
host_is_name = host_is_name_part(host) or address_is_local(host)
if not host_is_name:
try:
addrs.extend(_async_ip_address_to_addrs(ip_address(host), port))
except ValueError:
# Not an IP address
continue
zc_error = None
if host_is_name_part(host) or address_is_local(host):
name = host.partition(".")[0]
try:
addrs.extend(
@ -198,13 +206,7 @@ async def async_resolve_host(
)
except ResolveAPIError as err:
zc_error = err
else:
with contextlib.suppress(ValueError):
addrs.extend(_async_ip_address_to_addrs(ip_address(host), port))
if not addrs:
addrs.extend(await _async_resolve_host_getaddrinfo(host, port))
addrs.extend(await _async_resolve_host_getaddrinfo(host, port))
if not addrs:
if zc_error: