mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-03-12 13:29:49 +01:00
Add support for passing multiple addresses to the client (#796)
This commit is contained in:
parent
4668b1ff54
commit
de1d08493d
@ -220,6 +220,7 @@ class APIClient:
|
||||
zeroconf_instance: ZeroconfInstanceType | None = None,
|
||||
noise_psk: str | None = None,
|
||||
expected_name: str | None = None,
|
||||
addresses: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Create a client, this object is shared across sessions.
|
||||
|
||||
@ -235,10 +236,14 @@ class APIClient:
|
||||
:param expected_name: Require the devices name to match the given expected name.
|
||||
Can be used to prevent accidentally connecting to a different device if
|
||||
IP passed as address but DHCP reassigned IP.
|
||||
:param addresses: Optional list of IP addresses to connect to which takes
|
||||
precedence over the address parameter. This is most commonly used when
|
||||
the device has dual stack IPv4 and IPv6 addresses and you do not know
|
||||
which one to connect to.
|
||||
"""
|
||||
self._debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||
self._params = ConnectionParams(
|
||||
address=str(address),
|
||||
addresses=addresses if addresses else [str(address)],
|
||||
port=port,
|
||||
password=password,
|
||||
client_info=client_info,
|
||||
@ -274,17 +279,17 @@ class APIClient:
|
||||
|
||||
@property
|
||||
def address(self) -> str:
|
||||
return self._params.address
|
||||
return self._params.addresses[0]
|
||||
|
||||
def _set_log_name(self) -> None:
|
||||
"""Set the log name of the device."""
|
||||
resolved_address: str | None = None
|
||||
if self._connection and self._connection.resolved_addr_info:
|
||||
resolved_address = self._connection.resolved_addr_info[0].sockaddr.address
|
||||
connected_address: str | None = None
|
||||
if self._connection and self._connection.connected_address:
|
||||
connected_address = self._connection.connected_address
|
||||
self.log_name = build_log_name(
|
||||
self.cached_name,
|
||||
self.address,
|
||||
resolved_address,
|
||||
self._params.addresses,
|
||||
connected_address,
|
||||
)
|
||||
if self._connection:
|
||||
self._connection.set_log_name(self.log_name)
|
||||
@ -328,8 +333,8 @@ class APIClient:
|
||||
self.log_name,
|
||||
)
|
||||
await self._execute_connection_coro(self._connection.start_connection())
|
||||
# If we resolved the address, we should set the log name now
|
||||
if self._connection.resolved_addr_info:
|
||||
# If we connected, we should set the log name now
|
||||
if self._connection.connected_address:
|
||||
self._set_log_name()
|
||||
|
||||
async def finish_connection(
|
||||
|
@ -74,7 +74,7 @@ cdef object _handle_complex_message
|
||||
|
||||
@cython.dataclasses.dataclass
|
||||
cdef class ConnectionParams:
|
||||
cdef public str address
|
||||
cdef public list addresses
|
||||
cdef public object port
|
||||
cdef public object password
|
||||
cdef public object client_info
|
||||
@ -108,7 +108,7 @@ cdef class APIConnection:
|
||||
cdef bint _handshake_complete
|
||||
cdef bint _debug_enabled
|
||||
cdef public str received_name
|
||||
cdef public object resolved_addr_info
|
||||
cdef public str connected_address
|
||||
|
||||
cpdef void send_message(self, object msg)
|
||||
|
||||
|
@ -107,7 +107,7 @@ _float = float
|
||||
|
||||
@dataclass
|
||||
class ConnectionParams:
|
||||
address: str
|
||||
addresses: list[str]
|
||||
port: int
|
||||
password: str | None
|
||||
client_info: str
|
||||
@ -207,7 +207,7 @@ class APIConnection:
|
||||
"_handshake_complete",
|
||||
"_debug_enabled",
|
||||
"received_name",
|
||||
"resolved_addr_info",
|
||||
"connected_address",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@ -230,7 +230,7 @@ class APIConnection:
|
||||
# Message handlers currently subscribed to incoming messages
|
||||
self._message_handlers: dict[Any, set[Callable[[message.Message], None]]] = {}
|
||||
# The friendly name to show for this connection in the logs
|
||||
self.log_name = log_name or params.address
|
||||
self.log_name = log_name or ",".join(params.addresses)
|
||||
|
||||
# futures currently subscribed to exceptions in the read task
|
||||
self._read_exception_futures: set[asyncio.Future[None]] = set()
|
||||
@ -251,7 +251,7 @@ class APIConnection:
|
||||
self._handshake_complete = False
|
||||
self._debug_enabled = debug_enabled
|
||||
self.received_name: str = ""
|
||||
self.resolved_addr_info: list[hr.AddrInfo] = []
|
||||
self.connected_address: str | None = None
|
||||
|
||||
def set_log_name(self, name: str) -> None:
|
||||
"""Set the friendly log name for this connection."""
|
||||
@ -325,7 +325,7 @@ class APIConnection:
|
||||
try:
|
||||
async with asyncio_timeout(RESOLVE_TIMEOUT):
|
||||
return await hr.async_resolve_host(
|
||||
self._params.address,
|
||||
self._params.addresses,
|
||||
self._params.port,
|
||||
self._params.zeroconf_manager,
|
||||
)
|
||||
@ -338,11 +338,9 @@ class APIConnection:
|
||||
"""Step 2 in connect process: connect the socket."""
|
||||
if self._debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"%s: Connecting to %s:%s (%s)",
|
||||
"%s: Connecting to %s",
|
||||
self.log_name,
|
||||
self._params.address,
|
||||
self._params.port,
|
||||
addrs,
|
||||
", ".join(str(addr.sockaddr) for addr in addrs),
|
||||
)
|
||||
|
||||
addr_infos: list[aiohappyeyeballs.AddrInfoType] = [
|
||||
@ -350,7 +348,7 @@ class APIConnection:
|
||||
addr.family,
|
||||
addr.type,
|
||||
addr.proto,
|
||||
self._params.address,
|
||||
"",
|
||||
astuple(addr.sockaddr),
|
||||
)
|
||||
for addr in addrs
|
||||
@ -361,9 +359,11 @@ class APIConnection:
|
||||
while addr_infos:
|
||||
try:
|
||||
async with asyncio_timeout(TCP_CONNECT_TIMEOUT):
|
||||
# Devices are likely on the local network so we
|
||||
# only use a 100ms happy eyeballs delay
|
||||
sock = await aiohappyeyeballs.start_connection(
|
||||
addr_infos,
|
||||
happy_eyeballs_delay=0.25,
|
||||
happy_eyeballs_delay=0.1,
|
||||
interleave=interleave,
|
||||
loop=self._loop,
|
||||
)
|
||||
@ -387,14 +387,14 @@ class APIConnection:
|
||||
# Try to reduce the pressure on esphome device as it measures
|
||||
# ram in bytes and we measure ram in megabytes.
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE)
|
||||
self.connected_address = sock.getpeername()[0]
|
||||
|
||||
if self._debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"%s: Opened socket to %s:%s (%s)",
|
||||
"%s: Opened socket to %s:%s",
|
||||
self.log_name,
|
||||
self._params.address,
|
||||
self.connected_address,
|
||||
self._params.port,
|
||||
addrs,
|
||||
)
|
||||
|
||||
async def _connect_init_frame_helper(self) -> None:
|
||||
@ -567,8 +567,7 @@ class APIConnection:
|
||||
|
||||
async def _do_connect(self) -> None:
|
||||
"""Do the actual connect process."""
|
||||
self.resolved_addr_info = await self._connect_resolve_host()
|
||||
await self._connect_socket_connect(self.resolved_addr_info)
|
||||
await self._connect_socket_connect(await self._connect_resolve_host())
|
||||
|
||||
async def start_connection(self) -> None:
|
||||
"""Start the connection process.
|
||||
|
@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import socket
|
||||
from dataclasses import dataclass
|
||||
@ -181,35 +180,46 @@ def _async_ip_address_to_addrs(
|
||||
|
||||
|
||||
async def async_resolve_host(
|
||||
host: str,
|
||||
hosts: list[str],
|
||||
port: int,
|
||||
zeroconf_manager: ZeroconfManager | None = None,
|
||||
) -> list[AddrInfo]:
|
||||
addrs: list[AddrInfo] = []
|
||||
zc_error: Exception | None = None
|
||||
|
||||
zc_error = None
|
||||
if host_is_name_part(host) or address_is_local(host):
|
||||
name = host.partition(".")[0]
|
||||
try:
|
||||
addrs.extend(
|
||||
await _async_resolve_host_zeroconf(
|
||||
name, port, zeroconf_manager=zeroconf_manager
|
||||
for host in hosts:
|
||||
host_addrs: list[AddrInfo] = []
|
||||
host_is_local_name = host_is_name_part(host) or address_is_local(host)
|
||||
|
||||
if host_is_local_name:
|
||||
name = host.partition(".")[0]
|
||||
try:
|
||||
host_addrs.extend(
|
||||
await _async_resolve_host_zeroconf(
|
||||
name, port, zeroconf_manager=zeroconf_manager
|
||||
)
|
||||
)
|
||||
)
|
||||
except ResolveAPIError as err:
|
||||
zc_error = err
|
||||
except ResolveAPIError as err:
|
||||
zc_error = err
|
||||
|
||||
else:
|
||||
with contextlib.suppress(ValueError):
|
||||
addrs.extend(_async_ip_address_to_addrs(ip_address(host), port))
|
||||
if not host_is_local_name:
|
||||
try:
|
||||
host_addrs.extend(_async_ip_address_to_addrs(ip_address(host), port))
|
||||
except ValueError:
|
||||
# Not an IP address
|
||||
pass
|
||||
|
||||
if not addrs:
|
||||
addrs.extend(await _async_resolve_host_getaddrinfo(host, port))
|
||||
if not host_addrs:
|
||||
host_addrs.extend(await _async_resolve_host_getaddrinfo(host, port))
|
||||
|
||||
addrs.extend(host_addrs)
|
||||
|
||||
if not addrs:
|
||||
if zc_error:
|
||||
# Only show ZC error if getaddrinfo also didn't work
|
||||
raise zc_error
|
||||
raise ResolveAPIError(f"Could not resolve host {host} - got no results from OS")
|
||||
raise ResolveAPIError(
|
||||
f"Could not resolve host {hosts} - got no results from OS"
|
||||
)
|
||||
|
||||
return addrs
|
||||
|
@ -36,11 +36,18 @@ def address_is_local(address: str) -> bool:
|
||||
return address.removesuffix(".").endswith(".local")
|
||||
|
||||
|
||||
def build_log_name(name: str | None, address: str, resolved_address: str | None) -> str:
|
||||
def build_log_name(
|
||||
name: str | None, addresses: list[str], connected_address: str | None
|
||||
) -> str:
|
||||
"""Return a log name for a connection."""
|
||||
if not name and address_is_local(address) or host_is_name_part(address):
|
||||
name = address.partition(".")[0]
|
||||
preferred_address = resolved_address or address
|
||||
preferred_address = connected_address
|
||||
for address in addresses:
|
||||
if not name and address_is_local(address) or host_is_name_part(address):
|
||||
name = address.partition(".")[0]
|
||||
elif not preferred_address:
|
||||
preferred_address = address
|
||||
if not preferred_address:
|
||||
return name or addresses[0]
|
||||
if (
|
||||
name
|
||||
and name != preferred_address
|
||||
|
@ -6,7 +6,7 @@ import socket
|
||||
from dataclasses import replace
|
||||
from functools import partial
|
||||
from typing import Callable
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
@ -50,12 +50,6 @@ def resolve_host():
|
||||
yield func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def socket_socket():
|
||||
with patch("socket.socket") as func:
|
||||
yield func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patchable_api_client() -> APIClient:
|
||||
class PatchableAPIClient(APIClient):
|
||||
@ -71,7 +65,7 @@ def patchable_api_client() -> APIClient:
|
||||
|
||||
def get_mock_connection_params() -> ConnectionParams:
|
||||
return ConnectionParams(
|
||||
address="fake.address",
|
||||
addresses=["fake.address"],
|
||||
port=6052,
|
||||
password=None,
|
||||
client_info="Tests client",
|
||||
@ -119,7 +113,11 @@ def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnectio
|
||||
@pytest.fixture()
|
||||
def aiohappyeyeballs_start_connection():
|
||||
with patch("aioesphomeapi.connection.aiohappyeyeballs.start_connection") as func:
|
||||
func.return_value = MagicMock(type=socket.SOCK_STREAM)
|
||||
mock_socket = create_autospec(socket.socket, spec_set=True, instance=True)
|
||||
mock_socket.type = socket.SOCK_STREAM
|
||||
mock_socket.fileno.return_value = 1
|
||||
mock_socket.getpeername.return_value = ("10.0.0.512", 323)
|
||||
func.return_value = mock_socket
|
||||
yield func
|
||||
|
||||
|
||||
@ -139,7 +137,6 @@ def _create_mock_transport_protocol(
|
||||
async def plaintext_connect_task_no_login(
|
||||
conn: APIConnection,
|
||||
resolve_host,
|
||||
socket_socket,
|
||||
event_loop,
|
||||
aiohappyeyeballs_start_connection,
|
||||
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||
@ -161,7 +158,6 @@ async def plaintext_connect_task_no_login(
|
||||
async def plaintext_connect_task_no_login_with_expected_name(
|
||||
conn_with_expected_name: APIConnection,
|
||||
resolve_host,
|
||||
socket_socket,
|
||||
event_loop,
|
||||
aiohappyeyeballs_start_connection,
|
||||
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||
@ -184,7 +180,6 @@ async def plaintext_connect_task_no_login_with_expected_name(
|
||||
async def plaintext_connect_task_with_login(
|
||||
conn_with_password: APIConnection,
|
||||
resolve_host,
|
||||
socket_socket,
|
||||
event_loop,
|
||||
aiohappyeyeballs_start_connection,
|
||||
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||
@ -203,7 +198,7 @@ async def plaintext_connect_task_with_login(
|
||||
|
||||
@pytest_asyncio.fixture(name="api_client")
|
||||
async def api_client(
|
||||
resolve_host, socket_socket, event_loop, aiohappyeyeballs_start_connection
|
||||
resolve_host, event_loop, aiohappyeyeballs_start_connection
|
||||
) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]:
|
||||
protocol: APIPlaintextFrameHelper | None = None
|
||||
transport = MagicMock()
|
||||
|
@ -199,7 +199,8 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_plaintext_frame_helper(
|
||||
@pytest.mark.asyncio
|
||||
async def test_plaintext_frame_helper(
|
||||
in_bytes: bytes, pkt_data: bytes, pkt_type: int
|
||||
) -> None:
|
||||
for _ in range(3):
|
||||
@ -592,7 +593,9 @@ async def test_noise_frame_helper_bad_encryption(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_plaintext_with_wrong_preamble(conn: APIConnection):
|
||||
async def test_init_plaintext_with_wrong_preamble(
|
||||
conn: APIConnection, aiohappyeyeballs_start_connection
|
||||
):
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol = get_mock_protocol(conn)
|
||||
with patch.object(loop, "create_connection") as create_connection:
|
||||
|
@ -4,9 +4,10 @@ import asyncio
|
||||
import contextlib
|
||||
import itertools
|
||||
import logging
|
||||
import socket
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, call, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from google.protobuf import message
|
||||
@ -169,7 +170,8 @@ def patch_api_version(client: APIClient, version: APIVersion):
|
||||
client._connection.api_version = version
|
||||
|
||||
|
||||
def test_expected_name(auth_client: APIClient) -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_expected_name(auth_client: APIClient) -> None:
|
||||
"""Ensure expected name can be set externally."""
|
||||
assert auth_client.expected_name is None
|
||||
auth_client.expected_name = "awesome"
|
||||
@ -219,9 +221,15 @@ async def test_connection_released_if_connecting_is_cancelled() -> None:
|
||||
cli = APIClient("1.2.3.4", 1234, None)
|
||||
asyncio.get_event_loop()
|
||||
|
||||
async def _start_connection_with_delay(*args, **kwargs):
|
||||
await asyncio.sleep(1)
|
||||
mock_socket = create_autospec(socket.socket, spec_set=True, instance=True)
|
||||
mock_socket.getpeername.return_value = ("4.3.3.3", 323)
|
||||
return mock_socket
|
||||
|
||||
with patch(
|
||||
"aioesphomeapi.connection.aiohappyeyeballs.start_connection",
|
||||
side_effect=partial(asyncio.sleep, 1),
|
||||
_start_connection_with_delay,
|
||||
):
|
||||
start_task = asyncio.create_task(cli.start_connection())
|
||||
await asyncio.sleep(0)
|
||||
@ -232,8 +240,14 @@ async def test_connection_released_if_connecting_is_cancelled() -> None:
|
||||
await start_task
|
||||
assert cli._connection is None
|
||||
|
||||
async def _start_connection_without_delay(*args, **kwargs):
|
||||
mock_socket = create_autospec(socket.socket, spec_set=True, instance=True)
|
||||
mock_socket.getpeername.return_value = ("4.3.3.3", 323)
|
||||
return mock_socket
|
||||
|
||||
with patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection), patch(
|
||||
"aioesphomeapi.connection.aiohappyeyeballs.start_connection"
|
||||
"aioesphomeapi.connection.aiohappyeyeballs.start_connection",
|
||||
_start_connection_without_delay,
|
||||
):
|
||||
await cli.start_connection()
|
||||
await asyncio.sleep(0)
|
||||
@ -894,7 +908,7 @@ async def test_noise_psk_handles_subclassed_string():
|
||||
)
|
||||
# Make sure its not a subclassed string
|
||||
assert type(cli._params.noise_psk) is str
|
||||
assert type(cli._params.address) is str
|
||||
assert type(cli._params.addresses[0]) is str
|
||||
assert type(cli._params.expected_name) is str
|
||||
|
||||
rl = ReconnectLogic(
|
||||
@ -930,7 +944,7 @@ async def test_no_noise_psk():
|
||||
)
|
||||
# Make sure its not a subclassed string
|
||||
assert cli._params.noise_psk is None
|
||||
assert type(cli._params.address) is str
|
||||
assert type(cli._params.addresses[0]) is str
|
||||
assert type(cli._params.expected_name) is str
|
||||
|
||||
|
||||
@ -945,7 +959,7 @@ async def test_empty_noise_psk_or_expected_name():
|
||||
expected_name="",
|
||||
)
|
||||
assert cli._params.noise_psk is None
|
||||
assert type(cli._params.address) is str
|
||||
assert type(cli._params.addresses[0]) is str
|
||||
assert cli._params.expected_name is None
|
||||
|
||||
|
||||
|
@ -221,7 +221,7 @@ async def test_plaintext_connection(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_connection_socket_error(
|
||||
conn: APIConnection, resolve_host, socket_socket
|
||||
conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection
|
||||
):
|
||||
"""Test handling of socket error during start connection."""
|
||||
loop = asyncio.get_event_loop()
|
||||
@ -238,7 +238,7 @@ async def test_start_connection_socket_error(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_connection_times_out(
|
||||
conn: APIConnection, resolve_host, socket_socket
|
||||
conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection
|
||||
):
|
||||
"""Test handling of start connection timing out."""
|
||||
asyncio.get_event_loop()
|
||||
@ -264,9 +264,7 @@ async def test_start_connection_times_out(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_connection_os_error(
|
||||
conn: APIConnection, resolve_host, socket_socket
|
||||
):
|
||||
async def test_start_connection_os_error(conn: APIConnection, resolve_host):
|
||||
"""Test handling of start connection has an OSError."""
|
||||
asyncio.get_event_loop()
|
||||
|
||||
@ -284,9 +282,7 @@ async def test_start_connection_os_error(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_connection_is_cancelled(
|
||||
conn: APIConnection, resolve_host, socket_socket
|
||||
):
|
||||
async def test_start_connection_is_cancelled(conn: APIConnection, resolve_host):
|
||||
"""Test handling of start connection is cancelled."""
|
||||
asyncio.get_event_loop()
|
||||
|
||||
@ -305,7 +301,7 @@ async def test_start_connection_is_cancelled(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_finish_connection_is_cancelled(
|
||||
conn: APIConnection, resolve_host, socket_socket
|
||||
conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection
|
||||
):
|
||||
"""Test handling of finishing connection being cancelled."""
|
||||
loop = asyncio.get_event_loop()
|
||||
@ -368,7 +364,7 @@ async def test_finish_connection_times_out(
|
||||
async def test_plaintext_connection_fails_handshake(
|
||||
conn: APIConnection,
|
||||
resolve_host: AsyncMock,
|
||||
socket_socket: MagicMock,
|
||||
aiohappyeyeballs_start_connection: MagicMock,
|
||||
exception_map: tuple[Exception, Exception],
|
||||
) -> None:
|
||||
"""Test that the frame helper is closed before the underlying socket.
|
||||
@ -558,7 +554,7 @@ async def test_force_disconnect_fails(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_resolver_times_out(
|
||||
conn: APIConnection, socket_socket, event_loop, aiohappyeyeballs_start_connection
|
||||
conn: APIConnection, event_loop, aiohappyeyeballs_start_connection
|
||||
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
@ -571,7 +567,8 @@ async def test_connect_resolver_times_out(
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
), pytest.raises(
|
||||
ResolveAPIError, match="Timeout while resolving IP address for fake.address"
|
||||
ResolveAPIError,
|
||||
match="Timeout while resolving IP address for fake.address",
|
||||
):
|
||||
await connect(conn, login=False)
|
||||
|
||||
@ -581,7 +578,6 @@ async def test_disconnect_fails_to_send_response(
|
||||
connection_params: ConnectionParams,
|
||||
event_loop: asyncio.AbstractEventLoop,
|
||||
resolve_host,
|
||||
socket_socket,
|
||||
aiohappyeyeballs_start_connection,
|
||||
) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
@ -632,7 +628,6 @@ async def test_disconnect_success_case(
|
||||
connection_params: ConnectionParams,
|
||||
event_loop: asyncio.AbstractEventLoop,
|
||||
resolve_host,
|
||||
socket_socket,
|
||||
aiohappyeyeballs_start_connection,
|
||||
) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
@ -99,7 +99,7 @@ async def test_resolve_host_zeroconf_fails_end_to_end(async_zeroconf: AsyncZeroc
|
||||
"aioesphomeapi.host_resolver.AsyncServiceInfo.async_request",
|
||||
side_effect=Exception("no buffers"),
|
||||
), pytest.raises(ResolveAPIError, match="no buffers"):
|
||||
await hr.async_resolve_host("asdf.local", 6052)
|
||||
await hr.async_resolve_host(["asdf.local"], 6052)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -140,7 +140,7 @@ async def test_resolve_host_getaddrinfo_oserror(event_loop):
|
||||
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
|
||||
async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos):
|
||||
resolve_zc.return_value = addr_infos
|
||||
ret = await hr.async_resolve_host("example.local", 6052)
|
||||
ret = await hr.async_resolve_host(["example.local"], 6052)
|
||||
|
||||
resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None)
|
||||
resolve_addr.assert_not_called()
|
||||
@ -153,7 +153,7 @@ async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos):
|
||||
async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos):
|
||||
resolve_zc.return_value = []
|
||||
resolve_addr.return_value = addr_infos
|
||||
ret = await hr.async_resolve_host("example.local", 6052)
|
||||
ret = await hr.async_resolve_host(["example.local"], 6052)
|
||||
|
||||
resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None)
|
||||
resolve_addr.assert_called_once_with("example.local", 6052)
|
||||
@ -166,7 +166,7 @@ async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos):
|
||||
async def test_resolve_host_mdns_no_results(resolve_addr, addr_infos):
|
||||
resolve_addr.return_value = addr_infos
|
||||
with pytest.raises(ResolveAPIError):
|
||||
await hr.async_resolve_host("example.local", 6052)
|
||||
await hr.async_resolve_host(["example.local"], 6052)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -174,7 +174,7 @@ async def test_resolve_host_mdns_no_results(resolve_addr, addr_infos):
|
||||
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
|
||||
async def test_resolve_host_addrinfo(resolve_addr, resolve_zc, addr_infos):
|
||||
resolve_addr.return_value = addr_infos
|
||||
ret = await hr.async_resolve_host("example.com", 6052)
|
||||
ret = await hr.async_resolve_host(["example.com"], 6052)
|
||||
|
||||
resolve_zc.assert_not_called()
|
||||
resolve_addr.assert_called_once_with("example.com", 6052)
|
||||
@ -187,7 +187,7 @@ async def test_resolve_host_addrinfo(resolve_addr, resolve_zc, addr_infos):
|
||||
async def test_resolve_host_addrinfo_empty(resolve_addr, resolve_zc, addr_infos):
|
||||
resolve_addr.return_value = []
|
||||
with pytest.raises(APIConnectionError):
|
||||
await hr.async_resolve_host("example.com", 6052)
|
||||
await hr.async_resolve_host(["example.com"], 6052)
|
||||
|
||||
resolve_zc.assert_not_called()
|
||||
resolve_addr.assert_called_once_with("example.com", 6052)
|
||||
@ -199,7 +199,7 @@ async def test_resolve_host_addrinfo_empty(resolve_addr, resolve_zc, addr_infos)
|
||||
async def test_resolve_host_with_address(resolve_addr, resolve_zc):
|
||||
resolve_zc.return_value = []
|
||||
resolve_addr.return_value = addr_infos
|
||||
ret = await hr.async_resolve_host("127.0.0.1", 6052)
|
||||
ret = await hr.async_resolve_host(["127.0.0.1"], 6052)
|
||||
|
||||
resolve_zc.assert_not_called()
|
||||
resolve_addr.assert_not_called()
|
||||
|
Loading…
Reference in New Issue
Block a user