diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 4691ecf..cc699ee 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -220,6 +220,7 @@ class APIClient: zeroconf_instance: ZeroconfInstanceType | None = None, noise_psk: str | None = None, expected_name: str | None = None, + addresses: list[str] | None = None, ) -> None: """Create a client, this object is shared across sessions. @@ -235,10 +236,14 @@ class APIClient: :param expected_name: Require the devices name to match the given expected name. Can be used to prevent accidentally connecting to a different device if IP passed as address but DHCP reassigned IP. + :param addresses: Optional list of IP addresses to connect to which takes + precedence over the address parameter. This is most commonly used when + the device has dual stack IPv4 and IPv6 addresses and you do not know + which one to connect to. """ self._debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) self._params = ConnectionParams( - address=str(address), + addresses=addresses if addresses else [str(address)], port=port, password=password, client_info=client_info, @@ -274,17 +279,17 @@ class APIClient: @property def address(self) -> str: - return self._params.address + return self._params.addresses[0] def _set_log_name(self) -> None: """Set the log name of the device.""" - resolved_address: str | None = None - if self._connection and self._connection.resolved_addr_info: - resolved_address = self._connection.resolved_addr_info[0].sockaddr.address + connected_address: str | None = None + if self._connection and self._connection.connected_address: + connected_address = self._connection.connected_address self.log_name = build_log_name( self.cached_name, - self.address, - resolved_address, + self._params.addresses, + connected_address, ) if self._connection: self._connection.set_log_name(self.log_name) @@ -328,8 +333,8 @@ class APIClient: self.log_name, ) await self._execute_connection_coro(self._connection.start_connection()) - # If we resolved the address, we should set the log name now - if self._connection.resolved_addr_info: + # If we connected, we should set the log name now + if self._connection.connected_address: self._set_log_name() async def finish_connection( diff --git a/aioesphomeapi/connection.pxd b/aioesphomeapi/connection.pxd index 62aa99c..4c356b5 100644 --- a/aioesphomeapi/connection.pxd +++ b/aioesphomeapi/connection.pxd @@ -74,7 +74,7 @@ cdef object _handle_complex_message @cython.dataclasses.dataclass cdef class ConnectionParams: - cdef public str address + cdef public list addresses cdef public object port cdef public object password cdef public object client_info @@ -108,7 +108,7 @@ cdef class APIConnection: cdef bint _handshake_complete cdef bint _debug_enabled cdef public str received_name - cdef public object resolved_addr_info + cdef public str connected_address cpdef void send_message(self, object msg) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index b97565c..d371b6f 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -107,7 +107,7 @@ _float = float @dataclass class ConnectionParams: - address: str + addresses: list[str] port: int password: str | None client_info: str @@ -207,7 +207,7 @@ class APIConnection: "_handshake_complete", "_debug_enabled", "received_name", - "resolved_addr_info", + "connected_address", ) def __init__( @@ -230,7 +230,7 @@ class APIConnection: # Message handlers currently subscribed to incoming messages self._message_handlers: dict[Any, set[Callable[[message.Message], None]]] = {} # The friendly name to show for this connection in the logs - self.log_name = log_name or params.address + self.log_name = log_name or ",".join(params.addresses) # futures currently subscribed to exceptions in the read task self._read_exception_futures: set[asyncio.Future[None]] = set() @@ -251,7 +251,7 @@ class APIConnection: self._handshake_complete = False self._debug_enabled = debug_enabled self.received_name: str = "" - self.resolved_addr_info: list[hr.AddrInfo] = [] + self.connected_address: str | None = None def set_log_name(self, name: str) -> None: """Set the friendly log name for this connection.""" @@ -325,7 +325,7 @@ class APIConnection: try: async with asyncio_timeout(RESOLVE_TIMEOUT): return await hr.async_resolve_host( - self._params.address, + self._params.addresses, self._params.port, self._params.zeroconf_manager, ) @@ -338,11 +338,9 @@ class APIConnection: """Step 2 in connect process: connect the socket.""" if self._debug_enabled: _LOGGER.debug( - "%s: Connecting to %s:%s (%s)", + "%s: Connecting to %s", self.log_name, - self._params.address, - self._params.port, - addrs, + ", ".join(str(addr.sockaddr) for addr in addrs), ) addr_infos: list[aiohappyeyeballs.AddrInfoType] = [ @@ -350,7 +348,7 @@ class APIConnection: addr.family, addr.type, addr.proto, - self._params.address, + "", astuple(addr.sockaddr), ) for addr in addrs @@ -361,9 +359,11 @@ class APIConnection: while addr_infos: try: async with asyncio_timeout(TCP_CONNECT_TIMEOUT): + # Devices are likely on the local network so we + # only use a 100ms happy eyeballs delay sock = await aiohappyeyeballs.start_connection( addr_infos, - happy_eyeballs_delay=0.25, + happy_eyeballs_delay=0.1, interleave=interleave, loop=self._loop, ) @@ -387,14 +387,14 @@ class APIConnection: # Try to reduce the pressure on esphome device as it measures # ram in bytes and we measure ram in megabytes. sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE) + self.connected_address = sock.getpeername()[0] if self._debug_enabled: _LOGGER.debug( - "%s: Opened socket to %s:%s (%s)", + "%s: Opened socket to %s:%s", self.log_name, - self._params.address, + self.connected_address, self._params.port, - addrs, ) async def _connect_init_frame_helper(self) -> None: @@ -567,8 +567,7 @@ class APIConnection: async def _do_connect(self) -> None: """Do the actual connect process.""" - self.resolved_addr_info = await self._connect_resolve_host() - await self._connect_socket_connect(self.resolved_addr_info) + await self._connect_socket_connect(await self._connect_resolve_host()) async def start_connection(self) -> None: """Start the connection process. diff --git a/aioesphomeapi/host_resolver.py b/aioesphomeapi/host_resolver.py index 153f6ef..14e34c9 100644 --- a/aioesphomeapi/host_resolver.py +++ b/aioesphomeapi/host_resolver.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import contextlib import logging import socket from dataclasses import dataclass @@ -181,35 +180,46 @@ def _async_ip_address_to_addrs( async def async_resolve_host( - host: str, + hosts: list[str], port: int, zeroconf_manager: ZeroconfManager | None = None, ) -> list[AddrInfo]: addrs: list[AddrInfo] = [] + zc_error: Exception | None = None - zc_error = None - if host_is_name_part(host) or address_is_local(host): - name = host.partition(".")[0] - try: - addrs.extend( - await _async_resolve_host_zeroconf( - name, port, zeroconf_manager=zeroconf_manager + for host in hosts: + host_addrs: list[AddrInfo] = [] + host_is_local_name = host_is_name_part(host) or address_is_local(host) + + if host_is_local_name: + name = host.partition(".")[0] + try: + host_addrs.extend( + await _async_resolve_host_zeroconf( + name, port, zeroconf_manager=zeroconf_manager + ) ) - ) - except ResolveAPIError as err: - zc_error = err + except ResolveAPIError as err: + zc_error = err - else: - with contextlib.suppress(ValueError): - addrs.extend(_async_ip_address_to_addrs(ip_address(host), port)) + if not host_is_local_name: + try: + host_addrs.extend(_async_ip_address_to_addrs(ip_address(host), port)) + except ValueError: + # Not an IP address + pass - if not addrs: - addrs.extend(await _async_resolve_host_getaddrinfo(host, port)) + if not host_addrs: + host_addrs.extend(await _async_resolve_host_getaddrinfo(host, port)) + + addrs.extend(host_addrs) if not addrs: if zc_error: # Only show ZC error if getaddrinfo also didn't work raise zc_error - raise ResolveAPIError(f"Could not resolve host {host} - got no results from OS") + raise ResolveAPIError( + f"Could not resolve host {hosts} - got no results from OS" + ) return addrs diff --git a/aioesphomeapi/util.py b/aioesphomeapi/util.py index ae226bb..0ca8dae 100644 --- a/aioesphomeapi/util.py +++ b/aioesphomeapi/util.py @@ -36,11 +36,18 @@ def address_is_local(address: str) -> bool: return address.removesuffix(".").endswith(".local") -def build_log_name(name: str | None, address: str, resolved_address: str | None) -> str: +def build_log_name( + name: str | None, addresses: list[str], connected_address: str | None +) -> str: """Return a log name for a connection.""" - if not name and address_is_local(address) or host_is_name_part(address): - name = address.partition(".")[0] - preferred_address = resolved_address or address + preferred_address = connected_address + for address in addresses: + if not name and address_is_local(address) or host_is_name_part(address): + name = address.partition(".")[0] + elif not preferred_address: + preferred_address = address + if not preferred_address: + return name or addresses[0] if ( name and name != preferred_address diff --git a/tests/conftest.py b/tests/conftest.py index 35bb264..4e25771 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,7 @@ import socket from dataclasses import replace from functools import partial from typing import Callable -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, create_autospec, patch import pytest import pytest_asyncio @@ -50,12 +50,6 @@ def resolve_host(): yield func -@pytest.fixture -def socket_socket(): - with patch("socket.socket") as func: - yield func - - @pytest.fixture def patchable_api_client() -> APIClient: class PatchableAPIClient(APIClient): @@ -71,7 +65,7 @@ def patchable_api_client() -> APIClient: def get_mock_connection_params() -> ConnectionParams: return ConnectionParams( - address="fake.address", + addresses=["fake.address"], port=6052, password=None, client_info="Tests client", @@ -119,7 +113,11 @@ def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnectio @pytest.fixture() def aiohappyeyeballs_start_connection(): with patch("aioesphomeapi.connection.aiohappyeyeballs.start_connection") as func: - func.return_value = MagicMock(type=socket.SOCK_STREAM) + mock_socket = create_autospec(socket.socket, spec_set=True, instance=True) + mock_socket.type = socket.SOCK_STREAM + mock_socket.fileno.return_value = 1 + mock_socket.getpeername.return_value = ("10.0.0.512", 323) + func.return_value = mock_socket yield func @@ -139,7 +137,6 @@ def _create_mock_transport_protocol( async def plaintext_connect_task_no_login( conn: APIConnection, resolve_host, - socket_socket, event_loop, aiohappyeyeballs_start_connection, ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: @@ -161,7 +158,6 @@ async def plaintext_connect_task_no_login( async def plaintext_connect_task_no_login_with_expected_name( conn_with_expected_name: APIConnection, resolve_host, - socket_socket, event_loop, aiohappyeyeballs_start_connection, ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: @@ -184,7 +180,6 @@ async def plaintext_connect_task_no_login_with_expected_name( async def plaintext_connect_task_with_login( conn_with_password: APIConnection, resolve_host, - socket_socket, event_loop, aiohappyeyeballs_start_connection, ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: @@ -203,7 +198,7 @@ async def plaintext_connect_task_with_login( @pytest_asyncio.fixture(name="api_client") async def api_client( - resolve_host, socket_socket, event_loop, aiohappyeyeballs_start_connection + resolve_host, event_loop, aiohappyeyeballs_start_connection ) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]: protocol: APIPlaintextFrameHelper | None = None transport = MagicMock() diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index f9f2818..640f48c 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -199,7 +199,8 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper): ), ], ) -def test_plaintext_frame_helper( +@pytest.mark.asyncio +async def test_plaintext_frame_helper( in_bytes: bytes, pkt_data: bytes, pkt_type: int ) -> None: for _ in range(3): @@ -592,7 +593,9 @@ async def test_noise_frame_helper_bad_encryption( @pytest.mark.asyncio -async def test_init_plaintext_with_wrong_preamble(conn: APIConnection): +async def test_init_plaintext_with_wrong_preamble( + conn: APIConnection, aiohappyeyeballs_start_connection +): loop = asyncio.get_event_loop() protocol = get_mock_protocol(conn) with patch.object(loop, "create_connection") as create_connection: diff --git a/tests/test_client.py b/tests/test_client.py index 313559d..70fe312 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,9 +4,10 @@ import asyncio import contextlib import itertools import logging +import socket from functools import partial from typing import Any -from unittest.mock import AsyncMock, MagicMock, call, patch +from unittest.mock import AsyncMock, MagicMock, call, create_autospec, patch import pytest from google.protobuf import message @@ -169,7 +170,8 @@ def patch_api_version(client: APIClient, version: APIVersion): client._connection.api_version = version -def test_expected_name(auth_client: APIClient) -> None: +@pytest.mark.asyncio +async def test_expected_name(auth_client: APIClient) -> None: """Ensure expected name can be set externally.""" assert auth_client.expected_name is None auth_client.expected_name = "awesome" @@ -219,9 +221,15 @@ async def test_connection_released_if_connecting_is_cancelled() -> None: cli = APIClient("1.2.3.4", 1234, None) asyncio.get_event_loop() + async def _start_connection_with_delay(*args, **kwargs): + await asyncio.sleep(1) + mock_socket = create_autospec(socket.socket, spec_set=True, instance=True) + mock_socket.getpeername.return_value = ("4.3.3.3", 323) + return mock_socket + with patch( "aioesphomeapi.connection.aiohappyeyeballs.start_connection", - side_effect=partial(asyncio.sleep, 1), + _start_connection_with_delay, ): start_task = asyncio.create_task(cli.start_connection()) await asyncio.sleep(0) @@ -232,8 +240,14 @@ async def test_connection_released_if_connecting_is_cancelled() -> None: await start_task assert cli._connection is None + async def _start_connection_without_delay(*args, **kwargs): + mock_socket = create_autospec(socket.socket, spec_set=True, instance=True) + mock_socket.getpeername.return_value = ("4.3.3.3", 323) + return mock_socket + with patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection), patch( - "aioesphomeapi.connection.aiohappyeyeballs.start_connection" + "aioesphomeapi.connection.aiohappyeyeballs.start_connection", + _start_connection_without_delay, ): await cli.start_connection() await asyncio.sleep(0) @@ -894,7 +908,7 @@ async def test_noise_psk_handles_subclassed_string(): ) # Make sure its not a subclassed string assert type(cli._params.noise_psk) is str - assert type(cli._params.address) is str + assert type(cli._params.addresses[0]) is str assert type(cli._params.expected_name) is str rl = ReconnectLogic( @@ -930,7 +944,7 @@ async def test_no_noise_psk(): ) # Make sure its not a subclassed string assert cli._params.noise_psk is None - assert type(cli._params.address) is str + assert type(cli._params.addresses[0]) is str assert type(cli._params.expected_name) is str @@ -945,7 +959,7 @@ async def test_empty_noise_psk_or_expected_name(): expected_name="", ) assert cli._params.noise_psk is None - assert type(cli._params.address) is str + assert type(cli._params.addresses[0]) is str assert cli._params.expected_name is None diff --git a/tests/test_connection.py b/tests/test_connection.py index 73cee88..fd4d255 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -221,7 +221,7 @@ async def test_plaintext_connection( @pytest.mark.asyncio async def test_start_connection_socket_error( - conn: APIConnection, resolve_host, socket_socket + conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection ): """Test handling of socket error during start connection.""" loop = asyncio.get_event_loop() @@ -238,7 +238,7 @@ async def test_start_connection_socket_error( @pytest.mark.asyncio async def test_start_connection_times_out( - conn: APIConnection, resolve_host, socket_socket + conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection ): """Test handling of start connection timing out.""" asyncio.get_event_loop() @@ -264,9 +264,7 @@ async def test_start_connection_times_out( @pytest.mark.asyncio -async def test_start_connection_os_error( - conn: APIConnection, resolve_host, socket_socket -): +async def test_start_connection_os_error(conn: APIConnection, resolve_host): """Test handling of start connection has an OSError.""" asyncio.get_event_loop() @@ -284,9 +282,7 @@ async def test_start_connection_os_error( @pytest.mark.asyncio -async def test_start_connection_is_cancelled( - conn: APIConnection, resolve_host, socket_socket -): +async def test_start_connection_is_cancelled(conn: APIConnection, resolve_host): """Test handling of start connection is cancelled.""" asyncio.get_event_loop() @@ -305,7 +301,7 @@ async def test_start_connection_is_cancelled( @pytest.mark.asyncio async def test_finish_connection_is_cancelled( - conn: APIConnection, resolve_host, socket_socket + conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection ): """Test handling of finishing connection being cancelled.""" loop = asyncio.get_event_loop() @@ -368,7 +364,7 @@ async def test_finish_connection_times_out( async def test_plaintext_connection_fails_handshake( conn: APIConnection, resolve_host: AsyncMock, - socket_socket: MagicMock, + aiohappyeyeballs_start_connection: MagicMock, exception_map: tuple[Exception, Exception], ) -> None: """Test that the frame helper is closed before the underlying socket. @@ -558,7 +554,7 @@ async def test_force_disconnect_fails( @pytest.mark.asyncio async def test_connect_resolver_times_out( - conn: APIConnection, socket_socket, event_loop, aiohappyeyeballs_start_connection + conn: APIConnection, event_loop, aiohappyeyeballs_start_connection ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: transport = MagicMock() connected = asyncio.Event() @@ -571,7 +567,8 @@ async def test_connect_resolver_times_out( "create_connection", side_effect=partial(_create_mock_transport_protocol, transport, connected), ), pytest.raises( - ResolveAPIError, match="Timeout while resolving IP address for fake.address" + ResolveAPIError, + match="Timeout while resolving IP address for fake.address", ): await connect(conn, login=False) @@ -581,7 +578,6 @@ async def test_disconnect_fails_to_send_response( connection_params: ConnectionParams, event_loop: asyncio.AbstractEventLoop, resolve_host, - socket_socket, aiohappyeyeballs_start_connection, ) -> None: loop = asyncio.get_event_loop() @@ -632,7 +628,6 @@ async def test_disconnect_success_case( connection_params: ConnectionParams, event_loop: asyncio.AbstractEventLoop, resolve_host, - socket_socket, aiohappyeyeballs_start_connection, ) -> None: loop = asyncio.get_event_loop() diff --git a/tests/test_host_resolver.py b/tests/test_host_resolver.py index 322f378..82f1095 100644 --- a/tests/test_host_resolver.py +++ b/tests/test_host_resolver.py @@ -99,7 +99,7 @@ async def test_resolve_host_zeroconf_fails_end_to_end(async_zeroconf: AsyncZeroc "aioesphomeapi.host_resolver.AsyncServiceInfo.async_request", side_effect=Exception("no buffers"), ), pytest.raises(ResolveAPIError, match="no buffers"): - await hr.async_resolve_host("asdf.local", 6052) + await hr.async_resolve_host(["asdf.local"], 6052) @pytest.mark.asyncio @@ -140,7 +140,7 @@ async def test_resolve_host_getaddrinfo_oserror(event_loop): @patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo") async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos): resolve_zc.return_value = addr_infos - ret = await hr.async_resolve_host("example.local", 6052) + ret = await hr.async_resolve_host(["example.local"], 6052) resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None) resolve_addr.assert_not_called() @@ -153,7 +153,7 @@ async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos): async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos): resolve_zc.return_value = [] resolve_addr.return_value = addr_infos - ret = await hr.async_resolve_host("example.local", 6052) + ret = await hr.async_resolve_host(["example.local"], 6052) resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None) resolve_addr.assert_called_once_with("example.local", 6052) @@ -166,7 +166,7 @@ async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos): async def test_resolve_host_mdns_no_results(resolve_addr, addr_infos): resolve_addr.return_value = addr_infos with pytest.raises(ResolveAPIError): - await hr.async_resolve_host("example.local", 6052) + await hr.async_resolve_host(["example.local"], 6052) @pytest.mark.asyncio @@ -174,7 +174,7 @@ async def test_resolve_host_mdns_no_results(resolve_addr, addr_infos): @patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo") async def test_resolve_host_addrinfo(resolve_addr, resolve_zc, addr_infos): resolve_addr.return_value = addr_infos - ret = await hr.async_resolve_host("example.com", 6052) + ret = await hr.async_resolve_host(["example.com"], 6052) resolve_zc.assert_not_called() resolve_addr.assert_called_once_with("example.com", 6052) @@ -187,7 +187,7 @@ async def test_resolve_host_addrinfo(resolve_addr, resolve_zc, addr_infos): async def test_resolve_host_addrinfo_empty(resolve_addr, resolve_zc, addr_infos): resolve_addr.return_value = [] with pytest.raises(APIConnectionError): - await hr.async_resolve_host("example.com", 6052) + await hr.async_resolve_host(["example.com"], 6052) resolve_zc.assert_not_called() resolve_addr.assert_called_once_with("example.com", 6052) @@ -199,7 +199,7 @@ async def test_resolve_host_addrinfo_empty(resolve_addr, resolve_zc, addr_infos) async def test_resolve_host_with_address(resolve_addr, resolve_zc): resolve_zc.return_value = [] resolve_addr.return_value = addr_infos - ret = await hr.async_resolve_host("127.0.0.1", 6052) + ret = await hr.async_resolve_host(["127.0.0.1"], 6052) resolve_zc.assert_not_called() resolve_addr.assert_not_called()