Refactor zeroconf code to avoid creating instances when one is unneeded (#643)

This commit is contained in:
J. Nick Koston 2023-11-17 13:11:36 -06:00 committed by GitHub
parent 9a86f449a6
commit b12903e2e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 392 additions and 222 deletions

View File

@ -20,6 +20,8 @@ cdef class APIFrameHelper:
cdef str _log_name cdef str _log_name
cdef object _debug_enabled cdef object _debug_enabled
cpdef set_log_name(self, str log_name)
@cython.locals(original_pos="unsigned int", new_pos="unsigned int") @cython.locals(original_pos="unsigned int", new_pos="unsigned int")
cdef bytes _read_exactly(self, int length) cdef bytes _read_exactly(self, int length)

View File

@ -62,6 +62,10 @@ class APIFrameHelper:
self._log_name = log_name self._log_name = log_name
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG) self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
def set_log_name(self, log_name: str) -> None:
"""Set the log name."""
self._log_name = log_name
def _set_ready_future_exception(self, exc: Exception | type[Exception]) -> None: def _set_ready_future_exception(self, exc: Exception | type[Exception]) -> None:
if not self._ready_future.done(): if not self._ready_future.done():
self._ready_future.set_exception(exc) self._ready_future.set_exception(exc)

View File

@ -111,7 +111,6 @@ from .core import (
UnhandledAPIConnectionError, UnhandledAPIConnectionError,
to_human_readable_address, to_human_readable_address,
) )
from .host_resolver import ZeroconfInstanceType
from .model import ( from .model import (
AlarmControlPanelCommand, AlarmControlPanelCommand,
AlarmControlPanelEntityState, AlarmControlPanelEntityState,
@ -177,6 +176,8 @@ from .model import (
VoiceAssistantCommand, VoiceAssistantCommand,
VoiceAssistantEventType, VoiceAssistantEventType,
) )
from .util import build_log_name
from .zeroconf import ZeroconfInstanceType, ZeroconfManager
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -258,10 +259,10 @@ class APIClient:
__slots__ = ( __slots__ = (
"_params", "_params",
"_connection", "_connection",
"_cached_name", "cached_name",
"_background_tasks", "_background_tasks",
"_loop", "_loop",
"_log_name", "log_name",
) )
def __init__( def __init__(
@ -272,7 +273,7 @@ class APIClient:
*, *,
client_info: str = "aioesphomeapi", client_info: str = "aioesphomeapi",
keepalive: float = KEEP_ALIVE_FREQUENCY, keepalive: float = KEEP_ALIVE_FREQUENCY,
zeroconf_instance: ZeroconfInstanceType = None, zeroconf_instance: ZeroconfInstanceType | None = None,
noise_psk: str | None = None, noise_psk: str | None = None,
expected_name: str | None = None, expected_name: str | None = None,
) -> None: ) -> None:
@ -297,17 +298,21 @@ class APIClient:
password=password, password=password,
client_info=client_info, client_info=client_info,
keepalive=keepalive, keepalive=keepalive,
zeroconf_instance=zeroconf_instance, zeroconf_manager=ZeroconfManager(zeroconf_instance),
# treat empty '' psk string as missing (like password) # treat empty '' psk string as missing (like password)
noise_psk=_stringify_or_none(noise_psk) or None, noise_psk=_stringify_or_none(noise_psk) or None,
expected_name=_stringify_or_none(expected_name) or None, expected_name=_stringify_or_none(expected_name) or None,
) )
self._connection: APIConnection | None = None self._connection: APIConnection | None = None
self._cached_name: str | None = None self.cached_name: str | None = None
self._background_tasks: set[asyncio.Task[Any]] = set() self._background_tasks: set[asyncio.Task[Any]] = set()
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
self._set_log_name() self._set_log_name()
@property
def zeroconf_manager(self) -> ZeroconfManager:
return self._params.zeroconf_manager
@property @property
def expected_name(self) -> str | None: def expected_name(self) -> str | None:
return self._params.expected_name return self._params.expected_name
@ -320,26 +325,23 @@ class APIClient:
def address(self) -> str: def address(self) -> str:
return self._params.address return self._params.address
def _get_log_name(self) -> str:
"""Get the log name of the device."""
address = self.address
address_is_host = address.endswith(".local")
if self._cached_name is not None:
if address_is_host:
return self._cached_name
return f"{self._cached_name} @ {address}"
if address_is_host:
return address[:-6]
return address
def _set_log_name(self) -> None: def _set_log_name(self) -> None:
"""Set the log name of the device.""" """Set the log name of the device."""
self._log_name = self._get_log_name() resolved_address: str | None = None
if self._connection and self._connection.resolved_addr_info:
resolved_address = self._connection.resolved_addr_info.sockaddr.address
self.log_name = build_log_name(
self.cached_name,
self.address,
resolved_address,
)
if self._connection:
self._connection.set_log_name(self.log_name)
def set_cached_name_if_unset(self, name: str) -> None: def set_cached_name_if_unset(self, name: str) -> None:
"""Set the cached name of the device if not set.""" """Set the cached name of the device if not set."""
if not self._cached_name: if not self.cached_name:
self._cached_name = name self.cached_name = name
self._set_log_name() self._set_log_name()
async def connect( async def connect(
@ -357,7 +359,7 @@ class APIClient:
) -> None: ) -> None:
"""Start connecting to the device.""" """Start connecting to the device."""
if self._connection is not None: if self._connection is not None:
raise APIConnectionError(f"Already connected to {self._log_name}!") raise APIConnectionError(f"Already connected to {self.log_name}!")
async def _on_stop(expected_disconnect: bool) -> None: async def _on_stop(expected_disconnect: bool) -> None:
# Hook into on_stop handler to clear connection when stopped # Hook into on_stop handler to clear connection when stopped
@ -365,9 +367,7 @@ class APIClient:
if on_stop is not None: if on_stop is not None:
await on_stop(expected_disconnect) await on_stop(expected_disconnect)
self._connection = APIConnection( self._connection = APIConnection(self._params, _on_stop, log_name=self.log_name)
self._params, _on_stop, log_name=self._log_name
)
try: try:
await self._connection.start_connection() await self._connection.start_connection()
@ -377,8 +377,11 @@ class APIClient:
except Exception as e: except Exception as e:
self._connection = None self._connection = None
raise UnhandledAPIConnectionError( raise UnhandledAPIConnectionError(
f"Unexpected error while connecting to {self._log_name}: {e}" f"Unexpected error while connecting to {self.log_name}: {e}"
) from e ) from e
# If we resolved the address, we should set the log name now
if self._connection.resolved_addr_info:
self._set_log_name()
async def finish_connection( async def finish_connection(
self, self,
@ -394,8 +397,10 @@ class APIClient:
except Exception as e: except Exception as e:
self._connection = None self._connection = None
raise UnhandledAPIConnectionError( raise UnhandledAPIConnectionError(
f"Unexpected error while connecting to {self._log_name}: {e}" f"Unexpected error while connecting to {self.log_name}: {e}"
) from e ) from e
if received_name := self._connection.received_name:
self._set_name_from_device(received_name)
async def disconnect(self, force: bool = False) -> None: async def disconnect(self, force: bool = False) -> None:
if self._connection is None: if self._connection is None:
@ -408,10 +413,10 @@ class APIClient:
def _check_authenticated(self) -> None: def _check_authenticated(self) -> None:
connection = self._connection connection = self._connection
if not connection: if not connection:
raise APIConnectionError(f"Not connected to {self._log_name}!") raise APIConnectionError(f"Not connected to {self.log_name}!")
if not connection.is_connected: if not connection.is_connected:
raise APIConnectionError( raise APIConnectionError(
f"Authenticated connection not ready yet for {self._log_name}; " f"Authenticated connection not ready yet for {self.log_name}; "
f"current state is {connection.connection_state}!" f"current state is {connection.connection_state}!"
) )
@ -423,11 +428,14 @@ class APIClient:
DeviceInfoRequest(), DeviceInfoResponse DeviceInfoRequest(), DeviceInfoResponse
) )
info = DeviceInfo.from_pb(resp) info = DeviceInfo.from_pb(resp)
self._cached_name = info.name self._set_name_from_device(info.name)
connection.set_log_name(self._log_name)
self._set_log_name()
return info return info
def _set_name_from_device(self, name: str) -> None:
"""Set the name from a DeviceInfo message."""
self.cached_name = name
self._set_log_name()
async def list_entities_services( async def list_entities_services(
self, self,
) -> tuple[list[EntityInfo], list[UserService]]: ) -> tuple[list[EntityInfo], list[UserService]]:

