From cb5cea784e4c39eec08560658bb0477fe90c406e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 15 Oct 2023 13:05:23 -1000 Subject: [PATCH] Improve host resolver coverage (#583) --- aioesphomeapi/host_resolver.py | 33 ++++++++++---------- tests/test_host_resolver.py | 56 +++++++++++++++++++++++----------- 2 files changed, 55 insertions(+), 34 deletions(-) diff --git a/aioesphomeapi/host_resolver.py b/aioesphomeapi/host_resolver.py index 81b8960..c28edeb 100644 --- a/aioesphomeapi/host_resolver.py +++ b/aioesphomeapi/host_resolver.py @@ -7,12 +7,12 @@ from dataclasses import dataclass from ipaddress import IPv4Address, IPv6Address from typing import Union, cast -import zeroconf -import zeroconf.asyncio +from zeroconf import IPVersion, Zeroconf +from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf from .core import APIConnectionError, ResolveAPIError -ZeroconfInstanceType = Union[zeroconf.Zeroconf, zeroconf.asyncio.AsyncZeroconf, None] +ZeroconfInstanceType = Union[Zeroconf, AsyncZeroconf, None] @dataclass(frozen=True) @@ -47,39 +47,38 @@ async def _async_zeroconf_get_service_info( service_type: str, service_name: str, timeout: float, -) -> "zeroconf.asyncio.AsyncServiceInfo" | None: +) -> AsyncServiceInfo | None: # Use or create zeroconf instance, ensure it's an AsyncZeroconf + async_zc_instance: AsyncZeroconf | None = None if zeroconf_instance is None: try: - zc = zeroconf.asyncio.AsyncZeroconf() + async_zc_instance = AsyncZeroconf() except Exception: raise ResolveAPIError( "Cannot start mDNS sockets, is this a docker container without " "host network mode?" ) - do_close = True - elif isinstance(zeroconf_instance, zeroconf.asyncio.AsyncZeroconf): + zc = async_zc_instance.zeroconf + elif isinstance(zeroconf_instance, AsyncZeroconf): + zc = zeroconf_instance.zeroconf + elif isinstance(zeroconf_instance, Zeroconf): zc = zeroconf_instance - do_close = False - elif isinstance(zeroconf_instance, zeroconf.Zeroconf): - zc = zeroconf.asyncio.AsyncZeroconf(zc=zeroconf_instance) - do_close = False else: raise ValueError( f"Invalid type passed for zeroconf_instance: {type(zeroconf_instance)}" ) try: - info = await zc.async_get_service_info( - service_type, service_name, int(timeout * 1000) - ) + info = AsyncServiceInfo(service_type, service_name) + if await info.async_request(zc, int(timeout * 1000)): + return info except Exception as exc: raise ResolveAPIError( f"Error resolving mDNS {service_name} via mDNS: {exc}" ) from exc finally: - if do_close: - await zc.async_close() + if async_zc_instance: + await async_zc_instance.async_close() return info @@ -101,7 +100,7 @@ async def _async_resolve_host_zeroconf( return [] addrs: list[AddrInfo] = [] - for ip_address in info.ip_addresses_by_version(zeroconf.IPVersion.All): + for ip_address in info.ip_addresses_by_version(IPVersion.All): is_ipv6 = ip_address.version == 6 sockaddr: Sockaddr if is_ipv6: diff --git a/tests/test_host_resolver.py b/tests/test_host_resolver.py index fc6a3f4..6afcde5 100644 --- a/tests/test_host_resolver.py +++ b/tests/test_host_resolver.py @@ -1,17 +1,21 @@ import socket from ipaddress import ip_address - +import asyncio import pytest from mock import AsyncMock, MagicMock, patch - +from zeroconf import DNSCache +from zeroconf.asyncio import AsyncZeroconf, AsyncServiceInfo import aioesphomeapi.host_resolver as hr from aioesphomeapi.core import APIConnectionError @pytest.fixture def async_zeroconf(): - with patch("zeroconf.asyncio.AsyncZeroconf") as klass: - yield klass.return_value + with patch("aioesphomeapi.host_resolver.AsyncZeroconf") as klass: + async_zeroconf = klass.return_value + async_zeroconf.async_close = AsyncMock() + async_zeroconf.zeroconf.cache = DNSCache() + yield async_zeroconf @pytest.fixture @@ -38,31 +42,49 @@ def addr_infos(): @pytest.mark.asyncio -async def test_resolve_host_zeroconf(async_zeroconf, addr_infos): - info = MagicMock() +async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos): + info = MagicMock(auto_spec=AsyncServiceInfo) info.ip_addresses_by_version.return_value = [ ip_address(b"\n\x00\x00*"), ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"), ] - async_zeroconf.async_get_service_info = AsyncMock(return_value=info) - async_zeroconf.async_close = AsyncMock() + info.async_request = AsyncMock(return_value=True) + with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info): + ret = await hr._async_resolve_host_zeroconf("asdf", 6052) - ret = await hr._async_resolve_host_zeroconf("asdf", 6052) - - async_zeroconf.async_get_service_info.assert_called_once_with( - "_esphomelib._tcp.local.", "asdf._esphomelib._tcp.local.", 3000 - ) + info.async_request.assert_called_once() async_zeroconf.async_close.assert_called_once_with() - assert ret == addr_infos @pytest.mark.asyncio -async def test_resolve_host_zeroconf_empty(async_zeroconf): - async_zeroconf.async_get_service_info = AsyncMock(return_value=None) +async def test_resolve_host_passed_zeroconf_does_not_close(addr_infos): + async_zeroconf = AsyncZeroconf(zc=MagicMock()) async_zeroconf.async_close = AsyncMock() + async_zeroconf.zeroconf.cache = DNSCache() + info = MagicMock(auto_spec=AsyncServiceInfo) + info.ip_addresses_by_version.return_value = [ + ip_address(b"\n\x00\x00*"), + ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"), + ] + info.async_request = AsyncMock(return_value=True) + with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info): + ret = await hr._async_resolve_host_zeroconf( + "asdf", 6052, zeroconf_instance=async_zeroconf + ) - ret = await hr._async_resolve_host_zeroconf("asdf.local", 6052) + info.async_request.assert_called_once() + async_zeroconf.async_close.assert_not_called() + assert ret == addr_infos + + +@pytest.mark.asyncio +async def test_resolve_host_zeroconf_empty(async_zeroconf: AsyncZeroconf): + with patch( + "aioesphomeapi.host_resolver.AsyncServiceInfo.async_request" + ) as mock_async_request: + ret = await hr._async_resolve_host_zeroconf("asdf.local", 6052) + assert mock_async_request.call_count == 1 assert ret == []