From df0dbadae7bd9942c9e3d67042770539f183e2a4 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 23 Nov 2023 14:48:34 +0100 Subject: [PATCH] Ensure scope_id is preserved from zeroconf resolution on python versions that support it (#664) --- aioesphomeapi/host_resolver.py | 109 +++++++++++++++------------------ tests/test_host_resolver.py | 42 ++++++++++++- 2 files changed, 87 insertions(+), 64 deletions(-) diff --git a/aioesphomeapi/host_resolver.py b/aioesphomeapi/host_resolver.py index d51bb16..679caf6 100644 --- a/aioesphomeapi/host_resolver.py +++ b/aioesphomeapi/host_resolver.py @@ -5,8 +5,8 @@ import contextlib import logging import socket from dataclasses import dataclass -from ipaddress import IPv4Address, IPv6Address -from typing import cast +from ipaddress import IPv4Address, IPv6Address, ip_address +from typing import TYPE_CHECKING, cast from zeroconf import IPVersion from zeroconf.asyncio import AsyncServiceInfo @@ -56,7 +56,7 @@ async def _async_zeroconf_get_service_info( service_name: str, server: str, timeout: float, -) -> AsyncServiceInfo | None: +) -> AsyncServiceInfo: # Use or create zeroconf instance, ensure it's an AsyncZeroconf try: zc = zeroconf_manager.get_async_zeroconf().zeroconf @@ -67,8 +67,7 @@ async def _async_zeroconf_get_service_info( ) from exc try: info = AsyncServiceInfo(service_type, service_name, server=server) - if await info.async_request(zc, int(timeout * 1000)): - return info + await info.async_request(zc, int(timeout * 1000)) except Exception as exc: raise ResolveAPIError( f"Error resolving mDNS {service_name} via mDNS: {exc}" @@ -78,6 +77,16 @@ async def _async_zeroconf_get_service_info( return info +def _scope_id_to_int(value: str | None) -> int: + """Convert a scope id to int if possible.""" + if value is None: + return 0 + try: + return int(value) + except ValueError: + return 0 + + async def _async_resolve_host_zeroconf( host: str, port: int, @@ -96,35 +105,9 @@ async def _async_resolve_host_zeroconf( server, timeout, ) - - if info is None: - return [] - addrs: list[AddrInfo] = [] - for ip_address in info.ip_addresses_by_version(IPVersion.All): - is_ipv6 = ip_address.version == 6 - sockaddr: IPv6Sockaddr | IPv4Sockaddr - if is_ipv6: - sockaddr = IPv6Sockaddr( - address=str(ip_address), - port=port, - flowinfo=0, - scope_id=0, - ) - else: - sockaddr = IPv4Sockaddr( - address=str(ip_address), - port=port, - ) - - addrs.append( - AddrInfo( - family=socket.AF_INET6 if is_ipv6 else socket.AF_INET, - type=socket.SOCK_STREAM, - proto=socket.IPPROTO_TCP, - sockaddr=sockaddr, - ) - ) + for ip in info.ip_addresses_by_version(IPVersion.All): + addrs.extend(_async_ip_address_to_addrs(ip, port)) # type: ignore[arg-type] return addrs @@ -160,34 +143,37 @@ async def _async_resolve_host_getaddrinfo(host: str, port: int) -> list[AddrInfo return addrs -def _async_ip_address_to_addrs(host: str, port: int) -> list[AddrInfo]: +def _async_ip_address_to_addrs( + ip: IPv4Address | IPv6Address, port: int +) -> list[AddrInfo]: """Convert an ipaddress to AddrInfo.""" - with contextlib.suppress(ValueError): - return [ - AddrInfo( - family=socket.AF_INET6, - type=socket.SOCK_STREAM, - proto=socket.IPPROTO_TCP, - sockaddr=IPv6Sockaddr( - address=str(IPv6Address(host)), port=port, flowinfo=0, scope_id=0 - ), - ) - ] + addrs: list[AddrInfo] = [] + is_ipv6 = ip.version == 6 + sockaddr: IPv6Sockaddr | IPv4Sockaddr + if is_ipv6: + if TYPE_CHECKING: + assert isinstance(ip, IPv6Address) + sockaddr = IPv6Sockaddr( + address=str(ip).partition("%")[0], + port=port, + flowinfo=0, + scope_id=_scope_id_to_int(ip.scope_id), + ) + else: + sockaddr = IPv4Sockaddr( + address=str(ip), + port=port, + ) - with contextlib.suppress(ValueError): - return [ - AddrInfo( - family=socket.AF_INET, - type=socket.SOCK_STREAM, - proto=socket.IPPROTO_TCP, - sockaddr=IPv4Sockaddr( - address=str(IPv4Address(host)), - port=port, - ), - ) - ] - - return [] + addrs.append( + AddrInfo( + family=socket.AF_INET6 if is_ipv6 else socket.AF_INET, + type=socket.SOCK_STREAM, + proto=socket.IPPROTO_TCP, + sockaddr=sockaddr, + ) + ) + return addrs async def async_resolve_host( @@ -206,11 +192,12 @@ async def async_resolve_host( name, port, zeroconf_manager=zeroconf_manager ) ) - except APIConnectionError as err: + except ResolveAPIError as err: zc_error = err else: - addrs.extend(_async_ip_address_to_addrs(host, port)) + with contextlib.suppress(ValueError): + addrs.extend(_async_ip_address_to_addrs(ip_address(host), port)) if not addrs: addrs.extend(await _async_resolve_host_getaddrinfo(host, port)) diff --git a/tests/test_host_resolver.py b/tests/test_host_resolver.py index d2b9287..def1267 100644 --- a/tests/test_host_resolver.py +++ b/tests/test_host_resolver.py @@ -1,7 +1,7 @@ from __future__ import annotations import socket -from ipaddress import ip_address +from ipaddress import IPv6Address, ip_address from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -38,9 +38,10 @@ def addr_infos(): @pytest.mark.asyncio async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos): info = MagicMock(auto_spec=AsyncServiceInfo) + ipv6 = IPv6Address("2001:db8:85a3::8a2e:370:7334%0") info.ip_addresses_by_version.return_value = [ ip_address(b"\n\x00\x00*"), - ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"), + ipv6, ] info.async_request = AsyncMock(return_value=True) with patch( @@ -57,9 +58,10 @@ async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos): async def test_resolve_host_passed_zeroconf(addr_infos, async_zeroconf): zeroconf_manager = ZeroconfManager() info = MagicMock(auto_spec=AsyncServiceInfo) + ipv6 = IPv6Address("2001:db8:85a3::8a2e:370:7334%0") info.ip_addresses_by_version.return_value = [ ip_address(b"\n\x00\x00*"), - ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"), + ipv6, ] info.async_request = AsyncMock(return_value=True) with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info): @@ -81,6 +83,25 @@ async def test_resolve_host_zeroconf_empty(async_zeroconf: AsyncZeroconf): assert ret == [] +@pytest.mark.asyncio +async def test_resolve_host_zeroconf_fails(async_zeroconf: AsyncZeroconf): + with patch( + "aioesphomeapi.host_resolver.AsyncServiceInfo.async_request", + side_effect=Exception("no buffers"), + ), pytest.raises(ResolveAPIError, match="no buffers"): + await hr._async_resolve_host_zeroconf("asdf.local", 6052) + + +@pytest.mark.asyncio +@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo", return_value=[]) +async def test_resolve_host_zeroconf_fails_end_to_end(async_zeroconf: AsyncZeroconf): + with patch( + "aioesphomeapi.host_resolver.AsyncServiceInfo.async_request", + side_effect=Exception("no buffers"), + ), pytest.raises(ResolveAPIError, match="no buffers"): + await hr.async_resolve_host("asdf.local", 6052) + + @pytest.mark.asyncio async def test_resolve_host_getaddrinfo(event_loop, addr_infos): with patch.object(event_loop, "getaddrinfo") as mock: @@ -139,6 +160,15 @@ async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos): assert ret == addr_infos[0] +@pytest.mark.asyncio +@patch("aioesphomeapi.host_resolver.AsyncServiceInfo.async_request", return_value=False) +@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo") +async def test_resolve_host_mdns_no_results(resolve_addr, addr_infos): + resolve_addr.return_value = addr_infos + with pytest.raises(ResolveAPIError): + await hr.async_resolve_host("example.local", 6052) + + @pytest.mark.asyncio @patch("aioesphomeapi.host_resolver._async_resolve_host_zeroconf") @patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo") @@ -216,3 +246,9 @@ async def test_resolve_host_create_zeroconf_oserror( "aioesphomeapi.zeroconf.AsyncZeroconf", side_effect=OSError("out of buffers") ), pytest.raises(ResolveAPIError, match="out of buffers"): await hr._async_resolve_host_zeroconf("asdf", 6052) + + +def test_scope_id_to_int(): + assert hr._scope_id_to_int("123") == 123 + assert hr._scope_id_to_int("eth0") == 0 + assert hr._scope_id_to_int(None) == 0