View File

@ -68,6 +68,8 @@ cdef class APIConnection:
cdef public bint is_connected cdef public bint is_connected
cdef bint _handshake_complete cdef bint _handshake_complete
cdef object _debug_enabled cdef object _debug_enabled
cdef public str received_name
cdef public object resolved_addr_info
cpdef send_message(self, object msg) cpdef send_message(self, object msg)

View File

@ -49,6 +49,7 @@ from .core import (
TimeoutAPIError, TimeoutAPIError,
) )
from .model import APIVersion from .model import APIVersion
from .zeroconf import ZeroconfManager
if sys.version_info[:2] < (3, 11): if sys.version_info[:2] < (3, 11):
from async_timeout import timeout as asyncio_timeout from async_timeout import timeout as asyncio_timeout
@ -111,7 +112,7 @@ class ConnectionParams:
password: str | None password: str | None
client_info: str client_info: str
keepalive: float keepalive: float
zeroconf_instance: hr.ZeroconfInstanceType zeroconf_manager: ZeroconfManager
noise_psk: str | None noise_psk: str | None
expected_name: str | None expected_name: str | None
@ -159,6 +160,8 @@ class APIConnection:
"is_connected", "is_connected",
"_handshake_complete", "_handshake_complete",
"_debug_enabled", "_debug_enabled",
"received_name",
"resolved_addr_info",
) )
def __init__( def __init__(
@ -201,10 +204,14 @@ class APIConnection:
self.is_connected = False self.is_connected = False
self._handshake_complete = False self._handshake_complete = False
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG) self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
self.received_name: str = ""
self.resolved_addr_info: hr.AddrInfo | None = None
def set_log_name(self, name: str) -> None: def set_log_name(self, name: str) -> None:
"""Set the friendly log name for this connection.""" """Set the friendly log name for this connection."""
self.log_name = name self.log_name = name
if self._frame_helper is not None:
self._frame_helper.set_log_name(name)
def _cleanup(self) -> None: def _cleanup(self) -> None:
"""Clean up all resources that have been allocated. """Clean up all resources that have been allocated.
@ -276,7 +283,7 @@ class APIConnection:
return await hr.async_resolve_host( return await hr.async_resolve_host(
self._params.address, self._params.address,
self._params.port, self._params.port,
self._params.zeroconf_instance, self._params.zeroconf_manager,
) )
except asyncio_TimeoutError as err: except asyncio_TimeoutError as err:
raise ResolveAPIError( raise ResolveAPIError(
@ -427,17 +434,16 @@ class APIConnection:
self.api_version = api_version self.api_version = api_version
expected_name = self._params.expected_name expected_name = self._params.expected_name
received_name = resp.name if received_name := resp.name:
if ( if expected_name is not None and received_name != expected_name:
expected_name is not None raise BadNameAPIError(
and received_name != "" f"Expected '{expected_name}' but server sent "
and received_name != expected_name f"a different name: '{received_name}'",
): received_name,
raise BadNameAPIError( )
f"Expected '{expected_name}' but server sent "
f"a different name: '{received_name}'", self.received_name = received_name
received_name, self.set_log_name(received_name)
)
def _async_schedule_keep_alive(self, now: _float) -> None: def _async_schedule_keep_alive(self, now: _float) -> None:
"""Start the keep alive task.""" """Start the keep alive task."""
@ -506,8 +512,8 @@ class APIConnection:
async def _do_connect(self) -> None: async def _do_connect(self) -> None:
"""Do the actual connect process.""" """Do the actual connect process."""
in_do_connect.set(True) in_do_connect.set(True)
addr = await self._connect_resolve_host() self.resolved_addr_info = await self._connect_resolve_host()
await self._connect_socket_connect(addr) await self._connect_socket_connect(self.resolved_addr_info)
async def start_connection(self) -> None: async def start_connection(self) -> None:
"""Start the connection process. """Start the connection process.

View File

@ -6,35 +6,38 @@ import logging
import socket import socket
from dataclasses import dataclass from dataclasses import dataclass
from ipaddress import IPv4Address, IPv6Address from ipaddress import IPv4Address, IPv6Address
from typing import Union, cast from typing import cast
from zeroconf import IPVersion, Zeroconf from zeroconf import IPVersion
from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf from zeroconf.asyncio import AsyncServiceInfo
from .core import APIConnectionError, ResolveAPIError from .core import APIConnectionError, ResolveAPIError
from .util import address_is_local, host_is_name_part
from .zeroconf import ZeroconfManager
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
ZeroconfInstanceType = Union[Zeroconf, AsyncZeroconf, None]
SERVICE_TYPE = "_esphomelib._tcp.local." SERVICE_TYPE = "_esphomelib._tcp.local."
@dataclass(frozen=True) @dataclass(frozen=True)
class Sockaddr: class Sockaddr:
pass """Base socket address."""
address: str
port: int
@dataclass(frozen=True) @dataclass(frozen=True)
class IPv4Sockaddr(Sockaddr): class IPv4Sockaddr(Sockaddr):
address: str """IPv4 socket address."""
port: int
@dataclass(frozen=True) @dataclass(frozen=True)
class IPv6Sockaddr(Sockaddr): class IPv6Sockaddr(Sockaddr):
address: str """IPv6 socket address."""
port: int
flowinfo: int flowinfo: int
scope_id: int scope_id: int
@ -44,35 +47,23 @@ class AddrInfo:
family: int family: int
type: int type: int
proto: int proto: int
sockaddr: Sockaddr sockaddr: IPv4Sockaddr | IPv6Sockaddr
async def _async_zeroconf_get_service_info( async def _async_zeroconf_get_service_info(
zeroconf_instance: ZeroconfInstanceType, zeroconf_manager: ZeroconfManager,
service_type: str, service_type: str,
service_name: str, service_name: str,
timeout: float, timeout: float,
) -> AsyncServiceInfo | None: ) -> AsyncServiceInfo | None:
# Use or create zeroconf instance, ensure it's an AsyncZeroconf # Use or create zeroconf instance, ensure it's an AsyncZeroconf
async_zc_instance: AsyncZeroconf | None = None try:
if zeroconf_instance is None: zc = zeroconf_manager.get_async_zeroconf().zeroconf
try: except Exception as exc:
async_zc_instance = AsyncZeroconf() raise ResolveAPIError(
except Exception: f"Cannot start mDNS sockets: {exc}, is this a docker container without "
raise ResolveAPIError( "host network mode?"
"Cannot start mDNS sockets, is this a docker container without " ) from exc
"host network mode?"
)
zc = async_zc_instance.zeroconf
elif isinstance(zeroconf_instance, AsyncZeroconf):
zc = zeroconf_instance.zeroconf
elif isinstance(zeroconf_instance, Zeroconf):
zc = zeroconf_instance
else:
raise ValueError(
f"Invalid type passed for zeroconf_instance: {type(zeroconf_instance)}"
)
try: try:
info = AsyncServiceInfo(service_type, service_name) info = AsyncServiceInfo(service_type, service_name)
if await info.async_request(zc, int(timeout * 1000)): if await info.async_request(zc, int(timeout * 1000)):
@ -82,8 +73,7 @@ async def _async_zeroconf_get_service_info(
f"Error resolving mDNS {service_name} via mDNS: {exc}" f"Error resolving mDNS {service_name} via mDNS: {exc}"
) from exc ) from exc
finally: finally:
if async_zc_instance: await zeroconf_manager.async_close()
await async_zc_instance.async_close()
return info return info
@ -92,13 +82,13 @@ async def _async_resolve_host_zeroconf(
port: int, port: int,
*, *,
timeout: float = 3.0, timeout: float = 3.0,
zeroconf_instance: ZeroconfInstanceType = None, zeroconf_manager: ZeroconfManager | None = None,
) -> list[AddrInfo]: ) -> list[AddrInfo]:
service_name = f"{host}.{SERVICE_TYPE}" service_name = f"{host}.{SERVICE_TYPE}"
_LOGGER.debug("Resolving host %s via mDNS", service_name) _LOGGER.debug("Resolving host %s via mDNS", service_name)
info = await _async_zeroconf_get_service_info( info = await _async_zeroconf_get_service_info(
zeroconf_instance, SERVICE_TYPE, service_name, timeout zeroconf_manager or ZeroconfManager(), SERVICE_TYPE, service_name, timeout
) )
if info is None: if info is None:
@ -107,7 +97,7 @@ async def _async_resolve_host_zeroconf(
addrs: list[AddrInfo] = [] addrs: list[AddrInfo] = []
for ip_address in info.ip_addresses_by_version(IPVersion.All): for ip_address in info.ip_addresses_by_version(IPVersion.All):
is_ipv6 = ip_address.version == 6 is_ipv6 = ip_address.version == 6
sockaddr: Sockaddr sockaddr: IPv6Sockaddr | IPv4Sockaddr
if is_ipv6: if is_ipv6:
sockaddr = IPv6Sockaddr( sockaddr = IPv6Sockaddr(
address=str(ip_address), address=str(ip_address),
@ -143,7 +133,7 @@ async def _async_resolve_host_getaddrinfo(host: str, port: int) -> list[AddrInfo
addrs: list[AddrInfo] = [] addrs: list[AddrInfo] = []
for family, type_, proto, _, raw in res: for family, type_, proto, _, raw in res:
sockaddr: Sockaddr sockaddr: IPv4Sockaddr | IPv6Sockaddr
if family == socket.AF_INET: if family == socket.AF_INET:
raw = cast(tuple[str, int], raw) raw = cast(tuple[str, int], raw)
address, port = raw address, port = raw
@ -197,17 +187,17 @@ def _async_ip_address_to_addrs(host: str, port: int) -> list[AddrInfo]:
async def async_resolve_host( async def async_resolve_host(
host: str, host: str,
port: int, port: int,
zeroconf_instance: ZeroconfInstanceType = None, zeroconf_manager: ZeroconfManager | None = None,
) -> AddrInfo: ) -> AddrInfo:
addrs: list[AddrInfo] = [] addrs: list[AddrInfo] = []
zc_error = None zc_error = None
if "." not in host or host.endswith(".local"): if host_is_name_part(host) or address_is_local(host):
name = host.partition(".")[0] name = host.partition(".")[0]
try: try:
addrs.extend( addrs.extend(
await _async_resolve_host_zeroconf( await _async_resolve_host_zeroconf(
name, port, zeroconf_instance=zeroconf_instance name, port, zeroconf_manager=zeroconf_manager
) )
) )
except APIConnectionError as err: except APIConnectionError as err:

View File

@ -7,8 +7,6 @@ import logging
import sys import sys
from datetime import datetime from datetime import datetime
from zeroconf.asyncio import AsyncZeroconf
from .api_pb2 import SubscribeLogsResponse # type: ignore from .api_pb2 import SubscribeLogsResponse # type: ignore
from .client import APIClient from .client import APIClient
from .log_runner import async_run from .log_runner import async_run
@ -29,14 +27,11 @@ async def main(argv: list[str]) -> None:
datefmt="%Y-%m-%d %H:%M:%S", datefmt="%Y-%m-%d %H:%M:%S",
) )
aiozc = AsyncZeroconf()
cli = APIClient( cli = APIClient(
args.address, args.address,
args.port, args.port,
args.password or "", args.password or "",
noise_psk=args.noise_psk, noise_psk=args.noise_psk,
zeroconf_instance=aiozc.zeroconf,
keepalive=10, keepalive=10,
) )
@ -46,12 +41,10 @@ async def main(argv: list[str]) -> None:
text = message.decode("utf8", "backslashreplace") text = message.decode("utf8", "backslashreplace")
print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}]{text}") print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}]{text}")
stop = await async_run(cli, on_log, aio_zeroconf_instance=aiozc) stop = await async_run(cli, on_log)
try: try:
while True: await asyncio.Event().wait()
await asyncio.sleep(60)
finally: finally:
await aiozc.async_close()
await stop() await stop()

