Improve host resolver coverage (#583)

This commit is contained in:
J. Nick Koston 2023-10-15 13:05:23 -10:00 committed by GitHub
parent 2a78e9588e
commit cb5cea784e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 55 additions and 34 deletions

View File

@ -7,12 +7,12 @@ from dataclasses import dataclass
from ipaddress import IPv4Address, IPv6Address from ipaddress import IPv4Address, IPv6Address
from typing import Union, cast from typing import Union, cast
import zeroconf from zeroconf import IPVersion, Zeroconf
import zeroconf.asyncio from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf
from .core import APIConnectionError, ResolveAPIError from .core import APIConnectionError, ResolveAPIError
ZeroconfInstanceType = Union[zeroconf.Zeroconf, zeroconf.asyncio.AsyncZeroconf, None] ZeroconfInstanceType = Union[Zeroconf, AsyncZeroconf, None]
@dataclass(frozen=True) @dataclass(frozen=True)
@ -47,39 +47,38 @@ async def _async_zeroconf_get_service_info(
service_type: str, service_type: str,
service_name: str, service_name: str,
timeout: float, timeout: float,
) -> "zeroconf.asyncio.AsyncServiceInfo" | None: ) -> AsyncServiceInfo | None:
# Use or create zeroconf instance, ensure it's an AsyncZeroconf # Use or create zeroconf instance, ensure it's an AsyncZeroconf
async_zc_instance: AsyncZeroconf | None = None
if zeroconf_instance is None: if zeroconf_instance is None:
try: try:
zc = zeroconf.asyncio.AsyncZeroconf() async_zc_instance = AsyncZeroconf()
except Exception: except Exception:
raise ResolveAPIError( raise ResolveAPIError(
"Cannot start mDNS sockets, is this a docker container without " "Cannot start mDNS sockets, is this a docker container without "
"host network mode?" "host network mode?"
) )
do_close = True zc = async_zc_instance.zeroconf
elif isinstance(zeroconf_instance, zeroconf.asyncio.AsyncZeroconf): elif isinstance(zeroconf_instance, AsyncZeroconf):
zc = zeroconf_instance.zeroconf
elif isinstance(zeroconf_instance, Zeroconf):
zc = zeroconf_instance zc = zeroconf_instance
do_close = False
elif isinstance(zeroconf_instance, zeroconf.Zeroconf):
zc = zeroconf.asyncio.AsyncZeroconf(zc=zeroconf_instance)
do_close = False
else: else:
raise ValueError( raise ValueError(
f"Invalid type passed for zeroconf_instance: {type(zeroconf_instance)}" f"Invalid type passed for zeroconf_instance: {type(zeroconf_instance)}"
) )
try: try:
info = await zc.async_get_service_info( info = AsyncServiceInfo(service_type, service_name)
service_type, service_name, int(timeout * 1000) if await info.async_request(zc, int(timeout * 1000)):
) return info
except Exception as exc: except Exception as exc:
raise ResolveAPIError( raise ResolveAPIError(
f"Error resolving mDNS {service_name} via mDNS: {exc}" f"Error resolving mDNS {service_name} via mDNS: {exc}"
) from exc ) from exc
finally: finally:
if do_close: if async_zc_instance:
await zc.async_close() await async_zc_instance.async_close()
return info return info
@ -101,7 +100,7 @@ async def _async_resolve_host_zeroconf(
return [] return []
addrs: list[AddrInfo] = [] addrs: list[AddrInfo] = []
for ip_address in info.ip_addresses_by_version(zeroconf.IPVersion.All): for ip_address in info.ip_addresses_by_version(IPVersion.All):
is_ipv6 = ip_address.version == 6 is_ipv6 = ip_address.version == 6
sockaddr: Sockaddr sockaddr: Sockaddr
if is_ipv6: if is_ipv6:

View File

@ -1,17 +1,21 @@
import socket import socket
from ipaddress import ip_address from ipaddress import ip_address
import asyncio
import pytest import pytest
from mock import AsyncMock, MagicMock, patch from mock import AsyncMock, MagicMock, patch
from zeroconf import DNSCache
from zeroconf.asyncio import AsyncZeroconf, AsyncServiceInfo
import aioesphomeapi.host_resolver as hr import aioesphomeapi.host_resolver as hr
from aioesphomeapi.core import APIConnectionError from aioesphomeapi.core import APIConnectionError
@pytest.fixture @pytest.fixture
def async_zeroconf(): def async_zeroconf():
with patch("zeroconf.asyncio.AsyncZeroconf") as klass: with patch("aioesphomeapi.host_resolver.AsyncZeroconf") as klass:
yield klass.return_value async_zeroconf = klass.return_value
async_zeroconf.async_close = AsyncMock()
async_zeroconf.zeroconf.cache = DNSCache()
yield async_zeroconf
@pytest.fixture @pytest.fixture
@ -38,31 +42,49 @@ def addr_infos():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_resolve_host_zeroconf(async_zeroconf, addr_infos): async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos):
info = MagicMock() info = MagicMock(auto_spec=AsyncServiceInfo)
info.ip_addresses_by_version.return_value = [ info.ip_addresses_by_version.return_value = [
ip_address(b"\n\x00\x00*"), ip_address(b"\n\x00\x00*"),
ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"), ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"),
] ]
async_zeroconf.async_get_service_info = AsyncMock(return_value=info) info.async_request = AsyncMock(return_value=True)
async_zeroconf.async_close = AsyncMock() with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info):
ret = await hr._async_resolve_host_zeroconf("asdf", 6052)
ret = await hr._async_resolve_host_zeroconf("asdf", 6052) info.async_request.assert_called_once()
async_zeroconf.async_get_service_info.assert_called_once_with(
"_esphomelib._tcp.local.", "asdf._esphomelib._tcp.local.", 3000
)
async_zeroconf.async_close.assert_called_once_with() async_zeroconf.async_close.assert_called_once_with()
assert ret == addr_infos assert ret == addr_infos
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_resolve_host_zeroconf_empty(async_zeroconf): async def test_resolve_host_passed_zeroconf_does_not_close(addr_infos):
async_zeroconf.async_get_service_info = AsyncMock(return_value=None) async_zeroconf = AsyncZeroconf(zc=MagicMock())
async_zeroconf.async_close = AsyncMock() async_zeroconf.async_close = AsyncMock()
async_zeroconf.zeroconf.cache = DNSCache()
info = MagicMock(auto_spec=AsyncServiceInfo)
info.ip_addresses_by_version.return_value = [
ip_address(b"\n\x00\x00*"),
ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"),
]
info.async_request = AsyncMock(return_value=True)
with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info):
ret = await hr._async_resolve_host_zeroconf(
"asdf", 6052, zeroconf_instance=async_zeroconf
)
ret = await hr._async_resolve_host_zeroconf("asdf.local", 6052) info.async_request.assert_called_once()
async_zeroconf.async_close.assert_not_called()
assert ret == addr_infos
@pytest.mark.asyncio
async def test_resolve_host_zeroconf_empty(async_zeroconf: AsyncZeroconf):
with patch(
"aioesphomeapi.host_resolver.AsyncServiceInfo.async_request"
) as mock_async_request:
ret = await hr._async_resolve_host_zeroconf("asdf.local", 6052)
assert mock_async_request.call_count == 1
assert ret == [] assert ret == []