Ensure scope_id is preserved from zeroconf resolution on python versions that support it (#664)

This commit is contained in:
J. Nick Koston 2023-11-23 14:48:34 +01:00 committed by GitHub
parent 00a6ce9f6a
commit df0dbadae7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 87 additions and 64 deletions

View File

@ -5,8 +5,8 @@ import contextlib
import logging
import socket
from dataclasses import dataclass
from ipaddress import IPv4Address, IPv6Address
from typing import cast
from ipaddress import IPv4Address, IPv6Address, ip_address
from typing import TYPE_CHECKING, cast
from zeroconf import IPVersion
from zeroconf.asyncio import AsyncServiceInfo
@ -56,7 +56,7 @@ async def _async_zeroconf_get_service_info(
service_name: str,
server: str,
timeout: float,
) -> AsyncServiceInfo | None:
) -> AsyncServiceInfo:
# Use or create zeroconf instance, ensure it's an AsyncZeroconf
try:
zc = zeroconf_manager.get_async_zeroconf().zeroconf
@ -67,8 +67,7 @@ async def _async_zeroconf_get_service_info(
) from exc
try:
info = AsyncServiceInfo(service_type, service_name, server=server)
if await info.async_request(zc, int(timeout * 1000)):
return info
await info.async_request(zc, int(timeout * 1000))
except Exception as exc:
raise ResolveAPIError(
f"Error resolving mDNS {service_name} via mDNS: {exc}"
@ -78,6 +77,16 @@ async def _async_zeroconf_get_service_info(
return info
def _scope_id_to_int(value: str | None) -> int:
"""Convert a scope id to int if possible."""
if value is None:
return 0
try:
return int(value)
except ValueError:
return 0
async def _async_resolve_host_zeroconf(
host: str,
port: int,
@ -96,35 +105,9 @@ async def _async_resolve_host_zeroconf(
server,
timeout,
)
if info is None:
return []
addrs: list[AddrInfo] = []
for ip_address in info.ip_addresses_by_version(IPVersion.All):
is_ipv6 = ip_address.version == 6
sockaddr: IPv6Sockaddr | IPv4Sockaddr
if is_ipv6:
sockaddr = IPv6Sockaddr(
address=str(ip_address),
port=port,
flowinfo=0,
scope_id=0,
)
else:
sockaddr = IPv4Sockaddr(
address=str(ip_address),
port=port,
)
addrs.append(
AddrInfo(
family=socket.AF_INET6 if is_ipv6 else socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
sockaddr=sockaddr,
)
)
for ip in info.ip_addresses_by_version(IPVersion.All):
addrs.extend(_async_ip_address_to_addrs(ip, port)) # type: ignore[arg-type]
return addrs
@ -160,34 +143,37 @@ async def _async_resolve_host_getaddrinfo(host: str, port: int) -> list[AddrInfo
return addrs
def _async_ip_address_to_addrs(host: str, port: int) -> list[AddrInfo]:
def _async_ip_address_to_addrs(
ip: IPv4Address | IPv6Address, port: int
) -> list[AddrInfo]:
"""Convert an ipaddress to AddrInfo."""
with contextlib.suppress(ValueError):
return [
AddrInfo(
family=socket.AF_INET6,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
sockaddr=IPv6Sockaddr(
address=str(IPv6Address(host)), port=port, flowinfo=0, scope_id=0
),
)
]
addrs: list[AddrInfo] = []
is_ipv6 = ip.version == 6
sockaddr: IPv6Sockaddr | IPv4Sockaddr
if is_ipv6:
if TYPE_CHECKING:
assert isinstance(ip, IPv6Address)
sockaddr = IPv6Sockaddr(
address=str(ip).partition("%")[0],
port=port,
flowinfo=0,
scope_id=_scope_id_to_int(ip.scope_id),
)
else:
sockaddr = IPv4Sockaddr(
address=str(ip),
port=port,
)
with contextlib.suppress(ValueError):
return [
AddrInfo(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
sockaddr=IPv4Sockaddr(
address=str(IPv4Address(host)),
port=port,
),
)
]
return []
addrs.append(
AddrInfo(
family=socket.AF_INET6 if is_ipv6 else socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
sockaddr=sockaddr,
)
)
return addrs
async def async_resolve_host(
@ -206,11 +192,12 @@ async def async_resolve_host(
name, port, zeroconf_manager=zeroconf_manager
)
)
except APIConnectionError as err:
except ResolveAPIError as err:
zc_error = err
else:
addrs.extend(_async_ip_address_to_addrs(host, port))
with contextlib.suppress(ValueError):
addrs.extend(_async_ip_address_to_addrs(ip_address(host), port))
if not addrs:
addrs.extend(await _async_resolve_host_getaddrinfo(host, port))

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import socket
from ipaddress import ip_address
from ipaddress import IPv6Address, ip_address
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@ -38,9 +38,10 @@ def addr_infos():
@pytest.mark.asyncio
async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos):
info = MagicMock(auto_spec=AsyncServiceInfo)
ipv6 = IPv6Address("2001:db8:85a3::8a2e:370:7334%0")
info.ip_addresses_by_version.return_value = [
ip_address(b"\n\x00\x00*"),
ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"),
ipv6,
]
info.async_request = AsyncMock(return_value=True)
with patch(
@ -57,9 +58,10 @@ async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos):
async def test_resolve_host_passed_zeroconf(addr_infos, async_zeroconf):
zeroconf_manager = ZeroconfManager()
info = MagicMock(auto_spec=AsyncServiceInfo)
ipv6 = IPv6Address("2001:db8:85a3::8a2e:370:7334%0")
info.ip_addresses_by_version.return_value = [
ip_address(b"\n\x00\x00*"),
ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"),
ipv6,
]
info.async_request = AsyncMock(return_value=True)
with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info):
@ -81,6 +83,25 @@ async def test_resolve_host_zeroconf_empty(async_zeroconf: AsyncZeroconf):
assert ret == []
@pytest.mark.asyncio
async def test_resolve_host_zeroconf_fails(async_zeroconf: AsyncZeroconf):
with patch(
"aioesphomeapi.host_resolver.AsyncServiceInfo.async_request",
side_effect=Exception("no buffers"),
), pytest.raises(ResolveAPIError, match="no buffers"):
await hr._async_resolve_host_zeroconf("asdf.local", 6052)
@pytest.mark.asyncio
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo", return_value=[])
async def test_resolve_host_zeroconf_fails_end_to_end(async_zeroconf: AsyncZeroconf):
with patch(
"aioesphomeapi.host_resolver.AsyncServiceInfo.async_request",
side_effect=Exception("no buffers"),
), pytest.raises(ResolveAPIError, match="no buffers"):
await hr.async_resolve_host("asdf.local", 6052)
@pytest.mark.asyncio
async def test_resolve_host_getaddrinfo(event_loop, addr_infos):
with patch.object(event_loop, "getaddrinfo") as mock:
@ -139,6 +160,15 @@ async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos):
assert ret == addr_infos[0]
@pytest.mark.asyncio
@patch("aioesphomeapi.host_resolver.AsyncServiceInfo.async_request", return_value=False)
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
async def test_resolve_host_mdns_no_results(resolve_addr, addr_infos):
resolve_addr.return_value = addr_infos
with pytest.raises(ResolveAPIError):
await hr.async_resolve_host("example.local", 6052)
@pytest.mark.asyncio
@patch("aioesphomeapi.host_resolver._async_resolve_host_zeroconf")
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
@ -216,3 +246,9 @@ async def test_resolve_host_create_zeroconf_oserror(
"aioesphomeapi.zeroconf.AsyncZeroconf", side_effect=OSError("out of buffers")
), pytest.raises(ResolveAPIError, match="out of buffers"):
await hr._async_resolve_host_zeroconf("asdf", 6052)
def test_scope_id_to_int():
assert hr._scope_id_to_int("123") == 123
assert hr._scope_id_to_int("eth0") == 0
assert hr._scope_id_to_int(None) == 0