mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-09-27 04:22:46 +02:00
Add support for passing multiple addresses to the client
If we have multiple IP addresses for the ESPHome device, and we do not know which one we should connect to, they should be passed as `addresses` when creating the `APIClient`
This commit is contained in:
parent
4668b1ff54
commit
44aee612a4
@ -220,6 +220,7 @@ class APIClient:
|
||||
zeroconf_instance: ZeroconfInstanceType | None = None,
|
||||
noise_psk: str | None = None,
|
||||
expected_name: str | None = None,
|
||||
addresses: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Create a client, this object is shared across sessions.
|
||||
|
||||
@ -235,10 +236,14 @@ class APIClient:
|
||||
:param expected_name: Require the devices name to match the given expected name.
|
||||
Can be used to prevent accidentally connecting to a different device if
|
||||
IP passed as address but DHCP reassigned IP.
|
||||
:param addresses: Optional list of IP addresses to connect to which takes
|
||||
precedence over the address parameter. This is most commonly used when
|
||||
the device has dual stack IPv4 and IPv6 addresses and you do not know
|
||||
which one to connect to.
|
||||
"""
|
||||
self._debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||
self._params = ConnectionParams(
|
||||
address=str(address),
|
||||
addresses=addresses if addresses else [str(address)],
|
||||
port=port,
|
||||
password=password,
|
||||
client_info=client_info,
|
||||
@ -274,17 +279,20 @@ class APIClient:
|
||||
|
||||
@property
|
||||
def address(self) -> str:
|
||||
return self._params.address
|
||||
return self._params.addresses[0]
|
||||
|
||||
def _set_log_name(self) -> None:
|
||||
"""Set the log name of the device."""
|
||||
resolved_address: str | None = None
|
||||
if self._connection and self._connection.resolved_addr_info:
|
||||
resolved_address = self._connection.resolved_addr_info[0].sockaddr.address
|
||||
ip_address: str | None = None
|
||||
if self._connection:
|
||||
if self._connection.connected_address:
|
||||
ip_address = self._connection.connected_address
|
||||
elif self._connection.resolved_addr_info:
|
||||
ip_address = self._connection.resolved_addr_info[0].sockaddr.address
|
||||
self.log_name = build_log_name(
|
||||
self.cached_name,
|
||||
self.address,
|
||||
resolved_address,
|
||||
ip_address,
|
||||
)
|
||||
if self._connection:
|
||||
self._connection.set_log_name(self.log_name)
|
||||
|
@ -74,7 +74,7 @@ cdef object _handle_complex_message
|
||||
|
||||
@cython.dataclasses.dataclass
|
||||
cdef class ConnectionParams:
|
||||
cdef public str address
|
||||
cdef public list addresses
|
||||
cdef public object port
|
||||
cdef public object password
|
||||
cdef public object client_info
|
||||
@ -109,6 +109,7 @@ cdef class APIConnection:
|
||||
cdef bint _debug_enabled
|
||||
cdef public str received_name
|
||||
cdef public object resolved_addr_info
|
||||
cdef public str connected_address
|
||||
|
||||
cpdef void send_message(self, object msg)
|
||||
|
||||
|
@ -107,7 +107,7 @@ _float = float
|
||||
|
||||
@dataclass
|
||||
class ConnectionParams:
|
||||
address: str
|
||||
addresses: list[str]
|
||||
port: int
|
||||
password: str | None
|
||||
client_info: str
|
||||
@ -208,6 +208,7 @@ class APIConnection:
|
||||
"_debug_enabled",
|
||||
"received_name",
|
||||
"resolved_addr_info",
|
||||
"connected_address",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@ -230,7 +231,7 @@ class APIConnection:
|
||||
# Message handlers currently subscribed to incoming messages
|
||||
self._message_handlers: dict[Any, set[Callable[[message.Message], None]]] = {}
|
||||
# The friendly name to show for this connection in the logs
|
||||
self.log_name = log_name or params.address
|
||||
self.log_name = log_name or params.addresses
|
||||
|
||||
# futures currently subscribed to exceptions in the read task
|
||||
self._read_exception_futures: set[asyncio.Future[None]] = set()
|
||||
@ -252,6 +253,7 @@ class APIConnection:
|
||||
self._debug_enabled = debug_enabled
|
||||
self.received_name: str = ""
|
||||
self.resolved_addr_info: list[hr.AddrInfo] = []
|
||||
self.connected_address: str | None = None
|
||||
|
||||
def set_log_name(self, name: str) -> None:
|
||||
"""Set the friendly log name for this connection."""
|
||||
@ -325,7 +327,7 @@ class APIConnection:
|
||||
try:
|
||||
async with asyncio_timeout(RESOLVE_TIMEOUT):
|
||||
return await hr.async_resolve_host(
|
||||
self._params.address,
|
||||
self._params.addresses,
|
||||
self._params.port,
|
||||
self._params.zeroconf_manager,
|
||||
)
|
||||
@ -340,7 +342,7 @@ class APIConnection:
|
||||
_LOGGER.debug(
|
||||
"%s: Connecting to %s:%s (%s)",
|
||||
self.log_name,
|
||||
self._params.address,
|
||||
self._params.addresses,
|
||||
self._params.port,
|
||||
addrs,
|
||||
)
|
||||
@ -350,7 +352,7 @@ class APIConnection:
|
||||
addr.family,
|
||||
addr.type,
|
||||
addr.proto,
|
||||
self._params.address,
|
||||
"",
|
||||
astuple(addr.sockaddr),
|
||||
)
|
||||
for addr in addrs
|
||||
@ -361,9 +363,11 @@ class APIConnection:
|
||||
while addr_infos:
|
||||
try:
|
||||
async with asyncio_timeout(TCP_CONNECT_TIMEOUT):
|
||||
# Devices are likely on the local network so we
|
||||
# only use a 100ms happy eyeballs delay
|
||||
sock = await aiohappyeyeballs.start_connection(
|
||||
addr_infos,
|
||||
happy_eyeballs_delay=0.25,
|
||||
happy_eyeballs_delay=0.1,
|
||||
interleave=interleave,
|
||||
loop=self._loop,
|
||||
)
|
||||
@ -387,12 +391,13 @@ class APIConnection:
|
||||
# Try to reduce the pressure on esphome device as it measures
|
||||
# ram in bytes and we measure ram in megabytes.
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE)
|
||||
self.connected_address = sock.getpeername()[0]
|
||||
|
||||
if self._debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"%s: Opened socket to %s:%s (%s)",
|
||||
self.log_name,
|
||||
self._params.address,
|
||||
self.connected_address,
|
||||
self._params.port,
|
||||
addrs,
|
||||
)
|
||||
|
@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import socket
|
||||
from dataclasses import dataclass
|
||||
@ -181,14 +180,23 @@ def _async_ip_address_to_addrs(
|
||||
|
||||
|
||||
async def async_resolve_host(
|
||||
host: str,
|
||||
hosts: list[str],
|
||||
port: int,
|
||||
zeroconf_manager: ZeroconfManager | None = None,
|
||||
) -> list[AddrInfo]:
|
||||
addrs: list[AddrInfo] = []
|
||||
zc_error: Exception | None = None
|
||||
|
||||
for host in hosts:
|
||||
host_is_name = host_is_name_part(host) or address_is_local(host)
|
||||
|
||||
if not host_is_name:
|
||||
try:
|
||||
addrs.extend(_async_ip_address_to_addrs(ip_address(host), port))
|
||||
except ValueError:
|
||||
# Not an IP address
|
||||
continue
|
||||
|
||||
zc_error = None
|
||||
if host_is_name_part(host) or address_is_local(host):
|
||||
name = host.partition(".")[0]
|
||||
try:
|
||||
addrs.extend(
|
||||
@ -198,12 +206,6 @@ async def async_resolve_host(
|
||||
)
|
||||
except ResolveAPIError as err:
|
||||
zc_error = err
|
||||
|
||||
else:
|
||||
with contextlib.suppress(ValueError):
|
||||
addrs.extend(_async_ip_address_to_addrs(ip_address(host), port))
|
||||
|
||||
if not addrs:
|
||||
addrs.extend(await _async_resolve_host_getaddrinfo(host, port))
|
||||
|
||||
if not addrs:
|
||||
|
Loading…
Reference in New Issue
Block a user