Add support for passing multiple addresses to the client (#796)

This commit is contained in:
J. Nick Koston 2023-12-12 11:22:14 -10:00 committed by GitHub
parent 4668b1ff54
commit de1d08493d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 120 additions and 92 deletions

View File

@ -220,6 +220,7 @@ class APIClient:
zeroconf_instance: ZeroconfInstanceType | None = None,
noise_psk: str | None = None,
expected_name: str | None = None,
addresses: list[str] | None = None,
) -> None:
"""Create a client, this object is shared across sessions.
@ -235,10 +236,14 @@ class APIClient:
:param expected_name: Require the devices name to match the given expected name.
Can be used to prevent accidentally connecting to a different device if
IP passed as address but DHCP reassigned IP.
:param addresses: Optional list of IP addresses to connect to which takes
precedence over the address parameter. This is most commonly used when
the device has dual stack IPv4 and IPv6 addresses and you do not know
which one to connect to.
"""
self._debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
self._params = ConnectionParams(
address=str(address),
addresses=addresses if addresses else [str(address)],
port=port,
password=password,
client_info=client_info,
@ -274,17 +279,17 @@ class APIClient:
@property
def address(self) -> str:
return self._params.address
return self._params.addresses[0]
def _set_log_name(self) -> None:
"""Set the log name of the device."""
resolved_address: str | None = None
if self._connection and self._connection.resolved_addr_info:
resolved_address = self._connection.resolved_addr_info[0].sockaddr.address
connected_address: str | None = None
if self._connection and self._connection.connected_address:
connected_address = self._connection.connected_address
self.log_name = build_log_name(
self.cached_name,
self.address,
resolved_address,
self._params.addresses,
connected_address,
)
if self._connection:
self._connection.set_log_name(self.log_name)
@ -328,8 +333,8 @@ class APIClient:
self.log_name,
)
await self._execute_connection_coro(self._connection.start_connection())
# If we resolved the address, we should set the log name now
if self._connection.resolved_addr_info:
# If we connected, we should set the log name now
if self._connection.connected_address:
self._set_log_name()
async def finish_connection(

View File

@ -74,7 +74,7 @@ cdef object _handle_complex_message
@cython.dataclasses.dataclass
cdef class ConnectionParams:
cdef public str address
cdef public list addresses
cdef public object port
cdef public object password
cdef public object client_info
@ -108,7 +108,7 @@ cdef class APIConnection:
cdef bint _handshake_complete
cdef bint _debug_enabled
cdef public str received_name
cdef public object resolved_addr_info
cdef public str connected_address
cpdef void send_message(self, object msg)

View File

@ -107,7 +107,7 @@ _float = float
@dataclass
class ConnectionParams:
address: str
addresses: list[str]
port: int
password: str | None
client_info: str
@ -207,7 +207,7 @@ class APIConnection:
"_handshake_complete",
"_debug_enabled",
"received_name",
"resolved_addr_info",
"connected_address",
)
def __init__(
@ -230,7 +230,7 @@ class APIConnection:
# Message handlers currently subscribed to incoming messages
self._message_handlers: dict[Any, set[Callable[[message.Message], None]]] = {}
# The friendly name to show for this connection in the logs
self.log_name = log_name or params.address
self.log_name = log_name or ",".join(params.addresses)
# futures currently subscribed to exceptions in the read task
self._read_exception_futures: set[asyncio.Future[None]] = set()
@ -251,7 +251,7 @@ class APIConnection:
self._handshake_complete = False
self._debug_enabled = debug_enabled
self.received_name: str = ""
self.resolved_addr_info: list[hr.AddrInfo] = []
self.connected_address: str | None = None
def set_log_name(self, name: str) -> None:
"""Set the friendly log name for this connection."""
@ -325,7 +325,7 @@ class APIConnection:
try:
async with asyncio_timeout(RESOLVE_TIMEOUT):
return await hr.async_resolve_host(
self._params.address,
self._params.addresses,
self._params.port,
self._params.zeroconf_manager,
)
@ -338,11 +338,9 @@ class APIConnection:
"""Step 2 in connect process: connect the socket."""
if self._debug_enabled:
_LOGGER.debug(
"%s: Connecting to %s:%s (%s)",
"%s: Connecting to %s",
self.log_name,
self._params.address,
self._params.port,
addrs,
", ".join(str(addr.sockaddr) for addr in addrs),
)
addr_infos: list[aiohappyeyeballs.AddrInfoType] = [
@ -350,7 +348,7 @@ class APIConnection:
addr.family,
addr.type,
addr.proto,
self._params.address,
"",
astuple(addr.sockaddr),
)
for addr in addrs
@ -361,9 +359,11 @@ class APIConnection:
while addr_infos:
try:
async with asyncio_timeout(TCP_CONNECT_TIMEOUT):
# Devices are likely on the local network so we
# only use a 100ms happy eyeballs delay
sock = await aiohappyeyeballs.start_connection(
addr_infos,
happy_eyeballs_delay=0.25,
happy_eyeballs_delay=0.1,
interleave=interleave,
loop=self._loop,
)
@ -387,14 +387,14 @@ class APIConnection:
# Try to reduce the pressure on esphome device as it measures
# ram in bytes and we measure ram in megabytes.
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE)
self.connected_address = sock.getpeername()[0]
if self._debug_enabled:
_LOGGER.debug(
"%s: Opened socket to %s:%s (%s)",
"%s: Opened socket to %s:%s",
self.log_name,
self._params.address,
self.connected_address,
self._params.port,
addrs,
)
async def _connect_init_frame_helper(self) -> None:
@ -567,8 +567,7 @@ class APIConnection:
async def _do_connect(self) -> None:
"""Do the actual connect process."""
self.resolved_addr_info = await self._connect_resolve_host()
await self._connect_socket_connect(self.resolved_addr_info)
await self._connect_socket_connect(await self._connect_resolve_host())
async def start_connection(self) -> None:
"""Start the connection process.

View File

@ -1,7 +1,6 @@
from __future__ import annotations
import asyncio
import contextlib
import logging
import socket
from dataclasses import dataclass
@ -181,35 +180,46 @@ def _async_ip_address_to_addrs(
async def async_resolve_host(
host: str,
hosts: list[str],
port: int,
zeroconf_manager: ZeroconfManager | None = None,
) -> list[AddrInfo]:
addrs: list[AddrInfo] = []
zc_error: Exception | None = None
zc_error = None
if host_is_name_part(host) or address_is_local(host):
name = host.partition(".")[0]
try:
addrs.extend(
await _async_resolve_host_zeroconf(
name, port, zeroconf_manager=zeroconf_manager
for host in hosts:
host_addrs: list[AddrInfo] = []
host_is_local_name = host_is_name_part(host) or address_is_local(host)
if host_is_local_name:
name = host.partition(".")[0]
try:
host_addrs.extend(
await _async_resolve_host_zeroconf(
name, port, zeroconf_manager=zeroconf_manager
)
)
)
except ResolveAPIError as err:
zc_error = err
except ResolveAPIError as err:
zc_error = err
else:
with contextlib.suppress(ValueError):
addrs.extend(_async_ip_address_to_addrs(ip_address(host), port))
if not host_is_local_name:
try:
host_addrs.extend(_async_ip_address_to_addrs(ip_address(host), port))
except ValueError:
# Not an IP address
pass
if not addrs:
addrs.extend(await _async_resolve_host_getaddrinfo(host, port))
if not host_addrs:
host_addrs.extend(await _async_resolve_host_getaddrinfo(host, port))
addrs.extend(host_addrs)
if not addrs:
if zc_error:
# Only show ZC error if getaddrinfo also didn't work
raise zc_error
raise ResolveAPIError(f"Could not resolve host {host} - got no results from OS")
raise ResolveAPIError(
f"Could not resolve host {hosts} - got no results from OS"
)
return addrs

View File

@ -36,11 +36,18 @@ def address_is_local(address: str) -> bool:
return address.removesuffix(".").endswith(".local")
def build_log_name(name: str | None, address: str, resolved_address: str | None) -> str:
def build_log_name(
name: str | None, addresses: list[str], connected_address: str | None
) -> str:
"""Return a log name for a connection."""
if not name and address_is_local(address) or host_is_name_part(address):
name = address.partition(".")[0]
preferred_address = resolved_address or address
preferred_address = connected_address
for address in addresses:
if not name and address_is_local(address) or host_is_name_part(address):
name = address.partition(".")[0]
elif not preferred_address:
preferred_address = address
if not preferred_address:
return name or addresses[0]
if (
name
and name != preferred_address

View File

@ -6,7 +6,7 @@ import socket
from dataclasses import replace
from functools import partial
from typing import Callable
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, create_autospec, patch
import pytest
import pytest_asyncio
@ -50,12 +50,6 @@ def resolve_host():
yield func
@pytest.fixture
def socket_socket():
with patch("socket.socket") as func:
yield func
@pytest.fixture
def patchable_api_client() -> APIClient:
class PatchableAPIClient(APIClient):
@ -71,7 +65,7 @@ def patchable_api_client() -> APIClient:
def get_mock_connection_params() -> ConnectionParams:
return ConnectionParams(
address="fake.address",
addresses=["fake.address"],
port=6052,
password=None,
client_info="Tests client",
@ -119,7 +113,11 @@ def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnectio
@pytest.fixture()
def aiohappyeyeballs_start_connection():
with patch("aioesphomeapi.connection.aiohappyeyeballs.start_connection") as func:
func.return_value = MagicMock(type=socket.SOCK_STREAM)
mock_socket = create_autospec(socket.socket, spec_set=True, instance=True)
mock_socket.type = socket.SOCK_STREAM
mock_socket.fileno.return_value = 1
mock_socket.getpeername.return_value = ("10.0.0.512", 323)
func.return_value = mock_socket
yield func
@ -139,7 +137,6 @@ def _create_mock_transport_protocol(
async def plaintext_connect_task_no_login(
conn: APIConnection,
resolve_host,
socket_socket,
event_loop,
aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
@ -161,7 +158,6 @@ async def plaintext_connect_task_no_login(
async def plaintext_connect_task_no_login_with_expected_name(
conn_with_expected_name: APIConnection,
resolve_host,
socket_socket,
event_loop,
aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
@ -184,7 +180,6 @@ async def plaintext_connect_task_no_login_with_expected_name(
async def plaintext_connect_task_with_login(
conn_with_password: APIConnection,
resolve_host,
socket_socket,
event_loop,
aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
@ -203,7 +198,7 @@ async def plaintext_connect_task_with_login(
@pytest_asyncio.fixture(name="api_client")
async def api_client(
resolve_host, socket_socket, event_loop, aiohappyeyeballs_start_connection
resolve_host, event_loop, aiohappyeyeballs_start_connection
) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]:
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()

View File

@ -199,7 +199,8 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
),
],
)
def test_plaintext_frame_helper(
@pytest.mark.asyncio
async def test_plaintext_frame_helper(
in_bytes: bytes, pkt_data: bytes, pkt_type: int
) -> None:
for _ in range(3):
@ -592,7 +593,9 @@ async def test_noise_frame_helper_bad_encryption(
@pytest.mark.asyncio
async def test_init_plaintext_with_wrong_preamble(conn: APIConnection):
async def test_init_plaintext_with_wrong_preamble(
conn: APIConnection, aiohappyeyeballs_start_connection
):
loop = asyncio.get_event_loop()
protocol = get_mock_protocol(conn)
with patch.object(loop, "create_connection") as create_connection:

View File

@ -4,9 +4,10 @@ import asyncio
import contextlib
import itertools
import logging
import socket
from functools import partial
from typing import Any
from unittest.mock import AsyncMock, MagicMock, call, patch
from unittest.mock import AsyncMock, MagicMock, call, create_autospec, patch
import pytest
from google.protobuf import message
@ -169,7 +170,8 @@ def patch_api_version(client: APIClient, version: APIVersion):
client._connection.api_version = version
def test_expected_name(auth_client: APIClient) -> None:
@pytest.mark.asyncio
async def test_expected_name(auth_client: APIClient) -> None:
"""Ensure expected name can be set externally."""
assert auth_client.expected_name is None
auth_client.expected_name = "awesome"
@ -219,9 +221,15 @@ async def test_connection_released_if_connecting_is_cancelled() -> None:
cli = APIClient("1.2.3.4", 1234, None)
asyncio.get_event_loop()
async def _start_connection_with_delay(*args, **kwargs):
await asyncio.sleep(1)
mock_socket = create_autospec(socket.socket, spec_set=True, instance=True)
mock_socket.getpeername.return_value = ("4.3.3.3", 323)
return mock_socket
with patch(
"aioesphomeapi.connection.aiohappyeyeballs.start_connection",
side_effect=partial(asyncio.sleep, 1),
_start_connection_with_delay,
):
start_task = asyncio.create_task(cli.start_connection())
await asyncio.sleep(0)
@ -232,8 +240,14 @@ async def test_connection_released_if_connecting_is_cancelled() -> None:
await start_task
assert cli._connection is None
async def _start_connection_without_delay(*args, **kwargs):
mock_socket = create_autospec(socket.socket, spec_set=True, instance=True)
mock_socket.getpeername.return_value = ("4.3.3.3", 323)
return mock_socket
with patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection), patch(
"aioesphomeapi.connection.aiohappyeyeballs.start_connection"
"aioesphomeapi.connection.aiohappyeyeballs.start_connection",
_start_connection_without_delay,
):
await cli.start_connection()
await asyncio.sleep(0)
@ -894,7 +908,7 @@ async def test_noise_psk_handles_subclassed_string():
)
# Make sure its not a subclassed string
assert type(cli._params.noise_psk) is str
assert type(cli._params.address) is str
assert type(cli._params.addresses[0]) is str
assert type(cli._params.expected_name) is str
rl = ReconnectLogic(
@ -930,7 +944,7 @@ async def test_no_noise_psk():
)
# Make sure its not a subclassed string
assert cli._params.noise_psk is None
assert type(cli._params.address) is str
assert type(cli._params.addresses[0]) is str
assert type(cli._params.expected_name) is str
@ -945,7 +959,7 @@ async def test_empty_noise_psk_or_expected_name():
expected_name="",
)
assert cli._params.noise_psk is None
assert type(cli._params.address) is str
assert type(cli._params.addresses[0]) is str
assert cli._params.expected_name is None

View File

@ -221,7 +221,7 @@ async def test_plaintext_connection(
@pytest.mark.asyncio
async def test_start_connection_socket_error(
conn: APIConnection, resolve_host, socket_socket
conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection
):
"""Test handling of socket error during start connection."""
loop = asyncio.get_event_loop()
@ -238,7 +238,7 @@ async def test_start_connection_socket_error(
@pytest.mark.asyncio
async def test_start_connection_times_out(
conn: APIConnection, resolve_host, socket_socket
conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection
):
"""Test handling of start connection timing out."""
asyncio.get_event_loop()
@ -264,9 +264,7 @@ async def test_start_connection_times_out(
@pytest.mark.asyncio
async def test_start_connection_os_error(
conn: APIConnection, resolve_host, socket_socket
):
async def test_start_connection_os_error(conn: APIConnection, resolve_host):
"""Test handling of start connection has an OSError."""
asyncio.get_event_loop()
@ -284,9 +282,7 @@ async def test_start_connection_os_error(
@pytest.mark.asyncio
async def test_start_connection_is_cancelled(
conn: APIConnection, resolve_host, socket_socket
):
async def test_start_connection_is_cancelled(conn: APIConnection, resolve_host):
"""Test handling of start connection is cancelled."""
asyncio.get_event_loop()
@ -305,7 +301,7 @@ async def test_start_connection_is_cancelled(
@pytest.mark.asyncio
async def test_finish_connection_is_cancelled(
conn: APIConnection, resolve_host, socket_socket
conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection
):
"""Test handling of finishing connection being cancelled."""
loop = asyncio.get_event_loop()
@ -368,7 +364,7 @@ async def test_finish_connection_times_out(
async def test_plaintext_connection_fails_handshake(
conn: APIConnection,
resolve_host: AsyncMock,
socket_socket: MagicMock,
aiohappyeyeballs_start_connection: MagicMock,
exception_map: tuple[Exception, Exception],
) -> None:
"""Test that the frame helper is closed before the underlying socket.
@ -558,7 +554,7 @@ async def test_force_disconnect_fails(
@pytest.mark.asyncio
async def test_connect_resolver_times_out(
conn: APIConnection, socket_socket, event_loop, aiohappyeyeballs_start_connection
conn: APIConnection, event_loop, aiohappyeyeballs_start_connection
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
transport = MagicMock()
connected = asyncio.Event()
@ -571,7 +567,8 @@ async def test_connect_resolver_times_out(
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
), pytest.raises(
ResolveAPIError, match="Timeout while resolving IP address for fake.address"
ResolveAPIError,
match="Timeout while resolving IP address for fake.address",
):
await connect(conn, login=False)
@ -581,7 +578,6 @@ async def test_disconnect_fails_to_send_response(
connection_params: ConnectionParams,
event_loop: asyncio.AbstractEventLoop,
resolve_host,
socket_socket,
aiohappyeyeballs_start_connection,
) -> None:
loop = asyncio.get_event_loop()
@ -632,7 +628,6 @@ async def test_disconnect_success_case(
connection_params: ConnectionParams,
event_loop: asyncio.AbstractEventLoop,
resolve_host,
socket_socket,
aiohappyeyeballs_start_connection,
) -> None:
loop = asyncio.get_event_loop()

View File

@ -99,7 +99,7 @@ async def test_resolve_host_zeroconf_fails_end_to_end(async_zeroconf: AsyncZeroc
"aioesphomeapi.host_resolver.AsyncServiceInfo.async_request",
side_effect=Exception("no buffers"),
), pytest.raises(ResolveAPIError, match="no buffers"):
await hr.async_resolve_host("asdf.local", 6052)
await hr.async_resolve_host(["asdf.local"], 6052)
@pytest.mark.asyncio
@ -140,7 +140,7 @@ async def test_resolve_host_getaddrinfo_oserror(event_loop):
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos):
resolve_zc.return_value = addr_infos
ret = await hr.async_resolve_host("example.local", 6052)
ret = await hr.async_resolve_host(["example.local"], 6052)
resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None)
resolve_addr.assert_not_called()
@ -153,7 +153,7 @@ async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos):
async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos):
resolve_zc.return_value = []
resolve_addr.return_value = addr_infos
ret = await hr.async_resolve_host("example.local", 6052)
ret = await hr.async_resolve_host(["example.local"], 6052)
resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None)
resolve_addr.assert_called_once_with("example.local", 6052)
@ -166,7 +166,7 @@ async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos):
async def test_resolve_host_mdns_no_results(resolve_addr, addr_infos):
resolve_addr.return_value = addr_infos
with pytest.raises(ResolveAPIError):
await hr.async_resolve_host("example.local", 6052)
await hr.async_resolve_host(["example.local"], 6052)
@pytest.mark.asyncio
@ -174,7 +174,7 @@ async def test_resolve_host_mdns_no_results(resolve_addr, addr_infos):
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
async def test_resolve_host_addrinfo(resolve_addr, resolve_zc, addr_infos):
resolve_addr.return_value = addr_infos
ret = await hr.async_resolve_host("example.com", 6052)
ret = await hr.async_resolve_host(["example.com"], 6052)
resolve_zc.assert_not_called()
resolve_addr.assert_called_once_with("example.com", 6052)
@ -187,7 +187,7 @@ async def test_resolve_host_addrinfo(resolve_addr, resolve_zc, addr_infos):
async def test_resolve_host_addrinfo_empty(resolve_addr, resolve_zc, addr_infos):
resolve_addr.return_value = []
with pytest.raises(APIConnectionError):
await hr.async_resolve_host("example.com", 6052)
await hr.async_resolve_host(["example.com"], 6052)
resolve_zc.assert_not_called()
resolve_addr.assert_called_once_with("example.com", 6052)
@ -199,7 +199,7 @@ async def test_resolve_host_addrinfo_empty(resolve_addr, resolve_zc, addr_infos)
async def test_resolve_host_with_address(resolve_addr, resolve_zc):
resolve_zc.return_value = []
resolve_addr.return_value = addr_infos
ret = await hr.async_resolve_host("127.0.0.1", 6052)
ret = await hr.async_resolve_host(["127.0.0.1"], 6052)
resolve_zc.assert_not_called()
resolve_addr.assert_not_called()