mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-03-10 13:09:48 +01:00
Support for old zeroconfs (#88)
This commit is contained in:
parent
8a89037cda
commit
8d08689b29
@ -1,14 +1,21 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import socket
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple, Union, cast
|
||||
from typing import List, Optional, Tuple, Union, cast
|
||||
|
||||
import zeroconf
|
||||
import zeroconf.asyncio
|
||||
|
||||
try:
|
||||
import zeroconf.asyncio
|
||||
|
||||
ZC_ASYNCIO = True
|
||||
except ImportError:
|
||||
ZC_ASYNCIO = False
|
||||
|
||||
from .core import APIConnectionError
|
||||
|
||||
ZeroconfInstanceType = Union[zeroconf.Zeroconf, zeroconf.asyncio.AsyncZeroconf, None]
|
||||
ZeroconfInstanceType = Union[zeroconf.Zeroconf, "zeroconf.asyncio.AsyncZeroconf", None]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -38,13 +45,61 @@ class AddrInfo:
|
||||
sockaddr: Sockaddr
|
||||
|
||||
|
||||
async def _async_resolve_host_zeroconf( # pylint: disable=too-many-branches
|
||||
host: str,
|
||||
port: int,
|
||||
*,
|
||||
timeout: float = 3.0,
|
||||
zeroconf_instance: ZeroconfInstanceType = None,
|
||||
) -> List[AddrInfo]:
|
||||
def _sync_zeroconf_get_service_info(
|
||||
zeroconf_instance: ZeroconfInstanceType,
|
||||
service_type: str,
|
||||
service_name: str,
|
||||
timeout: float,
|
||||
) -> Optional["zeroconf.ServiceInfo"]:
|
||||
# Use or create zeroconf instance, ensure it's an AsyncZeroconf
|
||||
if zeroconf_instance is None:
|
||||
try:
|
||||
zc = zeroconf.Zeroconf()
|
||||
except Exception:
|
||||
raise APIConnectionError(
|
||||
"Cannot start mDNS sockets, is this a docker container without "
|
||||
"host network mode?"
|
||||
)
|
||||
do_close = True
|
||||
elif isinstance(zeroconf_instance, zeroconf.Zeroconf):
|
||||
zc = zeroconf_instance
|
||||
do_close = False
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid type passed for zeroconf_instance: {type(zeroconf_instance)}"
|
||||
)
|
||||
|
||||
try:
|
||||
info = zc.get_service_info(service_type, service_name, int(timeout * 1000))
|
||||
except Exception as exc:
|
||||
raise APIConnectionError(
|
||||
f"Error resolving mDNS {service_name} via mDNS: {exc}"
|
||||
) from exc
|
||||
finally:
|
||||
if do_close:
|
||||
zc.close()
|
||||
return info
|
||||
|
||||
|
||||
async def _async_zeroconf_get_service_info(
|
||||
eventloop: asyncio.events.AbstractEventLoop,
|
||||
zeroconf_instance: ZeroconfInstanceType,
|
||||
service_type: str,
|
||||
service_name: str,
|
||||
timeout: float,
|
||||
) -> Optional["zeroconf.ServiceInfo"]:
|
||||
if not ZC_ASYNCIO:
|
||||
return await eventloop.run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
_sync_zeroconf_get_service_info,
|
||||
zeroconf_instance,
|
||||
service_type,
|
||||
service_name,
|
||||
timeout,
|
||||
),
|
||||
)
|
||||
|
||||
# Use or create zeroconf instance, ensure it's an AsyncZeroconf
|
||||
if zeroconf_instance is None:
|
||||
try:
|
||||
@ -66,20 +121,34 @@ async def _async_resolve_host_zeroconf( # pylint: disable=too-many-branches
|
||||
f"Invalid type passed for zeroconf_instance: {type(zeroconf_instance)}"
|
||||
)
|
||||
|
||||
service_type = "_esphomelib._tcp.local."
|
||||
service_name = f"{host}.{service_type}"
|
||||
|
||||
try:
|
||||
info = await zc.async_get_service_info(
|
||||
service_type, service_name, int(timeout * 1000)
|
||||
)
|
||||
except Exception as exc:
|
||||
raise APIConnectionError(
|
||||
f"Error resolving host {host} via mDNS: {exc}"
|
||||
f"Error resolving mDNS {service_name} via mDNS: {exc}"
|
||||
) from exc
|
||||
finally:
|
||||
if do_close:
|
||||
await zc.async_close()
|
||||
return info
|
||||
|
||||
|
||||
async def _async_resolve_host_zeroconf(
|
||||
eventloop: asyncio.events.AbstractEventLoop,
|
||||
host: str,
|
||||
port: int,
|
||||
*,
|
||||
timeout: float = 3.0,
|
||||
zeroconf_instance: ZeroconfInstanceType = None,
|
||||
) -> List[AddrInfo]:
|
||||
service_type = "_esphomelib._tcp.local."
|
||||
service_name = f"{host}.{service_type}"
|
||||
|
||||
info = await _async_zeroconf_get_service_info(
|
||||
eventloop, zeroconf_instance, service_type, service_name, timeout
|
||||
)
|
||||
|
||||
if info is None:
|
||||
return []
|
||||
@ -158,7 +227,7 @@ async def async_resolve_host(
|
||||
try:
|
||||
addrs.extend(
|
||||
await _async_resolve_host_zeroconf(
|
||||
name, port, zeroconf_instance=zeroconf_instance
|
||||
eventloop, name, port, zeroconf_instance=zeroconf_instance
|
||||
)
|
||||
)
|
||||
except APIConnectionError as err:
|
||||
|
@ -2,13 +2,7 @@ import asyncio
|
||||
import logging
|
||||
from typing import Awaitable, Callable, List, Optional
|
||||
|
||||
from zeroconf import ( # type: ignore[attr-defined]
|
||||
DNSPointer,
|
||||
DNSRecord,
|
||||
RecordUpdate,
|
||||
RecordUpdateListener,
|
||||
Zeroconf,
|
||||
)
|
||||
import zeroconf
|
||||
|
||||
from .client import APIClient
|
||||
from .core import APIConnectionError
|
||||
@ -16,7 +10,7 @@ from .core import APIConnectionError
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReconnectLogic(RecordUpdateListener): # type: ignore[misc]
|
||||
class ReconnectLogic(zeroconf.RecordUpdateListener): # type: ignore[misc,name-defined]
|
||||
"""Reconnectiong logic handler for ESPHome config entries.
|
||||
|
||||
Contains two reconnect strategies:
|
||||
@ -32,7 +26,7 @@ class ReconnectLogic(RecordUpdateListener): # type: ignore[misc]
|
||||
client: APIClient,
|
||||
on_connect: Callable[[], Awaitable[None]],
|
||||
on_disconnect: Callable[[], Awaitable[None]],
|
||||
zeroconf_instance: Zeroconf,
|
||||
zeroconf_instance: "zeroconf.Zeroconf",
|
||||
name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Initialize ReconnectingLogic.
|
||||
@ -63,6 +57,7 @@ class ReconnectLogic(RecordUpdateListener): # type: ignore[misc]
|
||||
self._wait_task_lock = asyncio.Lock()
|
||||
# Event for tracking when logic should stop
|
||||
self._stop_event = asyncio.Event()
|
||||
self._event_loop: Optional[asyncio.events.AbstractEventLoop] = None
|
||||
|
||||
@property
|
||||
def _is_stopped(self) -> bool:
|
||||
@ -182,6 +177,7 @@ class ReconnectLogic(RecordUpdateListener): # type: ignore[misc]
|
||||
"""Start the reconnecting logic background task."""
|
||||
# Create reconnection loop outside of HA's tracked tasks in order
|
||||
# not to delay startup.
|
||||
self._event_loop = asyncio.get_event_loop()
|
||||
self._loop_task = asyncio.create_task(self._reconnect_loop())
|
||||
|
||||
async with self._connected_lock:
|
||||
@ -220,8 +216,8 @@ class ReconnectLogic(RecordUpdateListener): # type: ignore[misc]
|
||||
self._zc.async_remove_listener(self)
|
||||
self._zc_listening = False
|
||||
|
||||
def _async_on_record(self, record: DNSRecord) -> None:
|
||||
if not isinstance(record, DNSPointer):
|
||||
def _async_on_record(self, record: "zeroconf.DNSRecord") -> None: # type: ignore[name-defined]
|
||||
if not isinstance(record, zeroconf.DNSPointer): # type: ignore[attr-defined]
|
||||
# We only consider PTR records and match using the alias name
|
||||
return
|
||||
if self._is_stopped or self.name is None:
|
||||
@ -243,9 +239,28 @@ class ReconnectLogic(RecordUpdateListener): # type: ignore[misc]
|
||||
)
|
||||
self._reconnect_event.set()
|
||||
|
||||
# From RecordUpdateListener for zeroconf>=0.32
|
||||
def async_update_records(
|
||||
self, zc: Zeroconf, now: float, records: List[RecordUpdate]
|
||||
self,
|
||||
zc: "zeroconf.Zeroconf", # pylint: disable=unused-argument
|
||||
now: float, # pylint: disable=unused-argument
|
||||
records: List["zeroconf.RecordUpdate"], # type: ignore[name-defined]
|
||||
) -> None:
|
||||
"""Listen to zeroconf updated mDNS records. This must be called from the eventloop."""
|
||||
for update in records:
|
||||
self._async_on_record(update.new)
|
||||
|
||||
# From RecordUpdateListener for zeroconf<0.32
|
||||
def update_record(
|
||||
self,
|
||||
zc: "zeroconf.Zeroconf", # pylint: disable=unused-argument
|
||||
now: float, # pylint: disable=unused-argument
|
||||
record: "zeroconf.DNSRecord", # type: ignore[name-defined]
|
||||
) -> None:
|
||||
assert self._event_loop is not None
|
||||
|
||||
async def corofunc() -> None:
|
||||
self._async_on_record(record)
|
||||
|
||||
# Dispatch in event loop
|
||||
asyncio.run_coroutine_threadsafe(corofunc(), self._event_loop)
|
||||
|
@ -1,2 +1,2 @@
|
||||
protobuf>=3.12.2,<4.0
|
||||
zeroconf>=0.32.0,<1.0
|
||||
zeroconf>=0.28.0,<1.0
|
||||
|
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import socket
|
||||
|
||||
import pytest
|
||||
@ -45,8 +46,9 @@ async def test_resolve_host_zeroconf(async_zeroconf, addr_infos):
|
||||
]
|
||||
async_zeroconf.async_get_service_info = AsyncMock(return_value=info)
|
||||
async_zeroconf.async_close = AsyncMock()
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
ret = await hr._async_resolve_host_zeroconf("asdf", 6052)
|
||||
ret = await hr._async_resolve_host_zeroconf(loop, "asdf", 6052)
|
||||
|
||||
async_zeroconf.async_get_service_info.assert_called_once_with(
|
||||
"_esphomelib._tcp.local.", "asdf._esphomelib._tcp.local.", 3000
|
||||
@ -60,8 +62,9 @@ async def test_resolve_host_zeroconf(async_zeroconf, addr_infos):
|
||||
async def test_resolve_host_zeroconf_empty(async_zeroconf):
|
||||
async_zeroconf.async_get_service_info = AsyncMock(return_value=None)
|
||||
async_zeroconf.async_close = AsyncMock()
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
ret = await hr._async_resolve_host_zeroconf("asdf.local", 6052)
|
||||
ret = await hr._async_resolve_host_zeroconf(loop, "asdf.local", 6052)
|
||||
assert ret == []
|
||||
|
||||
|
||||
@ -103,9 +106,10 @@ async def test_resolve_host_getaddrinfo_oserror():
|
||||
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
|
||||
async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos):
|
||||
resolve_zc.return_value = addr_infos
|
||||
ret = await hr.async_resolve_host(None, "example.local", 6052)
|
||||
loop = asyncio.get_event_loop()
|
||||
ret = await hr.async_resolve_host(loop, "example.local", 6052)
|
||||
|
||||
resolve_zc.assert_called_once_with("example", 6052, zeroconf_instance=None)
|
||||
resolve_zc.assert_called_once_with(loop, "example", 6052, zeroconf_instance=None)
|
||||
resolve_addr.assert_not_called()
|
||||
assert ret == addr_infos[0]
|
||||
|
||||
@ -116,10 +120,11 @@ async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos):
|
||||
async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos):
|
||||
resolve_zc.return_value = []
|
||||
resolve_addr.return_value = addr_infos
|
||||
ret = await hr.async_resolve_host(None, "example.local", 6052)
|
||||
loop = asyncio.get_event_loop()
|
||||
ret = await hr.async_resolve_host(loop, "example.local", 6052)
|
||||
|
||||
resolve_zc.assert_called_once_with("example", 6052, zeroconf_instance=None)
|
||||
resolve_addr.assert_called_once_with(None, "example.local", 6052)
|
||||
resolve_zc.assert_called_once_with(loop, "example", 6052, zeroconf_instance=None)
|
||||
resolve_addr.assert_called_once_with(loop, "example.local", 6052)
|
||||
assert ret == addr_infos[0]
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user