Add happy eyeballs support (RFC 8305) (#789)

This commit is contained in:
J. Nick Koston 2023-12-12 07:24:31 -10:00 committed by GitHub
parent 280b9a7ab7
commit 05ee53c16d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 167 additions and 93 deletions

View File

@ -280,7 +280,7 @@ class APIClient:
"""Set the log name of the device."""
resolved_address: str | None = None
if self._connection and self._connection.resolved_addr_info:
resolved_address = self._connection.resolved_addr_info.sockaddr.address
resolved_address = self._connection.resolved_addr_info[0].sockaddr.address
self.log_name = build_log_name(
self.cached_name,
self.address,

View File

@ -15,6 +15,7 @@ from dataclasses import astuple, dataclass
from functools import lru_cache, partial
from typing import TYPE_CHECKING, Any, Callable
import aiohappyeyeballs
from google.protobuf import message
import aioesphomeapi.host_resolver as hr
@ -250,7 +251,7 @@ class APIConnection:
self._handshake_complete = False
self._debug_enabled = debug_enabled
self.received_name: str = ""
self.resolved_addr_info: hr.AddrInfo | None = None
self.resolved_addr_info: list[hr.AddrInfo] = []
def set_log_name(self, name: str) -> None:
"""Set the friendly log name for this connection."""
@ -319,7 +320,7 @@ class APIConnection:
"""Enable or disable debug logging."""
self._debug_enabled = enable
async def _connect_resolve_host(self) -> hr.AddrInfo:
async def _connect_resolve_host(self) -> list[hr.AddrInfo]:
"""Step 1 in connect process: resolve the address."""
try:
async with asyncio_timeout(RESOLVE_TIMEOUT):
@ -333,9 +334,53 @@ class APIConnection:
f"Timeout while resolving IP address for {self.log_name}"
) from err
async def _connect_socket_connect(self, addr: hr.AddrInfo) -> None:
async def _connect_socket_connect(self, addrs: list[hr.AddrInfo]) -> None:
"""Step 2 in connect process: connect the socket."""
sock = socket.socket(family=addr.family, type=addr.type, proto=addr.proto)
if self._debug_enabled:
_LOGGER.debug(
"%s: Connecting to %s:%s (%s)",
self.log_name,
self._params.address,
self._params.port,
addrs,
)
addr_infos: list[aiohappyeyeballs.AddrInfoType] = [
(
addr.family,
addr.type,
addr.proto,
self._params.address,
astuple(addr.sockaddr),
)
for addr in addrs
]
last_exception: Exception | None = None
sock: socket.socket | None = None
interleave = 1
while addr_infos:
try:
async with asyncio_timeout(TCP_CONNECT_TIMEOUT):
sock = await aiohappyeyeballs.start_connection(
addr_infos,
happy_eyeballs_delay=0.25,
interleave=interleave,
loop=self._loop,
)
break
except (OSError, asyncio_TimeoutError) as err:
last_exception = err
aiohappyeyeballs.pop_addr_infos_interleave(addr_infos, interleave)
if sock is None:
if isinstance(last_exception, asyncio_TimeoutError):
raise TimeoutAPIError(
f"Timeout while connecting to {addrs}"
) from last_exception
raise SocketAPIError(
f"Error connecting to {addrs}: {last_exception}"
) from last_exception
self._socket = sock
sock.setblocking(False)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
@ -343,31 +388,13 @@ class APIConnection:
# ram in bytes and we measure ram in megabytes.
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE)
if self._debug_enabled:
_LOGGER.debug(
"%s: Connecting to %s:%s (%s)",
self.log_name,
self._params.address,
self._params.port,
addr,
)
sockaddr = astuple(addr.sockaddr)
try:
async with asyncio_timeout(TCP_CONNECT_TIMEOUT):
await self._loop.sock_connect(sock, sockaddr)
except asyncio_TimeoutError as err:
raise SocketAPIError(f"Timeout while connecting to {sockaddr}") from err
except OSError as err:
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err
if self._debug_enabled:
_LOGGER.debug(
"%s: Opened socket to %s:%s (%s)",
self.log_name,
self._params.address,
self._params.port,
addr,
addrs,
)
async def _connect_init_frame_helper(self) -> None:

View File

