Refactor zeroconf code to avoid creating instances when one is unneeded (#643)

This commit is contained in:
J. Nick Koston 2023-11-17 13:11:36 -06:00 committed by GitHub
parent 9a86f449a6
commit b12903e2e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 392 additions and 222 deletions

View File

@ -20,6 +20,8 @@ cdef class APIFrameHelper:
cdef str _log_name
cdef object _debug_enabled
cpdef set_log_name(self, str log_name)
@cython.locals(original_pos="unsigned int", new_pos="unsigned int")
cdef bytes _read_exactly(self, int length)

View File

@ -62,6 +62,10 @@ class APIFrameHelper:
self._log_name = log_name
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
def set_log_name(self, log_name: str) -> None:
"""Set the log name."""
self._log_name = log_name
def _set_ready_future_exception(self, exc: Exception | type[Exception]) -> None:
if not self._ready_future.done():
self._ready_future.set_exception(exc)

View File

@ -111,7 +111,6 @@ from .core import (
UnhandledAPIConnectionError,
to_human_readable_address,
)
from .host_resolver import ZeroconfInstanceType
from .model import (
AlarmControlPanelCommand,
AlarmControlPanelEntityState,
@ -177,6 +176,8 @@ from .model import (
VoiceAssistantCommand,
VoiceAssistantEventType,
)
from .util import build_log_name
from .zeroconf import ZeroconfInstanceType, ZeroconfManager
_LOGGER = logging.getLogger(__name__)
@ -258,10 +259,10 @@ class APIClient:
__slots__ = (
"_params",
"_connection",
"_cached_name",
"cached_name",
"_background_tasks",
"_loop",
"_log_name",
"log_name",
)
def __init__(
@ -272,7 +273,7 @@ class APIClient:
*,
client_info: str = "aioesphomeapi",
keepalive: float = KEEP_ALIVE_FREQUENCY,
zeroconf_instance: ZeroconfInstanceType = None,
zeroconf_instance: ZeroconfInstanceType | None = None,
noise_psk: str | None = None,
expected_name: str | None = None,
) -> None:
@ -297,17 +298,21 @@ class APIClient:
password=password,
client_info=client_info,
keepalive=keepalive,
zeroconf_instance=zeroconf_instance,
zeroconf_manager=ZeroconfManager(zeroconf_instance),
# treat empty '' psk string as missing (like password)
noise_psk=_stringify_or_none(noise_psk) or None,
expected_name=_stringify_or_none(expected_name) or None,
)
self._connection: APIConnection | None = None
self._cached_name: str | None = None
self.cached_name: str | None = None
self._background_tasks: set[asyncio.Task[Any]] = set()
self._loop = asyncio.get_event_loop()
self._set_log_name()
@property
def zeroconf_manager(self) -> ZeroconfManager:
return self._params.zeroconf_manager
@property
def expected_name(self) -> str | None:
return self._params.expected_name
@ -320,26 +325,23 @@ class APIClient:
def address(self) -> str:
return self._params.address
def _get_log_name(self) -> str:
"""Get the log name of the device."""
address = self.address
address_is_host = address.endswith(".local")
if self._cached_name is not None:
if address_is_host:
return self._cached_name
return f"{self._cached_name} @ {address}"
if address_is_host:
return address[:-6]
return address
def _set_log_name(self) -> None:
"""Set the log name of the device."""
self._log_name = self._get_log_name()
resolved_address: str | None = None
if self._connection and self._connection.resolved_addr_info:
resolved_address = self._connection.resolved_addr_info.sockaddr.address
self.log_name = build_log_name(
self.cached_name,
self.address,
resolved_address,
)
if self._connection:
self._connection.set_log_name(self.log_name)
def set_cached_name_if_unset(self, name: str) -> None:
"""Set the cached name of the device if not set."""
if not self._cached_name:
self._cached_name = name
if not self.cached_name:
self.cached_name = name
self._set_log_name()
async def connect(
@ -357,7 +359,7 @@ class APIClient:
) -> None:
"""Start connecting to the device."""
if self._connection is not None:
raise APIConnectionError(f"Already connected to {self._log_name}!")
raise APIConnectionError(f"Already connected to {self.log_name}!")
async def _on_stop(expected_disconnect: bool) -> None:
# Hook into on_stop handler to clear connection when stopped
@ -365,9 +367,7 @@ class APIClient:
if on_stop is not None:
await on_stop(expected_disconnect)
self._connection = APIConnection(
self._params, _on_stop, log_name=self._log_name
)
self._connection = APIConnection(self._params, _on_stop, log_name=self.log_name)
try:
await self._connection.start_connection()
@ -377,8 +377,11 @@ class APIClient:
except Exception as e:
self._connection = None
raise UnhandledAPIConnectionError(
f"Unexpected error while connecting to {self._log_name}: {e}"
f"Unexpected error while connecting to {self.log_name}: {e}"
) from e
# If we resolved the address, we should set the log name now
if self._connection.resolved_addr_info:
self._set_log_name()
async def finish_connection(
self,
@ -394,8 +397,10 @@ class APIClient:
except Exception as e:
self._connection = None
raise UnhandledAPIConnectionError(
f"Unexpected error while connecting to {self._log_name}: {e}"
f"Unexpected error while connecting to {self.log_name}: {e}"
) from e
if received_name := self._connection.received_name:
self._set_name_from_device(received_name)
async def disconnect(self, force: bool = False) -> None:
if self._connection is None:
@ -408,10 +413,10 @@ class APIClient:
def _check_authenticated(self) -> None:
connection = self._connection
if not connection:
raise APIConnectionError(f"Not connected to {self._log_name}!")
raise APIConnectionError(f"Not connected to {self.log_name}!")
if not connection.is_connected:
raise APIConnectionError(
f"Authenticated connection not ready yet for {self._log_name}; "
f"Authenticated connection not ready yet for {self.log_name}; "
f"current state is {connection.connection_state}!"
)
@ -423,11 +428,14 @@ class APIClient:
DeviceInfoRequest(), DeviceInfoResponse
)
info = DeviceInfo.from_pb(resp)
self._cached_name = info.name
connection.set_log_name(self._log_name)
self._set_log_name()
self._set_name_from_device(info.name)
return info
def _set_name_from_device(self, name: str) -> None:
"""Set the name from a DeviceInfo message."""
self.cached_name = name
self._set_log_name()
async def list_entities_services(
self,
) -> tuple[list[EntityInfo], list[UserService]]:

View File

@ -68,6 +68,8 @@ cdef class APIConnection:
cdef public bint is_connected
cdef bint _handshake_complete
cdef object _debug_enabled
cdef public str received_name
cdef public object resolved_addr_info
cpdef send_message(self, object msg)

View File

@ -49,6 +49,7 @@ from .core import (
TimeoutAPIError,
)
from .model import APIVersion
from .zeroconf import ZeroconfManager
if sys.version_info[:2] < (3, 11):
from async_timeout import timeout as asyncio_timeout
@ -111,7 +112,7 @@ class ConnectionParams:
password: str | None
client_info: str
keepalive: float
zeroconf_instance: hr.ZeroconfInstanceType
zeroconf_manager: ZeroconfManager
noise_psk: str | None
expected_name: str | None
@ -159,6 +160,8 @@ class APIConnection:
"is_connected",
"_handshake_complete",
"_debug_enabled",
"received_name",
"resolved_addr_info",
)
def __init__(
@ -201,10 +204,14 @@ class APIConnection:
self.is_connected = False
self._handshake_complete = False
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
self.received_name: str = ""
self.resolved_addr_info: hr.AddrInfo | None = None
def set_log_name(self, name: str) -> None:
"""Set the friendly log name for this connection."""
self.log_name = name
if self._frame_helper is not None:
self._frame_helper.set_log_name(name)
def _cleanup(self) -> None:
"""Clean up all resources that have been allocated.
@ -276,7 +283,7 @@ class APIConnection:
return await hr.async_resolve_host(
self._params.address,
self._params.port,
self._params.zeroconf_instance,
self._params.zeroconf_manager,
)
except asyncio_TimeoutError as err:
raise ResolveAPIError(
@ -427,17 +434,16 @@ class APIConnection:
self.api_version = api_version
expected_name = self._params.expected_name
received_name = resp.name
if (
expected_name is not None
and received_name != ""
and received_name != expected_name
):
raise BadNameAPIError(
f"Expected '{expected_name}' but server sent "
f"a different name: '{received_name}'",
received_name,
)
if received_name := resp.name:
if expected_name is not None and received_name != expected_name:
raise BadNameAPIError(
f"Expected '{expected_name}' but server sent "
f"a different name: '{received_name}'",
received_name,
)
self.received_name = received_name
self.set_log_name(received_name)
def _async_schedule_keep_alive(self, now: _float) -> None:
"""Start the keep alive task."""
@ -506,8 +512,8 @@ class APIConnection:
async def _do_connect(self) -> None:
"""Do the actual connect process."""
in_do_connect.set(True)
addr = await self._connect_resolve_host()
await self._connect_socket_connect(addr)
self.resolved_addr_info = await self._connect_resolve_host()
await self._connect_socket_connect(self.resolved_addr_info)
async def start_connection(self) -> None:
"""Start the connection process.

View File

@ -6,35 +6,38 @@ import logging
import socket
from dataclasses import dataclass
from ipaddress import IPv4Address, IPv6Address
from typing import Union, cast
from typing import cast
from zeroconf import IPVersion, Zeroconf
from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf
from zeroconf import IPVersion
from zeroconf.asyncio import AsyncServiceInfo
from .core import APIConnectionError, ResolveAPIError
from .util import address_is_local, host_is_name_part
from .zeroconf import ZeroconfManager
_LOGGER = logging.getLogger(__name__)
ZeroconfInstanceType = Union[Zeroconf, AsyncZeroconf, None]
SERVICE_TYPE = "_esphomelib._tcp.local."
@dataclass(frozen=True)
class Sockaddr:
pass
"""Base socket address."""
address: str
port: int
@dataclass(frozen=True)
class IPv4Sockaddr(Sockaddr):
address: str
port: int
"""IPv4 socket address."""
@dataclass(frozen=True)
class IPv6Sockaddr(Sockaddr):
address: str
port: int
"""IPv6 socket address."""
flowinfo: int
scope_id: int
@ -44,35 +47,23 @@ class AddrInfo:
family: int
type: int
proto: int
sockaddr: Sockaddr
sockaddr: IPv4Sockaddr | IPv6Sockaddr
async def _async_zeroconf_get_service_info(
zeroconf_instance: ZeroconfInstanceType,
zeroconf_manager: ZeroconfManager,
service_type: str,
service_name: str,
timeout: float,
) -> AsyncServiceInfo | None:
# Use or create zeroconf instance, ensure it's an AsyncZeroconf
async_zc_instance: AsyncZeroconf | None = None
if zeroconf_instance is None:
try:
async_zc_instance = AsyncZeroconf()
except Exception:
raise ResolveAPIError(
"Cannot start mDNS sockets, is this a docker container without "
"host network mode?"
)
zc = async_zc_instance.zeroconf
elif isinstance(zeroconf_instance, AsyncZeroconf):
zc = zeroconf_instance.zeroconf
elif isinstance(zeroconf_instance, Zeroconf):
zc = zeroconf_instance
else:
raise ValueError(
f"Invalid type passed for zeroconf_instance: {type(zeroconf_instance)}"
)
try:
zc = zeroconf_manager.get_async_zeroconf().zeroconf
except Exception as exc:
raise ResolveAPIError(
f"Cannot start mDNS sockets: {exc}, is this a docker container without "
"host network mode?"
) from exc
try:
info = AsyncServiceInfo(service_type, service_name)
if await info.async_request(zc, int(timeout * 1000)):
@ -82,8 +73,7 @@ async def _async_zeroconf_get_service_info(
f"Error resolving mDNS {service_name} via mDNS: {exc}"
) from exc
finally:
if async_zc_instance:
await async_zc_instance.async_close()
await zeroconf_manager.async_close()
return info
@ -92,13 +82,13 @@ async def _async_resolve_host_zeroconf(
port: int,
*,
timeout: float = 3.0,
zeroconf_instance: ZeroconfInstanceType = None,
zeroconf_manager: ZeroconfManager | None = None,
) -> list[AddrInfo]:
service_name = f"{host}.{SERVICE_TYPE}"
_LOGGER.debug("Resolving host %s via mDNS", service_name)
info = await _async_zeroconf_get_service_info(
zeroconf_instance, SERVICE_TYPE, service_name, timeout
zeroconf_manager or ZeroconfManager(), SERVICE_TYPE, service_name, timeout
)
if info is None:
@ -107,7 +97,7 @@ async def _async_resolve_host_zeroconf(
addrs: list[AddrInfo] = []
for ip_address in info.ip_addresses_by_version(IPVersion.All):
is_ipv6 = ip_address.version == 6
sockaddr: Sockaddr
sockaddr: IPv6Sockaddr | IPv4Sockaddr
if is_ipv6:
sockaddr = IPv6Sockaddr(
address=str(ip_address),
@ -143,7 +133,7 @@ async def _async_resolve_host_getaddrinfo(host: str, port: int) -> list[AddrInfo
addrs: list[AddrInfo] = []
for family, type_, proto, _, raw in res:
sockaddr: Sockaddr
sockaddr: IPv4Sockaddr | IPv6Sockaddr
if family == socket.AF_INET:
raw = cast(tuple[str, int], raw)
address, port = raw
@ -197,17 +187,17 @@ def _async_ip_address_to_addrs(host: str, port: int) -> list[AddrInfo]:
async def async_resolve_host(
host: str,
port: int,
zeroconf_instance: ZeroconfInstanceType = None,
zeroconf_manager: ZeroconfManager | None = None,
) -> AddrInfo:
addrs: list[AddrInfo] = []
zc_error = None
if "." not in host or host.endswith(".local"):
if host_is_name_part(host) or address_is_local(host):
name = host.partition(".")[0]
try:
addrs.extend(
await _async_resolve_host_zeroconf(
name, port, zeroconf_instance=zeroconf_instance
name, port, zeroconf_manager=zeroconf_manager
)
)
except APIConnectionError as err:

View File

@ -7,8 +7,6 @@ import logging
import sys
from datetime import datetime
from zeroconf.asyncio import AsyncZeroconf
from .api_pb2 import SubscribeLogsResponse # type: ignore
from .client import APIClient
from .log_runner import async_run
@ -29,14 +27,11 @@ async def main(argv: list[str]) -> None:
datefmt="%Y-%m-%d %H:%M:%S",
)
aiozc = AsyncZeroconf()
cli = APIClient(
args.address,
args.port,
args.password or "",
noise_psk=args.noise_psk,
zeroconf_instance=aiozc.zeroconf,
keepalive=10,
)
@ -46,12 +41,10 @@ async def main(argv: list[str]) -> None:
text = message.decode("utf8", "backslashreplace")
print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}]{text}")
stop = await async_run(cli, on_log, aio_zeroconf_instance=aiozc)
stop = await async_run(cli, on_log)
try:
while True:
await asyncio.sleep(60)
await asyncio.Event().wait()
finally:
await aiozc.async_close()
await stop()

View File

@ -47,21 +47,16 @@ async def async_run(
) -> None:
_LOGGER.warning("Disconnected from API")
passed_in_zeroconf = aio_zeroconf_instance is not None
aiozc = aio_zeroconf_instance or AsyncZeroconf()
logic = ReconnectLogic(
client=cli,
on_connect=on_connect,
on_disconnect=on_disconnect,
zeroconf_instance=aiozc.zeroconf,
zeroconf_instance=aio_zeroconf_instance,
name=name,
)
await logic.start()
async def _stop() -> None:
if not passed_in_zeroconf:
await aiozc.async_close()
await logic.stop()
await cli.disconnect()

View File

@ -19,6 +19,8 @@ from .core import (
RequiresEncryptionAPIError,
UnhandledAPIConnectionError,
)
from .util import address_is_local
from .zeroconf import ZeroconfInstanceType
_LOGGER = logging.getLogger(__name__)
@ -62,7 +64,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
client: APIClient,
on_connect: Callable[[], Awaitable[None]],
on_disconnect: Callable[[bool], Awaitable[None]],
zeroconf_instance: zeroconf.Zeroconf,
zeroconf_instance: ZeroconfInstanceType | None = None,
name: str | None = None,
on_connect_error: Callable[[Exception], Awaitable[None]] | None = None,
) -> None:
@ -74,21 +76,19 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
"""
self.loop = asyncio.get_event_loop()
self._cli = client
self.name: str | None
if client.address.endswith(".local"):
self.name = client.address[:-6]
self._log_name = self.name
elif name:
self.name: str | None = None
if name:
self.name = name
self._log_name = f"{name} @ {self._cli.address}"
self._cli.set_cached_name_if_unset(name)
else:
self.name = None
self._log_name = client.address
elif address_is_local(client.address):
self.name = client.address.partition(".")[0]
if self.name:
self._cli.set_cached_name_if_unset(self.name)
self._on_connect_cb = on_connect
self._on_disconnect_cb = on_disconnect
self._on_connect_error_cb = on_connect_error
self._zc = zeroconf_instance
self._zeroconf_manager = client.zeroconf_manager
if zeroconf_instance is not None:
self._zeroconf_manager.set_instance(zeroconf_instance)
self._ptr_alias: str | None = None
self._a_name: str | None = None
# Flag to check if the device is connected
@ -116,7 +116,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
_LOGGER.info(
"Processing %s disconnect from ESPHome API for %s",
disconnect_type,
self._log_name,
self._cli.log_name,
)
# Run disconnect hook
@ -172,7 +172,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
_LOGGER.log(
level,
"Can't connect to ESPHome API for %s: %s (%s)",
self._log_name,
self._cli.log_name,
err,
type(err).__name__,
# Print stacktrace if unhandled
@ -197,7 +197,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
finish_connect_time = time.perf_counter()
connect_time = finish_connect_time - start_connect_time
_LOGGER.info(
"Successfully connected to %s in %0.3fs", self._log_name, connect_time
"Successfully connected to %s in %0.3fs", self._cli.log_name, connect_time
)
self._stop_zc_listen()
self._async_set_connection_state_while_locked(ReconnectLogicState.HANDSHAKING)
@ -221,7 +221,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
finish_handshake_time = time.perf_counter()
handshake_time = finish_handshake_time - finish_connect_time
_LOGGER.info(
"Successful handshake with %s in %0.3fs", self._log_name, handshake_time
"Successful handshake with %s in %0.3fs", self._cli.log_name, handshake_time
)
self._async_set_connection_state_while_locked(ReconnectLogicState.READY)
await self._on_connect_cb()
@ -250,7 +250,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
return
_LOGGER.debug(
"%s: Cancelling existing connect task, to try again now!",
self._log_name,
self._cli.log_name,
)
self._connect_task.cancel("Scheduling new connect attempt")
self._connect_task = None
@ -260,7 +260,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
self._connect_task = asyncio.create_task(
self._connect_once_or_reschedule(),
name=f"{self._log_name}: aioesphomeapi connect",
name=f"{self._cli.log_name}: aioesphomeapi connect",
)
def _cancel_connect(self, msg: str) -> None:
@ -277,9 +277,9 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
Must only be called from _call_connect_once
"""
_LOGGER.debug("Trying to connect to %s", self._log_name)
_LOGGER.debug("Trying to connect to %s", self._cli.log_name)
async with self._connected_lock:
_LOGGER.debug("Connected lock acquired for %s", self._log_name)
_LOGGER.debug("Connected lock acquired for %s", self._cli.log_name)
if (
self._connection_state != ReconnectLogicState.DISCONNECTED
or self._is_stopped
@ -291,9 +291,9 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
wait_time = int(round(min(1.8**tries, 60.0)))
if tries == 1:
_LOGGER.info(
"Trying to connect to %s in the background", self._log_name
"Trying to connect to %s in the background", self._cli.log_name
)
_LOGGER.debug("Retrying %s in %d seconds", self._log_name, wait_time)
_LOGGER.debug("Retrying %s in %d seconds", self._cli.log_name, wait_time)
if wait_time:
# If we are waiting, start listening for mDNS records
self._start_zc_listen()
@ -311,7 +311,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
"""Stop the connect logic."""
self._stop_task = asyncio.create_task(
self.stop(),
name=f"{self._log_name}: aioesphomeapi reconnect_logic stop_callback",
name=f"{self._cli.log_name}: aioesphomeapi reconnect_logic stop_callback",
)
self._stop_task.add_done_callback(self._remove_stop_task)
@ -342,6 +342,8 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
ReconnectLogicState.DISCONNECTED
)
await self._zeroconf_manager.async_close()
def _start_zc_listen(self) -> None:
"""Listen for mDNS records.
@ -352,14 +354,18 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
_LOGGER.debug("Starting zeroconf listener for %s", self.name)
self._ptr_alias = f"{self.name}._esphomelib._tcp.local."
self._a_name = f"{self.name}.local."
self._zc.async_add_listener(self, None)
self._zeroconf_manager.get_async_zeroconf().zeroconf.async_add_listener(
self, None
)
self._zc_listening = True
def _stop_zc_listen(self) -> None:
"""Stop listening for zeroconf updates."""
if self._zc_listening:
_LOGGER.debug("Removing zeroconf listener for %s", self.name)
self._zc.async_remove_listener(self)
self._zeroconf_manager.get_async_zeroconf().zeroconf.async_remove_listener(
self
)
self._zc_listening = False
def async_update_records(
@ -389,7 +395,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
# Tell connection logic to retry connection attempt now (even before connect timer finishes)
_LOGGER.debug(
"%s: Triggering connect because of received mDNS record %s",
self._log_name,
self._cli.log_name,
record_update.new,
)
# We can't stop the zeroconf listener here because we are in the middle of

View File

@ -24,3 +24,27 @@ def fix_float_single_double_conversion(value: float) -> float:
l10 = math.ceil(math.log10(abs_val))
prec = 7 - l10
return round(value, prec)
def host_is_name_part(address: str) -> bool:
"""Return True if a host is the name part."""
return "." not in address and ":" not in address
def address_is_local(address: str) -> bool:
"""Return True if the address is a local address."""
return address.removesuffix(".").endswith(".local")
def build_log_name(name: str | None, address: str, resolved_address: str | None) -> str:
"""Return a log name for a connection."""
if not name and address_is_local(address) or host_is_name_part(address):
name = address.partition(".")[0]
preferred_address = resolved_address or address
if (
name
and name != preferred_address
and not preferred_address.startswith(f"{name}.")
):
return f"{name} @ {preferred_address}"
return preferred_address

60
aioesphomeapi/zeroconf.py Normal file
View File

@ -0,0 +1,60 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Union
from zeroconf import Zeroconf
from zeroconf.asyncio import AsyncZeroconf
ZeroconfInstanceType = Union[Zeroconf, AsyncZeroconf]
_LOGGER = logging.getLogger(__name__)
class ZeroconfManager:
"""Manage the Zeroconf objects.
This class is used to manage the Zeroconf objects. It is used to create
the Zeroconf objects and to close them. It attempts to avoid creating
a Zeroconf object unless one is actually needed.
"""
def __init__(self, zeroconf: ZeroconfInstanceType | None = None) -> None:
"""Initialize the ZeroconfManager."""
self._created = False
self._aiozc: AsyncZeroconf | None = None
if zeroconf is not None:
self.set_instance(zeroconf)
def set_instance(self, zc: AsyncZeroconf | Zeroconf) -> None:
"""Set the AsyncZeroconf instance."""
if self._aiozc:
if isinstance(zc, AsyncZeroconf) and self._aiozc.zeroconf is zc.zeroconf:
return
if isinstance(zc, Zeroconf) and self._aiozc.zeroconf is zc:
self._aiozc = AsyncZeroconf(zc=zc)
return
raise RuntimeError("Zeroconf instance already set to a different instance")
self._aiozc = zc if isinstance(zc, AsyncZeroconf) else AsyncZeroconf(zc=zc)
def _create_async_zeroconf(self) -> None:
"""Create an AsyncZeroconf instance."""
_LOGGER.debug("Creating new AsyncZeroconf instance")
self._aiozc = AsyncZeroconf()
self._created = True
def get_async_zeroconf(self) -> AsyncZeroconf:
"""Get the AsyncZeroconf instance."""
if not self._aiozc:
self._create_async_zeroconf()
if TYPE_CHECKING:
assert self._aiozc is not None
return self._aiozc
async def async_close(self) -> None:
"""Close the Zeroconf connection."""
if not self._created or not self._aiozc:
return
await self._aiozc.async_close()
self._aiozc = None
self._created = False

View File

@ -4,7 +4,7 @@ import asyncio
import time
from datetime import datetime, timezone
from functools import partial
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
from google.protobuf import message
from zeroconf import Zeroconf
@ -27,26 +27,30 @@ PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
def get_mock_zeroconf() -> MagicMock:
return MagicMock(spec=Zeroconf)
with patch("zeroconf.Zeroconf.start"):
zc = Zeroconf()
zc.close = MagicMock()
return zc
def get_mock_async_zeroconf() -> MagicMock:
mock = MagicMock(spec=AsyncZeroconf)
mock.zeroconf = get_mock_zeroconf()
mock.async_close = AsyncMock()
return mock
def get_mock_async_zeroconf() -> AsyncZeroconf:
aiozc = AsyncZeroconf(zc=get_mock_zeroconf())
aiozc.async_close = AsyncMock()
return aiozc
class Estr(str):
"""A subclassed string."""
def generate_plaintext_packet(msg: bytes, type_: int) -> bytes:
def generate_plaintext_packet(msg: message.Message) -> bytes:
type_ = PROTO_TO_MESSAGE_TYPE[msg.__class__]
bytes_ = msg.SerializeToString()
return (
b"\0"
+ _cached_varuint_to_bytes(len(msg))
+ _cached_varuint_to_bytes(len(bytes_))
+ _cached_varuint_to_bytes(type_)
+ msg
+ bytes_
)
@ -99,10 +103,7 @@ def send_plaintext_hello(protocol: APIPlaintextFrameHelper) -> None:
hello_response.api_version_major = 1
hello_response.api_version_minor = 9
hello_response.name = "fake"
hello_msg = hello_response.SerializeToString()
protocol.data_received(
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
)
protocol.data_received(generate_plaintext_packet(hello_response))
def send_plaintext_connect_response(
@ -110,8 +111,4 @@ def send_plaintext_connect_response(
) -> None:
connect_response: message.Message = ConnectResponse()
connect_response.invalid_password = invalid_password
connect_msg = connect_response.SerializeToString()
protocol.data_received(
generate_plaintext_packet(connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse])
)
protocol.data_received(generate_plaintext_packet(connect_response))

View File

@ -12,8 +12,14 @@ from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi.client import APIClient, ConnectionParams
from aioesphomeapi.connection import APIConnection
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
from aioesphomeapi.zeroconf import ZeroconfManager
from .common import connect, send_plaintext_hello
from .common import connect, get_mock_async_zeroconf, send_plaintext_hello
@pytest.fixture
def async_zeroconf():
return get_mock_async_zeroconf()
@pytest.fixture
@ -42,7 +48,7 @@ def connection_params() -> ConnectionParams:
password=None,
client_info="Tests client",
keepalive=15.0,
zeroconf_instance=None,
zeroconf_manager=ZeroconfManager(),
noise_psk=None,
expected_name=None,
)

View File

@ -141,6 +141,7 @@ def test_plaintext_frame_helper(
assert type_ == pkt_type
assert data == pkt_data
helper.close()
@pytest.mark.parametrize(

View File

@ -18,6 +18,8 @@ from aioesphomeapi.api_pb2 import (
CameraImageResponse,
ClimateCommandRequest,
CoverCommandRequest,
DeviceInfoResponse,
DisconnectResponse,
ExecuteServiceArgument,
ExecuteServiceRequest,
FanCommandRequest,
@ -34,6 +36,7 @@ from aioesphomeapi.api_pb2 import (
)
from aioesphomeapi.client import APIClient
from aioesphomeapi.connection import APIConnection
from aioesphomeapi.core import APIConnectionError
from aioesphomeapi.model import (
AlarmControlPanelCommand,
APIVersion,
@ -680,12 +683,7 @@ async def test_bluetooth_disconnect(
response: message.Message = BluetoothDeviceConnectionResponse(
address=1234, connected=False
)
protocol.data_received(
generate_plaintext_packet(
response.SerializeToString(),
PROTO_TO_MESSAGE_TYPE[BluetoothDeviceConnectionResponse],
)
)
protocol.data_received(generate_plaintext_packet(response))
await disconnect_task
@ -700,12 +698,7 @@ async def test_bluetooth_pair(
pair_task = asyncio.create_task(client.bluetooth_device_pair(1234))
await asyncio.sleep(0)
response: message.Message = BluetoothDevicePairingResponse(address=1234)
protocol.data_received(
generate_plaintext_packet(
response.SerializeToString(),
PROTO_TO_MESSAGE_TYPE[BluetoothDevicePairingResponse],
)
)
protocol.data_received(generate_plaintext_packet(response))
await pair_task
@ -720,12 +713,7 @@ async def test_bluetooth_unpair(
unpair_task = asyncio.create_task(client.bluetooth_device_unpair(1234))
await asyncio.sleep(0)
response: message.Message = BluetoothDeviceUnpairingResponse(address=1234)
protocol.data_received(
generate_plaintext_packet(
response.SerializeToString(),
PROTO_TO_MESSAGE_TYPE[BluetoothDeviceUnpairingResponse],
)
)
protocol.data_received(generate_plaintext_packet(response))
await unpair_task
@ -740,10 +728,36 @@ async def test_bluetooth_clear_cache(
clear_task = asyncio.create_task(client.bluetooth_device_clear_cache(1234))
await asyncio.sleep(0)
response: message.Message = BluetoothDeviceClearCacheResponse(address=1234)
protocol.data_received(
generate_plaintext_packet(
response.SerializeToString(),
PROTO_TO_MESSAGE_TYPE[BluetoothDeviceClearCacheResponse],
)
)
protocol.data_received(generate_plaintext_packet(response))
await clear_task
@pytest.mark.asyncio
async def test_device_info(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test fetching device info."""
client, connection, transport, protocol = api_client
assert client.log_name == "mydevice.local"
device_info_task = asyncio.create_task(client.device_info())
await asyncio.sleep(0)
response: message.Message = DeviceInfoResponse(
name="realname",
friendly_name="My Device",
has_deep_sleep=True,
)
protocol.data_received(generate_plaintext_packet(response))
device_info = await device_info_task
assert device_info.name == "realname"
assert device_info.friendly_name == "My Device"
assert device_info.has_deep_sleep
assert client.log_name == "realname @ 10.0.0.512"
disconnect_task = asyncio.create_task(client.disconnect())
await asyncio.sleep(0)
response: message.Message = DisconnectResponse()
protocol.data_received(generate_plaintext_packet(response))
await disconnect_task
with pytest.raises(APIConnectionError, match="CLOSED"):
await client.device_info()

