Add support for passing multiple addresses to the client (#796)

This commit is contained in:
J. Nick Koston 2023-12-12 11:22:14 -10:00 committed by GitHub
parent 4668b1ff54
commit de1d08493d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 120 additions and 92 deletions

View File

@ -220,6 +220,7 @@ class APIClient:
zeroconf_instance: ZeroconfInstanceType | None = None, zeroconf_instance: ZeroconfInstanceType | None = None,
noise_psk: str | None = None, noise_psk: str | None = None,
expected_name: str | None = None, expected_name: str | None = None,
addresses: list[str] | None = None,
) -> None: ) -> None:
"""Create a client, this object is shared across sessions. """Create a client, this object is shared across sessions.
@ -235,10 +236,14 @@ class APIClient:
:param expected_name: Require the devices name to match the given expected name. :param expected_name: Require the devices name to match the given expected name.
Can be used to prevent accidentally connecting to a different device if Can be used to prevent accidentally connecting to a different device if
IP passed as address but DHCP reassigned IP. IP passed as address but DHCP reassigned IP.
:param addresses: Optional list of IP addresses to connect to which takes
precedence over the address parameter. This is most commonly used when
the device has dual stack IPv4 and IPv6 addresses and you do not know
which one to connect to.
""" """
self._debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) self._debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
self._params = ConnectionParams( self._params = ConnectionParams(
address=str(address), addresses=addresses if addresses else [str(address)],
port=port, port=port,
password=password, password=password,
client_info=client_info, client_info=client_info,
@ -274,17 +279,17 @@ class APIClient:
@property @property
def address(self) -> str: def address(self) -> str:
return self._params.address return self._params.addresses[0]
def _set_log_name(self) -> None: def _set_log_name(self) -> None:
"""Set the log name of the device.""" """Set the log name of the device."""
resolved_address: str | None = None connected_address: str | None = None
if self._connection and self._connection.resolved_addr_info: if self._connection and self._connection.connected_address:
resolved_address = self._connection.resolved_addr_info[0].sockaddr.address connected_address = self._connection.connected_address
self.log_name = build_log_name( self.log_name = build_log_name(
self.cached_name, self.cached_name,
self.address, self._params.addresses,
resolved_address, connected_address,
) )
if self._connection: if self._connection:
self._connection.set_log_name(self.log_name) self._connection.set_log_name(self.log_name)
@ -328,8 +333,8 @@ class APIClient:
self.log_name, self.log_name,
) )
await self._execute_connection_coro(self._connection.start_connection()) await self._execute_connection_coro(self._connection.start_connection())
# If we resolved the address, we should set the log name now # If we connected, we should set the log name now
if self._connection.resolved_addr_info: if self._connection.connected_address:
self._set_log_name() self._set_log_name()
async def finish_connection( async def finish_connection(

View File

@ -74,7 +74,7 @@ cdef object _handle_complex_message
@cython.dataclasses.dataclass @cython.dataclasses.dataclass
cdef class ConnectionParams: cdef class ConnectionParams:
cdef public str address cdef public list addresses
cdef public object port cdef public object port
cdef public object password cdef public object password
cdef public object client_info cdef public object client_info
@ -108,7 +108,7 @@ cdef class APIConnection:
cdef bint _handshake_complete cdef bint _handshake_complete
cdef bint _debug_enabled cdef bint _debug_enabled
cdef public str received_name cdef public str received_name
cdef public object resolved_addr_info cdef public str connected_address
cpdef void send_message(self, object msg) cpdef void send_message(self, object msg)

View File

@ -107,7 +107,7 @@ _float = float
@dataclass @dataclass
class ConnectionParams: class ConnectionParams:
address: str addresses: list[str]
port: int port: int
password: str | None password: str | None
client_info: str client_info: str
@ -207,7 +207,7 @@ class APIConnection:
"_handshake_complete", "_handshake_complete",
"_debug_enabled", "_debug_enabled",
"received_name", "received_name",
"resolved_addr_info", "connected_address",
) )
def __init__( def __init__(
@ -230,7 +230,7 @@ class APIConnection:
# Message handlers currently subscribed to incoming messages # Message handlers currently subscribed to incoming messages
self._message_handlers: dict[Any, set[Callable[[message.Message], None]]] = {} self._message_handlers: dict[Any, set[Callable[[message.Message], None]]] = {}
# The friendly name to show for this connection in the logs # The friendly name to show for this connection in the logs
self.log_name = log_name or params.address self.log_name = log_name or ",".join(params.addresses)
# futures currently subscribed to exceptions in the read task # futures currently subscribed to exceptions in the read task
self._read_exception_futures: set[asyncio.Future[None]] = set() self._read_exception_futures: set[asyncio.Future[None]] = set()
@ -251,7 +251,7 @@ class APIConnection:
self._handshake_complete = False self._handshake_complete = False
self._debug_enabled = debug_enabled self._debug_enabled = debug_enabled
self.received_name: str = "" self.received_name: str = ""
self.resolved_addr_info: list[hr.AddrInfo] = [] self.connected_address: str | None = None
def set_log_name(self, name: str) -> None: def set_log_name(self, name: str) -> None:
"""Set the friendly log name for this connection.""" """Set the friendly log name for this connection."""
@ -325,7 +325,7 @@ class APIConnection:
try: try:
async with asyncio_timeout(RESOLVE_TIMEOUT): async with asyncio_timeout(RESOLVE_TIMEOUT):
return await hr.async_resolve_host( return await hr.async_resolve_host(
self._params.address, self._params.addresses,
self._params.port, self._params.port,
self._params.zeroconf_manager, self._params.zeroconf_manager,
) )
@ -338,11 +338,9 @@ class APIConnection:
"""Step 2 in connect process: connect the socket.""" """Step 2 in connect process: connect the socket."""
if self._debug_enabled: if self._debug_enabled:
_LOGGER.debug( _LOGGER.debug(
"%s: Connecting to %s:%s (%s)", "%s: Connecting to %s",
self.log_name, self.log_name,
self._params.address, ", ".join(str(addr.sockaddr) for addr in addrs),
self._params.port,
addrs,
) )
addr_infos: list[aiohappyeyeballs.AddrInfoType] = [ addr_infos: list[aiohappyeyeballs.AddrInfoType] = [
@ -350,7 +348,7 @@ class APIConnection:
addr.family, addr.family,
addr.type, addr.type,
addr.proto, addr.proto,
self._params.address, "",
astuple(addr.sockaddr), astuple(addr.sockaddr),
) )
for addr in addrs for addr in addrs
@ -361,9 +359,11 @@ class APIConnection:
while addr_infos: while addr_infos:
try: try:
async with asyncio_timeout(TCP_CONNECT_TIMEOUT): async with asyncio_timeout(TCP_CONNECT_TIMEOUT):
# Devices are likely on the local network so we
# only use a 100ms happy eyeballs delay
sock = await aiohappyeyeballs.start_connection( sock = await aiohappyeyeballs.start_connection(
addr_infos, addr_infos,
happy_eyeballs_delay=0.25, happy_eyeballs_delay=0.1,
interleave=interleave, interleave=interleave,
loop=self._loop, loop=self._loop,
) )
@ -387,14 +387,14 @@ class APIConnection:
# Try to reduce the pressure on esphome device as it measures # Try to reduce the pressure on esphome device as it measures
# ram in bytes and we measure ram in megabytes. # ram in bytes and we measure ram in megabytes.
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE) sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE)
self.connected_address = sock.getpeername()[0]
if self._debug_enabled: if self._debug_enabled:
_LOGGER.debug( _LOGGER.debug(
"%s: Opened socket to %s:%s (%s)", "%s: Opened socket to %s:%s",
self.log_name, self.log_name,
self._params.address, self.connected_address,
self._params.port, self._params.port,
addrs,
) )
async def _connect_init_frame_helper(self) -> None: async def _connect_init_frame_helper(self) -> None:
@ -567,8 +567,7 @@ class APIConnection:
async def _do_connect(self) -> None: async def _do_connect(self) -> None:
"""Do the actual connect process.""" """Do the actual connect process."""
self.resolved_addr_info = await self._connect_resolve_host() await self._connect_socket_connect(await self._connect_resolve_host())
await self._connect_socket_connect(self.resolved_addr_info)
async def start_connection(self) -> None: async def start_connection(self) -> None:
"""Start the connection process. """Start the connection process.

