mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-03-12 13:29:49 +01:00
Add support for passing multiple addresses to the client (#796)
This commit is contained in:
parent
4668b1ff54
commit
de1d08493d
@ -220,6 +220,7 @@ class APIClient:
|
|||||||
zeroconf_instance: ZeroconfInstanceType | None = None,
|
zeroconf_instance: ZeroconfInstanceType | None = None,
|
||||||
noise_psk: str | None = None,
|
noise_psk: str | None = None,
|
||||||
expected_name: str | None = None,
|
expected_name: str | None = None,
|
||||||
|
addresses: list[str] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create a client, this object is shared across sessions.
|
"""Create a client, this object is shared across sessions.
|
||||||
|
|
||||||
@ -235,10 +236,14 @@ class APIClient:
|
|||||||
:param expected_name: Require the devices name to match the given expected name.
|
:param expected_name: Require the devices name to match the given expected name.
|
||||||
Can be used to prevent accidentally connecting to a different device if
|
Can be used to prevent accidentally connecting to a different device if
|
||||||
IP passed as address but DHCP reassigned IP.
|
IP passed as address but DHCP reassigned IP.
|
||||||
|
:param addresses: Optional list of IP addresses to connect to which takes
|
||||||
|
precedence over the address parameter. This is most commonly used when
|
||||||
|
the device has dual stack IPv4 and IPv6 addresses and you do not know
|
||||||
|
which one to connect to.
|
||||||
"""
|
"""
|
||||||
self._debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
self._debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||||
self._params = ConnectionParams(
|
self._params = ConnectionParams(
|
||||||
address=str(address),
|
addresses=addresses if addresses else [str(address)],
|
||||||
port=port,
|
port=port,
|
||||||
password=password,
|
password=password,
|
||||||
client_info=client_info,
|
client_info=client_info,
|
||||||
@ -274,17 +279,17 @@ class APIClient:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def address(self) -> str:
|
def address(self) -> str:
|
||||||
return self._params.address
|
return self._params.addresses[0]
|
||||||
|
|
||||||
def _set_log_name(self) -> None:
|
def _set_log_name(self) -> None:
|
||||||
"""Set the log name of the device."""
|
"""Set the log name of the device."""
|
||||||
resolved_address: str | None = None
|
connected_address: str | None = None
|
||||||
if self._connection and self._connection.resolved_addr_info:
|
if self._connection and self._connection.connected_address:
|
||||||
resolved_address = self._connection.resolved_addr_info[0].sockaddr.address
|
connected_address = self._connection.connected_address
|
||||||
self.log_name = build_log_name(
|
self.log_name = build_log_name(
|
||||||
self.cached_name,
|
self.cached_name,
|
||||||
self.address,
|
self._params.addresses,
|
||||||
resolved_address,
|
connected_address,
|
||||||
)
|
)
|
||||||
if self._connection:
|
if self._connection:
|
||||||
self._connection.set_log_name(self.log_name)
|
self._connection.set_log_name(self.log_name)
|
||||||
@ -328,8 +333,8 @@ class APIClient:
|
|||||||
self.log_name,
|
self.log_name,
|
||||||
)
|
)
|
||||||
await self._execute_connection_coro(self._connection.start_connection())
|
await self._execute_connection_coro(self._connection.start_connection())
|
||||||
# If we resolved the address, we should set the log name now
|
# If we connected, we should set the log name now
|
||||||
if self._connection.resolved_addr_info:
|
if self._connection.connected_address:
|
||||||
self._set_log_name()
|
self._set_log_name()
|
||||||
|
|
||||||
async def finish_connection(
|
async def finish_connection(
|
||||||
|
@ -74,7 +74,7 @@ cdef object _handle_complex_message
|
|||||||
|
|
||||||
@cython.dataclasses.dataclass
|
@cython.dataclasses.dataclass
|
||||||
cdef class ConnectionParams:
|
cdef class ConnectionParams:
|
||||||
cdef public str address
|
cdef public list addresses
|
||||||
cdef public object port
|
cdef public object port
|
||||||
cdef public object password
|
cdef public object password
|
||||||
cdef public object client_info
|
cdef public object client_info
|
||||||
@ -108,7 +108,7 @@ cdef class APIConnection:
|
|||||||
cdef bint _handshake_complete
|
cdef bint _handshake_complete
|
||||||
cdef bint _debug_enabled
|
cdef bint _debug_enabled
|
||||||
cdef public str received_name
|
cdef public str received_name
|
||||||
cdef public object resolved_addr_info
|
cdef public str connected_address
|
||||||
|
|
||||||
cpdef void send_message(self, object msg)
|
cpdef void send_message(self, object msg)
|
||||||
|
|
||||||
|
@ -107,7 +107,7 @@ _float = float
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConnectionParams:
|
class ConnectionParams:
|
||||||
address: str
|
addresses: list[str]
|
||||||
port: int
|
port: int
|
||||||
password: str | None
|
password: str | None
|
||||||
client_info: str
|
client_info: str
|
||||||
@ -207,7 +207,7 @@ class APIConnection:
|
|||||||
"_handshake_complete",
|
"_handshake_complete",
|
||||||
"_debug_enabled",
|
"_debug_enabled",
|
||||||
"received_name",
|
"received_name",
|
||||||
"resolved_addr_info",
|
"connected_address",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -230,7 +230,7 @@ class APIConnection:
|
|||||||
# Message handlers currently subscribed to incoming messages
|
# Message handlers currently subscribed to incoming messages
|
||||||
self._message_handlers: dict[Any, set[Callable[[message.Message], None]]] = {}
|
self._message_handlers: dict[Any, set[Callable[[message.Message], None]]] = {}
|
||||||
# The friendly name to show for this connection in the logs
|
# The friendly name to show for this connection in the logs
|
||||||
self.log_name = log_name or params.address
|
self.log_name = log_name or ",".join(params.addresses)
|
||||||
|
|
||||||
# futures currently subscribed to exceptions in the read task
|
# futures currently subscribed to exceptions in the read task
|
||||||
self._read_exception_futures: set[asyncio.Future[None]] = set()
|
self._read_exception_futures: set[asyncio.Future[None]] = set()
|
||||||
@ -251,7 +251,7 @@ class APIConnection:
|
|||||||
self._handshake_complete = False
|
self._handshake_complete = False
|
||||||
self._debug_enabled = debug_enabled
|
self._debug_enabled = debug_enabled
|
||||||
self.received_name: str = ""
|
self.received_name: str = ""
|
||||||
self.resolved_addr_info: list[hr.AddrInfo] = []
|
self.connected_address: str | None = None
|
||||||
|
|
||||||
def set_log_name(self, name: str) -> None:
|
def set_log_name(self, name: str) -> None:
|
||||||
"""Set the friendly log name for this connection."""
|
"""Set the friendly log name for this connection."""
|
||||||
@ -325,7 +325,7 @@ class APIConnection:
|
|||||||
try:
|
try:
|
||||||
async with asyncio_timeout(RESOLVE_TIMEOUT):
|
async with asyncio_timeout(RESOLVE_TIMEOUT):
|
||||||
return await hr.async_resolve_host(
|
return await hr.async_resolve_host(
|
||||||
self._params.address,
|
self._params.addresses,
|
||||||
self._params.port,
|
self._params.port,
|
||||||
self._params.zeroconf_manager,
|
self._params.zeroconf_manager,
|
||||||
)
|
)
|
||||||
@ -338,11 +338,9 @@ class APIConnection:
|
|||||||
"""Step 2 in connect process: connect the socket."""
|
"""Step 2 in connect process: connect the socket."""
|
||||||
if self._debug_enabled:
|
if self._debug_enabled:
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"%s: Connecting to %s:%s (%s)",
|
"%s: Connecting to %s",
|
||||||
self.log_name,
|
self.log_name,
|
||||||
self._params.address,
|
", ".join(str(addr.sockaddr) for addr in addrs),
|
||||||
self._params.port,
|
|
||||||
addrs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
addr_infos: list[aiohappyeyeballs.AddrInfoType] = [
|
addr_infos: list[aiohappyeyeballs.AddrInfoType] = [
|
||||||
@ -350,7 +348,7 @@ class APIConnection:
|
|||||||
addr.family,
|
addr.family,
|
||||||
addr.type,
|
addr.type,
|
||||||
addr.proto,
|
addr.proto,
|
||||||
self._params.address,
|
"",
|
||||||
astuple(addr.sockaddr),
|
astuple(addr.sockaddr),
|
||||||
)
|
)
|
||||||
for addr in addrs
|
for addr in addrs
|
||||||
@ -361,9 +359,11 @@ class APIConnection:
|
|||||||
while addr_infos:
|
while addr_infos:
|
||||||
try:
|
try:
|
||||||
async with asyncio_timeout(TCP_CONNECT_TIMEOUT):
|
async with asyncio_timeout(TCP_CONNECT_TIMEOUT):
|
||||||
|
# Devices are likely on the local network so we
|
||||||
|
# only use a 100ms happy eyeballs delay
|
||||||
sock = await aiohappyeyeballs.start_connection(
|
sock = await aiohappyeyeballs.start_connection(
|
||||||
addr_infos,
|
addr_infos,
|
||||||
happy_eyeballs_delay=0.25,
|
happy_eyeballs_delay=0.1,
|
||||||
interleave=interleave,
|
interleave=interleave,
|
||||||
loop=self._loop,
|
loop=self._loop,
|
||||||
)
|
)
|
||||||
@ -387,14 +387,14 @@ class APIConnection:
|
|||||||
# Try to reduce the pressure on esphome device as it measures
|
# Try to reduce the pressure on esphome device as it measures
|
||||||
# ram in bytes and we measure ram in megabytes.
|
# ram in bytes and we measure ram in megabytes.
|
||||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE)
|
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE)
|
||||||
|
self.connected_address = sock.getpeername()[0]
|
||||||
|
|
||||||
if self._debug_enabled:
|
if self._debug_enabled:
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"%s: Opened socket to %s:%s (%s)",
|
"%s: Opened socket to %s:%s",
|
||||||
self.log_name,
|
self.log_name,
|
||||||
self._params.address,
|
self.connected_address,
|
||||||
self._params.port,
|
self._params.port,
|
||||||
addrs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _connect_init_frame_helper(self) -> None:
|
async def _connect_init_frame_helper(self) -> None:
|
||||||
@ -567,8 +567,7 @@ class APIConnection:
|
|||||||
|
|
||||||
async def _do_connect(self) -> None:
|
async def _do_connect(self) -> None:
|
||||||
"""Do the actual connect process."""
|
"""Do the actual connect process."""
|
||||||
self.resolved_addr_info = await self._connect_resolve_host()
|
await self._connect_socket_connect(await self._connect_resolve_host())
|
||||||
await self._connect_socket_connect(self.resolved_addr_info)
|
|
||||||
|
|
||||||
async def start_connection(self) -> None:
|
async def start_connection(self) -> None:
|
||||||
"""Start the connection process.
|
"""Start the connection process.
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -181,35 +180,46 @@ def _async_ip_address_to_addrs(
|
|||||||
|
|
||||||
|
|
||||||
async def async_resolve_host(
|
async def async_resolve_host(
|
||||||
host: str,
|
hosts: list[str],
|
||||||
port: int,
|
port: int,
|
||||||
zeroconf_manager: ZeroconfManager | None = None,
|
zeroconf_manager: ZeroconfManager | None = None,
|
||||||
) -> list[AddrInfo]:
|
) -> list[AddrInfo]:
|
||||||
addrs: list[AddrInfo] = []
|
addrs: list[AddrInfo] = []
|
||||||
|
zc_error: Exception | None = None
|
||||||
|
|
||||||
zc_error = None
|
for host in hosts:
|
||||||
if host_is_name_part(host) or address_is_local(host):
|
host_addrs: list[AddrInfo] = []
|
||||||
name = host.partition(".")[0]
|
host_is_local_name = host_is_name_part(host) or address_is_local(host)
|
||||||
try:
|
|
||||||
addrs.extend(
|
if host_is_local_name:
|
||||||
await _async_resolve_host_zeroconf(
|
name = host.partition(".")[0]
|
||||||
name, port, zeroconf_manager=zeroconf_manager
|
try:
|
||||||
|
host_addrs.extend(
|
||||||
|
await _async_resolve_host_zeroconf(
|
||||||
|
name, port, zeroconf_manager=zeroconf_manager
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
except ResolveAPIError as err:
|
||||||
except ResolveAPIError as err:
|
zc_error = err
|
||||||
zc_error = err
|
|
||||||
|
|
||||||
else:
|
if not host_is_local_name:
|
||||||
with contextlib.suppress(ValueError):
|
try:
|
||||||
addrs.extend(_async_ip_address_to_addrs(ip_address(host), port))
|
host_addrs.extend(_async_ip_address_to_addrs(ip_address(host), port))
|
||||||
|
except ValueError:
|
||||||
|
# Not an IP address
|
||||||
|
pass
|
||||||
|
|
||||||
if not addrs:
|
if not host_addrs:
|
||||||
addrs.extend(await _async_resolve_host_getaddrinfo(host, port))
|
host_addrs.extend(await _async_resolve_host_getaddrinfo(host, port))
|
||||||
|
|
||||||
|
addrs.extend(host_addrs)
|
||||||
|
|
||||||
if not addrs:
|
if not addrs:
|
||||||
if zc_error:
|
if zc_error:
|
||||||
# Only show ZC error if getaddrinfo also didn't work
|
# Only show ZC error if getaddrinfo also didn't work
|
||||||
raise zc_error
|
raise zc_error
|
||||||
raise ResolveAPIError(f"Could not resolve host {host} - got no results from OS")
|
raise ResolveAPIError(
|
||||||
|
f"Could not resolve host {hosts} - got no results from OS"
|
||||||
|
)
|
||||||
|
|
||||||
return addrs
|
return addrs
|
||||||
|
@ -36,11 +36,18 @@ def address_is_local(address: str) -> bool:
|
|||||||
return address.removesuffix(".").endswith(".local")
|
return address.removesuffix(".").endswith(".local")
|
||||||
|
|
||||||
|
|
||||||
def build_log_name(name: str | None, address: str, resolved_address: str | None) -> str:
|
def build_log_name(
|
||||||
|
name: str | None, addresses: list[str], connected_address: str | None
|
||||||
|
) -> str:
|
||||||
"""Return a log name for a connection."""
|
"""Return a log name for a connection."""
|
||||||
if not name and address_is_local(address) or host_is_name_part(address):
|
preferred_address = connected_address
|
||||||
name = address.partition(".")[0]
|
for address in addresses:
|
||||||
preferred_address = resolved_address or address
|
if not name and address_is_local(address) or host_is_name_part(address):
|
||||||
|
name = address.partition(".")[0]
|
||||||
|
elif not preferred_address:
|
||||||
|
preferred_address = address
|
||||||
|
if not preferred_address:
|
||||||
|
return name or addresses[0]
|
||||||
if (
|
if (
|
||||||
name
|
name
|
||||||
and name != preferred_address
|
and name != preferred_address
|
||||||
|
@ -6,7 +6,7 @@ import socket
|
|||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, create_autospec, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
@ -50,12 +50,6 @@ def resolve_host():
|
|||||||
yield func
|
yield func
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def socket_socket():
|
|
||||||
with patch("socket.socket") as func:
|
|
||||||
yield func
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def patchable_api_client() -> APIClient:
|
def patchable_api_client() -> APIClient:
|
||||||
class PatchableAPIClient(APIClient):
|
class PatchableAPIClient(APIClient):
|
||||||
@ -71,7 +65,7 @@ def patchable_api_client() -> APIClient:
|
|||||||
|
|
||||||
def get_mock_connection_params() -> ConnectionParams:
|
def get_mock_connection_params() -> ConnectionParams:
|
||||||
return ConnectionParams(
|
return ConnectionParams(
|
||||||
address="fake.address",
|
addresses=["fake.address"],
|
||||||
port=6052,
|
port=6052,
|
||||||
password=None,
|
password=None,
|
||||||
client_info="Tests client",
|
client_info="Tests client",
|
||||||
@ -119,7 +113,11 @@ def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnectio
|
|||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def aiohappyeyeballs_start_connection():
|
def aiohappyeyeballs_start_connection():
|
||||||
with patch("aioesphomeapi.connection.aiohappyeyeballs.start_connection") as func:
|
with patch("aioesphomeapi.connection.aiohappyeyeballs.start_connection") as func:
|
||||||
func.return_value = MagicMock(type=socket.SOCK_STREAM)
|
mock_socket = create_autospec(socket.socket, spec_set=True, instance=True)
|
||||||
|
mock_socket.type = socket.SOCK_STREAM
|
||||||
|
mock_socket.fileno.return_value = 1
|
||||||
|
mock_socket.getpeername.return_value = ("10.0.0.512", 323)
|
||||||
|
func.return_value = mock_socket
|
||||||
yield func
|
yield func
|
||||||
|
|
||||||
|
|
||||||
@ -139,7 +137,6 @@ def _create_mock_transport_protocol(
|
|||||||
async def plaintext_connect_task_no_login(
|
async def plaintext_connect_task_no_login(
|
||||||
conn: APIConnection,
|
conn: APIConnection,
|
||||||
resolve_host,
|
resolve_host,
|
||||||
socket_socket,
|
|
||||||
event_loop,
|
event_loop,
|
||||||
aiohappyeyeballs_start_connection,
|
aiohappyeyeballs_start_connection,
|
||||||
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||||
@ -161,7 +158,6 @@ async def plaintext_connect_task_no_login(
|
|||||||
async def plaintext_connect_task_no_login_with_expected_name(
|
async def plaintext_connect_task_no_login_with_expected_name(
|
||||||
conn_with_expected_name: APIConnection,
|
conn_with_expected_name: APIConnection,
|
||||||
resolve_host,
|
resolve_host,
|
||||||
socket_socket,
|
|
||||||
event_loop,
|
event_loop,
|
||||||
aiohappyeyeballs_start_connection,
|
aiohappyeyeballs_start_connection,
|
||||||
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||||
@ -184,7 +180,6 @@ async def plaintext_connect_task_no_login_with_expected_name(
|
|||||||
async def plaintext_connect_task_with_login(
|
async def plaintext_connect_task_with_login(
|
||||||
conn_with_password: APIConnection,
|
conn_with_password: APIConnection,
|
||||||
resolve_host,
|
resolve_host,
|
||||||
socket_socket,
|
|
||||||
event_loop,
|
event_loop,
|
||||||
aiohappyeyeballs_start_connection,
|
aiohappyeyeballs_start_connection,
|
||||||
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||||
@ -203,7 +198,7 @@ async def plaintext_connect_task_with_login(
|
|||||||
|
|
||||||
@pytest_asyncio.fixture(name="api_client")
|
@pytest_asyncio.fixture(name="api_client")
|
||||||
async def api_client(
|
async def api_client(
|
||||||
resolve_host, socket_socket, event_loop, aiohappyeyeballs_start_connection
|
resolve_host, event_loop, aiohappyeyeballs_start_connection
|
||||||
) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]:
|
) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]:
|
||||||
protocol: APIPlaintextFrameHelper | None = None
|
protocol: APIPlaintextFrameHelper | None = None
|
||||||
transport = MagicMock()
|
transport = MagicMock()
|
||||||
|
@ -199,7 +199,8 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_plaintext_frame_helper(
|
@pytest.mark.asyncio
|
||||||
|
async def test_plaintext_frame_helper(
|
||||||
in_bytes: bytes, pkt_data: bytes, pkt_type: int
|
in_bytes: bytes, pkt_data: bytes, pkt_type: int
|
||||||
) -> None:
|
) -> None:
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
@ -592,7 +593,9 @@ async def test_noise_frame_helper_bad_encryption(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_init_plaintext_with_wrong_preamble(conn: APIConnection):
|
async def test_init_plaintext_with_wrong_preamble(
|
||||||
|
conn: APIConnection, aiohappyeyeballs_start_connection
|
||||||
|
):
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
protocol = get_mock_protocol(conn)
|
protocol = get_mock_protocol(conn)
|
||||||
with patch.object(loop, "create_connection") as create_connection:
|
with patch.object(loop, "create_connection") as create_connection:
|
||||||
|
@ -4,9 +4,10 @@ import asyncio
|
|||||||
import contextlib
|
import contextlib
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
|
import socket
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
from unittest.mock import AsyncMock, MagicMock, call, create_autospec, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from google.protobuf import message
|
from google.protobuf import message
|
||||||
@ -169,7 +170,8 @@ def patch_api_version(client: APIClient, version: APIVersion):
|
|||||||
client._connection.api_version = version
|
client._connection.api_version = version
|
||||||
|
|
||||||
|
|
||||||
def test_expected_name(auth_client: APIClient) -> None:
|
@pytest.mark.asyncio
|
||||||
|
async def test_expected_name(auth_client: APIClient) -> None:
|
||||||
"""Ensure expected name can be set externally."""
|
"""Ensure expected name can be set externally."""
|
||||||
assert auth_client.expected_name is None
|
assert auth_client.expected_name is None
|
||||||
auth_client.expected_name = "awesome"
|
auth_client.expected_name = "awesome"
|
||||||
@ -219,9 +221,15 @@ async def test_connection_released_if_connecting_is_cancelled() -> None:
|
|||||||
cli = APIClient("1.2.3.4", 1234, None)
|
cli = APIClient("1.2.3.4", 1234, None)
|
||||||
asyncio.get_event_loop()
|
asyncio.get_event_loop()
|
||||||
|
|
||||||
|
async def _start_connection_with_delay(*args, **kwargs):
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
mock_socket = create_autospec(socket.socket, spec_set=True, instance=True)
|
||||||
|
mock_socket.getpeername.return_value = ("4.3.3.3", 323)
|
||||||
|
return mock_socket
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"aioesphomeapi.connection.aiohappyeyeballs.start_connection",
|
"aioesphomeapi.connection.aiohappyeyeballs.start_connection",
|
||||||
side_effect=partial(asyncio.sleep, 1),
|
_start_connection_with_delay,
|
||||||
):
|
):
|
||||||
start_task = asyncio.create_task(cli.start_connection())
|
start_task = asyncio.create_task(cli.start_connection())
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
@ -232,8 +240,14 @@ async def test_connection_released_if_connecting_is_cancelled() -> None:
|
|||||||
await start_task
|
await start_task
|
||||||
assert cli._connection is None
|
assert cli._connection is None
|
||||||
|
|
||||||
|
async def _start_connection_without_delay(*args, **kwargs):
|
||||||
|
mock_socket = create_autospec(socket.socket, spec_set=True, instance=True)
|
||||||
|
mock_socket.getpeername.return_value = ("4.3.3.3", 323)
|
||||||
|
return mock_socket
|
||||||
|
|
||||||
with patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection), patch(
|
with patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection), patch(
|
||||||
"aioesphomeapi.connection.aiohappyeyeballs.start_connection"
|
"aioesphomeapi.connection.aiohappyeyeballs.start_connection",
|
||||||
|
_start_connection_without_delay,
|
||||||
):
|
):
|
||||||
await cli.start_connection()
|
await cli.start_connection()
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
@ -894,7 +908,7 @@ async def test_noise_psk_handles_subclassed_string():
|
|||||||
)
|
)
|
||||||
# Make sure its not a subclassed string
|
# Make sure its not a subclassed string
|
||||||
assert type(cli._params.noise_psk) is str
|
assert type(cli._params.noise_psk) is str
|
||||||
assert type(cli._params.address) is str
|
assert type(cli._params.addresses[0]) is str
|
||||||
assert type(cli._params.expected_name) is str
|
assert type(cli._params.expected_name) is str
|
||||||
|
|
||||||
rl = ReconnectLogic(
|
rl = ReconnectLogic(
|
||||||
@ -930,7 +944,7 @@ async def test_no_noise_psk():
|
|||||||
)
|
)
|
||||||
# Make sure its not a subclassed string
|
# Make sure its not a subclassed string
|
||||||
assert cli._params.noise_psk is None
|
assert cli._params.noise_psk is None
|
||||||
assert type(cli._params.address) is str
|
assert type(cli._params.addresses[0]) is str
|
||||||
assert type(cli._params.expected_name) is str
|
assert type(cli._params.expected_name) is str
|
||||||
|
|
||||||
|
|
||||||
@ -945,7 +959,7 @@ async def test_empty_noise_psk_or_expected_name():
|
|||||||
expected_name="",
|
expected_name="",
|
||||||
)
|
)
|
||||||
assert cli._params.noise_psk is None
|
assert cli._params.noise_psk is None
|
||||||
assert type(cli._params.address) is str
|
assert type(cli._params.addresses[0]) is str
|
||||||
assert cli._params.expected_name is None
|
assert cli._params.expected_name is None
|
||||||
|
|
||||||
|
|
||||||
|
@ -221,7 +221,7 @@ async def test_plaintext_connection(
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_start_connection_socket_error(
|
async def test_start_connection_socket_error(
|
||||||
conn: APIConnection, resolve_host, socket_socket
|
conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection
|
||||||
):
|
):
|
||||||
"""Test handling of socket error during start connection."""
|
"""Test handling of socket error during start connection."""
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
@ -238,7 +238,7 @@ async def test_start_connection_socket_error(
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_start_connection_times_out(
|
async def test_start_connection_times_out(
|
||||||
conn: APIConnection, resolve_host, socket_socket
|
conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection
|
||||||
):
|
):
|
||||||
"""Test handling of start connection timing out."""
|
"""Test handling of start connection timing out."""
|
||||||
asyncio.get_event_loop()
|
asyncio.get_event_loop()
|
||||||
@ -264,9 +264,7 @@ async def test_start_connection_times_out(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_start_connection_os_error(
|
async def test_start_connection_os_error(conn: APIConnection, resolve_host):
|
||||||
conn: APIConnection, resolve_host, socket_socket
|
|
||||||
):
|
|
||||||
"""Test handling of start connection has an OSError."""
|
"""Test handling of start connection has an OSError."""
|
||||||
asyncio.get_event_loop()
|
asyncio.get_event_loop()
|
||||||
|
|
||||||
@ -284,9 +282,7 @@ async def test_start_connection_os_error(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_start_connection_is_cancelled(
|
async def test_start_connection_is_cancelled(conn: APIConnection, resolve_host):
|
||||||
conn: APIConnection, resolve_host, socket_socket
|
|
||||||
):
|
|
||||||
"""Test handling of start connection is cancelled."""
|
"""Test handling of start connection is cancelled."""
|
||||||
asyncio.get_event_loop()
|
asyncio.get_event_loop()
|
||||||
|
|
||||||
@ -305,7 +301,7 @@ async def test_start_connection_is_cancelled(
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_finish_connection_is_cancelled(
|
async def test_finish_connection_is_cancelled(
|
||||||
conn: APIConnection, resolve_host, socket_socket
|
conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection
|
||||||
):
|
):
|
||||||
"""Test handling of finishing connection being cancelled."""
|
"""Test handling of finishing connection being cancelled."""
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
@ -368,7 +364,7 @@ async def test_finish_connection_times_out(
|
|||||||
async def test_plaintext_connection_fails_handshake(
|
async def test_plaintext_connection_fails_handshake(
|
||||||
conn: APIConnection,
|
conn: APIConnection,
|
||||||
resolve_host: AsyncMock,
|
resolve_host: AsyncMock,
|
||||||
socket_socket: MagicMock,
|
aiohappyeyeballs_start_connection: MagicMock,
|
||||||
exception_map: tuple[Exception, Exception],
|
exception_map: tuple[Exception, Exception],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that the frame helper is closed before the underlying socket.
|
"""Test that the frame helper is closed before the underlying socket.
|
||||||
@ -558,7 +554,7 @@ async def test_force_disconnect_fails(
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_connect_resolver_times_out(
|
async def test_connect_resolver_times_out(
|
||||||
conn: APIConnection, socket_socket, event_loop, aiohappyeyeballs_start_connection
|
conn: APIConnection, event_loop, aiohappyeyeballs_start_connection
|
||||||
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||||
transport = MagicMock()
|
transport = MagicMock()
|
||||||
connected = asyncio.Event()
|
connected = asyncio.Event()
|
||||||
@ -571,7 +567,8 @@ async def test_connect_resolver_times_out(
|
|||||||
"create_connection",
|
"create_connection",
|
||||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||||
), pytest.raises(
|
), pytest.raises(
|
||||||
ResolveAPIError, match="Timeout while resolving IP address for fake.address"
|
ResolveAPIError,
|
||||||
|
match="Timeout while resolving IP address for fake.address",
|
||||||
):
|
):
|
||||||
await connect(conn, login=False)
|
await connect(conn, login=False)
|
||||||
|
|
||||||
@ -581,7 +578,6 @@ async def test_disconnect_fails_to_send_response(
|
|||||||
connection_params: ConnectionParams,
|
connection_params: ConnectionParams,
|
||||||
event_loop: asyncio.AbstractEventLoop,
|
event_loop: asyncio.AbstractEventLoop,
|
||||||
resolve_host,
|
resolve_host,
|
||||||
socket_socket,
|
|
||||||
aiohappyeyeballs_start_connection,
|
aiohappyeyeballs_start_connection,
|
||||||
) -> None:
|
) -> None:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
@ -632,7 +628,6 @@ async def test_disconnect_success_case(
|
|||||||
connection_params: ConnectionParams,
|
connection_params: ConnectionParams,
|
||||||
event_loop: asyncio.AbstractEventLoop,
|
event_loop: asyncio.AbstractEventLoop,
|
||||||
resolve_host,
|
resolve_host,
|
||||||
socket_socket,
|
|
||||||
aiohappyeyeballs_start_connection,
|
aiohappyeyeballs_start_connection,
|
||||||
) -> None:
|
) -> None:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
@ -99,7 +99,7 @@ async def test_resolve_host_zeroconf_fails_end_to_end(async_zeroconf: AsyncZeroc
|
|||||||
"aioesphomeapi.host_resolver.AsyncServiceInfo.async_request",
|
"aioesphomeapi.host_resolver.AsyncServiceInfo.async_request",
|
||||||
side_effect=Exception("no buffers"),
|
side_effect=Exception("no buffers"),
|
||||||
), pytest.raises(ResolveAPIError, match="no buffers"):
|
), pytest.raises(ResolveAPIError, match="no buffers"):
|
||||||
await hr.async_resolve_host("asdf.local", 6052)
|
await hr.async_resolve_host(["asdf.local"], 6052)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -140,7 +140,7 @@ async def test_resolve_host_getaddrinfo_oserror(event_loop):
|
|||||||
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
|
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
|
||||||
async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos):
|
async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos):
|
||||||
resolve_zc.return_value = addr_infos
|
resolve_zc.return_value = addr_infos
|
||||||
ret = await hr.async_resolve_host("example.local", 6052)
|
ret = await hr.async_resolve_host(["example.local"], 6052)
|
||||||
|
|
||||||
resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None)
|
resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None)
|
||||||
resolve_addr.assert_not_called()
|
resolve_addr.assert_not_called()
|
||||||
@ -153,7 +153,7 @@ async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos):
|
|||||||
async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos):
|
async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos):
|
||||||
resolve_zc.return_value = []
|
resolve_zc.return_value = []
|
||||||
resolve_addr.return_value = addr_infos
|
resolve_addr.return_value = addr_infos
|
||||||
ret = await hr.async_resolve_host("example.local", 6052)
|
ret = await hr.async_resolve_host(["example.local"], 6052)
|
||||||
|
|
||||||
resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None)
|
resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None)
|
||||||
resolve_addr.assert_called_once_with("example.local", 6052)
|
resolve_addr.assert_called_once_with("example.local", 6052)
|
||||||
@ -166,7 +166,7 @@ async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos):
|
|||||||
async def test_resolve_host_mdns_no_results(resolve_addr, addr_infos):
|
async def test_resolve_host_mdns_no_results(resolve_addr, addr_infos):
|
||||||
resolve_addr.return_value = addr_infos
|
resolve_addr.return_value = addr_infos
|
||||||
with pytest.raises(ResolveAPIError):
|
with pytest.raises(ResolveAPIError):
|
||||||
await hr.async_resolve_host("example.local", 6052)
|
await hr.async_resolve_host(["example.local"], 6052)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -174,7 +174,7 @@ async def test_resolve_host_mdns_no_results(resolve_addr, addr_infos):
|
|||||||
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
|
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
|
||||||
async def test_resolve_host_addrinfo(resolve_addr, resolve_zc, addr_infos):
|
async def test_resolve_host_addrinfo(resolve_addr, resolve_zc, addr_infos):
|
||||||
resolve_addr.return_value = addr_infos
|
resolve_addr.return_value = addr_infos
|
||||||
ret = await hr.async_resolve_host("example.com", 6052)
|
ret = await hr.async_resolve_host(["example.com"], 6052)
|
||||||
|
|
||||||
resolve_zc.assert_not_called()
|
resolve_zc.assert_not_called()
|
||||||
resolve_addr.assert_called_once_with("example.com", 6052)
|
resolve_addr.assert_called_once_with("example.com", 6052)
|
||||||
@ -187,7 +187,7 @@ async def test_resolve_host_addrinfo(resolve_addr, resolve_zc, addr_infos):
|
|||||||
async def test_resolve_host_addrinfo_empty(resolve_addr, resolve_zc, addr_infos):
|
async def test_resolve_host_addrinfo_empty(resolve_addr, resolve_zc, addr_infos):
|
||||||
resolve_addr.return_value = []
|
resolve_addr.return_value = []
|
||||||
with pytest.raises(APIConnectionError):
|
with pytest.raises(APIConnectionError):
|
||||||
await hr.async_resolve_host("example.com", 6052)
|
await hr.async_resolve_host(["example.com"], 6052)
|
||||||
|
|
||||||
resolve_zc.assert_not_called()
|
resolve_zc.assert_not_called()
|
||||||
resolve_addr.assert_called_once_with("example.com", 6052)
|
resolve_addr.assert_called_once_with("example.com", 6052)
|
||||||
@ -199,7 +199,7 @@ async def test_resolve_host_addrinfo_empty(resolve_addr, resolve_zc, addr_infos)
|
|||||||
async def test_resolve_host_with_address(resolve_addr, resolve_zc):
|
async def test_resolve_host_with_address(resolve_addr, resolve_zc):
|
||||||
resolve_zc.return_value = []
|
resolve_zc.return_value = []
|
||||||
resolve_addr.return_value = addr_infos
|
resolve_addr.return_value = addr_infos
|
||||||
ret = await hr.async_resolve_host("127.0.0.1", 6052)
|
ret = await hr.async_resolve_host(["127.0.0.1"], 6052)
|
||||||
|
|
||||||
resolve_zc.assert_not_called()
|
resolve_zc.assert_not_called()
|
||||||
resolve_addr.assert_not_called()
|
resolve_addr.assert_not_called()
|
||||||
|
Loading…
Reference in New Issue
Block a user