mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-21 16:37:41 +01:00
Improve host resolver coverage (#583)
This commit is contained in:
parent
2a78e9588e
commit
cb5cea784e
@ -7,12 +7,12 @@ from dataclasses import dataclass
|
||||
from ipaddress import IPv4Address, IPv6Address
|
||||
from typing import Union, cast
|
||||
|
||||
import zeroconf
|
||||
import zeroconf.asyncio
|
||||
from zeroconf import IPVersion, Zeroconf
|
||||
from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf
|
||||
|
||||
from .core import APIConnectionError, ResolveAPIError
|
||||
|
||||
ZeroconfInstanceType = Union[zeroconf.Zeroconf, zeroconf.asyncio.AsyncZeroconf, None]
|
||||
ZeroconfInstanceType = Union[Zeroconf, AsyncZeroconf, None]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -47,39 +47,38 @@ async def _async_zeroconf_get_service_info(
|
||||
service_type: str,
|
||||
service_name: str,
|
||||
timeout: float,
|
||||
) -> "zeroconf.asyncio.AsyncServiceInfo" | None:
|
||||
) -> AsyncServiceInfo | None:
|
||||
# Use or create zeroconf instance, ensure it's an AsyncZeroconf
|
||||
async_zc_instance: AsyncZeroconf | None = None
|
||||
if zeroconf_instance is None:
|
||||
try:
|
||||
zc = zeroconf.asyncio.AsyncZeroconf()
|
||||
async_zc_instance = AsyncZeroconf()
|
||||
except Exception:
|
||||
raise ResolveAPIError(
|
||||
"Cannot start mDNS sockets, is this a docker container without "
|
||||
"host network mode?"
|
||||
)
|
||||
do_close = True
|
||||
elif isinstance(zeroconf_instance, zeroconf.asyncio.AsyncZeroconf):
|
||||
zc = async_zc_instance.zeroconf
|
||||
elif isinstance(zeroconf_instance, AsyncZeroconf):
|
||||
zc = zeroconf_instance.zeroconf
|
||||
elif isinstance(zeroconf_instance, Zeroconf):
|
||||
zc = zeroconf_instance
|
||||
do_close = False
|
||||
elif isinstance(zeroconf_instance, zeroconf.Zeroconf):
|
||||
zc = zeroconf.asyncio.AsyncZeroconf(zc=zeroconf_instance)
|
||||
do_close = False
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid type passed for zeroconf_instance: {type(zeroconf_instance)}"
|
||||
)
|
||||
|
||||
try:
|
||||
info = await zc.async_get_service_info(
|
||||
service_type, service_name, int(timeout * 1000)
|
||||
)
|
||||
info = AsyncServiceInfo(service_type, service_name)
|
||||
if await info.async_request(zc, int(timeout * 1000)):
|
||||
return info
|
||||
except Exception as exc:
|
||||
raise ResolveAPIError(
|
||||
f"Error resolving mDNS {service_name} via mDNS: {exc}"
|
||||
) from exc
|
||||
finally:
|
||||
if do_close:
|
||||
await zc.async_close()
|
||||
if async_zc_instance:
|
||||
await async_zc_instance.async_close()
|
||||
return info
|
||||
|
||||
|
||||
@ -101,7 +100,7 @@ async def _async_resolve_host_zeroconf(
|
||||
return []
|
||||
|
||||
addrs: list[AddrInfo] = []
|
||||
for ip_address in info.ip_addresses_by_version(zeroconf.IPVersion.All):
|
||||
for ip_address in info.ip_addresses_by_version(IPVersion.All):
|
||||
is_ipv6 = ip_address.version == 6
|
||||
sockaddr: Sockaddr
|
||||
if is_ipv6:
|
||||
|
@ -1,17 +1,21 @@
|
||||
import socket
|
||||
from ipaddress import ip_address
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from zeroconf import DNSCache
|
||||
from zeroconf.asyncio import AsyncZeroconf, AsyncServiceInfo
|
||||
import aioesphomeapi.host_resolver as hr
|
||||
from aioesphomeapi.core import APIConnectionError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_zeroconf():
|
||||
with patch("zeroconf.asyncio.AsyncZeroconf") as klass:
|
||||
yield klass.return_value
|
||||
with patch("aioesphomeapi.host_resolver.AsyncZeroconf") as klass:
|
||||
async_zeroconf = klass.return_value
|
||||
async_zeroconf.async_close = AsyncMock()
|
||||
async_zeroconf.zeroconf.cache = DNSCache()
|
||||
yield async_zeroconf
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -38,31 +42,49 @@ def addr_infos():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_host_zeroconf(async_zeroconf, addr_infos):
|
||||
info = MagicMock()
|
||||
async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos):
|
||||
info = MagicMock(auto_spec=AsyncServiceInfo)
|
||||
info.ip_addresses_by_version.return_value = [
|
||||
ip_address(b"\n\x00\x00*"),
|
||||
ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"),
|
||||
]
|
||||
async_zeroconf.async_get_service_info = AsyncMock(return_value=info)
|
||||
async_zeroconf.async_close = AsyncMock()
|
||||
info.async_request = AsyncMock(return_value=True)
|
||||
with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info):
|
||||
ret = await hr._async_resolve_host_zeroconf("asdf", 6052)
|
||||
|
||||
ret = await hr._async_resolve_host_zeroconf("asdf", 6052)
|
||||
|
||||
async_zeroconf.async_get_service_info.assert_called_once_with(
|
||||
"_esphomelib._tcp.local.", "asdf._esphomelib._tcp.local.", 3000
|
||||
)
|
||||
info.async_request.assert_called_once()
|
||||
async_zeroconf.async_close.assert_called_once_with()
|
||||
|
||||
assert ret == addr_infos
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_host_zeroconf_empty(async_zeroconf):
|
||||
async_zeroconf.async_get_service_info = AsyncMock(return_value=None)
|
||||
async def test_resolve_host_passed_zeroconf_does_not_close(addr_infos):
|
||||
async_zeroconf = AsyncZeroconf(zc=MagicMock())
|
||||
async_zeroconf.async_close = AsyncMock()
|
||||
async_zeroconf.zeroconf.cache = DNSCache()
|
||||
info = MagicMock(auto_spec=AsyncServiceInfo)
|
||||
info.ip_addresses_by_version.return_value = [
|
||||
ip_address(b"\n\x00\x00*"),
|
||||
ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"),
|
||||
]
|
||||
info.async_request = AsyncMock(return_value=True)
|
||||
with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info):
|
||||
ret = await hr._async_resolve_host_zeroconf(
|
||||
"asdf", 6052, zeroconf_instance=async_zeroconf
|
||||
)
|
||||
|
||||
ret = await hr._async_resolve_host_zeroconf("asdf.local", 6052)
|
||||
info.async_request.assert_called_once()
|
||||
async_zeroconf.async_close.assert_not_called()
|
||||
assert ret == addr_infos
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_host_zeroconf_empty(async_zeroconf: AsyncZeroconf):
|
||||
with patch(
|
||||
"aioesphomeapi.host_resolver.AsyncServiceInfo.async_request"
|
||||
) as mock_async_request:
|
||||
ret = await hr._async_resolve_host_zeroconf("asdf.local", 6052)
|
||||
assert mock_async_request.call_count == 1
|
||||
assert ret == []
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user