View File

@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import contextlib
import logging import logging
import socket import socket
from dataclasses import dataclass from dataclasses import dataclass
@ -181,35 +180,46 @@ def _async_ip_address_to_addrs(
async def async_resolve_host( async def async_resolve_host(
host: str, hosts: list[str],
port: int, port: int,
zeroconf_manager: ZeroconfManager | None = None, zeroconf_manager: ZeroconfManager | None = None,
) -> list[AddrInfo]: ) -> list[AddrInfo]:
addrs: list[AddrInfo] = [] addrs: list[AddrInfo] = []
zc_error: Exception | None = None
zc_error = None for host in hosts:
if host_is_name_part(host) or address_is_local(host): host_addrs: list[AddrInfo] = []
name = host.partition(".")[0] host_is_local_name = host_is_name_part(host) or address_is_local(host)
try:
addrs.extend( if host_is_local_name:
await _async_resolve_host_zeroconf( name = host.partition(".")[0]
name, port, zeroconf_manager=zeroconf_manager try:
host_addrs.extend(
await _async_resolve_host_zeroconf(
name, port, zeroconf_manager=zeroconf_manager
)
) )
) except ResolveAPIError as err:
except ResolveAPIError as err: zc_error = err
zc_error = err
else: if not host_is_local_name:
with contextlib.suppress(ValueError): try:
addrs.extend(_async_ip_address_to_addrs(ip_address(host), port)) host_addrs.extend(_async_ip_address_to_addrs(ip_address(host), port))
except ValueError:
# Not an IP address
pass
if not addrs: if not host_addrs:
addrs.extend(await _async_resolve_host_getaddrinfo(host, port)) host_addrs.extend(await _async_resolve_host_getaddrinfo(host, port))
addrs.extend(host_addrs)
if not addrs: if not addrs:
if zc_error: if zc_error:
# Only show ZC error if getaddrinfo also didn't work # Only show ZC error if getaddrinfo also didn't work
raise zc_error raise zc_error
raise ResolveAPIError(f"Could not resolve host {host} - got no results from OS") raise ResolveAPIError(
f"Could not resolve host {hosts} - got no results from OS"
)
return addrs return addrs

View File

@ -36,11 +36,18 @@ def address_is_local(address: str) -> bool:
return address.removesuffix(".").endswith(".local") return address.removesuffix(".").endswith(".local")
def build_log_name(name: str | None, address: str, resolved_address: str | None) -> str: def build_log_name(
name: str | None, addresses: list[str], connected_address: str | None
) -> str:
"""Return a log name for a connection.""" """Return a log name for a connection."""
if not name and address_is_local(address) or host_is_name_part(address): preferred_address = connected_address
name = address.partition(".")[0] for address in addresses:
preferred_address = resolved_address or address if not name and address_is_local(address) or host_is_name_part(address):
name = address.partition(".")[0]
elif not preferred_address:
preferred_address = address
if not preferred_address:
return name or addresses[0]
if ( if (
name name
and name != preferred_address and name != preferred_address

View File

@ -6,7 +6,7 @@ import socket
from dataclasses import replace from dataclasses import replace
from functools import partial from functools import partial
from typing import Callable from typing import Callable
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, create_autospec, patch
import pytest import pytest
import pytest_asyncio import pytest_asyncio
@ -50,12 +50,6 @@ def resolve_host():
yield func yield func
@pytest.fixture
def socket_socket():
with patch("socket.socket") as func:
yield func
@pytest.fixture @pytest.fixture
def patchable_api_client() -> APIClient: def patchable_api_client() -> APIClient:
class PatchableAPIClient(APIClient): class PatchableAPIClient(APIClient):
@ -71,7 +65,7 @@ def patchable_api_client() -> APIClient:
def get_mock_connection_params() -> ConnectionParams: def get_mock_connection_params() -> ConnectionParams:
return ConnectionParams( return ConnectionParams(
address="fake.address", addresses=["fake.address"],
port=6052, port=6052,
password=None, password=None,
client_info="Tests client", client_info="Tests client",
@ -119,7 +113,11 @@ def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnectio
@pytest.fixture() @pytest.fixture()
def aiohappyeyeballs_start_connection(): def aiohappyeyeballs_start_connection():
with patch("aioesphomeapi.connection.aiohappyeyeballs.start_connection") as func: with patch("aioesphomeapi.connection.aiohappyeyeballs.start_connection") as func:
func.return_value = MagicMock(type=socket.SOCK_STREAM) mock_socket = create_autospec(socket.socket, spec_set=True, instance=True)
mock_socket.type = socket.SOCK_STREAM
mock_socket.fileno.return_value = 1
mock_socket.getpeername.return_value = ("10.0.0.512", 323)
func.return_value = mock_socket
yield func yield func
@ -139,7 +137,6 @@ def _create_mock_transport_protocol(
async def plaintext_connect_task_no_login( async def plaintext_connect_task_no_login(
conn: APIConnection, conn: APIConnection,
resolve_host, resolve_host,
socket_socket,
event_loop, event_loop,
aiohappyeyeballs_start_connection, aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
@ -161,7 +158,6 @@ async def plaintext_connect_task_no_login(
async def plaintext_connect_task_no_login_with_expected_name( async def plaintext_connect_task_no_login_with_expected_name(
conn_with_expected_name: APIConnection, conn_with_expected_name: APIConnection,
resolve_host, resolve_host,
socket_socket,
event_loop, event_loop,
aiohappyeyeballs_start_connection, aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
@ -184,7 +180,6 @@ async def plaintext_connect_task_no_login_with_expected_name(
async def plaintext_connect_task_with_login( async def plaintext_connect_task_with_login(
conn_with_password: APIConnection, conn_with_password: APIConnection,
resolve_host, resolve_host,
socket_socket,
event_loop, event_loop,
aiohappyeyeballs_start_connection, aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
@ -203,7 +198,7 @@ async def plaintext_connect_task_with_login(
@pytest_asyncio.fixture(name="api_client") @pytest_asyncio.fixture(name="api_client")
async def api_client( async def api_client(
resolve_host, socket_socket, event_loop, aiohappyeyeballs_start_connection resolve_host, event_loop, aiohappyeyeballs_start_connection
) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]: ) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]:
protocol: APIPlaintextFrameHelper | None = None protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock() transport = MagicMock()

View File

@ -199,7 +199,8 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
), ),
], ],
) )
def test_plaintext_frame_helper( @pytest.mark.asyncio
async def test_plaintext_frame_helper(
in_bytes: bytes, pkt_data: bytes, pkt_type: int in_bytes: bytes, pkt_data: bytes, pkt_type: int
) -> None: ) -> None:
for _ in range(3): for _ in range(3):
@ -592,7 +593,9 @@ async def test_noise_frame_helper_bad_encryption(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_init_plaintext_with_wrong_preamble(conn: APIConnection): async def test_init_plaintext_with_wrong_preamble(
conn: APIConnection, aiohappyeyeballs_start_connection
):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
protocol = get_mock_protocol(conn) protocol = get_mock_protocol(conn)
with patch.object(loop, "create_connection") as create_connection: with patch.object(loop, "create_connection") as create_connection:

View File

@ -4,9 +4,10 @@ import asyncio
import contextlib import contextlib
import itertools import itertools
import logging import logging
import socket
from functools import partial from functools import partial
from typing import Any from typing import Any
from unittest.mock import AsyncMock, MagicMock, call, patch from unittest.mock import AsyncMock, MagicMock, call, create_autospec, patch
import pytest import pytest
from google.protobuf import message from google.protobuf import message
@ -169,7 +170,8 @@ def patch_api_version(client: APIClient, version: APIVersion):
client._connection.api_version = version client._connection.api_version = version
def test_expected_name(auth_client: APIClient) -> None: @pytest.mark.asyncio
async def test_expected_name(auth_client: APIClient) -> None:
"""Ensure expected name can be set externally.""" """Ensure expected name can be set externally."""
assert auth_client.expected_name is None assert auth_client.expected_name is None
auth_client.expected_name = "awesome" auth_client.expected_name = "awesome"
@ -219,9 +221,15 @@ async def test_connection_released_if_connecting_is_cancelled() -> None:
cli = APIClient("1.2.3.4", 1234, None) cli = APIClient("1.2.3.4", 1234, None)
asyncio.get_event_loop() asyncio.get_event_loop()
async def _start_connection_with_delay(*args, **kwargs):
await asyncio.sleep(1)
mock_socket = create_autospec(socket.socket, spec_set=True, instance=True)
mock_socket.getpeername.return_value = ("4.3.3.3", 323)
return mock_socket
with patch( with patch(
"aioesphomeapi.connection.aiohappyeyeballs.start_connection", "aioesphomeapi.connection.aiohappyeyeballs.start_connection",
side_effect=partial(asyncio.sleep, 1), _start_connection_with_delay,
): ):
start_task = asyncio.create_task(cli.start_connection()) start_task = asyncio.create_task(cli.start_connection())
await asyncio.sleep(0) await asyncio.sleep(0)
@ -232,8 +240,14 @@ async def test_connection_released_if_connecting_is_cancelled() -> None:
await start_task await start_task
assert cli._connection is None assert cli._connection is None
async def _start_connection_without_delay(*args, **kwargs):
mock_socket = create_autospec(socket.socket, spec_set=True, instance=True)
mock_socket.getpeername.return_value = ("4.3.3.3", 323)
return mock_socket
with patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection), patch( with patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection), patch(
"aioesphomeapi.connection.aiohappyeyeballs.start_connection" "aioesphomeapi.connection.aiohappyeyeballs.start_connection",
_start_connection_without_delay,
): ):
await cli.start_connection() await cli.start_connection()
await asyncio.sleep(0) await asyncio.sleep(0)
@ -894,7 +908,7 @@ async def test_noise_psk_handles_subclassed_string():
) )
# Make sure its not a subclassed string # Make sure its not a subclassed string
assert type(cli._params.noise_psk) is str assert type(cli._params.noise_psk) is str
assert type(cli._params.address) is str assert type(cli._params.addresses[0]) is str
assert type(cli._params.expected_name) is str assert type(cli._params.expected_name) is str
rl = ReconnectLogic( rl = ReconnectLogic(
@ -930,7 +944,7 @@ async def test_no_noise_psk():
) )
# Make sure its not a subclassed string # Make sure its not a subclassed string
assert cli._params.noise_psk is None assert cli._params.noise_psk is None
assert type(cli._params.address) is str assert type(cli._params.addresses[0]) is str
assert type(cli._params.expected_name) is str assert type(cli._params.expected_name) is str
@ -945,7 +959,7 @@ async def test_empty_noise_psk_or_expected_name():
expected_name="", expected_name="",
) )
assert cli._params.noise_psk is None assert cli._params.noise_psk is None
assert type(cli._params.address) is str assert type(cli._params.addresses[0]) is str
assert cli._params.expected_name is None assert cli._params.expected_name is None

View File

@ -221,7 +221,7 @@ async def test_plaintext_connection(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_start_connection_socket_error( async def test_start_connection_socket_error(
conn: APIConnection, resolve_host, socket_socket conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection
): ):
"""Test handling of socket error during start connection.""" """Test handling of socket error during start connection."""
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -238,7 +238,7 @@ async def test_start_connection_socket_error(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_start_connection_times_out( async def test_start_connection_times_out(
conn: APIConnection, resolve_host, socket_socket conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection
): ):
"""Test handling of start connection timing out.""" """Test handling of start connection timing out."""
asyncio.get_event_loop() asyncio.get_event_loop()
@ -264,9 +264,7 @@ async def test_start_connection_times_out(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_start_connection_os_error( async def test_start_connection_os_error(conn: APIConnection, resolve_host):
conn: APIConnection, resolve_host, socket_socket
):
"""Test handling of start connection has an OSError.""" """Test handling of start connection has an OSError."""
asyncio.get_event_loop() asyncio.get_event_loop()
@ -284,9 +282,7 @@ async def test_start_connection_os_error(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_start_connection_is_cancelled( async def test_start_connection_is_cancelled(conn: APIConnection, resolve_host):
conn: APIConnection, resolve_host, socket_socket
):
"""Test handling of start connection is cancelled.""" """Test handling of start connection is cancelled."""
asyncio.get_event_loop() asyncio.get_event_loop()
@ -305,7 +301,7 @@ async def test_start_connection_is_cancelled(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_finish_connection_is_cancelled( async def test_finish_connection_is_cancelled(
conn: APIConnection, resolve_host, socket_socket conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection
): ):
"""Test handling of finishing connection being cancelled.""" """Test handling of finishing connection being cancelled."""
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -368,7 +364,7 @@ async def test_finish_connection_times_out(
async def test_plaintext_connection_fails_handshake( async def test_plaintext_connection_fails_handshake(
conn: APIConnection, conn: APIConnection,
resolve_host: AsyncMock, resolve_host: AsyncMock,
socket_socket: MagicMock, aiohappyeyeballs_start_connection: MagicMock,
exception_map: tuple[Exception, Exception], exception_map: tuple[Exception, Exception],
) -> None: ) -> None:
"""Test that the frame helper is closed before the underlying socket. """Test that the frame helper is closed before the underlying socket.
@ -558,7 +554,7 @@ async def test_force_disconnect_fails(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_connect_resolver_times_out( async def test_connect_resolver_times_out(
conn: APIConnection, socket_socket, event_loop, aiohappyeyeballs_start_connection conn: APIConnection, event_loop, aiohappyeyeballs_start_connection
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
transport = MagicMock() transport = MagicMock()
connected = asyncio.Event() connected = asyncio.Event()
@ -571,7 +567,8 @@ async def test_connect_resolver_times_out(
"create_connection", "create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected), side_effect=partial(_create_mock_transport_protocol, transport, connected),
), pytest.raises( ), pytest.raises(
ResolveAPIError, match="Timeout while resolving IP address for fake.address" ResolveAPIError,
match="Timeout while resolving IP address for fake.address",
): ):
await connect(conn, login=False) await connect(conn, login=False)
@ -581,7 +578,6 @@ async def test_disconnect_fails_to_send_response(
connection_params: ConnectionParams, connection_params: ConnectionParams,
event_loop: asyncio.AbstractEventLoop, event_loop: asyncio.AbstractEventLoop,
resolve_host, resolve_host,
socket_socket,
aiohappyeyeballs_start_connection, aiohappyeyeballs_start_connection,
) -> None: ) -> None:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -632,7 +628,6 @@ async def test_disconnect_success_case(
connection_params: ConnectionParams, connection_params: ConnectionParams,
event_loop: asyncio.AbstractEventLoop, event_loop: asyncio.AbstractEventLoop,
resolve_host, resolve_host,
socket_socket,
aiohappyeyeballs_start_connection, aiohappyeyeballs_start_connection,
) -> None: ) -> None:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()

View File

@ -99,7 +99,7 @@ async def test_resolve_host_zeroconf_fails_end_to_end(async_zeroconf: AsyncZeroc
"aioesphomeapi.host_resolver.AsyncServiceInfo.async_request", "aioesphomeapi.host_resolver.AsyncServiceInfo.async_request",
side_effect=Exception("no buffers"), side_effect=Exception("no buffers"),
), pytest.raises(ResolveAPIError, match="no buffers"): ), pytest.raises(ResolveAPIError, match="no buffers"):
await hr.async_resolve_host("asdf.local", 6052) await hr.async_resolve_host(["asdf.local"], 6052)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -140,7 +140,7 @@ async def test_resolve_host_getaddrinfo_oserror(event_loop):
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo") @patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos): async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos):
resolve_zc.return_value = addr_infos resolve_zc.return_value = addr_infos
ret = await hr.async_resolve_host("example.local", 6052) ret = await hr.async_resolve_host(["example.local"], 6052)
resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None) resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None)
resolve_addr.assert_not_called() resolve_addr.assert_not_called()
@ -153,7 +153,7 @@ async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos):
async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos): async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos):
resolve_zc.return_value = [] resolve_zc.return_value = []
resolve_addr.return_value = addr_infos resolve_addr.return_value = addr_infos
ret = await hr.async_resolve_host("example.local", 6052) ret = await hr.async_resolve_host(["example.local"], 6052)
resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None) resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None)
resolve_addr.assert_called_once_with("example.local", 6052) resolve_addr.assert_called_once_with("example.local", 6052)
@ -166,7 +166,7 @@ async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos):
async def test_resolve_host_mdns_no_results(resolve_addr, addr_infos): async def test_resolve_host_mdns_no_results(resolve_addr, addr_infos):
resolve_addr.return_value = addr_infos resolve_addr.return_value = addr_infos
with pytest.raises(ResolveAPIError): with pytest.raises(ResolveAPIError):
await hr.async_resolve_host("example.local", 6052) await hr.async_resolve_host(["example.local"], 6052)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -174,7 +174,7 @@ async def test_resolve_host_mdns_no_results(resolve_addr, addr_infos):
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo") @patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
async def test_resolve_host_addrinfo(resolve_addr, resolve_zc, addr_infos): async def test_resolve_host_addrinfo(resolve_addr, resolve_zc, addr_infos):
resolve_addr.return_value = addr_infos resolve_addr.return_value = addr_infos
ret = await hr.async_resolve_host("example.com", 6052) ret = await hr.async_resolve_host(["example.com"], 6052)
resolve_zc.assert_not_called() resolve_zc.assert_not_called()
resolve_addr.assert_called_once_with("example.com", 6052) resolve_addr.assert_called_once_with("example.com", 6052)
@ -187,7 +187,7 @@ async def test_resolve_host_addrinfo(resolve_addr, resolve_zc, addr_infos):
async def test_resolve_host_addrinfo_empty(resolve_addr, resolve_zc, addr_infos): async def test_resolve_host_addrinfo_empty(resolve_addr, resolve_zc, addr_infos):
resolve_addr.return_value = [] resolve_addr.return_value = []
with pytest.raises(APIConnectionError): with pytest.raises(APIConnectionError):
await hr.async_resolve_host("example.com", 6052) await hr.async_resolve_host(["example.com"], 6052)
resolve_zc.assert_not_called() resolve_zc.assert_not_called()
resolve_addr.assert_called_once_with("example.com", 6052) resolve_addr.assert_called_once_with("example.com", 6052)
@ -199,7 +199,7 @@ async def test_resolve_host_addrinfo_empty(resolve_addr, resolve_zc, addr_infos)
async def test_resolve_host_with_address(resolve_addr, resolve_zc): async def test_resolve_host_with_address(resolve_addr, resolve_zc):
resolve_zc.return_value = [] resolve_zc.return_value = []
resolve_addr.return_value = addr_infos resolve_addr.return_value = addr_infos
ret = await hr.async_resolve_host("127.0.0.1", 6052) ret = await hr.async_resolve_host(["127.0.0.1"], 6052)
resolve_zc.assert_not_called() resolve_zc.assert_not_called()
resolve_addr.assert_not_called() resolve_addr.assert_not_called()