Add happy eyeballs support

replaces and closes #233
This commit is contained in:
J. Nick Koston 2023-12-09 17:40:38 -10:00
parent 165331bd41
commit 6340c3c5a4
No known key found for this signature in database
4 changed files with 59 additions and 30 deletions

View File

@ -280,7 +280,7 @@ class APIClient:
"""Set the log name of the device."""
resolved_address: str | None = None
if self._connection and self._connection.resolved_addr_info:
resolved_address = self._connection.resolved_addr_info.sockaddr.address
resolved_address = self._connection.resolved_addr_info[0].sockaddr.address
self.log_name = build_log_name(
self.cached_name,
self.address,

View File

@ -15,6 +15,7 @@ from dataclasses import astuple, dataclass
from functools import lru_cache, partial
from typing import TYPE_CHECKING, Any, Callable
import aiohappyeyeballs
from google.protobuf import message
import aioesphomeapi.host_resolver as hr
@ -250,7 +251,7 @@ class APIConnection:
self._handshake_complete = False
self._debug_enabled = debug_enabled
self.received_name: str = ""
self.resolved_addr_info: hr.AddrInfo | None = None
self.resolved_addr_info: list[hr.AddrInfo] = []
def set_log_name(self, name: str) -> None:
"""Set the friendly log name for this connection."""
@ -319,7 +320,7 @@ class APIConnection:
"""Enable or disable debug logging."""
self._debug_enabled = enable
async def _connect_resolve_host(self) -> hr.AddrInfo:
async def _connect_resolve_host(self) -> list[hr.AddrInfo]:
"""Step 1 in connect process: resolve the address."""
try:
async with asyncio_timeout(RESOLVE_TIMEOUT):
@ -333,9 +334,54 @@ class APIConnection:
f"Timeout while resolving IP address for {self.log_name}"
) from err
async def _connect_socket_connect(self, addr: hr.AddrInfo) -> None:
async def _connect_socket_connect(self, addrs: list[hr.AddrInfo]) -> None:
"""Step 2 in connect process: connect the socket."""
sock = socket.socket(family=addr.family, type=addr.type, proto=addr.proto)
if self._debug_enabled:
_LOGGER.debug(
"%s: Connecting to %s:%s (%s)",
self.log_name,
self._params.address,
self._params.port,
addrs,
)
addr_infos: list[aiohappyeyeballs.AddrInfoType] = [
(
addr.family,
addr.type,
addr.proto,
self._params.address,
astuple(addr.sockaddr),
)
for addr in addrs
]
last_exception: Exception | None = None
sock: socket.socket | None = None
interleave = 1
while addr_infos:
try:
async with asyncio_timeout(TCP_CONNECT_TIMEOUT):
sock = await aiohappyeyeballs.start_connection(
addr_infos,
happy_eyeballs_delay=0.25,
interleave=interleave,
loop=self._loop,
)
break
except (OSError, asyncio_TimeoutError) as err:
last_exception = err
aiohappyeyeballs.pop_addr_infos_interleave(addr_infos, interleave)
if not sock:
if isinstance(last_exception, OSError):
raise SocketAPIError(
f"Error connecting to {addr_infos}: {last_exception}"
) from last_exception
else:
raise SocketAPIError(
f"Timeout while connecting to {addr_infos}"
) from last_exception
self._socket = sock
sock.setblocking(False)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
@ -343,31 +389,13 @@ class APIConnection:
# ram in bytes and we measure ram in megabytes.
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE)
if self._debug_enabled:
_LOGGER.debug(
"%s: Connecting to %s:%s (%s)",
self.log_name,
self._params.address,
self._params.port,
addr,
)
sockaddr = astuple(addr.sockaddr)
try:
async with asyncio_timeout(TCP_CONNECT_TIMEOUT):
await self._loop.sock_connect(sock, sockaddr)
except asyncio_TimeoutError as err:
raise SocketAPIError(f"Timeout while connecting to {sockaddr}") from err
except OSError as err:
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err
if self._debug_enabled:
_LOGGER.debug(
"%s: Opened socket to %s:%s (%s)",
self.log_name,
self._params.address,
self._params.port,
addr,
addrs,
)
async def _connect_init_frame_helper(self) -> None:

View File

@ -108,8 +108,10 @@ async def _async_resolve_host_zeroconf(
timeout,
)
addrs: list[AddrInfo] = []
for ip in info.ip_addresses_by_version(IPVersion.All):
addrs.extend(_async_ip_address_to_addrs(ip, port)) # type: ignore[arg-type]
for ip in info.ip_addresses_by_version(IPVersion.V6Only):
addrs.extend(_async_ip_address_to_addrs(ip, port))
for ip in info.ip_addresses_by_version(IPVersion.V4Only):
addrs.extend(_async_ip_address_to_addrs(ip, port))
return addrs
@ -182,7 +184,7 @@ async def async_resolve_host(
host: str,
port: int,
zeroconf_manager: ZeroconfManager | None = None,
) -> AddrInfo:
) -> list[AddrInfo]:
addrs: list[AddrInfo] = []
zc_error = None
@ -210,6 +212,4 @@ async def async_resolve_host(
raise zc_error
raise ResolveAPIError(f"Could not resolve host {host} - got no results from OS")
# Use first matching result
# Future: return all matches and use first working one
return addrs[0]
return addrs

View File

@ -1,3 +1,4 @@
aiohappyeyeballs>=1.8.1
protobuf>=3.19.0
zeroconf>=0.36.0,<1.0
chacha20poly1305-reuseable>=0.2.5