@ -108,8 +108,10 @@ async def _async_resolve_host_zeroconf(
timeout,
)
addrs: list[AddrInfo] = []
for ip in info.ip_addresses_by_version(IPVersion.All):
addrs.extend(_async_ip_address_to_addrs(ip, port)) # type: ignore[arg-type]
for ip in info.ip_addresses_by_version(IPVersion.V6Only):
addrs.extend(_async_ip_address_to_addrs(ip, port)) # type: ignore
for ip in info.ip_addresses_by_version(IPVersion.V4Only):
addrs.extend(_async_ip_address_to_addrs(ip, port)) # type: ignore
return addrs
@ -182,7 +184,7 @@ async def async_resolve_host(
host: str,
port: int,
zeroconf_manager: ZeroconfManager | None = None,
) -> AddrInfo:
) -> list[AddrInfo]:
addrs: list[AddrInfo] = []
zc_error = None
@ -210,6 +212,4 @@ async def async_resolve_host(
raise zc_error
raise ResolveAPIError(f"Could not resolve host {host} - got no results from OS")
# Use first matching result
# Future: return all matches and use first working one
return addrs[0]
return addrs

View File

@ -1,3 +1,4 @@
aiohappyeyeballs>=2.3.0
protobuf>=3.19.0
zeroconf>=0.128.4,<1.0
chacha20poly1305-reuseable>=0.12.0

View File

@ -39,12 +39,14 @@ def async_zeroconf():
@pytest.fixture
def resolve_host():
with patch("aioesphomeapi.host_resolver.async_resolve_host") as func:
func.return_value = AddrInfo(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
sockaddr=IPv4Sockaddr("10.0.0.512", 6052),
)
func.return_value = [
AddrInfo(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
sockaddr=IPv4Sockaddr("10.0.0.512", 6052),
)
]
yield func
@ -114,6 +116,13 @@ def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnectio
return PatchableAPIConnection(connection_params, mock_on_stop, True, None)
@pytest.fixture()
def aiohappyeyeballs_start_connection():
with patch("aioesphomeapi.connection.aiohappyeyeballs.start_connection") as func:
func.return_value = MagicMock(type=socket.SOCK_STREAM)
yield func
def _create_mock_transport_protocol(
transport: asyncio.Transport,
connected: asyncio.Event,
@ -128,13 +137,17 @@ def _create_mock_transport_protocol(
@pytest_asyncio.fixture(name="plaintext_connect_task_no_login")
async def plaintext_connect_task_no_login(
conn: APIConnection, resolve_host, socket_socket, event_loop
conn: APIConnection,
resolve_host,
socket_socket,
event_loop,
aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
loop = asyncio.get_event_loop()
transport = MagicMock()
connected = asyncio.Event()
with patch.object(event_loop, "sock_connect"), patch.object(
with patch.object(
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
@ -146,12 +159,16 @@ async def plaintext_connect_task_no_login(
@pytest_asyncio.fixture(name="plaintext_connect_task_expected_name")
async def plaintext_connect_task_no_login_with_expected_name(
conn_with_expected_name: APIConnection, resolve_host, socket_socket, event_loop
conn_with_expected_name: APIConnection,
resolve_host,
socket_socket,
event_loop,
aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
transport = MagicMock()
connected = asyncio.Event()
with patch.object(event_loop, "sock_connect"), patch.object(
with patch.object(
event_loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
@ -165,12 +182,16 @@ async def plaintext_connect_task_no_login_with_expected_name(
@pytest_asyncio.fixture(name="plaintext_connect_task_with_login")
async def plaintext_connect_task_with_login(
conn_with_password: APIConnection, resolve_host, socket_socket, event_loop
conn_with_password: APIConnection,
resolve_host,
socket_socket,
event_loop,
aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
transport = MagicMock()
connected = asyncio.Event()
with patch.object(event_loop, "sock_connect"), patch.object(
with patch.object(
event_loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
@ -182,7 +203,7 @@ async def plaintext_connect_task_with_login(
@pytest_asyncio.fixture(name="api_client")
async def api_client(
resolve_host, socket_socket, event_loop
resolve_host, socket_socket, event_loop, aiohappyeyeballs_start_connection
) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]:
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()
@ -193,7 +214,7 @@ async def api_client(
password=None,
)
with patch.object(event_loop, "sock_connect"), patch.object(
with patch.object(
event_loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),

View File

@ -194,14 +194,14 @@ async def test_connect_backwards_compat() -> None:
@pytest.mark.asyncio
async def test_finish_connection_wraps_exceptions_as_unhandled_api_error() -> None:
async def test_finish_connection_wraps_exceptions_as_unhandled_api_error(
aiohappyeyeballs_start_connection,
) -> None:
"""Verify finish_connect re-wraps exceptions as UnhandledAPIError."""
cli = APIClient("1.2.3.4", 1234, None)
loop = asyncio.get_event_loop()
with patch(
"aioesphomeapi.client.APIConnection", PatchableAPIConnection
), patch.object(loop, "sock_connect"):
asyncio.get_event_loop()
with patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection):
await cli.start_connection()
with patch.object(
@ -217,9 +217,12 @@ async def test_finish_connection_wraps_exceptions_as_unhandled_api_error() -> No
async def test_connection_released_if_connecting_is_cancelled() -> None:
"""Verify connection is unset if connecting is cancelled."""
cli = APIClient("1.2.3.4", 1234, None)
loop = asyncio.get_event_loop()
asyncio.get_event_loop()
with patch.object(loop, "sock_connect", side_effect=partial(asyncio.sleep, 1)):
with patch(
"aioesphomeapi.connection.aiohappyeyeballs.start_connection",
side_effect=partial(asyncio.sleep, 1),
):
start_task = asyncio.create_task(cli.start_connection())
await asyncio.sleep(0)
assert cli._connection is not None
@ -229,9 +232,9 @@ async def test_connection_released_if_connecting_is_cancelled() -> None:
await start_task
assert cli._connection is None
with patch(
"aioesphomeapi.client.APIConnection", PatchableAPIConnection
), patch.object(loop, "sock_connect"):
with patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection), patch(
"aioesphomeapi.connection.aiohappyeyeballs.start_connection"
):
await cli.start_connection()
await asyncio.sleep(0)
@ -252,8 +255,9 @@ async def test_request_while_handshaking(event_loop) -> None:
pass
cli = PatchableApiClient("host", 1234, None)
with patch.object(
event_loop, "sock_connect", side_effect=partial(asyncio.sleep, 1)
with patch(
"aioesphomeapi.connection.aiohappyeyeballs.start_connection",
side_effect=partial(asyncio.sleep, 1),
), patch.object(cli, "finish_connection"):
connect_task = asyncio.create_task(cli.connect())

View File

@ -241,14 +241,15 @@ async def test_start_connection_times_out(
conn: APIConnection, resolve_host, socket_socket
):
"""Test handling of start connection timing out."""
loop = asyncio.get_event_loop()
asyncio.get_event_loop()
async def _mock_socket_connect(*args, **kwargs):
await asyncio.sleep(500)
with patch.object(loop, "sock_connect", side_effect=_mock_socket_connect), patch(
"aioesphomeapi.connection.TCP_CONNECT_TIMEOUT", 0.0
):
with patch(
"aioesphomeapi.connection.aiohappyeyeballs.start_connection",
side_effect=_mock_socket_connect,
), patch("aioesphomeapi.connection.TCP_CONNECT_TIMEOUT", 0.0):
connect_task = asyncio.create_task(connect(conn, login=False))
await asyncio.sleep(0)
@ -267,9 +268,12 @@ async def test_start_connection_os_error(
conn: APIConnection, resolve_host, socket_socket
):
"""Test handling of start connection has an OSError."""
loop = asyncio.get_event_loop()
asyncio.get_event_loop()
with patch.object(loop, "sock_connect", side_effect=OSError("Socket error")):
with patch(
"aioesphomeapi.connection.aiohappyeyeballs.start_connection",
side_effect=OSError("Socket error"),
):
connect_task = asyncio.create_task(connect(conn, login=False))
await asyncio.sleep(0)
with pytest.raises(APIConnectionError, match="Socket error"):
@ -284,9 +288,12 @@ async def test_start_connection_is_cancelled(
conn: APIConnection, resolve_host, socket_socket
):
"""Test handling of start connection is cancelled."""
loop = asyncio.get_event_loop()
asyncio.get_event_loop()
with patch.object(loop, "sock_connect", side_effect=asyncio.CancelledError):
with patch(
"aioesphomeapi.connection.aiohappyeyeballs.start_connection",
side_effect=asyncio.CancelledError,
):
connect_task = asyncio.create_task(connect(conn, login=False))
await asyncio.sleep(0)
with pytest.raises(APIConnectionError, match="Starting connection cancelled"):
@ -551,7 +558,7 @@ async def test_force_disconnect_fails(
@pytest.mark.asyncio
async def test_connect_resolver_times_out(
conn: APIConnection, socket_socket, event_loop
conn: APIConnection, socket_socket, event_loop, aiohappyeyeballs_start_connection
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
transport = MagicMock()
connected = asyncio.Event()
@ -559,7 +566,7 @@ async def test_connect_resolver_times_out(
with patch(
"aioesphomeapi.host_resolver.async_resolve_host",
side_effect=asyncio.TimeoutError,
), patch.object(event_loop, "sock_connect"), patch.object(
), patch.object(
event_loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
@ -575,6 +582,7 @@ async def test_disconnect_fails_to_send_response(
event_loop: asyncio.AbstractEventLoop,
resolve_host,
socket_socket,
aiohappyeyeballs_start_connection,
) -> None:
loop = asyncio.get_event_loop()
transport = MagicMock()
@ -590,7 +598,7 @@ async def test_disconnect_fails_to_send_response(
nonlocal expected_disconnect
expected_disconnect = _expected_disconnect
with patch.object(event_loop, "sock_connect"), patch.object(
with patch.object(
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
@ -625,6 +633,7 @@ async def test_disconnect_success_case(
event_loop: asyncio.AbstractEventLoop,
resolve_host,
socket_socket,
aiohappyeyeballs_start_connection,
) -> None:
loop = asyncio.get_event_loop()
transport = MagicMock()
@ -640,7 +649,7 @@ async def test_disconnect_success_case(
nonlocal expected_disconnect
expected_disconnect = _expected_disconnect
with patch.object(event_loop, "sock_connect"), patch.object(
with patch.object(
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),

View File

@ -39,9 +39,9 @@ def addr_infos():
async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos):
info = MagicMock(auto_spec=AsyncServiceInfo)
ipv6 = IPv6Address("2001:db8:85a3::8a2e:370:7334%0")
info.ip_addresses_by_version.return_value = [
ip_address(b"\n\x00\x00*"),
ipv6,
info.ip_addresses_by_version.side_effect = [
[ip_address(b"\n\x00\x00*")],
[ipv6],
]
info.async_request = AsyncMock(return_value=True)
with patch(
@ -59,9 +59,9 @@ async def test_resolve_host_passed_zeroconf(addr_infos, async_zeroconf):
zeroconf_manager = ZeroconfManager()
info = MagicMock(auto_spec=AsyncServiceInfo)
ipv6 = IPv6Address("2001:db8:85a3::8a2e:370:7334%0")
info.ip_addresses_by_version.return_value = [
ip_address(b"\n\x00\x00*"),
ipv6,
info.ip_addresses_by_version.side_effect = [
[ip_address(b"\n\x00\x00*")],
[ipv6],
]
info.async_request = AsyncMock(return_value=True)
with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info):
@ -144,7 +144,7 @@ async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos):
resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None)
resolve_addr.assert_not_called()
assert ret == addr_infos[0]
assert ret == addr_infos
@pytest.mark.asyncio
@ -157,7 +157,7 @@ async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos):
resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None)
resolve_addr.assert_called_once_with("example.local", 6052)
assert ret == addr_infos[0]
assert ret == addr_infos
@pytest.mark.asyncio
@ -178,7 +178,7 @@ async def test_resolve_host_addrinfo(resolve_addr, resolve_zc, addr_infos):
resolve_zc.assert_not_called()
resolve_addr.assert_called_once_with("example.com", 6052)
assert ret == addr_infos[0]
assert ret == addr_infos
@pytest.mark.asyncio
@ -203,12 +203,14 @@ async def test_resolve_host_with_address(resolve_addr, resolve_zc):
resolve_zc.assert_not_called()
resolve_addr.assert_not_called()
assert ret == hr.AddrInfo(
family=socket.AddressFamily.AF_INET,
type=socket.SocketKind.SOCK_STREAM,
proto=6,
sockaddr=hr.IPv4Sockaddr(address="127.0.0.1", port=6052),
)
assert ret == [
hr.AddrInfo(
family=socket.AddressFamily.AF_INET,
type=socket.SocketKind.SOCK_STREAM,
proto=6,
sockaddr=hr.IPv4Sockaddr(address="127.0.0.1", port=6052),
)
]
@pytest.mark.asyncio

View File

@ -30,7 +30,11 @@ from .common import (
@pytest.mark.asyncio
async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnection):
async def test_log_runner(
event_loop: asyncio.AbstractEventLoop,
conn: APIConnection,
aiohappyeyeballs_start_connection,
):
"""Test the log runner logic."""
loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None
@ -69,7 +73,7 @@ async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnec
await original_subscribe_logs(*args, **kwargs)
subscribed.set()
with patch.object(event_loop, "sock_connect"), patch.object(
with patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
), patch.object(cli, "subscribe_logs", _wait_subscribe_cli):
stop = await async_run(cli, on_log, aio_zeroconf_instance=async_zeroconf)
@ -96,6 +100,7 @@ async def test_log_runner_reconnects_on_disconnect(
event_loop: asyncio.AbstractEventLoop,
conn: APIConnection,
caplog: pytest.LogCaptureFixture,
aiohappyeyeballs_start_connection,
) -> None:
"""Test the log runner reconnects on disconnect."""
loop = asyncio.get_event_loop()
@ -135,7 +140,7 @@ async def test_log_runner_reconnects_on_disconnect(
await original_subscribe_logs(*args, **kwargs)
subscribed.set()
with patch.object(event_loop, "sock_connect"), patch.object(
with patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
), patch.object(cli, "subscribe_logs", _wait_subscribe_cli):
stop = await async_run(cli, on_log, aio_zeroconf_instance=async_zeroconf)
@ -173,6 +178,7 @@ async def test_log_runner_reconnects_on_subscribe_failure(
event_loop: asyncio.AbstractEventLoop,
conn: APIConnection,
caplog: pytest.LogCaptureFixture,
aiohappyeyeballs_start_connection,
) -> None:
"""Test the log runner reconnects on subscribe failure."""
loop = asyncio.get_event_loop()
@ -214,7 +220,7 @@ async def test_log_runner_reconnects_on_subscribe_failure(
with patch.object(
cli, "disconnect", partial(cli.disconnect, force=True)
), patch.object(cli, "subscribe_logs", _wait_and_fail_subscribe_cli):
with patch.object(loop, "sock_connect"), patch.object(
with patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
stop = await async_run(cli, on_log, aio_zeroconf_instance=async_zeroconf)
@ -227,7 +233,7 @@ async def test_log_runner_reconnects_on_subscribe_failure(
assert cli._connection is None
with patch.object(loop, "sock_connect"), patch.object(
with patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
), patch.object(cli, "subscribe_logs"):
connected.clear()

View File

@ -672,7 +672,9 @@ async def test_reconnect_logic_stop_callback_waits_for_handshake(
@pytest.mark.asyncio
async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventLoop):
async def test_handling_unexpected_disconnect(
event_loop: asyncio.AbstractEventLoop, aiohappyeyeballs_start_connection
):
"""Test the disconnect callback fires with expected_disconnect=False."""
loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None
@ -710,7 +712,7 @@ async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventL
name="fake",
)
with patch.object(event_loop, "sock_connect"), patch.object(
with patch.object(
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
@ -726,7 +728,7 @@ async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventL
assert cli._connection.is_connected is True
await asyncio.sleep(0)
with patch.object(event_loop, "sock_connect"), patch.object(
with patch.object(
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
@ -746,7 +748,9 @@ async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventL
@pytest.mark.asyncio
async def test_backoff_on_encryption_error(
event_loop: asyncio.AbstractEventLoop, caplog: pytest.LogCaptureFixture
event_loop: asyncio.AbstractEventLoop,
caplog: pytest.LogCaptureFixture,
aiohappyeyeballs_start_connection,
) -> None:
"""Test we backoff on encryption error."""
loop = asyncio.get_event_loop()
@ -785,7 +789,7 @@ async def test_backoff_on_encryption_error(
name="fake",
)
with patch.object(event_loop, "sock_connect"), patch.object(
with patch.object(
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),