Ensure scope_id is preserved from zeroconf resolution on python versions that support it (#664)

This commit is contained in:
J. Nick Koston 2023-11-23 14:48:34 +01:00 committed by GitHub
parent 00a6ce9f6a
commit df0dbadae7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 87 additions and 64 deletions

View File

@ -5,8 +5,8 @@ import contextlib
import logging import logging
import socket import socket
from dataclasses import dataclass from dataclasses import dataclass
from ipaddress import IPv4Address, IPv6Address from ipaddress import IPv4Address, IPv6Address, ip_address
from typing import cast from typing import TYPE_CHECKING, cast
from zeroconf import IPVersion from zeroconf import IPVersion
from zeroconf.asyncio import AsyncServiceInfo from zeroconf.asyncio import AsyncServiceInfo
@ -56,7 +56,7 @@ async def _async_zeroconf_get_service_info(
service_name: str, service_name: str,
server: str, server: str,
timeout: float, timeout: float,
) -> AsyncServiceInfo | None: ) -> AsyncServiceInfo:
# Use or create zeroconf instance, ensure it's an AsyncZeroconf # Use or create zeroconf instance, ensure it's an AsyncZeroconf
try: try:
zc = zeroconf_manager.get_async_zeroconf().zeroconf zc = zeroconf_manager.get_async_zeroconf().zeroconf
@ -67,8 +67,7 @@ async def _async_zeroconf_get_service_info(
) from exc ) from exc
try: try:
info = AsyncServiceInfo(service_type, service_name, server=server) info = AsyncServiceInfo(service_type, service_name, server=server)
if await info.async_request(zc, int(timeout * 1000)): await info.async_request(zc, int(timeout * 1000))
return info
except Exception as exc: except Exception as exc:
raise ResolveAPIError( raise ResolveAPIError(
f"Error resolving mDNS {service_name} via mDNS: {exc}" f"Error resolving mDNS {service_name} via mDNS: {exc}"
@ -78,6 +77,16 @@ async def _async_zeroconf_get_service_info(
return info return info
def _scope_id_to_int(value: str | None) -> int:
"""Convert a scope id to int if possible."""
if value is None:
return 0
try:
return int(value)
except ValueError:
return 0
async def _async_resolve_host_zeroconf( async def _async_resolve_host_zeroconf(
host: str, host: str,
port: int, port: int,
@ -96,35 +105,9 @@ async def _async_resolve_host_zeroconf(
server, server,
timeout, timeout,
) )
if info is None:
return []
addrs: list[AddrInfo] = [] addrs: list[AddrInfo] = []
for ip_address in info.ip_addresses_by_version(IPVersion.All): for ip in info.ip_addresses_by_version(IPVersion.All):
is_ipv6 = ip_address.version == 6 addrs.extend(_async_ip_address_to_addrs(ip, port)) # type: ignore[arg-type]
sockaddr: IPv6Sockaddr | IPv4Sockaddr
if is_ipv6:
sockaddr = IPv6Sockaddr(
address=str(ip_address),
port=port,
flowinfo=0,
scope_id=0,
)
else:
sockaddr = IPv4Sockaddr(
address=str(ip_address),
port=port,
)
addrs.append(
AddrInfo(
family=socket.AF_INET6 if is_ipv6 else socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
sockaddr=sockaddr,
)
)
return addrs return addrs
@ -160,34 +143,37 @@ async def _async_resolve_host_getaddrinfo(host: str, port: int) -> list[AddrInfo
return addrs return addrs
def _async_ip_address_to_addrs(host: str, port: int) -> list[AddrInfo]: def _async_ip_address_to_addrs(
ip: IPv4Address | IPv6Address, port: int
) -> list[AddrInfo]:
"""Convert an ipaddress to AddrInfo.""" """Convert an ipaddress to AddrInfo."""
with contextlib.suppress(ValueError): addrs: list[AddrInfo] = []
return [ is_ipv6 = ip.version == 6
AddrInfo( sockaddr: IPv6Sockaddr | IPv4Sockaddr
family=socket.AF_INET6, if is_ipv6:
type=socket.SOCK_STREAM, if TYPE_CHECKING:
proto=socket.IPPROTO_TCP, assert isinstance(ip, IPv6Address)
sockaddr = IPv6Sockaddr( sockaddr = IPv6Sockaddr(
address=str(IPv6Address(host)), port=port, flowinfo=0, scope_id=0 address=str(ip).partition("%")[0],
), port=port,
flowinfo=0,
scope_id=_scope_id_to_int(ip.scope_id),
)
else:
sockaddr = IPv4Sockaddr(
address=str(ip),
port=port,
) )
]
with contextlib.suppress(ValueError): addrs.append(
return [
AddrInfo( AddrInfo(
family=socket.AF_INET, family=socket.AF_INET6 if is_ipv6 else socket.AF_INET,
type=socket.SOCK_STREAM, type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP, proto=socket.IPPROTO_TCP,
sockaddr=IPv4Sockaddr( sockaddr=sockaddr,
address=str(IPv4Address(host)),
port=port,
),
) )
] )
return addrs
return []
async def async_resolve_host( async def async_resolve_host(
@ -206,11 +192,12 @@ async def async_resolve_host(
name, port, zeroconf_manager=zeroconf_manager name, port, zeroconf_manager=zeroconf_manager
) )
) )
except APIConnectionError as err: except ResolveAPIError as err:
zc_error = err zc_error = err
else: else:
addrs.extend(_async_ip_address_to_addrs(host, port)) with contextlib.suppress(ValueError):
addrs.extend(_async_ip_address_to_addrs(ip_address(host), port))
if not addrs: if not addrs:
addrs.extend(await _async_resolve_host_getaddrinfo(host, port)) addrs.extend(await _async_resolve_host_getaddrinfo(host, port))

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import socket import socket
from ipaddress import ip_address from ipaddress import IPv6Address, ip_address
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
@ -38,9 +38,10 @@ def addr_infos():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos): async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos):
info = MagicMock(auto_spec=AsyncServiceInfo) info = MagicMock(auto_spec=AsyncServiceInfo)
ipv6 = IPv6Address("2001:db8:85a3::8a2e:370:7334%0")
info.ip_addresses_by_version.return_value = [ info.ip_addresses_by_version.return_value = [
ip_address(b"\n\x00\x00*"), ip_address(b"\n\x00\x00*"),
ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"), ipv6,
] ]
info.async_request = AsyncMock(return_value=True) info.async_request = AsyncMock(return_value=True)
with patch( with patch(
@ -57,9 +58,10 @@ async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos):
async def test_resolve_host_passed_zeroconf(addr_infos, async_zeroconf): async def test_resolve_host_passed_zeroconf(addr_infos, async_zeroconf):
zeroconf_manager = ZeroconfManager() zeroconf_manager = ZeroconfManager()
info = MagicMock(auto_spec=AsyncServiceInfo) info = MagicMock(auto_spec=AsyncServiceInfo)
ipv6 = IPv6Address("2001:db8:85a3::8a2e:370:7334%0")
info.ip_addresses_by_version.return_value = [ info.ip_addresses_by_version.return_value = [
ip_address(b"\n\x00\x00*"), ip_address(b"\n\x00\x00*"),
ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"), ipv6,
] ]
info.async_request = AsyncMock(return_value=True) info.async_request = AsyncMock(return_value=True)
with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info): with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info):
@ -81,6 +83,25 @@ async def test_resolve_host_zeroconf_empty(async_zeroconf: AsyncZeroconf):
assert ret == [] assert ret == []
@pytest.mark.asyncio
async def test_resolve_host_zeroconf_fails(async_zeroconf: AsyncZeroconf):
with patch(
"aioesphomeapi.host_resolver.AsyncServiceInfo.async_request",
side_effect=Exception("no buffers"),
), pytest.raises(ResolveAPIError, match="no buffers"):
await hr._async_resolve_host_zeroconf("asdf.local", 6052)
@pytest.mark.asyncio
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo", return_value=[])
async def test_resolve_host_zeroconf_fails_end_to_end(async_zeroconf: AsyncZeroconf):
with patch(
"aioesphomeapi.host_resolver.AsyncServiceInfo.async_request",
side_effect=Exception("no buffers"),
), pytest.raises(ResolveAPIError, match="no buffers"):
await hr.async_resolve_host("asdf.local", 6052)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_resolve_host_getaddrinfo(event_loop, addr_infos): async def test_resolve_host_getaddrinfo(event_loop, addr_infos):
with patch.object(event_loop, "getaddrinfo") as mock: with patch.object(event_loop, "getaddrinfo") as mock:
@ -139,6 +160,15 @@ async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos):
assert ret == addr_infos[0] assert ret == addr_infos[0]
@pytest.mark.asyncio
@patch("aioesphomeapi.host_resolver.AsyncServiceInfo.async_request", return_value=False)
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
async def test_resolve_host_mdns_no_results(resolve_addr, addr_infos):
resolve_addr.return_value = addr_infos
with pytest.raises(ResolveAPIError):
await hr.async_resolve_host("example.local", 6052)
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("aioesphomeapi.host_resolver._async_resolve_host_zeroconf") @patch("aioesphomeapi.host_resolver._async_resolve_host_zeroconf")
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo") @patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
@ -216,3 +246,9 @@ async def test_resolve_host_create_zeroconf_oserror(
"aioesphomeapi.zeroconf.AsyncZeroconf", side_effect=OSError("out of buffers") "aioesphomeapi.zeroconf.AsyncZeroconf", side_effect=OSError("out of buffers")
), pytest.raises(ResolveAPIError, match="out of buffers"): ), pytest.raises(ResolveAPIError, match="out of buffers"):
await hr._async_resolve_host_zeroconf("asdf", 6052) await hr._async_resolve_host_zeroconf("asdf", 6052)
def test_scope_id_to_int():
assert hr._scope_id_to_int("123") == 123
assert hr._scope_id_to_int("eth0") == 0
assert hr._scope_id_to_int(None) == 0