mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-25 17:17:42 +01:00
Ensure scope_id is preserved from zeroconf resolution on python versions that support it (#664)
This commit is contained in:
parent
00a6ce9f6a
commit
df0dbadae7
@ -5,8 +5,8 @@ import contextlib
|
||||
import logging
|
||||
import socket
|
||||
from dataclasses import dataclass
|
||||
from ipaddress import IPv4Address, IPv6Address
|
||||
from typing import cast
|
||||
from ipaddress import IPv4Address, IPv6Address, ip_address
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from zeroconf import IPVersion
|
||||
from zeroconf.asyncio import AsyncServiceInfo
|
||||
@ -56,7 +56,7 @@ async def _async_zeroconf_get_service_info(
|
||||
service_name: str,
|
||||
server: str,
|
||||
timeout: float,
|
||||
) -> AsyncServiceInfo | None:
|
||||
) -> AsyncServiceInfo:
|
||||
# Use or create zeroconf instance, ensure it's an AsyncZeroconf
|
||||
try:
|
||||
zc = zeroconf_manager.get_async_zeroconf().zeroconf
|
||||
@ -67,8 +67,7 @@ async def _async_zeroconf_get_service_info(
|
||||
) from exc
|
||||
try:
|
||||
info = AsyncServiceInfo(service_type, service_name, server=server)
|
||||
if await info.async_request(zc, int(timeout * 1000)):
|
||||
return info
|
||||
await info.async_request(zc, int(timeout * 1000))
|
||||
except Exception as exc:
|
||||
raise ResolveAPIError(
|
||||
f"Error resolving mDNS {service_name} via mDNS: {exc}"
|
||||
@ -78,6 +77,16 @@ async def _async_zeroconf_get_service_info(
|
||||
return info
|
||||
|
||||
|
||||
def _scope_id_to_int(value: str | None) -> int:
|
||||
"""Convert a scope id to int if possible."""
|
||||
if value is None:
|
||||
return 0
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return 0
|
||||
|
||||
|
||||
async def _async_resolve_host_zeroconf(
|
||||
host: str,
|
||||
port: int,
|
||||
@ -96,35 +105,9 @@ async def _async_resolve_host_zeroconf(
|
||||
server,
|
||||
timeout,
|
||||
)
|
||||
|
||||
if info is None:
|
||||
return []
|
||||
|
||||
addrs: list[AddrInfo] = []
|
||||
for ip_address in info.ip_addresses_by_version(IPVersion.All):
|
||||
is_ipv6 = ip_address.version == 6
|
||||
sockaddr: IPv6Sockaddr | IPv4Sockaddr
|
||||
if is_ipv6:
|
||||
sockaddr = IPv6Sockaddr(
|
||||
address=str(ip_address),
|
||||
port=port,
|
||||
flowinfo=0,
|
||||
scope_id=0,
|
||||
)
|
||||
else:
|
||||
sockaddr = IPv4Sockaddr(
|
||||
address=str(ip_address),
|
||||
port=port,
|
||||
)
|
||||
|
||||
addrs.append(
|
||||
AddrInfo(
|
||||
family=socket.AF_INET6 if is_ipv6 else socket.AF_INET,
|
||||
type=socket.SOCK_STREAM,
|
||||
proto=socket.IPPROTO_TCP,
|
||||
sockaddr=sockaddr,
|
||||
)
|
||||
)
|
||||
for ip in info.ip_addresses_by_version(IPVersion.All):
|
||||
addrs.extend(_async_ip_address_to_addrs(ip, port)) # type: ignore[arg-type]
|
||||
return addrs
|
||||
|
||||
|
||||
@ -160,34 +143,37 @@ async def _async_resolve_host_getaddrinfo(host: str, port: int) -> list[AddrInfo
|
||||
return addrs
|
||||
|
||||
|
||||
def _async_ip_address_to_addrs(host: str, port: int) -> list[AddrInfo]:
|
||||
def _async_ip_address_to_addrs(
|
||||
ip: IPv4Address | IPv6Address, port: int
|
||||
) -> list[AddrInfo]:
|
||||
"""Convert an ipaddress to AddrInfo."""
|
||||
with contextlib.suppress(ValueError):
|
||||
return [
|
||||
AddrInfo(
|
||||
family=socket.AF_INET6,
|
||||
type=socket.SOCK_STREAM,
|
||||
proto=socket.IPPROTO_TCP,
|
||||
sockaddr=IPv6Sockaddr(
|
||||
address=str(IPv6Address(host)), port=port, flowinfo=0, scope_id=0
|
||||
),
|
||||
)
|
||||
]
|
||||
addrs: list[AddrInfo] = []
|
||||
is_ipv6 = ip.version == 6
|
||||
sockaddr: IPv6Sockaddr | IPv4Sockaddr
|
||||
if is_ipv6:
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(ip, IPv6Address)
|
||||
sockaddr = IPv6Sockaddr(
|
||||
address=str(ip).partition("%")[0],
|
||||
port=port,
|
||||
flowinfo=0,
|
||||
scope_id=_scope_id_to_int(ip.scope_id),
|
||||
)
|
||||
else:
|
||||
sockaddr = IPv4Sockaddr(
|
||||
address=str(ip),
|
||||
port=port,
|
||||
)
|
||||
|
||||
with contextlib.suppress(ValueError):
|
||||
return [
|
||||
AddrInfo(
|
||||
family=socket.AF_INET,
|
||||
type=socket.SOCK_STREAM,
|
||||
proto=socket.IPPROTO_TCP,
|
||||
sockaddr=IPv4Sockaddr(
|
||||
address=str(IPv4Address(host)),
|
||||
port=port,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
return []
|
||||
addrs.append(
|
||||
AddrInfo(
|
||||
family=socket.AF_INET6 if is_ipv6 else socket.AF_INET,
|
||||
type=socket.SOCK_STREAM,
|
||||
proto=socket.IPPROTO_TCP,
|
||||
sockaddr=sockaddr,
|
||||
)
|
||||
)
|
||||
return addrs
|
||||
|
||||
|
||||
async def async_resolve_host(
|
||||
@ -206,11 +192,12 @@ async def async_resolve_host(
|
||||
name, port, zeroconf_manager=zeroconf_manager
|
||||
)
|
||||
)
|
||||
except APIConnectionError as err:
|
||||
except ResolveAPIError as err:
|
||||
zc_error = err
|
||||
|
||||
else:
|
||||
addrs.extend(_async_ip_address_to_addrs(host, port))
|
||||
with contextlib.suppress(ValueError):
|
||||
addrs.extend(_async_ip_address_to_addrs(ip_address(host), port))
|
||||
|
||||
if not addrs:
|
||||
addrs.extend(await _async_resolve_host_getaddrinfo(host, port))
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
from ipaddress import ip_address
|
||||
from ipaddress import IPv6Address, ip_address
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -38,9 +38,10 @@ def addr_infos():
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos):
|
||||
info = MagicMock(auto_spec=AsyncServiceInfo)
|
||||
ipv6 = IPv6Address("2001:db8:85a3::8a2e:370:7334%0")
|
||||
info.ip_addresses_by_version.return_value = [
|
||||
ip_address(b"\n\x00\x00*"),
|
||||
ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"),
|
||||
ipv6,
|
||||
]
|
||||
info.async_request = AsyncMock(return_value=True)
|
||||
with patch(
|
||||
@ -57,9 +58,10 @@ async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos):
|
||||
async def test_resolve_host_passed_zeroconf(addr_infos, async_zeroconf):
|
||||
zeroconf_manager = ZeroconfManager()
|
||||
info = MagicMock(auto_spec=AsyncServiceInfo)
|
||||
ipv6 = IPv6Address("2001:db8:85a3::8a2e:370:7334%0")
|
||||
info.ip_addresses_by_version.return_value = [
|
||||
ip_address(b"\n\x00\x00*"),
|
||||
ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"),
|
||||
ipv6,
|
||||
]
|
||||
info.async_request = AsyncMock(return_value=True)
|
||||
with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info):
|
||||
@ -81,6 +83,25 @@ async def test_resolve_host_zeroconf_empty(async_zeroconf: AsyncZeroconf):
|
||||
assert ret == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_host_zeroconf_fails(async_zeroconf: AsyncZeroconf):
|
||||
with patch(
|
||||
"aioesphomeapi.host_resolver.AsyncServiceInfo.async_request",
|
||||
side_effect=Exception("no buffers"),
|
||||
), pytest.raises(ResolveAPIError, match="no buffers"):
|
||||
await hr._async_resolve_host_zeroconf("asdf.local", 6052)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo", return_value=[])
|
||||
async def test_resolve_host_zeroconf_fails_end_to_end(async_zeroconf: AsyncZeroconf):
|
||||
with patch(
|
||||
"aioesphomeapi.host_resolver.AsyncServiceInfo.async_request",
|
||||
side_effect=Exception("no buffers"),
|
||||
), pytest.raises(ResolveAPIError, match="no buffers"):
|
||||
await hr.async_resolve_host("asdf.local", 6052)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_host_getaddrinfo(event_loop, addr_infos):
|
||||
with patch.object(event_loop, "getaddrinfo") as mock:
|
||||
@ -139,6 +160,15 @@ async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos):
|
||||
assert ret == addr_infos[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("aioesphomeapi.host_resolver.AsyncServiceInfo.async_request", return_value=False)
|
||||
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
|
||||
async def test_resolve_host_mdns_no_results(resolve_addr, addr_infos):
|
||||
resolve_addr.return_value = addr_infos
|
||||
with pytest.raises(ResolveAPIError):
|
||||
await hr.async_resolve_host("example.local", 6052)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("aioesphomeapi.host_resolver._async_resolve_host_zeroconf")
|
||||
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
|
||||
@ -216,3 +246,9 @@ async def test_resolve_host_create_zeroconf_oserror(
|
||||
"aioesphomeapi.zeroconf.AsyncZeroconf", side_effect=OSError("out of buffers")
|
||||
), pytest.raises(ResolveAPIError, match="out of buffers"):
|
||||
await hr._async_resolve_host_zeroconf("asdf", 6052)
|
||||
|
||||
|
||||
def test_scope_id_to_int():
|
||||
assert hr._scope_id_to_int("123") == 123
|
||||
assert hr._scope_id_to_int("eth0") == 0
|
||||
assert hr._scope_id_to_int(None) == 0
|
||||
|
Loading…
Reference in New Issue
Block a user