mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-06-22 09:56:14 +02:00
Add happy eyeballs support (RFC 8305) (#789)
This commit is contained in:
parent
280b9a7ab7
commit
05ee53c16d
|
@ -280,7 +280,7 @@ class APIClient:
|
|||
"""Set the log name of the device."""
|
||||
resolved_address: str | None = None
|
||||
if self._connection and self._connection.resolved_addr_info:
|
||||
resolved_address = self._connection.resolved_addr_info.sockaddr.address
|
||||
resolved_address = self._connection.resolved_addr_info[0].sockaddr.address
|
||||
self.log_name = build_log_name(
|
||||
self.cached_name,
|
||||
self.address,
|
||||
|
|
|
@ -15,6 +15,7 @@ from dataclasses import astuple, dataclass
|
|||
from functools import lru_cache, partial
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
import aiohappyeyeballs
|
||||
from google.protobuf import message
|
||||
|
||||
import aioesphomeapi.host_resolver as hr
|
||||
|
@ -250,7 +251,7 @@ class APIConnection:
|
|||
self._handshake_complete = False
|
||||
self._debug_enabled = debug_enabled
|
||||
self.received_name: str = ""
|
||||
self.resolved_addr_info: hr.AddrInfo | None = None
|
||||
self.resolved_addr_info: list[hr.AddrInfo] = []
|
||||
|
||||
def set_log_name(self, name: str) -> None:
|
||||
"""Set the friendly log name for this connection."""
|
||||
|
@ -319,7 +320,7 @@ class APIConnection:
|
|||
"""Enable or disable debug logging."""
|
||||
self._debug_enabled = enable
|
||||
|
||||
async def _connect_resolve_host(self) -> hr.AddrInfo:
|
||||
async def _connect_resolve_host(self) -> list[hr.AddrInfo]:
|
||||
"""Step 1 in connect process: resolve the address."""
|
||||
try:
|
||||
async with asyncio_timeout(RESOLVE_TIMEOUT):
|
||||
|
@ -333,9 +334,53 @@ class APIConnection:
|
|||
f"Timeout while resolving IP address for {self.log_name}"
|
||||
) from err
|
||||
|
||||
async def _connect_socket_connect(self, addr: hr.AddrInfo) -> None:
|
||||
async def _connect_socket_connect(self, addrs: list[hr.AddrInfo]) -> None:
|
||||
"""Step 2 in connect process: connect the socket."""
|
||||
sock = socket.socket(family=addr.family, type=addr.type, proto=addr.proto)
|
||||
if self._debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"%s: Connecting to %s:%s (%s)",
|
||||
self.log_name,
|
||||
self._params.address,
|
||||
self._params.port,
|
||||
addrs,
|
||||
)
|
||||
|
||||
addr_infos: list[aiohappyeyeballs.AddrInfoType] = [
|
||||
(
|
||||
addr.family,
|
||||
addr.type,
|
||||
addr.proto,
|
||||
self._params.address,
|
||||
astuple(addr.sockaddr),
|
||||
)
|
||||
for addr in addrs
|
||||
]
|
||||
last_exception: Exception | None = None
|
||||
sock: socket.socket | None = None
|
||||
interleave = 1
|
||||
while addr_infos:
|
||||
try:
|
||||
async with asyncio_timeout(TCP_CONNECT_TIMEOUT):
|
||||
sock = await aiohappyeyeballs.start_connection(
|
||||
addr_infos,
|
||||
happy_eyeballs_delay=0.25,
|
||||
interleave=interleave,
|
||||
loop=self._loop,
|
||||
)
|
||||
break
|
||||
except (OSError, asyncio_TimeoutError) as err:
|
||||
last_exception = err
|
||||
aiohappyeyeballs.pop_addr_infos_interleave(addr_infos, interleave)
|
||||
|
||||
if sock is None:
|
||||
if isinstance(last_exception, asyncio_TimeoutError):
|
||||
raise TimeoutAPIError(
|
||||
f"Timeout while connecting to {addrs}"
|
||||
) from last_exception
|
||||
raise SocketAPIError(
|
||||
f"Error connecting to {addrs}: {last_exception}"
|
||||
) from last_exception
|
||||
|
||||
self._socket = sock
|
||||
sock.setblocking(False)
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
|
@ -343,31 +388,13 @@ class APIConnection:
|
|||
# ram in bytes and we measure ram in megabytes.
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE)
|
||||
|
||||
if self._debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"%s: Connecting to %s:%s (%s)",
|
||||
self.log_name,
|
||||
self._params.address,
|
||||
self._params.port,
|
||||
addr,
|
||||
)
|
||||
sockaddr = astuple(addr.sockaddr)
|
||||
|
||||
try:
|
||||
async with asyncio_timeout(TCP_CONNECT_TIMEOUT):
|
||||
await self._loop.sock_connect(sock, sockaddr)
|
||||
except asyncio_TimeoutError as err:
|
||||
raise SocketAPIError(f"Timeout while connecting to {sockaddr}") from err
|
||||
except OSError as err:
|
||||
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err
|
||||
|
||||
if self._debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"%s: Opened socket to %s:%s (%s)",
|
||||
self.log_name,
|
||||
self._params.address,
|
||||
self._params.port,
|
||||
addr,
|
||||
addrs,
|
||||
)
|
||||
|
||||
async def _connect_init_frame_helper(self) -> None:
|
||||
|
|
|
@ -108,8 +108,10 @@ async def _async_resolve_host_zeroconf(
|
|||
timeout,
|
||||
)
|
||||
addrs: list[AddrInfo] = []
|
||||
for ip in info.ip_addresses_by_version(IPVersion.All):
|
||||
addrs.extend(_async_ip_address_to_addrs(ip, port)) # type: ignore[arg-type]
|
||||
for ip in info.ip_addresses_by_version(IPVersion.V6Only):
|
||||
addrs.extend(_async_ip_address_to_addrs(ip, port)) # type: ignore
|
||||
for ip in info.ip_addresses_by_version(IPVersion.V4Only):
|
||||
addrs.extend(_async_ip_address_to_addrs(ip, port)) # type: ignore
|
||||
return addrs
|
||||
|
||||
|
||||
|
@ -182,7 +184,7 @@ async def async_resolve_host(
|
|||
host: str,
|
||||
port: int,
|
||||
zeroconf_manager: ZeroconfManager | None = None,
|
||||
) -> AddrInfo:
|
||||
) -> list[AddrInfo]:
|
||||
addrs: list[AddrInfo] = []
|
||||
|
||||
zc_error = None
|
||||
|
@ -210,6 +212,4 @@ async def async_resolve_host(
|
|||
raise zc_error
|
||||
raise ResolveAPIError(f"Could not resolve host {host} - got no results from OS")
|
||||
|
||||
# Use first matching result
|
||||
# Future: return all matches and use first working one
|
||||
return addrs[0]
|
||||
return addrs
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
aiohappyeyeballs>=2.3.0
|
||||
protobuf>=3.19.0
|
||||
zeroconf>=0.128.4,<1.0
|
||||
chacha20poly1305-reuseable>=0.12.0
|
||||
|
|
|
@ -39,12 +39,14 @@ def async_zeroconf():
|
|||
@pytest.fixture
|
||||
def resolve_host():
|
||||
with patch("aioesphomeapi.host_resolver.async_resolve_host") as func:
|
||||
func.return_value = AddrInfo(
|
||||
family=socket.AF_INET,
|
||||
type=socket.SOCK_STREAM,
|
||||
proto=socket.IPPROTO_TCP,
|
||||
sockaddr=IPv4Sockaddr("10.0.0.512", 6052),
|
||||
)
|
||||
func.return_value = [
|
||||
AddrInfo(
|
||||
family=socket.AF_INET,
|
||||
type=socket.SOCK_STREAM,
|
||||
proto=socket.IPPROTO_TCP,
|
||||
sockaddr=IPv4Sockaddr("10.0.0.512", 6052),
|
||||
)
|
||||
]
|
||||
yield func
|
||||
|
||||
|
||||
|
@ -114,6 +116,13 @@ def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnectio
|
|||
return PatchableAPIConnection(connection_params, mock_on_stop, True, None)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def aiohappyeyeballs_start_connection():
|
||||
with patch("aioesphomeapi.connection.aiohappyeyeballs.start_connection") as func:
|
||||
func.return_value = MagicMock(type=socket.SOCK_STREAM)
|
||||
yield func
|
||||
|
||||
|
||||
def _create_mock_transport_protocol(
|
||||
transport: asyncio.Transport,
|
||||
connected: asyncio.Event,
|
||||
|
@ -128,13 +137,17 @@ def _create_mock_transport_protocol(
|
|||
|
||||
@pytest_asyncio.fixture(name="plaintext_connect_task_no_login")
|
||||
async def plaintext_connect_task_no_login(
|
||||
conn: APIConnection, resolve_host, socket_socket, event_loop
|
||||
conn: APIConnection,
|
||||
resolve_host,
|
||||
socket_socket,
|
||||
event_loop,
|
||||
aiohappyeyeballs_start_connection,
|
||||
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||
loop = asyncio.get_event_loop()
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
with patch.object(
|
||||
loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
|
@ -146,12 +159,16 @@ async def plaintext_connect_task_no_login(
|
|||
|
||||
@pytest_asyncio.fixture(name="plaintext_connect_task_expected_name")
|
||||
async def plaintext_connect_task_no_login_with_expected_name(
|
||||
conn_with_expected_name: APIConnection, resolve_host, socket_socket, event_loop
|
||||
conn_with_expected_name: APIConnection,
|
||||
resolve_host,
|
||||
socket_socket,
|
||||
event_loop,
|
||||
aiohappyeyeballs_start_connection,
|
||||
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
with patch.object(
|
||||
event_loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
|
@ -165,12 +182,16 @@ async def plaintext_connect_task_no_login_with_expected_name(
|
|||
|
||||
@pytest_asyncio.fixture(name="plaintext_connect_task_with_login")
|
||||
async def plaintext_connect_task_with_login(
|
||||
conn_with_password: APIConnection, resolve_host, socket_socket, event_loop
|
||||
conn_with_password: APIConnection,
|
||||
resolve_host,
|
||||
socket_socket,
|
||||
event_loop,
|
||||
aiohappyeyeballs_start_connection,
|
||||
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
with patch.object(
|
||||
event_loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
|
@ -182,7 +203,7 @@ async def plaintext_connect_task_with_login(
|
|||
|
||||
@pytest_asyncio.fixture(name="api_client")
|
||||
async def api_client(
|
||||
resolve_host, socket_socket, event_loop
|
||||
resolve_host, socket_socket, event_loop, aiohappyeyeballs_start_connection
|
||||
) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]:
|
||||
protocol: APIPlaintextFrameHelper | None = None
|
||||
transport = MagicMock()
|
||||
|
@ -193,7 +214,7 @@ async def api_client(
|
|||
password=None,
|
||||
)
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
with patch.object(
|
||||
event_loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
|
|
|
@ -194,14 +194,14 @@ async def test_connect_backwards_compat() -> None:
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_finish_connection_wraps_exceptions_as_unhandled_api_error() -> None:
|
||||
async def test_finish_connection_wraps_exceptions_as_unhandled_api_error(
|
||||
aiohappyeyeballs_start_connection,
|
||||
) -> None:
|
||||
"""Verify finish_connect re-wraps exceptions as UnhandledAPIError."""
|
||||
|
||||
cli = APIClient("1.2.3.4", 1234, None)
|
||||
loop = asyncio.get_event_loop()
|
||||
with patch(
|
||||
"aioesphomeapi.client.APIConnection", PatchableAPIConnection
|
||||
), patch.object(loop, "sock_connect"):
|
||||
asyncio.get_event_loop()
|
||||
with patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection):
|
||||
await cli.start_connection()
|
||||
|
||||
with patch.object(
|
||||
|
@ -217,9 +217,12 @@ async def test_finish_connection_wraps_exceptions_as_unhandled_api_error() -> No
|
|||
async def test_connection_released_if_connecting_is_cancelled() -> None:
|
||||
"""Verify connection is unset if connecting is cancelled."""
|
||||
cli = APIClient("1.2.3.4", 1234, None)
|
||||
loop = asyncio.get_event_loop()
|
||||
asyncio.get_event_loop()
|
||||
|
||||
with patch.object(loop, "sock_connect", side_effect=partial(asyncio.sleep, 1)):
|
||||
with patch(
|
||||
"aioesphomeapi.connection.aiohappyeyeballs.start_connection",
|
||||
side_effect=partial(asyncio.sleep, 1),
|
||||
):
|
||||
start_task = asyncio.create_task(cli.start_connection())
|
||||
await asyncio.sleep(0)
|
||||
assert cli._connection is not None
|
||||
|
@ -229,9 +232,9 @@ async def test_connection_released_if_connecting_is_cancelled() -> None:
|
|||
await start_task
|
||||
assert cli._connection is None
|
||||
|
||||
with patch(
|
||||
"aioesphomeapi.client.APIConnection", PatchableAPIConnection
|
||||
), patch.object(loop, "sock_connect"):
|
||||
with patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection), patch(
|
||||
"aioesphomeapi.connection.aiohappyeyeballs.start_connection"
|
||||
):
|
||||
await cli.start_connection()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
@ -252,8 +255,9 @@ async def test_request_while_handshaking(event_loop) -> None:
|
|||
pass
|
||||
|
||||
cli = PatchableApiClient("host", 1234, None)
|
||||
with patch.object(
|
||||
event_loop, "sock_connect", side_effect=partial(asyncio.sleep, 1)
|
||||
with patch(
|
||||
"aioesphomeapi.connection.aiohappyeyeballs.start_connection",
|
||||
side_effect=partial(asyncio.sleep, 1),
|
||||
), patch.object(cli, "finish_connection"):
|
||||
connect_task = asyncio.create_task(cli.connect())
|
||||
|
||||
|
|
|
@ -241,14 +241,15 @@ async def test_start_connection_times_out(
|
|||
conn: APIConnection, resolve_host, socket_socket
|
||||
):
|
||||
"""Test handling of start connection timing out."""
|
||||
loop = asyncio.get_event_loop()
|
||||
asyncio.get_event_loop()
|
||||
|
||||
async def _mock_socket_connect(*args, **kwargs):
|
||||
await asyncio.sleep(500)
|
||||
|
||||
with patch.object(loop, "sock_connect", side_effect=_mock_socket_connect), patch(
|
||||
"aioesphomeapi.connection.TCP_CONNECT_TIMEOUT", 0.0
|
||||
):
|
||||
with patch(
|
||||
"aioesphomeapi.connection.aiohappyeyeballs.start_connection",
|
||||
side_effect=_mock_socket_connect,
|
||||
), patch("aioesphomeapi.connection.TCP_CONNECT_TIMEOUT", 0.0):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
@ -267,9 +268,12 @@ async def test_start_connection_os_error(
|
|||
conn: APIConnection, resolve_host, socket_socket
|
||||
):
|
||||
"""Test handling of start connection has an OSError."""
|
||||
loop = asyncio.get_event_loop()
|
||||
asyncio.get_event_loop()
|
||||
|
||||
with patch.object(loop, "sock_connect", side_effect=OSError("Socket error")):
|
||||
with patch(
|
||||
"aioesphomeapi.connection.aiohappyeyeballs.start_connection",
|
||||
side_effect=OSError("Socket error"),
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await asyncio.sleep(0)
|
||||
with pytest.raises(APIConnectionError, match="Socket error"):
|
||||
|
@ -284,9 +288,12 @@ async def test_start_connection_is_cancelled(
|
|||
conn: APIConnection, resolve_host, socket_socket
|
||||
):
|
||||
"""Test handling of start connection is cancelled."""
|
||||
loop = asyncio.get_event_loop()
|
||||
asyncio.get_event_loop()
|
||||
|
||||
with patch.object(loop, "sock_connect", side_effect=asyncio.CancelledError):
|
||||
with patch(
|
||||
"aioesphomeapi.connection.aiohappyeyeballs.start_connection",
|
||||
side_effect=asyncio.CancelledError,
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await asyncio.sleep(0)
|
||||
with pytest.raises(APIConnectionError, match="Starting connection cancelled"):
|
||||
|
@ -551,7 +558,7 @@ async def test_force_disconnect_fails(
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_resolver_times_out(
|
||||
conn: APIConnection, socket_socket, event_loop
|
||||
conn: APIConnection, socket_socket, event_loop, aiohappyeyeballs_start_connection
|
||||
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
@ -559,7 +566,7 @@ async def test_connect_resolver_times_out(
|
|||
with patch(
|
||||
"aioesphomeapi.host_resolver.async_resolve_host",
|
||||
side_effect=asyncio.TimeoutError,
|
||||
), patch.object(event_loop, "sock_connect"), patch.object(
|
||||
), patch.object(
|
||||
event_loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
|
@ -575,6 +582,7 @@ async def test_disconnect_fails_to_send_response(
|
|||
event_loop: asyncio.AbstractEventLoop,
|
||||
resolve_host,
|
||||
socket_socket,
|
||||
aiohappyeyeballs_start_connection,
|
||||
) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
transport = MagicMock()
|
||||
|
@ -590,7 +598,7 @@ async def test_disconnect_fails_to_send_response(
|
|||
nonlocal expected_disconnect
|
||||
expected_disconnect = _expected_disconnect
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
with patch.object(
|
||||
loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
|
@ -625,6 +633,7 @@ async def test_disconnect_success_case(
|
|||
event_loop: asyncio.AbstractEventLoop,
|
||||
resolve_host,
|
||||
socket_socket,
|
||||
aiohappyeyeballs_start_connection,
|
||||
) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
transport = MagicMock()
|
||||
|
@ -640,7 +649,7 @@ async def test_disconnect_success_case(
|
|||
nonlocal expected_disconnect
|
||||
expected_disconnect = _expected_disconnect
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
with patch.object(
|
||||
loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
|
|
|
@ -39,9 +39,9 @@ def addr_infos():
|
|||
async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos):
|
||||
info = MagicMock(auto_spec=AsyncServiceInfo)
|
||||
ipv6 = IPv6Address("2001:db8:85a3::8a2e:370:7334%0")
|
||||
info.ip_addresses_by_version.return_value = [
|
||||
ip_address(b"\n\x00\x00*"),
|
||||
ipv6,
|
||||
info.ip_addresses_by_version.side_effect = [
|
||||
[ip_address(b"\n\x00\x00*")],
|
||||
[ipv6],
|
||||
]
|
||||
info.async_request = AsyncMock(return_value=True)
|
||||
with patch(
|
||||
|
@ -59,9 +59,9 @@ async def test_resolve_host_passed_zeroconf(addr_infos, async_zeroconf):
|
|||
zeroconf_manager = ZeroconfManager()
|
||||
info = MagicMock(auto_spec=AsyncServiceInfo)
|
||||
ipv6 = IPv6Address("2001:db8:85a3::8a2e:370:7334%0")
|
||||
info.ip_addresses_by_version.return_value = [
|
||||
ip_address(b"\n\x00\x00*"),
|
||||
ipv6,
|
||||
info.ip_addresses_by_version.side_effect = [
|
||||
[ip_address(b"\n\x00\x00*")],
|
||||
[ipv6],
|
||||
]
|
||||
info.async_request = AsyncMock(return_value=True)
|
||||
with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info):
|
||||
|
@ -144,7 +144,7 @@ async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos):
|
|||
|
||||
resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None)
|
||||
resolve_addr.assert_not_called()
|
||||
assert ret == addr_infos[0]
|
||||
assert ret == addr_infos
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -157,7 +157,7 @@ async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos):
|
|||
|
||||
resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None)
|
||||
resolve_addr.assert_called_once_with("example.local", 6052)
|
||||
assert ret == addr_infos[0]
|
||||
assert ret == addr_infos
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -178,7 +178,7 @@ async def test_resolve_host_addrinfo(resolve_addr, resolve_zc, addr_infos):
|
|||
|
||||
resolve_zc.assert_not_called()
|
||||
resolve_addr.assert_called_once_with("example.com", 6052)
|
||||
assert ret == addr_infos[0]
|
||||
assert ret == addr_infos
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -203,12 +203,14 @@ async def test_resolve_host_with_address(resolve_addr, resolve_zc):
|
|||
|
||||
resolve_zc.assert_not_called()
|
||||
resolve_addr.assert_not_called()
|
||||
assert ret == hr.AddrInfo(
|
||||
family=socket.AddressFamily.AF_INET,
|
||||
type=socket.SocketKind.SOCK_STREAM,
|
||||
proto=6,
|
||||
sockaddr=hr.IPv4Sockaddr(address="127.0.0.1", port=6052),
|
||||
)
|
||||
assert ret == [
|
||||
hr.AddrInfo(
|
||||
family=socket.AddressFamily.AF_INET,
|
||||
type=socket.SocketKind.SOCK_STREAM,
|
||||
proto=6,
|
||||
sockaddr=hr.IPv4Sockaddr(address="127.0.0.1", port=6052),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
@ -30,7 +30,11 @@ from .common import (
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnection):
|
||||
async def test_log_runner(
|
||||
event_loop: asyncio.AbstractEventLoop,
|
||||
conn: APIConnection,
|
||||
aiohappyeyeballs_start_connection,
|
||||
):
|
||||
"""Test the log runner logic."""
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol: APIPlaintextFrameHelper | None = None
|
||||
|
@ -69,7 +73,7 @@ async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnec
|
|||
await original_subscribe_logs(*args, **kwargs)
|
||||
subscribed.set()
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
with patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
), patch.object(cli, "subscribe_logs", _wait_subscribe_cli):
|
||||
stop = await async_run(cli, on_log, aio_zeroconf_instance=async_zeroconf)
|
||||
|
@ -96,6 +100,7 @@ async def test_log_runner_reconnects_on_disconnect(
|
|||
event_loop: asyncio.AbstractEventLoop,
|
||||
conn: APIConnection,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
aiohappyeyeballs_start_connection,
|
||||
) -> None:
|
||||
"""Test the log runner reconnects on disconnect."""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
@ -135,7 +140,7 @@ async def test_log_runner_reconnects_on_disconnect(
|
|||
await original_subscribe_logs(*args, **kwargs)
|
||||
subscribed.set()
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
with patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
), patch.object(cli, "subscribe_logs", _wait_subscribe_cli):
|
||||
stop = await async_run(cli, on_log, aio_zeroconf_instance=async_zeroconf)
|
||||
|
@ -173,6 +178,7 @@ async def test_log_runner_reconnects_on_subscribe_failure(
|
|||
event_loop: asyncio.AbstractEventLoop,
|
||||
conn: APIConnection,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
aiohappyeyeballs_start_connection,
|
||||
) -> None:
|
||||
"""Test the log runner reconnects on subscribe failure."""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
@ -214,7 +220,7 @@ async def test_log_runner_reconnects_on_subscribe_failure(
|
|||
with patch.object(
|
||||
cli, "disconnect", partial(cli.disconnect, force=True)
|
||||
), patch.object(cli, "subscribe_logs", _wait_and_fail_subscribe_cli):
|
||||
with patch.object(loop, "sock_connect"), patch.object(
|
||||
with patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
):
|
||||
stop = await async_run(cli, on_log, aio_zeroconf_instance=async_zeroconf)
|
||||
|
@ -227,7 +233,7 @@ async def test_log_runner_reconnects_on_subscribe_failure(
|
|||
|
||||
assert cli._connection is None
|
||||
|
||||
with patch.object(loop, "sock_connect"), patch.object(
|
||||
with patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
), patch.object(cli, "subscribe_logs"):
|
||||
connected.clear()
|
||||
|
|
|
@ -672,7 +672,9 @@ async def test_reconnect_logic_stop_callback_waits_for_handshake(
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventLoop):
|
||||
async def test_handling_unexpected_disconnect(
|
||||
event_loop: asyncio.AbstractEventLoop, aiohappyeyeballs_start_connection
|
||||
):
|
||||
"""Test the disconnect callback fires with expected_disconnect=False."""
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol: APIPlaintextFrameHelper | None = None
|
||||
|
@ -710,7 +712,7 @@ async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventL
|
|||
name="fake",
|
||||
)
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
with patch.object(
|
||||
loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
|
@ -726,7 +728,7 @@ async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventL
|
|||
assert cli._connection.is_connected is True
|
||||
await asyncio.sleep(0)
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
with patch.object(
|
||||
loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
|
@ -746,7 +748,9 @@ async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventL
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backoff_on_encryption_error(
|
||||
event_loop: asyncio.AbstractEventLoop, caplog: pytest.LogCaptureFixture
|
||||
event_loop: asyncio.AbstractEventLoop,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
aiohappyeyeballs_start_connection,
|
||||
) -> None:
|
||||
"""Test we backoff on encryption error."""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
@ -785,7 +789,7 @@ async def test_backoff_on_encryption_error(
|
|||
name="fake",
|
||||
)
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
with patch.object(
|
||||
loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
|
|
Loading…
Reference in New Issue
Block a user