Improve host resolver coverage (#583)

This commit is contained in:
J. Nick Koston 2023-10-15 13:05:23 -10:00 committed by GitHub
parent 2a78e9588e
commit cb5cea784e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 55 additions and 34 deletions

View File

@ -7,12 +7,12 @@ from dataclasses import dataclass
from ipaddress import IPv4Address, IPv6Address
from typing import Union, cast
import zeroconf
import zeroconf.asyncio
from zeroconf import IPVersion, Zeroconf
from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf
from .core import APIConnectionError, ResolveAPIError
ZeroconfInstanceType = Union[zeroconf.Zeroconf, zeroconf.asyncio.AsyncZeroconf, None]
ZeroconfInstanceType = Union[Zeroconf, AsyncZeroconf, None]
@dataclass(frozen=True)
@ -47,39 +47,38 @@ async def _async_zeroconf_get_service_info(
service_type: str,
service_name: str,
timeout: float,
) -> "zeroconf.asyncio.AsyncServiceInfo" | None:
) -> AsyncServiceInfo | None:
# Use or create zeroconf instance, ensure it's an AsyncZeroconf
async_zc_instance: AsyncZeroconf | None = None
if zeroconf_instance is None:
try:
zc = zeroconf.asyncio.AsyncZeroconf()
async_zc_instance = AsyncZeroconf()
except Exception:
raise ResolveAPIError(
"Cannot start mDNS sockets, is this a docker container without "
"host network mode?"
)
do_close = True
elif isinstance(zeroconf_instance, zeroconf.asyncio.AsyncZeroconf):
zc = async_zc_instance.zeroconf
elif isinstance(zeroconf_instance, AsyncZeroconf):
zc = zeroconf_instance.zeroconf
elif isinstance(zeroconf_instance, Zeroconf):
zc = zeroconf_instance
do_close = False
elif isinstance(zeroconf_instance, zeroconf.Zeroconf):
zc = zeroconf.asyncio.AsyncZeroconf(zc=zeroconf_instance)
do_close = False
else:
raise ValueError(
f"Invalid type passed for zeroconf_instance: {type(zeroconf_instance)}"
)
try:
info = await zc.async_get_service_info(
service_type, service_name, int(timeout * 1000)
)
info = AsyncServiceInfo(service_type, service_name)
if await info.async_request(zc, int(timeout * 1000)):
return info
except Exception as exc:
raise ResolveAPIError(
f"Error resolving mDNS {service_name} via mDNS: {exc}"
) from exc
finally:
if do_close:
await zc.async_close()
if async_zc_instance:
await async_zc_instance.async_close()
return info
@ -101,7 +100,7 @@ async def _async_resolve_host_zeroconf(
return []
addrs: list[AddrInfo] = []
for ip_address in info.ip_addresses_by_version(zeroconf.IPVersion.All):
for ip_address in info.ip_addresses_by_version(IPVersion.All):
is_ipv6 = ip_address.version == 6
sockaddr: Sockaddr
if is_ipv6:

View File

@ -1,17 +1,21 @@
import socket
from ipaddress import ip_address
import asyncio
import pytest
from mock import AsyncMock, MagicMock, patch
from zeroconf import DNSCache
from zeroconf.asyncio import AsyncZeroconf, AsyncServiceInfo
import aioesphomeapi.host_resolver as hr
from aioesphomeapi.core import APIConnectionError
@pytest.fixture
def async_zeroconf():
with patch("zeroconf.asyncio.AsyncZeroconf") as klass:
yield klass.return_value
with patch("aioesphomeapi.host_resolver.AsyncZeroconf") as klass:
async_zeroconf = klass.return_value
async_zeroconf.async_close = AsyncMock()
async_zeroconf.zeroconf.cache = DNSCache()
yield async_zeroconf
@pytest.fixture
@ -38,31 +42,49 @@ def addr_infos():
@pytest.mark.asyncio
async def test_resolve_host_zeroconf(async_zeroconf, addr_infos):
info = MagicMock()
async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos):
info = MagicMock(auto_spec=AsyncServiceInfo)
info.ip_addresses_by_version.return_value = [
ip_address(b"\n\x00\x00*"),
ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"),
]
async_zeroconf.async_get_service_info = AsyncMock(return_value=info)
async_zeroconf.async_close = AsyncMock()
info.async_request = AsyncMock(return_value=True)
with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info):
ret = await hr._async_resolve_host_zeroconf("asdf", 6052)
ret = await hr._async_resolve_host_zeroconf("asdf", 6052)
async_zeroconf.async_get_service_info.assert_called_once_with(
"_esphomelib._tcp.local.", "asdf._esphomelib._tcp.local.", 3000
)
info.async_request.assert_called_once()
async_zeroconf.async_close.assert_called_once_with()
assert ret == addr_infos
@pytest.mark.asyncio
async def test_resolve_host_zeroconf_empty(async_zeroconf):
async_zeroconf.async_get_service_info = AsyncMock(return_value=None)
async def test_resolve_host_passed_zeroconf_does_not_close(addr_infos):
async_zeroconf = AsyncZeroconf(zc=MagicMock())
async_zeroconf.async_close = AsyncMock()
async_zeroconf.zeroconf.cache = DNSCache()
info = MagicMock(auto_spec=AsyncServiceInfo)
info.ip_addresses_by_version.return_value = [
ip_address(b"\n\x00\x00*"),
ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"),
]
info.async_request = AsyncMock(return_value=True)
with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info):
ret = await hr._async_resolve_host_zeroconf(
"asdf", 6052, zeroconf_instance=async_zeroconf
)
ret = await hr._async_resolve_host_zeroconf("asdf.local", 6052)
info.async_request.assert_called_once()
async_zeroconf.async_close.assert_not_called()
assert ret == addr_infos
@pytest.mark.asyncio
async def test_resolve_host_zeroconf_empty(async_zeroconf: AsyncZeroconf):
with patch(
"aioesphomeapi.host_resolver.AsyncServiceInfo.async_request"
) as mock_async_request:
ret = await hr._async_resolve_host_zeroconf("asdf.local", 6052)
assert mock_async_request.call_count == 1
assert ret == []