View File

@ -47,21 +47,16 @@ async def async_run(
) -> None: ) -> None:
_LOGGER.warning("Disconnected from API") _LOGGER.warning("Disconnected from API")
passed_in_zeroconf = aio_zeroconf_instance is not None
aiozc = aio_zeroconf_instance or AsyncZeroconf()
logic = ReconnectLogic( logic = ReconnectLogic(
client=cli, client=cli,
on_connect=on_connect, on_connect=on_connect,
on_disconnect=on_disconnect, on_disconnect=on_disconnect,
zeroconf_instance=aiozc.zeroconf, zeroconf_instance=aio_zeroconf_instance,
name=name, name=name,
) )
await logic.start() await logic.start()
async def _stop() -> None: async def _stop() -> None:
if not passed_in_zeroconf:
await aiozc.async_close()
await logic.stop() await logic.stop()
await cli.disconnect() await cli.disconnect()

View File

@ -19,6 +19,8 @@ from .core import (
RequiresEncryptionAPIError, RequiresEncryptionAPIError,
UnhandledAPIConnectionError, UnhandledAPIConnectionError,
) )
from .util import address_is_local
from .zeroconf import ZeroconfInstanceType
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -62,7 +64,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
client: APIClient, client: APIClient,
on_connect: Callable[[], Awaitable[None]], on_connect: Callable[[], Awaitable[None]],
on_disconnect: Callable[[bool], Awaitable[None]], on_disconnect: Callable[[bool], Awaitable[None]],
zeroconf_instance: zeroconf.Zeroconf, zeroconf_instance: ZeroconfInstanceType | None = None,
name: str | None = None, name: str | None = None,
on_connect_error: Callable[[Exception], Awaitable[None]] | None = None, on_connect_error: Callable[[Exception], Awaitable[None]] | None = None,
) -> None: ) -> None:
@ -74,21 +76,19 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
""" """
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
self._cli = client self._cli = client
self.name: str | None self.name: str | None = None
if client.address.endswith(".local"): if name:
self.name = client.address[:-6]
self._log_name = self.name
elif name:
self.name = name self.name = name
self._log_name = f"{name} @ {self._cli.address}" elif address_is_local(client.address):
self._cli.set_cached_name_if_unset(name) self.name = client.address.partition(".")[0]
else: if self.name:
self.name = None self._cli.set_cached_name_if_unset(self.name)
self._log_name = client.address
self._on_connect_cb = on_connect self._on_connect_cb = on_connect
self._on_disconnect_cb = on_disconnect self._on_disconnect_cb = on_disconnect
self._on_connect_error_cb = on_connect_error self._on_connect_error_cb = on_connect_error
self._zc = zeroconf_instance self._zeroconf_manager = client.zeroconf_manager
if zeroconf_instance is not None:
self._zeroconf_manager.set_instance(zeroconf_instance)
self._ptr_alias: str | None = None self._ptr_alias: str | None = None
self._a_name: str | None = None self._a_name: str | None = None
# Flag to check if the device is connected # Flag to check if the device is connected
@ -116,7 +116,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
_LOGGER.info( _LOGGER.info(
"Processing %s disconnect from ESPHome API for %s", "Processing %s disconnect from ESPHome API for %s",
disconnect_type, disconnect_type,
self._log_name, self._cli.log_name,
) )
# Run disconnect hook # Run disconnect hook
@ -172,7 +172,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
_LOGGER.log( _LOGGER.log(
level, level,
"Can't connect to ESPHome API for %s: %s (%s)", "Can't connect to ESPHome API for %s: %s (%s)",
self._log_name, self._cli.log_name,
err, err,
type(err).__name__, type(err).__name__,
# Print stacktrace if unhandled # Print stacktrace if unhandled
@ -197,7 +197,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
finish_connect_time = time.perf_counter() finish_connect_time = time.perf_counter()
connect_time = finish_connect_time - start_connect_time connect_time = finish_connect_time - start_connect_time
_LOGGER.info( _LOGGER.info(
"Successfully connected to %s in %0.3fs", self._log_name, connect_time "Successfully connected to %s in %0.3fs", self._cli.log_name, connect_time
) )
self._stop_zc_listen() self._stop_zc_listen()
self._async_set_connection_state_while_locked(ReconnectLogicState.HANDSHAKING) self._async_set_connection_state_while_locked(ReconnectLogicState.HANDSHAKING)
@ -221,7 +221,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
finish_handshake_time = time.perf_counter() finish_handshake_time = time.perf_counter()
handshake_time = finish_handshake_time - finish_connect_time handshake_time = finish_handshake_time - finish_connect_time
_LOGGER.info( _LOGGER.info(
"Successful handshake with %s in %0.3fs", self._log_name, handshake_time "Successful handshake with %s in %0.3fs", self._cli.log_name, handshake_time
) )
self._async_set_connection_state_while_locked(ReconnectLogicState.READY) self._async_set_connection_state_while_locked(ReconnectLogicState.READY)
await self._on_connect_cb() await self._on_connect_cb()
@ -250,7 +250,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
return return
_LOGGER.debug( _LOGGER.debug(
"%s: Cancelling existing connect task, to try again now!", "%s: Cancelling existing connect task, to try again now!",
self._log_name, self._cli.log_name,
) )
self._connect_task.cancel("Scheduling new connect attempt") self._connect_task.cancel("Scheduling new connect attempt")
self._connect_task = None self._connect_task = None
@ -260,7 +260,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
self._connect_task = asyncio.create_task( self._connect_task = asyncio.create_task(
self._connect_once_or_reschedule(), self._connect_once_or_reschedule(),
name=f"{self._log_name}: aioesphomeapi connect", name=f"{self._cli.log_name}: aioesphomeapi connect",
) )
def _cancel_connect(self, msg: str) -> None: def _cancel_connect(self, msg: str) -> None:
@ -277,9 +277,9 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
Must only be called from _call_connect_once Must only be called from _call_connect_once
""" """
_LOGGER.debug("Trying to connect to %s", self._log_name) _LOGGER.debug("Trying to connect to %s", self._cli.log_name)
async with self._connected_lock: async with self._connected_lock:
_LOGGER.debug("Connected lock acquired for %s", self._log_name) _LOGGER.debug("Connected lock acquired for %s", self._cli.log_name)
if ( if (
self._connection_state != ReconnectLogicState.DISCONNECTED self._connection_state != ReconnectLogicState.DISCONNECTED
or self._is_stopped or self._is_stopped
@ -291,9 +291,9 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
wait_time = int(round(min(1.8**tries, 60.0))) wait_time = int(round(min(1.8**tries, 60.0)))
if tries == 1: if tries == 1:
_LOGGER.info( _LOGGER.info(
"Trying to connect to %s in the background", self._log_name "Trying to connect to %s in the background", self._cli.log_name
) )
_LOGGER.debug("Retrying %s in %d seconds", self._log_name, wait_time) _LOGGER.debug("Retrying %s in %d seconds", self._cli.log_name, wait_time)
if wait_time: if wait_time:
# If we are waiting, start listening for mDNS records # If we are waiting, start listening for mDNS records
self._start_zc_listen() self._start_zc_listen()
@ -311,7 +311,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
"""Stop the connect logic.""" """Stop the connect logic."""
self._stop_task = asyncio.create_task( self._stop_task = asyncio.create_task(
self.stop(), self.stop(),
name=f"{self._log_name}: aioesphomeapi reconnect_logic stop_callback", name=f"{self._cli.log_name}: aioesphomeapi reconnect_logic stop_callback",
) )
self._stop_task.add_done_callback(self._remove_stop_task) self._stop_task.add_done_callback(self._remove_stop_task)
@ -342,6 +342,8 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
ReconnectLogicState.DISCONNECTED ReconnectLogicState.DISCONNECTED
) )
await self._zeroconf_manager.async_close()
def _start_zc_listen(self) -> None: def _start_zc_listen(self) -> None:
"""Listen for mDNS records. """Listen for mDNS records.
@ -352,14 +354,18 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
_LOGGER.debug("Starting zeroconf listener for %s", self.name) _LOGGER.debug("Starting zeroconf listener for %s", self.name)
self._ptr_alias = f"{self.name}._esphomelib._tcp.local." self._ptr_alias = f"{self.name}._esphomelib._tcp.local."
self._a_name = f"{self.name}.local." self._a_name = f"{self.name}.local."
self._zc.async_add_listener(self, None) self._zeroconf_manager.get_async_zeroconf().zeroconf.async_add_listener(
self, None
)
self._zc_listening = True self._zc_listening = True
def _stop_zc_listen(self) -> None: def _stop_zc_listen(self) -> None:
"""Stop listening for zeroconf updates.""" """Stop listening for zeroconf updates."""
if self._zc_listening: if self._zc_listening:
_LOGGER.debug("Removing zeroconf listener for %s", self.name) _LOGGER.debug("Removing zeroconf listener for %s", self.name)
self._zc.async_remove_listener(self) self._zeroconf_manager.get_async_zeroconf().zeroconf.async_remove_listener(
self
)
self._zc_listening = False self._zc_listening = False
def async_update_records( def async_update_records(
@ -389,7 +395,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
# Tell connection logic to retry connection attempt now (even before connect timer finishes) # Tell connection logic to retry connection attempt now (even before connect timer finishes)
_LOGGER.debug( _LOGGER.debug(
"%s: Triggering connect because of received mDNS record %s", "%s: Triggering connect because of received mDNS record %s",
self._log_name, self._cli.log_name,
record_update.new, record_update.new,
) )
# We can't stop the zeroconf listener here because we are in the middle of # We can't stop the zeroconf listener here because we are in the middle of

View File

@ -24,3 +24,27 @@ def fix_float_single_double_conversion(value: float) -> float:
l10 = math.ceil(math.log10(abs_val)) l10 = math.ceil(math.log10(abs_val))
prec = 7 - l10 prec = 7 - l10
return round(value, prec) return round(value, prec)
def host_is_name_part(address: str) -> bool:
"""Return True if a host is the name part."""
return "." not in address and ":" not in address
def address_is_local(address: str) -> bool:
"""Return True if the address is a local address."""
return address.removesuffix(".").endswith(".local")
def build_log_name(name: str | None, address: str, resolved_address: str | None) -> str:
"""Return a log name for a connection."""
if not name and address_is_local(address) or host_is_name_part(address):
name = address.partition(".")[0]
preferred_address = resolved_address or address
if (
name
and name != preferred_address
and not preferred_address.startswith(f"{name}.")
):
return f"{name} @ {preferred_address}"
return preferred_address

60
aioesphomeapi/zeroconf.py Normal file
View File

@ -0,0 +1,60 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Union
from zeroconf import Zeroconf
from zeroconf.asyncio import AsyncZeroconf
ZeroconfInstanceType = Union[Zeroconf, AsyncZeroconf]
_LOGGER = logging.getLogger(__name__)
class ZeroconfManager:
"""Manage the Zeroconf objects.
This class is used to manage the Zeroconf objects. It is used to create
the Zeroconf objects and to close them. It attempts to avoid creating
a Zeroconf object unless one is actually needed.
"""
def __init__(self, zeroconf: ZeroconfInstanceType | None = None) -> None:
"""Initialize the ZeroconfManager."""
self._created = False
self._aiozc: AsyncZeroconf | None = None
if zeroconf is not None:
self.set_instance(zeroconf)
def set_instance(self, zc: AsyncZeroconf | Zeroconf) -> None:
"""Set the AsyncZeroconf instance."""
if self._aiozc:
if isinstance(zc, AsyncZeroconf) and self._aiozc.zeroconf is zc.zeroconf:
return
if isinstance(zc, Zeroconf) and self._aiozc.zeroconf is zc:
self._aiozc = AsyncZeroconf(zc=zc)
return
raise RuntimeError("Zeroconf instance already set to a different instance")
self._aiozc = zc if isinstance(zc, AsyncZeroconf) else AsyncZeroconf(zc=zc)
def _create_async_zeroconf(self) -> None:
"""Create an AsyncZeroconf instance."""
_LOGGER.debug("Creating new AsyncZeroconf instance")
self._aiozc = AsyncZeroconf()
self._created = True
def get_async_zeroconf(self) -> AsyncZeroconf:
"""Get the AsyncZeroconf instance."""
if not self._aiozc:
self._create_async_zeroconf()
if TYPE_CHECKING:
assert self._aiozc is not None
return self._aiozc
async def async_close(self) -> None:
"""Close the Zeroconf connection."""
if not self._created or not self._aiozc:
return
await self._aiozc.async_close()
self._aiozc = None
self._created = False

View File

@ -4,7 +4,7 @@ import asyncio
import time import time
from datetime import datetime, timezone from datetime import datetime, timezone
from functools import partial from functools import partial
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock, patch
from google.protobuf import message from google.protobuf import message
from zeroconf import Zeroconf from zeroconf import Zeroconf
@ -27,26 +27,30 @@ PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
def get_mock_zeroconf() -> MagicMock: def get_mock_zeroconf() -> MagicMock:
return MagicMock(spec=Zeroconf) with patch("zeroconf.Zeroconf.start"):
zc = Zeroconf()
zc.close = MagicMock()
return zc
def get_mock_async_zeroconf() -> MagicMock: def get_mock_async_zeroconf() -> AsyncZeroconf:
mock = MagicMock(spec=AsyncZeroconf) aiozc = AsyncZeroconf(zc=get_mock_zeroconf())
mock.zeroconf = get_mock_zeroconf() aiozc.async_close = AsyncMock()
mock.async_close = AsyncMock() return aiozc
return mock
class Estr(str): class Estr(str):
"""A subclassed string.""" """A subclassed string."""
def generate_plaintext_packet(msg: bytes, type_: int) -> bytes: def generate_plaintext_packet(msg: message.Message) -> bytes:
type_ = PROTO_TO_MESSAGE_TYPE[msg.__class__]
bytes_ = msg.SerializeToString()
return ( return (
b"\0" b"\0"
+ _cached_varuint_to_bytes(len(msg)) + _cached_varuint_to_bytes(len(bytes_))
+ _cached_varuint_to_bytes(type_) + _cached_varuint_to_bytes(type_)
+ msg + bytes_
) )
@ -99,10 +103,7 @@ def send_plaintext_hello(protocol: APIPlaintextFrameHelper) -> None:
hello_response.api_version_major = 1 hello_response.api_version_major = 1
hello_response.api_version_minor = 9 hello_response.api_version_minor = 9
hello_response.name = "fake" hello_response.name = "fake"
hello_msg = hello_response.SerializeToString() protocol.data_received(generate_plaintext_packet(hello_response))
protocol.data_received(
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
)
def send_plaintext_connect_response( def send_plaintext_connect_response(
@ -110,8 +111,4 @@ def send_plaintext_connect_response(
) -> None: ) -> None:
connect_response: message.Message = ConnectResponse() connect_response: message.Message = ConnectResponse()
connect_response.invalid_password = invalid_password connect_response.invalid_password = invalid_password
connect_msg = connect_response.SerializeToString() protocol.data_received(generate_plaintext_packet(connect_response))
protocol.data_received(
generate_plaintext_packet(connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse])
)

View File

@ -12,8 +12,14 @@ from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi.client import APIClient, ConnectionParams from aioesphomeapi.client import APIClient, ConnectionParams
from aioesphomeapi.connection import APIConnection from aioesphomeapi.connection import APIConnection
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
from aioesphomeapi.zeroconf import ZeroconfManager
from .common import connect, send_plaintext_hello from .common import connect, get_mock_async_zeroconf, send_plaintext_hello
@pytest.fixture
def async_zeroconf():
return get_mock_async_zeroconf()
@pytest.fixture @pytest.fixture
@ -42,7 +48,7 @@ def connection_params() -> ConnectionParams:
password=None, password=None,
client_info="Tests client", client_info="Tests client",
keepalive=15.0, keepalive=15.0,
zeroconf_instance=None, zeroconf_manager=ZeroconfManager(),
noise_psk=None, noise_psk=None,
expected_name=None, expected_name=None,
) )

View File

@ -141,6 +141,7 @@ def test_plaintext_frame_helper(
assert type_ == pkt_type assert type_ == pkt_type
assert data == pkt_data assert data == pkt_data
helper.close()
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@ -18,6 +18,8 @@ from aioesphomeapi.api_pb2 import (
CameraImageResponse, CameraImageResponse,
ClimateCommandRequest, ClimateCommandRequest,
CoverCommandRequest, CoverCommandRequest,
DeviceInfoResponse,
DisconnectResponse,
ExecuteServiceArgument, ExecuteServiceArgument,
ExecuteServiceRequest, ExecuteServiceRequest,
FanCommandRequest, FanCommandRequest,
@ -34,6 +36,7 @@ from aioesphomeapi.api_pb2 import (
) )
from aioesphomeapi.client import APIClient from aioesphomeapi.client import APIClient
from aioesphomeapi.connection import APIConnection from aioesphomeapi.connection import APIConnection
from aioesphomeapi.core import APIConnectionError
from aioesphomeapi.model import ( from aioesphomeapi.model import (
AlarmControlPanelCommand, AlarmControlPanelCommand,
APIVersion, APIVersion,
@ -680,12 +683,7 @@ async def test_bluetooth_disconnect(
response: message.Message = BluetoothDeviceConnectionResponse( response: message.Message = BluetoothDeviceConnectionResponse(
address=1234, connected=False address=1234, connected=False
) )
protocol.data_received( protocol.data_received(generate_plaintext_packet(response))
generate_plaintext_packet(
response.SerializeToString(),
PROTO_TO_MESSAGE_TYPE[BluetoothDeviceConnectionResponse],
)
)
await disconnect_task await disconnect_task
@ -700,12 +698,7 @@ async def test_bluetooth_pair(
pair_task = asyncio.create_task(client.bluetooth_device_pair(1234)) pair_task = asyncio.create_task(client.bluetooth_device_pair(1234))
await asyncio.sleep(0) await asyncio.sleep(0)
response: message.Message = BluetoothDevicePairingResponse(address=1234) response: message.Message = BluetoothDevicePairingResponse(address=1234)
protocol.data_received( protocol.data_received(generate_plaintext_packet(response))
generate_plaintext_packet(
response.SerializeToString(),
PROTO_TO_MESSAGE_TYPE[BluetoothDevicePairingResponse],
)
)
await pair_task await pair_task
@ -720,12 +713,7 @@ async def test_bluetooth_unpair(
unpair_task = asyncio.create_task(client.bluetooth_device_unpair(1234)) unpair_task = asyncio.create_task(client.bluetooth_device_unpair(1234))
await asyncio.sleep(0) await asyncio.sleep(0)
response: message.Message = BluetoothDeviceUnpairingResponse(address=1234) response: message.Message = BluetoothDeviceUnpairingResponse(address=1234)
protocol.data_received( protocol.data_received(generate_plaintext_packet(response))
generate_plaintext_packet(
response.SerializeToString(),
PROTO_TO_MESSAGE_TYPE[BluetoothDeviceUnpairingResponse],
)
)
await unpair_task await unpair_task
@ -740,10 +728,36 @@ async def test_bluetooth_clear_cache(
clear_task = asyncio.create_task(client.bluetooth_device_clear_cache(1234)) clear_task = asyncio.create_task(client.bluetooth_device_clear_cache(1234))
await asyncio.sleep(0) await asyncio.sleep(0)
response: message.Message = BluetoothDeviceClearCacheResponse(address=1234) response: message.Message = BluetoothDeviceClearCacheResponse(address=1234)
protocol.data_received( protocol.data_received(generate_plaintext_packet(response))
generate_plaintext_packet(
response.SerializeToString(),
PROTO_TO_MESSAGE_TYPE[BluetoothDeviceClearCacheResponse],
)
)
await clear_task await clear_task
@pytest.mark.asyncio
async def test_device_info(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test fetching device info."""
client, connection, transport, protocol = api_client
assert client.log_name == "mydevice.local"
device_info_task = asyncio.create_task(client.device_info())
await asyncio.sleep(0)
response: message.Message = DeviceInfoResponse(
name="realname",
friendly_name="My Device",
has_deep_sleep=True,
)
protocol.data_received(generate_plaintext_packet(response))
device_info = await device_info_task
assert device_info.name == "realname"
assert device_info.friendly_name == "My Device"
assert device_info.has_deep_sleep
assert client.log_name == "realname @ 10.0.0.512"
disconnect_task = asyncio.create_task(client.disconnect())
await asyncio.sleep(0)
response: message.Message = DisconnectResponse()
protocol.data_received(generate_plaintext_packet(response))
await disconnect_task
with pytest.raises(APIConnectionError, match="CLOSED"):
await client.device_info()

View File

@ -5,20 +5,11 @@ from ipaddress import ip_address
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from zeroconf import DNSCache
from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf
import aioesphomeapi.host_resolver as hr import aioesphomeapi.host_resolver as hr
from aioesphomeapi.core import APIConnectionError from aioesphomeapi.core import APIConnectionError, ResolveAPIError
from aioesphomeapi.zeroconf import ZeroconfManager
@pytest.fixture
def async_zeroconf():
with patch("aioesphomeapi.host_resolver.AsyncZeroconf") as klass:
async_zeroconf = klass.return_value
async_zeroconf.async_close = AsyncMock()
async_zeroconf.zeroconf.cache = DNSCache()
yield async_zeroconf
@pytest.fixture @pytest.fixture
@ -52,7 +43,9 @@ async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos):
ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"), ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"),
] ]
info.async_request = AsyncMock(return_value=True) info.async_request = AsyncMock(return_value=True)
with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info): with patch(
"aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info
), patch("aioesphomeapi.zeroconf.AsyncZeroconf", return_value=async_zeroconf):
ret = await hr._async_resolve_host_zeroconf("asdf", 6052) ret = await hr._async_resolve_host_zeroconf("asdf", 6052)
info.async_request.assert_called_once() info.async_request.assert_called_once()
@ -61,10 +54,8 @@ async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_resolve_host_passed_zeroconf_does_not_close(addr_infos): async def test_resolve_host_passed_zeroconf(addr_infos, async_zeroconf):
async_zeroconf = AsyncZeroconf(zc=MagicMock()) zeroconf_manager = ZeroconfManager()
async_zeroconf.async_close = AsyncMock()
async_zeroconf.zeroconf.cache = DNSCache()
info = MagicMock(auto_spec=AsyncServiceInfo) info = MagicMock(auto_spec=AsyncServiceInfo)
info.ip_addresses_by_version.return_value = [ info.ip_addresses_by_version.return_value = [
ip_address(b"\n\x00\x00*"), ip_address(b"\n\x00\x00*"),
@ -73,11 +64,10 @@ async def test_resolve_host_passed_zeroconf_does_not_close(addr_infos):
info.async_request = AsyncMock(return_value=True) info.async_request = AsyncMock(return_value=True)
with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info): with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info):
ret = await hr._async_resolve_host_zeroconf( ret = await hr._async_resolve_host_zeroconf(
"asdf", 6052, zeroconf_instance=async_zeroconf "asdf", 6052, zeroconf_manager=zeroconf_manager
) )
info.async_request.assert_called_once() info.async_request.assert_called_once()
async_zeroconf.async_close.assert_not_called()
assert ret == addr_infos assert ret == addr_infos
@ -131,7 +121,7 @@ async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos):
resolve_zc.return_value = addr_infos resolve_zc.return_value = addr_infos
ret = await hr.async_resolve_host("example.local", 6052) ret = await hr.async_resolve_host("example.local", 6052)
resolve_zc.assert_called_once_with("example", 6052, zeroconf_instance=None) resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None)
resolve_addr.assert_not_called() resolve_addr.assert_not_called()
assert ret == addr_infos[0] assert ret == addr_infos[0]
@ -144,7 +134,7 @@ async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos):
resolve_addr.return_value = addr_infos resolve_addr.return_value = addr_infos
ret = await hr.async_resolve_host("example.local", 6052) ret = await hr.async_resolve_host("example.local", 6052)
resolve_zc.assert_called_once_with("example", 6052, zeroconf_instance=None) resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None)
resolve_addr.assert_called_once_with("example.local", 6052) resolve_addr.assert_called_once_with("example.local", 6052)
assert ret == addr_infos[0] assert ret == addr_infos[0]
@ -189,3 +179,40 @@ async def test_resolve_host_with_address(resolve_addr, resolve_zc):
proto=6, proto=6,
sockaddr=hr.IPv4Sockaddr(address="127.0.0.1", port=6052), sockaddr=hr.IPv4Sockaddr(address="127.0.0.1", port=6052),
) )
@pytest.mark.asyncio
async def test_resolve_host_zeroconf_service_info_oserror(
async_zeroconf: AsyncZeroconf, addr_infos
):
info = MagicMock(auto_spec=AsyncServiceInfo)
info.ip_addresses_by_version.return_value = [
ip_address(b"\n\x00\x00*"),
ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"),
]
info.async_request = AsyncMock(return_value=True)
with patch(
"aioesphomeapi.host_resolver.AsyncServiceInfo.async_request",
side_effect=OSError("out of buffers"),
), patch(
"aioesphomeapi.zeroconf.AsyncZeroconf", return_value=async_zeroconf
), pytest.raises(
ResolveAPIError, match="out of buffers"
):
await hr._async_resolve_host_zeroconf("asdf", 6052)
@pytest.mark.asyncio
async def test_resolve_host_create_zeroconf_oserror(
async_zeroconf: AsyncZeroconf, addr_infos
):
info = MagicMock(auto_spec=AsyncServiceInfo)
info.ip_addresses_by_version.return_value = [
ip_address(b"\n\x00\x00*"),
ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"),
]
info.async_request = AsyncMock(return_value=True)
with patch(
"aioesphomeapi.zeroconf.AsyncZeroconf", side_effect=OSError("out of buffers")
), pytest.raises(ResolveAPIError, match="out of buffers"):
await hr._async_resolve_host_zeroconf("asdf", 6052)

View File

@ -75,21 +75,11 @@ async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnec
response: message.Message = SubscribeLogsResponse() response: message.Message = SubscribeLogsResponse()
response.message = b"Hello world" response.message = b"Hello world"
protocol.data_received( protocol.data_received(generate_plaintext_packet(response))
generate_plaintext_packet(
response.SerializeToString(),
PROTO_TO_MESSAGE_TYPE[SubscribeLogsResponse],
)
)
assert len(messages) == 1 assert len(messages) == 1
assert messages[0].message == b"Hello world" assert messages[0].message == b"Hello world"
stop_task = asyncio.create_task(stop()) stop_task = asyncio.create_task(stop())
await asyncio.sleep(0) await asyncio.sleep(0)
disconnect_response = DisconnectResponse() disconnect_response = DisconnectResponse()
protocol.data_received( protocol.data_received(generate_plaintext_packet(disconnect_response))
generate_plaintext_packet(
disconnect_response.SerializeToString(),
PROTO_TO_MESSAGE_TYPE[DisconnectResponse],
)
)
await stop_task await stop_task

View File

@ -41,14 +41,13 @@ async def test_reconnect_logic_name_from_host():
async def on_connect() -> None: async def on_connect() -> None:
pass pass
rl = ReconnectLogic( ReconnectLogic(
client=cli, client=cli,
on_disconnect=on_disconnect, on_disconnect=on_disconnect,
on_connect=on_connect, on_connect=on_connect,
zeroconf_instance=MagicMock(spec=AsyncZeroconf), zeroconf_instance=MagicMock(spec=AsyncZeroconf),
) )
assert rl._log_name == "mydevice" assert cli.log_name == "mydevice.local"
assert cli._log_name == "mydevice"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -66,15 +65,14 @@ async def test_reconnect_logic_name_from_host_and_set():
async def on_connect() -> None: async def on_connect() -> None:
pass pass
rl = ReconnectLogic( ReconnectLogic(
client=cli, client=cli,
on_disconnect=on_disconnect, on_disconnect=on_disconnect,
on_connect=on_connect, on_connect=on_connect,
zeroconf_instance=get_mock_zeroconf(), zeroconf_instance=get_mock_zeroconf(),
name="mydevice", name="mydevice",
) )
assert rl._log_name == "mydevice" assert cli.log_name == "mydevice.local"
assert cli._log_name == "mydevice"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -92,14 +90,13 @@ async def test_reconnect_logic_name_from_address():
async def on_connect() -> None: async def on_connect() -> None:
pass pass
rl = ReconnectLogic( ReconnectLogic(
client=cli, client=cli,
on_disconnect=on_disconnect, on_disconnect=on_disconnect,
on_connect=on_connect, on_connect=on_connect,
zeroconf_instance=get_mock_zeroconf(), zeroconf_instance=get_mock_zeroconf(),
) )
assert rl._log_name == "1.2.3.4" assert cli.log_name == "1.2.3.4"
assert cli._log_name == "1.2.3.4"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -117,15 +114,14 @@ async def test_reconnect_logic_name_from_name():
async def on_connect() -> None: async def on_connect() -> None:
pass pass
rl = ReconnectLogic( ReconnectLogic(
client=cli, client=cli,
on_disconnect=on_disconnect, on_disconnect=on_disconnect,
on_connect=on_connect, on_connect=on_connect,
zeroconf_instance=get_mock_zeroconf(), zeroconf_instance=get_mock_zeroconf(),
name="mydevice", name="mydevice",
) )
assert rl._log_name == "mydevice @ 1.2.3.4" assert cli.log_name == "mydevice @ 1.2.3.4"
assert cli._log_name == "mydevice @ 1.2.3.4"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -164,8 +160,7 @@ async def test_reconnect_logic_state():
name="mydevice", name="mydevice",
on_connect_error=on_connect_fail, on_connect_error=on_connect_fail,
) )
assert rl._log_name == "mydevice @ 1.2.3.4" assert cli.log_name == "mydevice @ 1.2.3.4"
assert cli._log_name == "mydevice @ 1.2.3.4"
with patch.object(cli, "start_connection", side_effect=APIConnectionError): with patch.object(cli, "start_connection", side_effect=APIConnectionError):
await rl.start() await rl.start()
@ -241,8 +236,7 @@ async def test_reconnect_retry():
name="mydevice", name="mydevice",
on_connect_error=on_connect_fail, on_connect_error=on_connect_fail,
) )
assert rl._log_name == "mydevice @ 1.2.3.4" assert cli.log_name == "mydevice @ 1.2.3.4"
assert cli._log_name == "mydevice @ 1.2.3.4"
with patch.object(cli, "start_connection", side_effect=APIConnectionError): with patch.object(cli, "start_connection", side_effect=APIConnectionError):
await rl.start() await rl.start()
@ -338,8 +332,7 @@ async def test_reconnect_zeroconf(
name="mydevice", name="mydevice",
on_connect_error=AsyncMock(), on_connect_error=AsyncMock(),
) )
assert rl._log_name == "mydevice @ 1.2.3.4" assert cli.log_name == "mydevice @ 1.2.3.4"
assert cli._log_name == "mydevice @ 1.2.3.4"
async def slow_connect_fail(*args, **kwargs): async def slow_connect_fail(*args, **kwargs):
await asyncio.sleep(10) await asyncio.sleep(10)

52
tests/test_zeroconf.py Normal file
View File

@ -0,0 +1,52 @@
from __future__ import annotations
from unittest.mock import patch
import pytest
from zeroconf.asyncio import AsyncZeroconf
from aioesphomeapi.zeroconf import ZeroconfManager
from .common import get_mock_async_zeroconf
@pytest.mark.asyncio
async def test_does_not_closed_passed_in_async_instance(async_zeroconf: AsyncZeroconf):
"""Test that the passed in instance is not closed."""
manager = ZeroconfManager()
manager.set_instance(async_zeroconf)
await manager.async_close()
assert async_zeroconf.async_close.call_count == 0
@pytest.mark.asyncio
async def test_does_not_closed_passed_in_sync_instance(async_zeroconf: AsyncZeroconf):
"""Test that the passed in instance is not closed."""
manager = ZeroconfManager()
manager.set_instance(async_zeroconf.zeroconf)
await manager.async_close()
assert async_zeroconf.async_close.call_count == 0
@pytest.mark.asyncio
async def test_closes_created_instance(async_zeroconf: AsyncZeroconf):
"""Test that the created instance is closed."""
with patch("aioesphomeapi.zeroconf.AsyncZeroconf", return_value=async_zeroconf):
manager = ZeroconfManager()
assert manager.get_async_zeroconf() is async_zeroconf
await manager.async_close()
assert async_zeroconf.async_close.call_count == 1
@pytest.mark.asyncio
async def test_runtime_error_multiple_instances(async_zeroconf: AsyncZeroconf):
"""Test runtime error is raised on multiple instances."""
manager = ZeroconfManager(async_zeroconf)
new_instance = get_mock_async_zeroconf()
with pytest.raises(RuntimeError):
manager.set_instance(new_instance)
manager.set_instance(async_zeroconf)
manager.set_instance(async_zeroconf.zeroconf)
manager.set_instance(async_zeroconf)
await manager.async_close()
assert async_zeroconf.async_close.call_count == 0