View File

@ -5,20 +5,11 @@ from ipaddress import ip_address
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from zeroconf import DNSCache
from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf
import aioesphomeapi.host_resolver as hr
from aioesphomeapi.core import APIConnectionError
@pytest.fixture
def async_zeroconf():
with patch("aioesphomeapi.host_resolver.AsyncZeroconf") as klass:
async_zeroconf = klass.return_value
async_zeroconf.async_close = AsyncMock()
async_zeroconf.zeroconf.cache = DNSCache()
yield async_zeroconf
from aioesphomeapi.core import APIConnectionError, ResolveAPIError
from aioesphomeapi.zeroconf import ZeroconfManager
@pytest.fixture
@ -52,7 +43,9 @@ async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos):
ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"),
]
info.async_request = AsyncMock(return_value=True)
with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info):
with patch(
"aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info
), patch("aioesphomeapi.zeroconf.AsyncZeroconf", return_value=async_zeroconf):
ret = await hr._async_resolve_host_zeroconf("asdf", 6052)
info.async_request.assert_called_once()
@ -61,10 +54,8 @@ async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos):
@pytest.mark.asyncio
async def test_resolve_host_passed_zeroconf_does_not_close(addr_infos):
async_zeroconf = AsyncZeroconf(zc=MagicMock())
async_zeroconf.async_close = AsyncMock()
async_zeroconf.zeroconf.cache = DNSCache()
async def test_resolve_host_passed_zeroconf(addr_infos, async_zeroconf):
zeroconf_manager = ZeroconfManager()
info = MagicMock(auto_spec=AsyncServiceInfo)
info.ip_addresses_by_version.return_value = [
ip_address(b"\n\x00\x00*"),
@ -73,11 +64,10 @@ async def test_resolve_host_passed_zeroconf_does_not_close(addr_infos):
info.async_request = AsyncMock(return_value=True)
with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info):
ret = await hr._async_resolve_host_zeroconf(
"asdf", 6052, zeroconf_instance=async_zeroconf
"asdf", 6052, zeroconf_manager=zeroconf_manager
)
info.async_request.assert_called_once()
async_zeroconf.async_close.assert_not_called()
assert ret == addr_infos
@ -131,7 +121,7 @@ async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos):
resolve_zc.return_value = addr_infos
ret = await hr.async_resolve_host("example.local", 6052)
resolve_zc.assert_called_once_with("example", 6052, zeroconf_instance=None)
resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None)
resolve_addr.assert_not_called()
assert ret == addr_infos[0]
@ -144,7 +134,7 @@ async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos):
resolve_addr.return_value = addr_infos
ret = await hr.async_resolve_host("example.local", 6052)
resolve_zc.assert_called_once_with("example", 6052, zeroconf_instance=None)
resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None)
resolve_addr.assert_called_once_with("example.local", 6052)
assert ret == addr_infos[0]
@ -189,3 +179,40 @@ async def test_resolve_host_with_address(resolve_addr, resolve_zc):
proto=6,
sockaddr=hr.IPv4Sockaddr(address="127.0.0.1", port=6052),
)
@pytest.mark.asyncio
async def test_resolve_host_zeroconf_service_info_oserror(
async_zeroconf: AsyncZeroconf, addr_infos
):
info = MagicMock(auto_spec=AsyncServiceInfo)
info.ip_addresses_by_version.return_value = [
ip_address(b"\n\x00\x00*"),
ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"),
]
info.async_request = AsyncMock(return_value=True)
with patch(
"aioesphomeapi.host_resolver.AsyncServiceInfo.async_request",
side_effect=OSError("out of buffers"),
), patch(
"aioesphomeapi.zeroconf.AsyncZeroconf", return_value=async_zeroconf
), pytest.raises(
ResolveAPIError, match="out of buffers"
):
await hr._async_resolve_host_zeroconf("asdf", 6052)
@pytest.mark.asyncio
async def test_resolve_host_create_zeroconf_oserror(
async_zeroconf: AsyncZeroconf, addr_infos
):
info = MagicMock(auto_spec=AsyncServiceInfo)
info.ip_addresses_by_version.return_value = [
ip_address(b"\n\x00\x00*"),
ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"),
]
info.async_request = AsyncMock(return_value=True)
with patch(
"aioesphomeapi.zeroconf.AsyncZeroconf", side_effect=OSError("out of buffers")
), pytest.raises(ResolveAPIError, match="out of buffers"):
await hr._async_resolve_host_zeroconf("asdf", 6052)

