diff --git a/aioesphomeapi/host_resolver.py b/aioesphomeapi/host_resolver.py index fc27717..d8a1e02 100644 --- a/aioesphomeapi/host_resolver.py +++ b/aioesphomeapi/host_resolver.py @@ -1,14 +1,21 @@ import asyncio +import functools import socket from dataclasses import dataclass -from typing import List, Tuple, Union, cast +from typing import List, Optional, Tuple, Union, cast import zeroconf -import zeroconf.asyncio + +try: + import zeroconf.asyncio + + ZC_ASYNCIO = True +except ImportError: + ZC_ASYNCIO = False from .core import APIConnectionError -ZeroconfInstanceType = Union[zeroconf.Zeroconf, zeroconf.asyncio.AsyncZeroconf, None] +ZeroconfInstanceType = Union[zeroconf.Zeroconf, "zeroconf.asyncio.AsyncZeroconf", None] @dataclass(frozen=True) @@ -38,13 +45,61 @@ class AddrInfo: sockaddr: Sockaddr -async def _async_resolve_host_zeroconf( # pylint: disable=too-many-branches - host: str, - port: int, - *, - timeout: float = 3.0, - zeroconf_instance: ZeroconfInstanceType = None, -) -> List[AddrInfo]: +def _sync_zeroconf_get_service_info( + zeroconf_instance: ZeroconfInstanceType, + service_type: str, + service_name: str, + timeout: float, +) -> Optional["zeroconf.ServiceInfo"]: + # Use or create zeroconf instance, ensure it's an AsyncZeroconf + if zeroconf_instance is None: + try: + zc = zeroconf.Zeroconf() + except Exception: + raise APIConnectionError( + "Cannot start mDNS sockets, is this a docker container without " + "host network mode?" + ) + do_close = True + elif isinstance(zeroconf_instance, zeroconf.Zeroconf): + zc = zeroconf_instance + do_close = False + else: + raise ValueError( + f"Invalid type passed for zeroconf_instance: {type(zeroconf_instance)}" + ) + + try: + info = zc.get_service_info(service_type, service_name, int(timeout * 1000)) + except Exception as exc: + raise APIConnectionError( + f"Error resolving mDNS {service_name} via mDNS: {exc}" + ) from exc + finally: + if do_close: + zc.close() + return info + + +async def _async_zeroconf_get_service_info( + eventloop: asyncio.events.AbstractEventLoop, + zeroconf_instance: ZeroconfInstanceType, + service_type: str, + service_name: str, + timeout: float, +) -> Optional["zeroconf.ServiceInfo"]: + if not ZC_ASYNCIO: + return await eventloop.run_in_executor( + None, + functools.partial( + _sync_zeroconf_get_service_info, + zeroconf_instance, + service_type, + service_name, + timeout, + ), + ) + # Use or create zeroconf instance, ensure it's an AsyncZeroconf if zeroconf_instance is None: try: @@ -66,20 +121,34 @@ async def _async_resolve_host_zeroconf( # pylint: disable=too-many-branches f"Invalid type passed for zeroconf_instance: {type(zeroconf_instance)}" ) - service_type = "_esphomelib._tcp.local." - service_name = f"{host}.{service_type}" - try: info = await zc.async_get_service_info( service_type, service_name, int(timeout * 1000) ) except Exception as exc: raise APIConnectionError( - f"Error resolving host {host} via mDNS: {exc}" + f"Error resolving mDNS {service_name} via mDNS: {exc}" ) from exc finally: if do_close: await zc.async_close() + return info + + +async def _async_resolve_host_zeroconf( + eventloop: asyncio.events.AbstractEventLoop, + host: str, + port: int, + *, + timeout: float = 3.0, + zeroconf_instance: ZeroconfInstanceType = None, +) -> List[AddrInfo]: + service_type = "_esphomelib._tcp.local." + service_name = f"{host}.{service_type}" + + info = await _async_zeroconf_get_service_info( + eventloop, zeroconf_instance, service_type, service_name, timeout + ) if info is None: return [] @@ -158,7 +227,7 @@ async def async_resolve_host( try: addrs.extend( await _async_resolve_host_zeroconf( - name, port, zeroconf_instance=zeroconf_instance + eventloop, name, port, zeroconf_instance=zeroconf_instance ) ) except APIConnectionError as err: diff --git a/aioesphomeapi/reconnect_logic.py b/aioesphomeapi/reconnect_logic.py index 5e4ec33..fec280d 100644 --- a/aioesphomeapi/reconnect_logic.py +++ b/aioesphomeapi/reconnect_logic.py @@ -2,13 +2,7 @@ import asyncio import logging from typing import Awaitable, Callable, List, Optional -from zeroconf import ( # type: ignore[attr-defined] - DNSPointer, - DNSRecord, - RecordUpdate, - RecordUpdateListener, - Zeroconf, -) +import zeroconf from .client import APIClient from .core import APIConnectionError @@ -16,7 +10,7 @@ from .core import APIConnectionError _LOGGER = logging.getLogger(__name__) -class ReconnectLogic(RecordUpdateListener): # type: ignore[misc] +class ReconnectLogic(zeroconf.RecordUpdateListener): # type: ignore[misc,name-defined] """Reconnectiong logic handler for ESPHome config entries. Contains two reconnect strategies: @@ -32,7 +26,7 @@ class ReconnectLogic(RecordUpdateListener): # type: ignore[misc] client: APIClient, on_connect: Callable[[], Awaitable[None]], on_disconnect: Callable[[], Awaitable[None]], - zeroconf_instance: Zeroconf, + zeroconf_instance: "zeroconf.Zeroconf", name: Optional[str] = None, ) -> None: """Initialize ReconnectingLogic. @@ -63,6 +57,7 @@ class ReconnectLogic(RecordUpdateListener): # type: ignore[misc] self._wait_task_lock = asyncio.Lock() # Event for tracking when logic should stop self._stop_event = asyncio.Event() + self._event_loop: Optional[asyncio.events.AbstractEventLoop] = None @property def _is_stopped(self) -> bool: @@ -182,6 +177,7 @@ class ReconnectLogic(RecordUpdateListener): # type: ignore[misc] """Start the reconnecting logic background task.""" # Create reconnection loop outside of HA's tracked tasks in order # not to delay startup. + self._event_loop = asyncio.get_event_loop() self._loop_task = asyncio.create_task(self._reconnect_loop()) async with self._connected_lock: @@ -220,8 +216,8 @@ class ReconnectLogic(RecordUpdateListener): # type: ignore[misc] self._zc.async_remove_listener(self) self._zc_listening = False - def _async_on_record(self, record: DNSRecord) -> None: - if not isinstance(record, DNSPointer): + def _async_on_record(self, record: "zeroconf.DNSRecord") -> None: # type: ignore[name-defined] + if not isinstance(record, zeroconf.DNSPointer): # type: ignore[attr-defined] # We only consider PTR records and match using the alias name return if self._is_stopped or self.name is None: @@ -243,9 +239,28 @@ class ReconnectLogic(RecordUpdateListener): # type: ignore[misc] ) self._reconnect_event.set() + # From RecordUpdateListener for zeroconf>=0.32 def async_update_records( - self, zc: Zeroconf, now: float, records: List[RecordUpdate] + self, + zc: "zeroconf.Zeroconf", # pylint: disable=unused-argument + now: float, # pylint: disable=unused-argument + records: List["zeroconf.RecordUpdate"], # type: ignore[name-defined] ) -> None: """Listen to zeroconf updated mDNS records. This must be called from the eventloop.""" for update in records: self._async_on_record(update.new) + + # From RecordUpdateListener for zeroconf<0.32 + def update_record( + self, + zc: "zeroconf.Zeroconf", # pylint: disable=unused-argument + now: float, # pylint: disable=unused-argument + record: "zeroconf.DNSRecord", # type: ignore[name-defined] + ) -> None: + assert self._event_loop is not None + + async def corofunc() -> None: + self._async_on_record(record) + + # Dispatch in event loop + asyncio.run_coroutine_threadsafe(corofunc(), self._event_loop) diff --git a/requirements.txt b/requirements.txt index 9cfded3..1916b64 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ protobuf>=3.12.2,<4.0 -zeroconf>=0.32.0,<1.0 +zeroconf>=0.28.0,<1.0 diff --git a/tests/test_host_resolver.py b/tests/test_host_resolver.py index 6ccb05f..6287031 100644 --- a/tests/test_host_resolver.py +++ b/tests/test_host_resolver.py @@ -1,3 +1,4 @@ +import asyncio import socket import pytest @@ -45,8 +46,9 @@ async def test_resolve_host_zeroconf(async_zeroconf, addr_infos): ] async_zeroconf.async_get_service_info = AsyncMock(return_value=info) async_zeroconf.async_close = AsyncMock() + loop = asyncio.get_event_loop() - ret = await hr._async_resolve_host_zeroconf("asdf", 6052) + ret = await hr._async_resolve_host_zeroconf(loop, "asdf", 6052) async_zeroconf.async_get_service_info.assert_called_once_with( "_esphomelib._tcp.local.", "asdf._esphomelib._tcp.local.", 3000 @@ -60,8 +62,9 @@ async def test_resolve_host_zeroconf(async_zeroconf, addr_infos): async def test_resolve_host_zeroconf_empty(async_zeroconf): async_zeroconf.async_get_service_info = AsyncMock(return_value=None) async_zeroconf.async_close = AsyncMock() + loop = asyncio.get_event_loop() - ret = await hr._async_resolve_host_zeroconf("asdf.local", 6052) + ret = await hr._async_resolve_host_zeroconf(loop, "asdf.local", 6052) assert ret == [] @@ -103,9 +106,10 @@ async def test_resolve_host_getaddrinfo_oserror(): @patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo") async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos): resolve_zc.return_value = addr_infos - ret = await hr.async_resolve_host(None, "example.local", 6052) + loop = asyncio.get_event_loop() + ret = await hr.async_resolve_host(loop, "example.local", 6052) - resolve_zc.assert_called_once_with("example", 6052, zeroconf_instance=None) + resolve_zc.assert_called_once_with(loop, "example", 6052, zeroconf_instance=None) resolve_addr.assert_not_called() assert ret == addr_infos[0] @@ -116,10 +120,11 @@ async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos): async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos): resolve_zc.return_value = [] resolve_addr.return_value = addr_infos - ret = await hr.async_resolve_host(None, "example.local", 6052) + loop = asyncio.get_event_loop() + ret = await hr.async_resolve_host(loop, "example.local", 6052) - resolve_zc.assert_called_once_with("example", 6052, zeroconf_instance=None) - resolve_addr.assert_called_once_with(None, "example.local", 6052) + resolve_zc.assert_called_once_with(loop, "example", 6052, zeroconf_instance=None) + resolve_addr.assert_called_once_with(loop, "example.local", 6052) assert ret == addr_infos[0]