mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-26 17:27:40 +01:00
Ensure scope_id is preserved from zeroconf resolution on python versions that support it (#664)
This commit is contained in:
parent
00a6ce9f6a
commit
df0dbadae7
@ -5,8 +5,8 @@ import contextlib
|
|||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from ipaddress import IPv4Address, IPv6Address
|
from ipaddress import IPv4Address, IPv6Address, ip_address
|
||||||
from typing import cast
|
from typing import TYPE_CHECKING, cast
|
||||||
|
|
||||||
from zeroconf import IPVersion
|
from zeroconf import IPVersion
|
||||||
from zeroconf.asyncio import AsyncServiceInfo
|
from zeroconf.asyncio import AsyncServiceInfo
|
||||||
@ -56,7 +56,7 @@ async def _async_zeroconf_get_service_info(
|
|||||||
service_name: str,
|
service_name: str,
|
||||||
server: str,
|
server: str,
|
||||||
timeout: float,
|
timeout: float,
|
||||||
) -> AsyncServiceInfo | None:
|
) -> AsyncServiceInfo:
|
||||||
# Use or create zeroconf instance, ensure it's an AsyncZeroconf
|
# Use or create zeroconf instance, ensure it's an AsyncZeroconf
|
||||||
try:
|
try:
|
||||||
zc = zeroconf_manager.get_async_zeroconf().zeroconf
|
zc = zeroconf_manager.get_async_zeroconf().zeroconf
|
||||||
@ -67,8 +67,7 @@ async def _async_zeroconf_get_service_info(
|
|||||||
) from exc
|
) from exc
|
||||||
try:
|
try:
|
||||||
info = AsyncServiceInfo(service_type, service_name, server=server)
|
info = AsyncServiceInfo(service_type, service_name, server=server)
|
||||||
if await info.async_request(zc, int(timeout * 1000)):
|
await info.async_request(zc, int(timeout * 1000))
|
||||||
return info
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise ResolveAPIError(
|
raise ResolveAPIError(
|
||||||
f"Error resolving mDNS {service_name} via mDNS: {exc}"
|
f"Error resolving mDNS {service_name} via mDNS: {exc}"
|
||||||
@ -78,6 +77,16 @@ async def _async_zeroconf_get_service_info(
|
|||||||
return info
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
def _scope_id_to_int(value: str | None) -> int:
|
||||||
|
"""Convert a scope id to int if possible."""
|
||||||
|
if value is None:
|
||||||
|
return 0
|
||||||
|
try:
|
||||||
|
return int(value)
|
||||||
|
except ValueError:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
async def _async_resolve_host_zeroconf(
|
async def _async_resolve_host_zeroconf(
|
||||||
host: str,
|
host: str,
|
||||||
port: int,
|
port: int,
|
||||||
@ -96,35 +105,9 @@ async def _async_resolve_host_zeroconf(
|
|||||||
server,
|
server,
|
||||||
timeout,
|
timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
if info is None:
|
|
||||||
return []
|
|
||||||
|
|
||||||
addrs: list[AddrInfo] = []
|
addrs: list[AddrInfo] = []
|
||||||
for ip_address in info.ip_addresses_by_version(IPVersion.All):
|
for ip in info.ip_addresses_by_version(IPVersion.All):
|
||||||
is_ipv6 = ip_address.version == 6
|
addrs.extend(_async_ip_address_to_addrs(ip, port)) # type: ignore[arg-type]
|
||||||
sockaddr: IPv6Sockaddr | IPv4Sockaddr
|
|
||||||
if is_ipv6:
|
|
||||||
sockaddr = IPv6Sockaddr(
|
|
||||||
address=str(ip_address),
|
|
||||||
port=port,
|
|
||||||
flowinfo=0,
|
|
||||||
scope_id=0,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
sockaddr = IPv4Sockaddr(
|
|
||||||
address=str(ip_address),
|
|
||||||
port=port,
|
|
||||||
)
|
|
||||||
|
|
||||||
addrs.append(
|
|
||||||
AddrInfo(
|
|
||||||
family=socket.AF_INET6 if is_ipv6 else socket.AF_INET,
|
|
||||||
type=socket.SOCK_STREAM,
|
|
||||||
proto=socket.IPPROTO_TCP,
|
|
||||||
sockaddr=sockaddr,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return addrs
|
return addrs
|
||||||
|
|
||||||
|
|
||||||
@ -160,34 +143,37 @@ async def _async_resolve_host_getaddrinfo(host: str, port: int) -> list[AddrInfo
|
|||||||
return addrs
|
return addrs
|
||||||
|
|
||||||
|
|
||||||
def _async_ip_address_to_addrs(host: str, port: int) -> list[AddrInfo]:
|
def _async_ip_address_to_addrs(
|
||||||
|
ip: IPv4Address | IPv6Address, port: int
|
||||||
|
) -> list[AddrInfo]:
|
||||||
"""Convert an ipaddress to AddrInfo."""
|
"""Convert an ipaddress to AddrInfo."""
|
||||||
with contextlib.suppress(ValueError):
|
addrs: list[AddrInfo] = []
|
||||||
return [
|
is_ipv6 = ip.version == 6
|
||||||
AddrInfo(
|
sockaddr: IPv6Sockaddr | IPv4Sockaddr
|
||||||
family=socket.AF_INET6,
|
if is_ipv6:
|
||||||
type=socket.SOCK_STREAM,
|
if TYPE_CHECKING:
|
||||||
proto=socket.IPPROTO_TCP,
|
assert isinstance(ip, IPv6Address)
|
||||||
sockaddr = IPv6Sockaddr(
|
sockaddr = IPv6Sockaddr(
|
||||||
address=str(IPv6Address(host)), port=port, flowinfo=0, scope_id=0
|
address=str(ip).partition("%")[0],
|
||||||
),
|
port=port,
|
||||||
|
flowinfo=0,
|
||||||
|
scope_id=_scope_id_to_int(ip.scope_id),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sockaddr = IPv4Sockaddr(
|
||||||
|
address=str(ip),
|
||||||
|
port=port,
|
||||||
)
|
)
|
||||||
]
|
|
||||||
|
|
||||||
with contextlib.suppress(ValueError):
|
addrs.append(
|
||||||
return [
|
|
||||||
AddrInfo(
|
AddrInfo(
|
||||||
family=socket.AF_INET,
|
family=socket.AF_INET6 if is_ipv6 else socket.AF_INET,
|
||||||
type=socket.SOCK_STREAM,
|
type=socket.SOCK_STREAM,
|
||||||
proto=socket.IPPROTO_TCP,
|
proto=socket.IPPROTO_TCP,
|
||||||
sockaddr=IPv4Sockaddr(
|
sockaddr=sockaddr,
|
||||||
address=str(IPv4Address(host)),
|
|
||||||
port=port,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
]
|
)
|
||||||
|
return addrs
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
async def async_resolve_host(
|
async def async_resolve_host(
|
||||||
@ -206,11 +192,12 @@ async def async_resolve_host(
|
|||||||
name, port, zeroconf_manager=zeroconf_manager
|
name, port, zeroconf_manager=zeroconf_manager
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except APIConnectionError as err:
|
except ResolveAPIError as err:
|
||||||
zc_error = err
|
zc_error = err
|
||||||
|
|
||||||
else:
|
else:
|
||||||
addrs.extend(_async_ip_address_to_addrs(host, port))
|
with contextlib.suppress(ValueError):
|
||||||
|
addrs.extend(_async_ip_address_to_addrs(ip_address(host), port))
|
||||||
|
|
||||||
if not addrs:
|
if not addrs:
|
||||||
addrs.extend(await _async_resolve_host_getaddrinfo(host, port))
|
addrs.extend(await _async_resolve_host_getaddrinfo(host, port))
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import socket
|
import socket
|
||||||
from ipaddress import ip_address
|
from ipaddress import IPv6Address, ip_address
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -38,9 +38,10 @@ def addr_infos():
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos):
|
async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos):
|
||||||
info = MagicMock(auto_spec=AsyncServiceInfo)
|
info = MagicMock(auto_spec=AsyncServiceInfo)
|
||||||
|
ipv6 = IPv6Address("2001:db8:85a3::8a2e:370:7334%0")
|
||||||
info.ip_addresses_by_version.return_value = [
|
info.ip_addresses_by_version.return_value = [
|
||||||
ip_address(b"\n\x00\x00*"),
|
ip_address(b"\n\x00\x00*"),
|
||||||
ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"),
|
ipv6,
|
||||||
]
|
]
|
||||||
info.async_request = AsyncMock(return_value=True)
|
info.async_request = AsyncMock(return_value=True)
|
||||||
with patch(
|
with patch(
|
||||||
@ -57,9 +58,10 @@ async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos):
|
|||||||
async def test_resolve_host_passed_zeroconf(addr_infos, async_zeroconf):
|
async def test_resolve_host_passed_zeroconf(addr_infos, async_zeroconf):
|
||||||
zeroconf_manager = ZeroconfManager()
|
zeroconf_manager = ZeroconfManager()
|
||||||
info = MagicMock(auto_spec=AsyncServiceInfo)
|
info = MagicMock(auto_spec=AsyncServiceInfo)
|
||||||
|
ipv6 = IPv6Address("2001:db8:85a3::8a2e:370:7334%0")
|
||||||
info.ip_addresses_by_version.return_value = [
|
info.ip_addresses_by_version.return_value = [
|
||||||
ip_address(b"\n\x00\x00*"),
|
ip_address(b"\n\x00\x00*"),
|
||||||
ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"),
|
ipv6,
|
||||||
]
|
]
|
||||||
info.async_request = AsyncMock(return_value=True)
|
info.async_request = AsyncMock(return_value=True)
|
||||||
with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info):
|
with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info):
|
||||||
@ -81,6 +83,25 @@ async def test_resolve_host_zeroconf_empty(async_zeroconf: AsyncZeroconf):
|
|||||||
assert ret == []
|
assert ret == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resolve_host_zeroconf_fails(async_zeroconf: AsyncZeroconf):
|
||||||
|
with patch(
|
||||||
|
"aioesphomeapi.host_resolver.AsyncServiceInfo.async_request",
|
||||||
|
side_effect=Exception("no buffers"),
|
||||||
|
), pytest.raises(ResolveAPIError, match="no buffers"):
|
||||||
|
await hr._async_resolve_host_zeroconf("asdf.local", 6052)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo", return_value=[])
|
||||||
|
async def test_resolve_host_zeroconf_fails_end_to_end(async_zeroconf: AsyncZeroconf):
|
||||||
|
with patch(
|
||||||
|
"aioesphomeapi.host_resolver.AsyncServiceInfo.async_request",
|
||||||
|
side_effect=Exception("no buffers"),
|
||||||
|
), pytest.raises(ResolveAPIError, match="no buffers"):
|
||||||
|
await hr.async_resolve_host("asdf.local", 6052)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_resolve_host_getaddrinfo(event_loop, addr_infos):
|
async def test_resolve_host_getaddrinfo(event_loop, addr_infos):
|
||||||
with patch.object(event_loop, "getaddrinfo") as mock:
|
with patch.object(event_loop, "getaddrinfo") as mock:
|
||||||
@ -139,6 +160,15 @@ async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos):
|
|||||||
assert ret == addr_infos[0]
|
assert ret == addr_infos[0]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("aioesphomeapi.host_resolver.AsyncServiceInfo.async_request", return_value=False)
|
||||||
|
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
|
||||||
|
async def test_resolve_host_mdns_no_results(resolve_addr, addr_infos):
|
||||||
|
resolve_addr.return_value = addr_infos
|
||||||
|
with pytest.raises(ResolveAPIError):
|
||||||
|
await hr.async_resolve_host("example.local", 6052)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("aioesphomeapi.host_resolver._async_resolve_host_zeroconf")
|
@patch("aioesphomeapi.host_resolver._async_resolve_host_zeroconf")
|
||||||
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
|
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
|
||||||
@ -216,3 +246,9 @@ async def test_resolve_host_create_zeroconf_oserror(
|
|||||||
"aioesphomeapi.zeroconf.AsyncZeroconf", side_effect=OSError("out of buffers")
|
"aioesphomeapi.zeroconf.AsyncZeroconf", side_effect=OSError("out of buffers")
|
||||||
), pytest.raises(ResolveAPIError, match="out of buffers"):
|
), pytest.raises(ResolveAPIError, match="out of buffers"):
|
||||||
await hr._async_resolve_host_zeroconf("asdf", 6052)
|
await hr._async_resolve_host_zeroconf("asdf", 6052)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scope_id_to_int():
|
||||||
|
assert hr._scope_id_to_int("123") == 123
|
||||||
|
assert hr._scope_id_to_int("eth0") == 0
|
||||||
|
assert hr._scope_id_to_int(None) == 0
|
||||||
|
Loading…
Reference in New Issue
Block a user