Support for old zeroconfs (#88)

This commit is contained in:
Otto Winter 2021-09-07 18:52:54 +02:00 committed by GitHub
parent 8a89037cda
commit 8d08689b29
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 124 additions and 35 deletions

View File

@ -1,14 +1,21 @@
import asyncio
import functools
import socket
from dataclasses import dataclass
from typing import List, Tuple, Union, cast
from typing import List, Optional, Tuple, Union, cast
import zeroconf
import zeroconf.asyncio
try:
import zeroconf.asyncio
ZC_ASYNCIO = True
except ImportError:
ZC_ASYNCIO = False
from .core import APIConnectionError
ZeroconfInstanceType = Union[zeroconf.Zeroconf, zeroconf.asyncio.AsyncZeroconf, None]
ZeroconfInstanceType = Union[zeroconf.Zeroconf, "zeroconf.asyncio.AsyncZeroconf", None]
@dataclass(frozen=True)
@ -38,13 +45,61 @@ class AddrInfo:
sockaddr: Sockaddr
async def _async_resolve_host_zeroconf( # pylint: disable=too-many-branches
host: str,
port: int,
*,
timeout: float = 3.0,
zeroconf_instance: ZeroconfInstanceType = None,
) -> List[AddrInfo]:
def _sync_zeroconf_get_service_info(
zeroconf_instance: ZeroconfInstanceType,
service_type: str,
service_name: str,
timeout: float,
) -> Optional["zeroconf.ServiceInfo"]:
# Use or create zeroconf instance, ensure it's an AsyncZeroconf
if zeroconf_instance is None:
try:
zc = zeroconf.Zeroconf()
except Exception:
raise APIConnectionError(
"Cannot start mDNS sockets, is this a docker container without "
"host network mode?"
)
do_close = True
elif isinstance(zeroconf_instance, zeroconf.Zeroconf):
zc = zeroconf_instance
do_close = False
else:
raise ValueError(
f"Invalid type passed for zeroconf_instance: {type(zeroconf_instance)}"
)
try:
info = zc.get_service_info(service_type, service_name, int(timeout * 1000))
except Exception as exc:
raise APIConnectionError(
f"Error resolving mDNS {service_name} via mDNS: {exc}"
) from exc
finally:
if do_close:
zc.close()
return info
async def _async_zeroconf_get_service_info(
eventloop: asyncio.events.AbstractEventLoop,
zeroconf_instance: ZeroconfInstanceType,
service_type: str,
service_name: str,
timeout: float,
) -> Optional["zeroconf.ServiceInfo"]:
if not ZC_ASYNCIO:
return await eventloop.run_in_executor(
None,
functools.partial(
_sync_zeroconf_get_service_info,
zeroconf_instance,
service_type,
service_name,
timeout,
),
)
# Use or create zeroconf instance, ensure it's an AsyncZeroconf
if zeroconf_instance is None:
try:
@ -66,20 +121,34 @@ async def _async_resolve_host_zeroconf( # pylint: disable=too-many-branches
f"Invalid type passed for zeroconf_instance: {type(zeroconf_instance)}"
)
service_type = "_esphomelib._tcp.local."
service_name = f"{host}.{service_type}"
try:
info = await zc.async_get_service_info(
service_type, service_name, int(timeout * 1000)
)
except Exception as exc:
raise APIConnectionError(
f"Error resolving host {host} via mDNS: {exc}"
f"Error resolving mDNS {service_name} via mDNS: {exc}"
) from exc
finally:
if do_close:
await zc.async_close()
return info
async def _async_resolve_host_zeroconf(
eventloop: asyncio.events.AbstractEventLoop,
host: str,
port: int,
*,
timeout: float = 3.0,
zeroconf_instance: ZeroconfInstanceType = None,
) -> List[AddrInfo]:
service_type = "_esphomelib._tcp.local."
service_name = f"{host}.{service_type}"
info = await _async_zeroconf_get_service_info(
eventloop, zeroconf_instance, service_type, service_name, timeout
)
if info is None:
return []
@ -158,7 +227,7 @@ async def async_resolve_host(
try:
addrs.extend(
await _async_resolve_host_zeroconf(
name, port, zeroconf_instance=zeroconf_instance
eventloop, name, port, zeroconf_instance=zeroconf_instance
)
)
except APIConnectionError as err:

View File

@ -2,13 +2,7 @@ import asyncio
import logging
from typing import Awaitable, Callable, List, Optional
from zeroconf import ( # type: ignore[attr-defined]
DNSPointer,
DNSRecord,
RecordUpdate,
RecordUpdateListener,
Zeroconf,
)
import zeroconf
from .client import APIClient
from .core import APIConnectionError
@ -16,7 +10,7 @@ from .core import APIConnectionError
_LOGGER = logging.getLogger(__name__)
class ReconnectLogic(RecordUpdateListener): # type: ignore[misc]
class ReconnectLogic(zeroconf.RecordUpdateListener): # type: ignore[misc,name-defined]
"""Reconnectiong logic handler for ESPHome config entries.
Contains two reconnect strategies:
@ -32,7 +26,7 @@ class ReconnectLogic(RecordUpdateListener): # type: ignore[misc]
client: APIClient,
on_connect: Callable[[], Awaitable[None]],
on_disconnect: Callable[[], Awaitable[None]],
zeroconf_instance: Zeroconf,
zeroconf_instance: "zeroconf.Zeroconf",
name: Optional[str] = None,
) -> None:
"""Initialize ReconnectingLogic.
@ -63,6 +57,7 @@ class ReconnectLogic(RecordUpdateListener): # type: ignore[misc]
self._wait_task_lock = asyncio.Lock()
# Event for tracking when logic should stop
self._stop_event = asyncio.Event()
self._event_loop: Optional[asyncio.events.AbstractEventLoop] = None
@property
def _is_stopped(self) -> bool:
@ -182,6 +177,7 @@ class ReconnectLogic(RecordUpdateListener): # type: ignore[misc]
"""Start the reconnecting logic background task."""
# Create reconnection loop outside of HA's tracked tasks in order
# not to delay startup.
self._event_loop = asyncio.get_event_loop()
self._loop_task = asyncio.create_task(self._reconnect_loop())
async with self._connected_lock:
@ -220,8 +216,8 @@ class ReconnectLogic(RecordUpdateListener): # type: ignore[misc]
self._zc.async_remove_listener(self)
self._zc_listening = False
def _async_on_record(self, record: DNSRecord) -> None:
if not isinstance(record, DNSPointer):
def _async_on_record(self, record: "zeroconf.DNSRecord") -> None: # type: ignore[name-defined]
if not isinstance(record, zeroconf.DNSPointer): # type: ignore[attr-defined]
# We only consider PTR records and match using the alias name
return
if self._is_stopped or self.name is None:
@ -243,9 +239,28 @@ class ReconnectLogic(RecordUpdateListener): # type: ignore[misc]
)
self._reconnect_event.set()
# From RecordUpdateListener for zeroconf>=0.32
def async_update_records(
self, zc: Zeroconf, now: float, records: List[RecordUpdate]
self,
zc: "zeroconf.Zeroconf", # pylint: disable=unused-argument
now: float, # pylint: disable=unused-argument
records: List["zeroconf.RecordUpdate"], # type: ignore[name-defined]
) -> None:
"""Listen to zeroconf updated mDNS records. This must be called from the eventloop."""
for update in records:
self._async_on_record(update.new)
# From RecordUpdateListener for zeroconf<0.32
def update_record(
self,
zc: "zeroconf.Zeroconf", # pylint: disable=unused-argument
now: float, # pylint: disable=unused-argument
record: "zeroconf.DNSRecord", # type: ignore[name-defined]
) -> None:
assert self._event_loop is not None
async def corofunc() -> None:
self._async_on_record(record)
# Dispatch in event loop
asyncio.run_coroutine_threadsafe(corofunc(), self._event_loop)

View File

@ -1,2 +1,2 @@
protobuf>=3.12.2,<4.0
zeroconf>=0.32.0,<1.0
zeroconf>=0.28.0,<1.0

View File

@ -1,3 +1,4 @@
import asyncio
import socket
import pytest
@ -45,8 +46,9 @@ async def test_resolve_host_zeroconf(async_zeroconf, addr_infos):
]
async_zeroconf.async_get_service_info = AsyncMock(return_value=info)
async_zeroconf.async_close = AsyncMock()
loop = asyncio.get_event_loop()
ret = await hr._async_resolve_host_zeroconf("asdf", 6052)
ret = await hr._async_resolve_host_zeroconf(loop, "asdf", 6052)
async_zeroconf.async_get_service_info.assert_called_once_with(
"_esphomelib._tcp.local.", "asdf._esphomelib._tcp.local.", 3000
@ -60,8 +62,9 @@ async def test_resolve_host_zeroconf(async_zeroconf, addr_infos):
async def test_resolve_host_zeroconf_empty(async_zeroconf):
async_zeroconf.async_get_service_info = AsyncMock(return_value=None)
async_zeroconf.async_close = AsyncMock()
loop = asyncio.get_event_loop()
ret = await hr._async_resolve_host_zeroconf("asdf.local", 6052)
ret = await hr._async_resolve_host_zeroconf(loop, "asdf.local", 6052)
assert ret == []
@ -103,9 +106,10 @@ async def test_resolve_host_getaddrinfo_oserror():
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos):
resolve_zc.return_value = addr_infos
ret = await hr.async_resolve_host(None, "example.local", 6052)
loop = asyncio.get_event_loop()
ret = await hr.async_resolve_host(loop, "example.local", 6052)
resolve_zc.assert_called_once_with("example", 6052, zeroconf_instance=None)
resolve_zc.assert_called_once_with(loop, "example", 6052, zeroconf_instance=None)
resolve_addr.assert_not_called()
assert ret == addr_infos[0]
@ -116,10 +120,11 @@ async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos):
async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos):
resolve_zc.return_value = []
resolve_addr.return_value = addr_infos
ret = await hr.async_resolve_host(None, "example.local", 6052)
loop = asyncio.get_event_loop()
ret = await hr.async_resolve_host(loop, "example.local", 6052)
resolve_zc.assert_called_once_with("example", 6052, zeroconf_instance=None)
resolve_addr.assert_called_once_with(None, "example.local", 6052)
resolve_zc.assert_called_once_with(loop, "example", 6052, zeroconf_instance=None)
resolve_addr.assert_called_once_with(loop, "example.local", 6052)
assert ret == addr_infos[0]