View File

@ -75,21 +75,11 @@ async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnec
response: message.Message = SubscribeLogsResponse()
response.message = b"Hello world"
protocol.data_received(
generate_plaintext_packet(
response.SerializeToString(),
PROTO_TO_MESSAGE_TYPE[SubscribeLogsResponse],
)
)
protocol.data_received(generate_plaintext_packet(response))
assert len(messages) == 1
assert messages[0].message == b"Hello world"
stop_task = asyncio.create_task(stop())
await asyncio.sleep(0)
disconnect_response = DisconnectResponse()
protocol.data_received(
generate_plaintext_packet(
disconnect_response.SerializeToString(),
PROTO_TO_MESSAGE_TYPE[DisconnectResponse],
)
)
protocol.data_received(generate_plaintext_packet(disconnect_response))
await stop_task

View File

@ -41,14 +41,13 @@ async def test_reconnect_logic_name_from_host():
async def on_connect() -> None:
pass
rl = ReconnectLogic(
ReconnectLogic(
client=cli,
on_disconnect=on_disconnect,
on_connect=on_connect,
zeroconf_instance=MagicMock(spec=AsyncZeroconf),
)
assert rl._log_name == "mydevice"
assert cli._log_name == "mydevice"
assert cli.log_name == "mydevice.local"
@pytest.mark.asyncio
@ -66,15 +65,14 @@ async def test_reconnect_logic_name_from_host_and_set():
async def on_connect() -> None:
pass
rl = ReconnectLogic(
ReconnectLogic(
client=cli,
on_disconnect=on_disconnect,
on_connect=on_connect,
zeroconf_instance=get_mock_zeroconf(),
name="mydevice",
)
assert rl._log_name == "mydevice"
assert cli._log_name == "mydevice"
assert cli.log_name == "mydevice.local"
@pytest.mark.asyncio
@ -92,14 +90,13 @@ async def test_reconnect_logic_name_from_address():
async def on_connect() -> None:
pass
rl = ReconnectLogic(
ReconnectLogic(
client=cli,
on_disconnect=on_disconnect,
on_connect=on_connect,
zeroconf_instance=get_mock_zeroconf(),
)
assert rl._log_name == "1.2.3.4"
assert cli._log_name == "1.2.3.4"
assert cli.log_name == "1.2.3.4"
@pytest.mark.asyncio
@ -117,15 +114,14 @@ async def test_reconnect_logic_name_from_name():
async def on_connect() -> None:
pass
rl = ReconnectLogic(
ReconnectLogic(
client=cli,
on_disconnect=on_disconnect,
on_connect=on_connect,
zeroconf_instance=get_mock_zeroconf(),
name="mydevice",
)
assert rl._log_name == "mydevice @ 1.2.3.4"
assert cli._log_name == "mydevice @ 1.2.3.4"
assert cli.log_name == "mydevice @ 1.2.3.4"
@pytest.mark.asyncio
@ -164,8 +160,7 @@ async def test_reconnect_logic_state():
name="mydevice",
on_connect_error=on_connect_fail,
)
assert rl._log_name == "mydevice @ 1.2.3.4"
assert cli._log_name == "mydevice @ 1.2.3.4"
assert cli.log_name == "mydevice @ 1.2.3.4"
with patch.object(cli, "start_connection", side_effect=APIConnectionError):
await rl.start()
@ -241,8 +236,7 @@ async def test_reconnect_retry():
name="mydevice",
on_connect_error=on_connect_fail,
)
assert rl._log_name == "mydevice @ 1.2.3.4"
assert cli._log_name == "mydevice @ 1.2.3.4"
assert cli.log_name == "mydevice @ 1.2.3.4"
with patch.object(cli, "start_connection", side_effect=APIConnectionError):
await rl.start()
@ -338,8 +332,7 @@ async def test_reconnect_zeroconf(
name="mydevice",
on_connect_error=AsyncMock(),
)
assert rl._log_name == "mydevice @ 1.2.3.4"
assert cli._log_name == "mydevice @ 1.2.3.4"
assert cli.log_name == "mydevice @ 1.2.3.4"
async def slow_connect_fail(*args, **kwargs):
await asyncio.sleep(10)

52
tests/test_zeroconf.py Normal file
View File

@ -0,0 +1,52 @@
from __future__ import annotations
from unittest.mock import patch
import pytest
from zeroconf.asyncio import AsyncZeroconf
from aioesphomeapi.zeroconf import ZeroconfManager
from .common import get_mock_async_zeroconf
@pytest.mark.asyncio
async def test_does_not_closed_passed_in_async_instance(async_zeroconf: AsyncZeroconf):
"""Test that the passed in instance is not closed."""
manager = ZeroconfManager()
manager.set_instance(async_zeroconf)
await manager.async_close()
assert async_zeroconf.async_close.call_count == 0
@pytest.mark.asyncio
async def test_does_not_closed_passed_in_sync_instance(async_zeroconf: AsyncZeroconf):
"""Test that the passed in instance is not closed."""
manager = ZeroconfManager()
manager.set_instance(async_zeroconf.zeroconf)
await manager.async_close()
assert async_zeroconf.async_close.call_count == 0
@pytest.mark.asyncio
async def test_closes_created_instance(async_zeroconf: AsyncZeroconf):
"""Test that the created instance is closed."""
with patch("aioesphomeapi.zeroconf.AsyncZeroconf", return_value=async_zeroconf):
manager = ZeroconfManager()
assert manager.get_async_zeroconf() is async_zeroconf
await manager.async_close()
assert async_zeroconf.async_close.call_count == 1
@pytest.mark.asyncio
async def test_runtime_error_multiple_instances(async_zeroconf: AsyncZeroconf):
"""Test runtime error is raised on multiple instances."""
manager = ZeroconfManager(async_zeroconf)
new_instance = get_mock_async_zeroconf()
with pytest.raises(RuntimeError):
manager.set_instance(new_instance)
manager.set_instance(async_zeroconf)
manager.set_instance(async_zeroconf.zeroconf)
manager.set_instance(async_zeroconf)
await manager.async_close()
assert async_zeroconf.async_close.call_count == 0