diff --git a/aioesphomeapi/_frame_helper/base.pxd b/aioesphomeapi/_frame_helper/base.pxd index 25d9200..794357d 100644 --- a/aioesphomeapi/_frame_helper/base.pxd +++ b/aioesphomeapi/_frame_helper/base.pxd @@ -20,6 +20,8 @@ cdef class APIFrameHelper: cdef str _log_name cdef object _debug_enabled + cpdef set_log_name(self, str log_name) + @cython.locals(original_pos="unsigned int", new_pos="unsigned int") cdef bytes _read_exactly(self, int length) diff --git a/aioesphomeapi/_frame_helper/base.py b/aioesphomeapi/_frame_helper/base.py index 2520ab1..cf3860c 100644 --- a/aioesphomeapi/_frame_helper/base.py +++ b/aioesphomeapi/_frame_helper/base.py @@ -62,6 +62,10 @@ class APIFrameHelper: self._log_name = log_name self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG) + def set_log_name(self, log_name: str) -> None: + """Set the log name.""" + self._log_name = log_name + def _set_ready_future_exception(self, exc: Exception | type[Exception]) -> None: if not self._ready_future.done(): self._ready_future.set_exception(exc) diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index df370e1..057e616 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -111,7 +111,6 @@ from .core import ( UnhandledAPIConnectionError, to_human_readable_address, ) -from .host_resolver import ZeroconfInstanceType from .model import ( AlarmControlPanelCommand, AlarmControlPanelEntityState, @@ -177,6 +176,8 @@ from .model import ( VoiceAssistantCommand, VoiceAssistantEventType, ) +from .util import build_log_name +from .zeroconf import ZeroconfInstanceType, ZeroconfManager _LOGGER = logging.getLogger(__name__) @@ -258,10 +259,10 @@ class APIClient: __slots__ = ( "_params", "_connection", - "_cached_name", + "cached_name", "_background_tasks", "_loop", - "_log_name", + "log_name", ) def __init__( @@ -272,7 +273,7 @@ class APIClient: *, client_info: str = "aioesphomeapi", keepalive: float = KEEP_ALIVE_FREQUENCY, - zeroconf_instance: ZeroconfInstanceType = None, + zeroconf_instance: ZeroconfInstanceType | None = None, noise_psk: str | None = None, expected_name: str | None = None, ) -> None: @@ -297,17 +298,21 @@ class APIClient: password=password, client_info=client_info, keepalive=keepalive, - zeroconf_instance=zeroconf_instance, + zeroconf_manager=ZeroconfManager(zeroconf_instance), # treat empty '' psk string as missing (like password) noise_psk=_stringify_or_none(noise_psk) or None, expected_name=_stringify_or_none(expected_name) or None, ) self._connection: APIConnection | None = None - self._cached_name: str | None = None + self.cached_name: str | None = None self._background_tasks: set[asyncio.Task[Any]] = set() self._loop = asyncio.get_event_loop() self._set_log_name() + @property + def zeroconf_manager(self) -> ZeroconfManager: + return self._params.zeroconf_manager + @property def expected_name(self) -> str | None: return self._params.expected_name @@ -320,26 +325,23 @@ class APIClient: def address(self) -> str: return self._params.address - def _get_log_name(self) -> str: - """Get the log name of the device.""" - address = self.address - address_is_host = address.endswith(".local") - if self._cached_name is not None: - if address_is_host: - return self._cached_name - return f"{self._cached_name} @ {address}" - if address_is_host: - return address[:-6] - return address - def _set_log_name(self) -> None: """Set the log name of the device.""" - self._log_name = self._get_log_name() + resolved_address: str | None = None + if self._connection and self._connection.resolved_addr_info: + resolved_address = self._connection.resolved_addr_info.sockaddr.address + self.log_name = build_log_name( + self.cached_name, + self.address, + resolved_address, + ) + if self._connection: + self._connection.set_log_name(self.log_name) def set_cached_name_if_unset(self, name: str) -> None: """Set the cached name of the device if not set.""" - if not self._cached_name: - self._cached_name = name + if not self.cached_name: + self.cached_name = name self._set_log_name() async def connect( @@ -357,7 +359,7 @@ class APIClient: ) -> None: """Start connecting to the device.""" if self._connection is not None: - raise APIConnectionError(f"Already connected to {self._log_name}!") + raise APIConnectionError(f"Already connected to {self.log_name}!") async def _on_stop(expected_disconnect: bool) -> None: # Hook into on_stop handler to clear connection when stopped @@ -365,9 +367,7 @@ class APIClient: if on_stop is not None: await on_stop(expected_disconnect) - self._connection = APIConnection( - self._params, _on_stop, log_name=self._log_name - ) + self._connection = APIConnection(self._params, _on_stop, log_name=self.log_name) try: await self._connection.start_connection() @@ -377,8 +377,11 @@ class APIClient: except Exception as e: self._connection = None raise UnhandledAPIConnectionError( - f"Unexpected error while connecting to {self._log_name}: {e}" + f"Unexpected error while connecting to {self.log_name}: {e}" ) from e + # If we resolved the address, we should set the log name now + if self._connection.resolved_addr_info: + self._set_log_name() async def finish_connection( self, @@ -394,8 +397,10 @@ class APIClient: except Exception as e: self._connection = None raise UnhandledAPIConnectionError( - f"Unexpected error while connecting to {self._log_name}: {e}" + f"Unexpected error while connecting to {self.log_name}: {e}" ) from e + if received_name := self._connection.received_name: + self._set_name_from_device(received_name) async def disconnect(self, force: bool = False) -> None: if self._connection is None: @@ -408,10 +413,10 @@ class APIClient: def _check_authenticated(self) -> None: connection = self._connection if not connection: - raise APIConnectionError(f"Not connected to {self._log_name}!") + raise APIConnectionError(f"Not connected to {self.log_name}!") if not connection.is_connected: raise APIConnectionError( - f"Authenticated connection not ready yet for {self._log_name}; " + f"Authenticated connection not ready yet for {self.log_name}; " f"current state is {connection.connection_state}!" ) @@ -423,11 +428,14 @@ class APIClient: DeviceInfoRequest(), DeviceInfoResponse ) info = DeviceInfo.from_pb(resp) - self._cached_name = info.name - connection.set_log_name(self._log_name) - self._set_log_name() + self._set_name_from_device(info.name) return info + def _set_name_from_device(self, name: str) -> None: + """Set the name from a DeviceInfo message.""" + self.cached_name = name + self._set_log_name() + async def list_entities_services( self, ) -> tuple[list[EntityInfo], list[UserService]]: diff --git a/aioesphomeapi/connection.pxd b/aioesphomeapi/connection.pxd index 82600d7..31c7997 100644 --- a/aioesphomeapi/connection.pxd +++ b/aioesphomeapi/connection.pxd @@ -68,6 +68,8 @@ cdef class APIConnection: cdef public bint is_connected cdef bint _handshake_complete cdef object _debug_enabled + cdef public str received_name + cdef public object resolved_addr_info cpdef send_message(self, object msg) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index bfb0eb6..c2a5e21 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -49,6 +49,7 @@ from .core import ( TimeoutAPIError, ) from .model import APIVersion +from .zeroconf import ZeroconfManager if sys.version_info[:2] < (3, 11): from async_timeout import timeout as asyncio_timeout @@ -111,7 +112,7 @@ class ConnectionParams: password: str | None client_info: str keepalive: float - zeroconf_instance: hr.ZeroconfInstanceType + zeroconf_manager: ZeroconfManager noise_psk: str | None expected_name: str | None @@ -159,6 +160,8 @@ class APIConnection: "is_connected", "_handshake_complete", "_debug_enabled", + "received_name", + "resolved_addr_info", ) def __init__( @@ -201,10 +204,14 @@ class APIConnection: self.is_connected = False self._handshake_complete = False self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG) + self.received_name: str = "" + self.resolved_addr_info: hr.AddrInfo | None = None def set_log_name(self, name: str) -> None: """Set the friendly log name for this connection.""" self.log_name = name + if self._frame_helper is not None: + self._frame_helper.set_log_name(name) def _cleanup(self) -> None: """Clean up all resources that have been allocated. @@ -276,7 +283,7 @@ class APIConnection: return await hr.async_resolve_host( self._params.address, self._params.port, - self._params.zeroconf_instance, + self._params.zeroconf_manager, ) except asyncio_TimeoutError as err: raise ResolveAPIError( @@ -427,17 +434,16 @@ class APIConnection: self.api_version = api_version expected_name = self._params.expected_name - received_name = resp.name - if ( - expected_name is not None - and received_name != "" - and received_name != expected_name - ): - raise BadNameAPIError( - f"Expected '{expected_name}' but server sent " - f"a different name: '{received_name}'", - received_name, - ) + if received_name := resp.name: + if expected_name is not None and received_name != expected_name: + raise BadNameAPIError( + f"Expected '{expected_name}' but server sent " + f"a different name: '{received_name}'", + received_name, + ) + + self.received_name = received_name + self.set_log_name(received_name) def _async_schedule_keep_alive(self, now: _float) -> None: """Start the keep alive task.""" @@ -506,8 +512,8 @@ class APIConnection: async def _do_connect(self) -> None: """Do the actual connect process.""" in_do_connect.set(True) - addr = await self._connect_resolve_host() - await self._connect_socket_connect(addr) + self.resolved_addr_info = await self._connect_resolve_host() + await self._connect_socket_connect(self.resolved_addr_info) async def start_connection(self) -> None: """Start the connection process. diff --git a/aioesphomeapi/host_resolver.py b/aioesphomeapi/host_resolver.py index 97df657..54e06ad 100644 --- a/aioesphomeapi/host_resolver.py +++ b/aioesphomeapi/host_resolver.py @@ -6,35 +6,38 @@ import logging import socket from dataclasses import dataclass from ipaddress import IPv4Address, IPv6Address -from typing import Union, cast +from typing import cast -from zeroconf import IPVersion, Zeroconf -from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf +from zeroconf import IPVersion +from zeroconf.asyncio import AsyncServiceInfo from .core import APIConnectionError, ResolveAPIError +from .util import address_is_local, host_is_name_part +from .zeroconf import ZeroconfManager _LOGGER = logging.getLogger(__name__) -ZeroconfInstanceType = Union[Zeroconf, AsyncZeroconf, None] SERVICE_TYPE = "_esphomelib._tcp.local." @dataclass(frozen=True) class Sockaddr: - pass + """Base socket address.""" + + address: str + port: int @dataclass(frozen=True) class IPv4Sockaddr(Sockaddr): - address: str - port: int + """IPv4 socket address.""" @dataclass(frozen=True) class IPv6Sockaddr(Sockaddr): - address: str - port: int + """IPv6 socket address.""" + flowinfo: int scope_id: int @@ -44,35 +47,23 @@ class AddrInfo: family: int type: int proto: int - sockaddr: Sockaddr + sockaddr: IPv4Sockaddr | IPv6Sockaddr async def _async_zeroconf_get_service_info( - zeroconf_instance: ZeroconfInstanceType, + zeroconf_manager: ZeroconfManager, service_type: str, service_name: str, timeout: float, ) -> AsyncServiceInfo | None: # Use or create zeroconf instance, ensure it's an AsyncZeroconf - async_zc_instance: AsyncZeroconf | None = None - if zeroconf_instance is None: - try: - async_zc_instance = AsyncZeroconf() - except Exception: - raise ResolveAPIError( - "Cannot start mDNS sockets, is this a docker container without " - "host network mode?" - ) - zc = async_zc_instance.zeroconf - elif isinstance(zeroconf_instance, AsyncZeroconf): - zc = zeroconf_instance.zeroconf - elif isinstance(zeroconf_instance, Zeroconf): - zc = zeroconf_instance - else: - raise ValueError( - f"Invalid type passed for zeroconf_instance: {type(zeroconf_instance)}" - ) - + try: + zc = zeroconf_manager.get_async_zeroconf().zeroconf + except Exception as exc: + raise ResolveAPIError( + f"Cannot start mDNS sockets: {exc}, is this a docker container without " + "host network mode?" + ) from exc try: info = AsyncServiceInfo(service_type, service_name) if await info.async_request(zc, int(timeout * 1000)): @@ -82,8 +73,7 @@ async def _async_zeroconf_get_service_info( f"Error resolving mDNS {service_name} via mDNS: {exc}" ) from exc finally: - if async_zc_instance: - await async_zc_instance.async_close() + await zeroconf_manager.async_close() return info @@ -92,13 +82,13 @@ async def _async_resolve_host_zeroconf( port: int, *, timeout: float = 3.0, - zeroconf_instance: ZeroconfInstanceType = None, + zeroconf_manager: ZeroconfManager | None = None, ) -> list[AddrInfo]: service_name = f"{host}.{SERVICE_TYPE}" _LOGGER.debug("Resolving host %s via mDNS", service_name) info = await _async_zeroconf_get_service_info( - zeroconf_instance, SERVICE_TYPE, service_name, timeout + zeroconf_manager or ZeroconfManager(), SERVICE_TYPE, service_name, timeout ) if info is None: @@ -107,7 +97,7 @@ async def _async_resolve_host_zeroconf( addrs: list[AddrInfo] = [] for ip_address in info.ip_addresses_by_version(IPVersion.All): is_ipv6 = ip_address.version == 6 - sockaddr: Sockaddr + sockaddr: IPv6Sockaddr | IPv4Sockaddr if is_ipv6: sockaddr = IPv6Sockaddr( address=str(ip_address), @@ -143,7 +133,7 @@ async def _async_resolve_host_getaddrinfo(host: str, port: int) -> list[AddrInfo addrs: list[AddrInfo] = [] for family, type_, proto, _, raw in res: - sockaddr: Sockaddr + sockaddr: IPv4Sockaddr | IPv6Sockaddr if family == socket.AF_INET: raw = cast(tuple[str, int], raw) address, port = raw @@ -197,17 +187,17 @@ def _async_ip_address_to_addrs(host: str, port: int) -> list[AddrInfo]: async def async_resolve_host( host: str, port: int, - zeroconf_instance: ZeroconfInstanceType = None, + zeroconf_manager: ZeroconfManager | None = None, ) -> AddrInfo: addrs: list[AddrInfo] = [] zc_error = None - if "." not in host or host.endswith(".local"): + if host_is_name_part(host) or address_is_local(host): name = host.partition(".")[0] try: addrs.extend( await _async_resolve_host_zeroconf( - name, port, zeroconf_instance=zeroconf_instance + name, port, zeroconf_manager=zeroconf_manager ) ) except APIConnectionError as err: diff --git a/aioesphomeapi/log_reader.py b/aioesphomeapi/log_reader.py index 11e9ebe..507a499 100644 --- a/aioesphomeapi/log_reader.py +++ b/aioesphomeapi/log_reader.py @@ -7,8 +7,6 @@ import logging import sys from datetime import datetime -from zeroconf.asyncio import AsyncZeroconf - from .api_pb2 import SubscribeLogsResponse # type: ignore from .client import APIClient from .log_runner import async_run @@ -29,14 +27,11 @@ async def main(argv: list[str]) -> None: datefmt="%Y-%m-%d %H:%M:%S", ) - aiozc = AsyncZeroconf() - cli = APIClient( args.address, args.port, args.password or "", noise_psk=args.noise_psk, - zeroconf_instance=aiozc.zeroconf, keepalive=10, ) @@ -46,12 +41,10 @@ async def main(argv: list[str]) -> None: text = message.decode("utf8", "backslashreplace") print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}]{text}") - stop = await async_run(cli, on_log, aio_zeroconf_instance=aiozc) + stop = await async_run(cli, on_log) try: - while True: - await asyncio.sleep(60) + await asyncio.Event().wait() finally: - await aiozc.async_close() await stop() diff --git a/aioesphomeapi/log_runner.py b/aioesphomeapi/log_runner.py index 25a6c10..4205942 100644 --- a/aioesphomeapi/log_runner.py +++ b/aioesphomeapi/log_runner.py @@ -47,21 +47,16 @@ async def async_run( ) -> None: _LOGGER.warning("Disconnected from API") - passed_in_zeroconf = aio_zeroconf_instance is not None - aiozc = aio_zeroconf_instance or AsyncZeroconf() - logic = ReconnectLogic( client=cli, on_connect=on_connect, on_disconnect=on_disconnect, - zeroconf_instance=aiozc.zeroconf, + zeroconf_instance=aio_zeroconf_instance, name=name, ) await logic.start() async def _stop() -> None: - if not passed_in_zeroconf: - await aiozc.async_close() await logic.stop() await cli.disconnect() diff --git a/aioesphomeapi/reconnect_logic.py b/aioesphomeapi/reconnect_logic.py index cc972fd..1214eff 100644 --- a/aioesphomeapi/reconnect_logic.py +++ b/aioesphomeapi/reconnect_logic.py @@ -19,6 +19,8 @@ from .core import ( RequiresEncryptionAPIError, UnhandledAPIConnectionError, ) +from .util import address_is_local +from .zeroconf import ZeroconfInstanceType _LOGGER = logging.getLogger(__name__) @@ -62,7 +64,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): client: APIClient, on_connect: Callable[[], Awaitable[None]], on_disconnect: Callable[[bool], Awaitable[None]], - zeroconf_instance: zeroconf.Zeroconf, + zeroconf_instance: ZeroconfInstanceType | None = None, name: str | None = None, on_connect_error: Callable[[Exception], Awaitable[None]] | None = None, ) -> None: @@ -74,21 +76,19 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): """ self.loop = asyncio.get_event_loop() self._cli = client - self.name: str | None - if client.address.endswith(".local"): - self.name = client.address[:-6] - self._log_name = self.name - elif name: + self.name: str | None = None + if name: self.name = name - self._log_name = f"{name} @ {self._cli.address}" - self._cli.set_cached_name_if_unset(name) - else: - self.name = None - self._log_name = client.address + elif address_is_local(client.address): + self.name = client.address.partition(".")[0] + if self.name: + self._cli.set_cached_name_if_unset(self.name) self._on_connect_cb = on_connect self._on_disconnect_cb = on_disconnect self._on_connect_error_cb = on_connect_error - self._zc = zeroconf_instance + self._zeroconf_manager = client.zeroconf_manager + if zeroconf_instance is not None: + self._zeroconf_manager.set_instance(zeroconf_instance) self._ptr_alias: str | None = None self._a_name: str | None = None # Flag to check if the device is connected @@ -116,7 +116,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): _LOGGER.info( "Processing %s disconnect from ESPHome API for %s", disconnect_type, - self._log_name, + self._cli.log_name, ) # Run disconnect hook @@ -172,7 +172,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): _LOGGER.log( level, "Can't connect to ESPHome API for %s: %s (%s)", - self._log_name, + self._cli.log_name, err, type(err).__name__, # Print stacktrace if unhandled @@ -197,7 +197,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): finish_connect_time = time.perf_counter() connect_time = finish_connect_time - start_connect_time _LOGGER.info( - "Successfully connected to %s in %0.3fs", self._log_name, connect_time + "Successfully connected to %s in %0.3fs", self._cli.log_name, connect_time ) self._stop_zc_listen() self._async_set_connection_state_while_locked(ReconnectLogicState.HANDSHAKING) @@ -221,7 +221,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): finish_handshake_time = time.perf_counter() handshake_time = finish_handshake_time - finish_connect_time _LOGGER.info( - "Successful handshake with %s in %0.3fs", self._log_name, handshake_time + "Successful handshake with %s in %0.3fs", self._cli.log_name, handshake_time ) self._async_set_connection_state_while_locked(ReconnectLogicState.READY) await self._on_connect_cb() @@ -250,7 +250,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): return _LOGGER.debug( "%s: Cancelling existing connect task, to try again now!", - self._log_name, + self._cli.log_name, ) self._connect_task.cancel("Scheduling new connect attempt") self._connect_task = None @@ -260,7 +260,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): self._connect_task = asyncio.create_task( self._connect_once_or_reschedule(), - name=f"{self._log_name}: aioesphomeapi connect", + name=f"{self._cli.log_name}: aioesphomeapi connect", ) def _cancel_connect(self, msg: str) -> None: @@ -277,9 +277,9 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): Must only be called from _call_connect_once """ - _LOGGER.debug("Trying to connect to %s", self._log_name) + _LOGGER.debug("Trying to connect to %s", self._cli.log_name) async with self._connected_lock: - _LOGGER.debug("Connected lock acquired for %s", self._log_name) + _LOGGER.debug("Connected lock acquired for %s", self._cli.log_name) if ( self._connection_state != ReconnectLogicState.DISCONNECTED or self._is_stopped @@ -291,9 +291,9 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): wait_time = int(round(min(1.8**tries, 60.0))) if tries == 1: _LOGGER.info( - "Trying to connect to %s in the background", self._log_name + "Trying to connect to %s in the background", self._cli.log_name ) - _LOGGER.debug("Retrying %s in %d seconds", self._log_name, wait_time) + _LOGGER.debug("Retrying %s in %d seconds", self._cli.log_name, wait_time) if wait_time: # If we are waiting, start listening for mDNS records self._start_zc_listen() @@ -311,7 +311,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): """Stop the connect logic.""" self._stop_task = asyncio.create_task( self.stop(), - name=f"{self._log_name}: aioesphomeapi reconnect_logic stop_callback", + name=f"{self._cli.log_name}: aioesphomeapi reconnect_logic stop_callback", ) self._stop_task.add_done_callback(self._remove_stop_task) @@ -342,6 +342,8 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): ReconnectLogicState.DISCONNECTED ) + await self._zeroconf_manager.async_close() + def _start_zc_listen(self) -> None: """Listen for mDNS records. @@ -352,14 +354,18 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): _LOGGER.debug("Starting zeroconf listener for %s", self.name) self._ptr_alias = f"{self.name}._esphomelib._tcp.local." self._a_name = f"{self.name}.local." - self._zc.async_add_listener(self, None) + self._zeroconf_manager.get_async_zeroconf().zeroconf.async_add_listener( + self, None + ) self._zc_listening = True def _stop_zc_listen(self) -> None: """Stop listening for zeroconf updates.""" if self._zc_listening: _LOGGER.debug("Removing zeroconf listener for %s", self.name) - self._zc.async_remove_listener(self) + self._zeroconf_manager.get_async_zeroconf().zeroconf.async_remove_listener( + self + ) self._zc_listening = False def async_update_records( @@ -389,7 +395,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): # Tell connection logic to retry connection attempt now (even before connect timer finishes) _LOGGER.debug( "%s: Triggering connect because of received mDNS record %s", - self._log_name, + self._cli.log_name, record_update.new, ) # We can't stop the zeroconf listener here because we are in the middle of diff --git a/aioesphomeapi/util.py b/aioesphomeapi/util.py index 2f87f21..ae226bb 100644 --- a/aioesphomeapi/util.py +++ b/aioesphomeapi/util.py @@ -24,3 +24,27 @@ def fix_float_single_double_conversion(value: float) -> float: l10 = math.ceil(math.log10(abs_val)) prec = 7 - l10 return round(value, prec) + + +def host_is_name_part(address: str) -> bool: + """Return True if a host is the name part.""" + return "." not in address and ":" not in address + + +def address_is_local(address: str) -> bool: + """Return True if the address is a local address.""" + return address.removesuffix(".").endswith(".local") + + +def build_log_name(name: str | None, address: str, resolved_address: str | None) -> str: + """Return a log name for a connection.""" + if not name and address_is_local(address) or host_is_name_part(address): + name = address.partition(".")[0] + preferred_address = resolved_address or address + if ( + name + and name != preferred_address + and not preferred_address.startswith(f"{name}.") + ): + return f"{name} @ {preferred_address}" + return preferred_address diff --git a/aioesphomeapi/zeroconf.py b/aioesphomeapi/zeroconf.py new file mode 100644 index 0000000..48f19fc --- /dev/null +++ b/aioesphomeapi/zeroconf.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Union + +from zeroconf import Zeroconf +from zeroconf.asyncio import AsyncZeroconf + +ZeroconfInstanceType = Union[Zeroconf, AsyncZeroconf] + +_LOGGER = logging.getLogger(__name__) + + +class ZeroconfManager: + """Manage the Zeroconf objects. + + This class is used to manage the Zeroconf objects. It is used to create + the Zeroconf objects and to close them. It attempts to avoid creating + a Zeroconf object unless one is actually needed. + """ + + def __init__(self, zeroconf: ZeroconfInstanceType | None = None) -> None: + """Initialize the ZeroconfManager.""" + self._created = False + self._aiozc: AsyncZeroconf | None = None + if zeroconf is not None: + self.set_instance(zeroconf) + + def set_instance(self, zc: AsyncZeroconf | Zeroconf) -> None: + """Set the AsyncZeroconf instance.""" + if self._aiozc: + if isinstance(zc, AsyncZeroconf) and self._aiozc.zeroconf is zc.zeroconf: + return + if isinstance(zc, Zeroconf) and self._aiozc.zeroconf is zc: + self._aiozc = AsyncZeroconf(zc=zc) + return + raise RuntimeError("Zeroconf instance already set to a different instance") + self._aiozc = zc if isinstance(zc, AsyncZeroconf) else AsyncZeroconf(zc=zc) + + def _create_async_zeroconf(self) -> None: + """Create an AsyncZeroconf instance.""" + _LOGGER.debug("Creating new AsyncZeroconf instance") + self._aiozc = AsyncZeroconf() + self._created = True + + def get_async_zeroconf(self) -> AsyncZeroconf: + """Get the AsyncZeroconf instance.""" + if not self._aiozc: + self._create_async_zeroconf() + if TYPE_CHECKING: + assert self._aiozc is not None + return self._aiozc + + async def async_close(self) -> None: + """Close the Zeroconf connection.""" + if not self._created or not self._aiozc: + return + await self._aiozc.async_close() + self._aiozc = None + self._created = False diff --git a/tests/common.py b/tests/common.py index 414f7d2..248659f 100644 --- a/tests/common.py +++ b/tests/common.py @@ -4,7 +4,7 @@ import asyncio import time from datetime import datetime, timezone from functools import partial -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch from google.protobuf import message from zeroconf import Zeroconf @@ -27,26 +27,30 @@ PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()} def get_mock_zeroconf() -> MagicMock: - return MagicMock(spec=Zeroconf) + with patch("zeroconf.Zeroconf.start"): + zc = Zeroconf() + zc.close = MagicMock() + return zc -def get_mock_async_zeroconf() -> MagicMock: - mock = MagicMock(spec=AsyncZeroconf) - mock.zeroconf = get_mock_zeroconf() - mock.async_close = AsyncMock() - return mock +def get_mock_async_zeroconf() -> AsyncZeroconf: + aiozc = AsyncZeroconf(zc=get_mock_zeroconf()) + aiozc.async_close = AsyncMock() + return aiozc class Estr(str): """A subclassed string.""" -def generate_plaintext_packet(msg: bytes, type_: int) -> bytes: +def generate_plaintext_packet(msg: message.Message) -> bytes: + type_ = PROTO_TO_MESSAGE_TYPE[msg.__class__] + bytes_ = msg.SerializeToString() return ( b"\0" - + _cached_varuint_to_bytes(len(msg)) + + _cached_varuint_to_bytes(len(bytes_)) + _cached_varuint_to_bytes(type_) - + msg + + bytes_ ) @@ -99,10 +103,7 @@ def send_plaintext_hello(protocol: APIPlaintextFrameHelper) -> None: hello_response.api_version_major = 1 hello_response.api_version_minor = 9 hello_response.name = "fake" - hello_msg = hello_response.SerializeToString() - protocol.data_received( - generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse]) - ) + protocol.data_received(generate_plaintext_packet(hello_response)) def send_plaintext_connect_response( @@ -110,8 +111,4 @@ def send_plaintext_connect_response( ) -> None: connect_response: message.Message = ConnectResponse() connect_response.invalid_password = invalid_password - connect_msg = connect_response.SerializeToString() - - protocol.data_received( - generate_plaintext_packet(connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse]) - ) + protocol.data_received(generate_plaintext_packet(connect_response)) diff --git a/tests/conftest.py b/tests/conftest.py index 1a0ebeb..1ef58bb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,8 +12,14 @@ from aioesphomeapi._frame_helper import APIPlaintextFrameHelper from aioesphomeapi.client import APIClient, ConnectionParams from aioesphomeapi.connection import APIConnection from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr +from aioesphomeapi.zeroconf import ZeroconfManager -from .common import connect, send_plaintext_hello +from .common import connect, get_mock_async_zeroconf, send_plaintext_hello + + +@pytest.fixture +def async_zeroconf(): + return get_mock_async_zeroconf() @pytest.fixture @@ -42,7 +48,7 @@ def connection_params() -> ConnectionParams: password=None, client_info="Tests client", keepalive=15.0, - zeroconf_instance=None, + zeroconf_manager=ZeroconfManager(), noise_psk=None, expected_name=None, ) diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index eb0e026..adcd221 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -141,6 +141,7 @@ def test_plaintext_frame_helper( assert type_ == pkt_type assert data == pkt_data + helper.close() @pytest.mark.parametrize( diff --git a/tests/test_client.py b/tests/test_client.py index 899fb10..82e4c7b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -18,6 +18,8 @@ from aioesphomeapi.api_pb2 import ( CameraImageResponse, ClimateCommandRequest, CoverCommandRequest, + DeviceInfoResponse, + DisconnectResponse, ExecuteServiceArgument, ExecuteServiceRequest, FanCommandRequest, @@ -34,6 +36,7 @@ from aioesphomeapi.api_pb2 import ( ) from aioesphomeapi.client import APIClient from aioesphomeapi.connection import APIConnection +from aioesphomeapi.core import APIConnectionError from aioesphomeapi.model import ( AlarmControlPanelCommand, APIVersion, @@ -680,12 +683,7 @@ async def test_bluetooth_disconnect( response: message.Message = BluetoothDeviceConnectionResponse( address=1234, connected=False ) - protocol.data_received( - generate_plaintext_packet( - response.SerializeToString(), - PROTO_TO_MESSAGE_TYPE[BluetoothDeviceConnectionResponse], - ) - ) + protocol.data_received(generate_plaintext_packet(response)) await disconnect_task @@ -700,12 +698,7 @@ async def test_bluetooth_pair( pair_task = asyncio.create_task(client.bluetooth_device_pair(1234)) await asyncio.sleep(0) response: message.Message = BluetoothDevicePairingResponse(address=1234) - protocol.data_received( - generate_plaintext_packet( - response.SerializeToString(), - PROTO_TO_MESSAGE_TYPE[BluetoothDevicePairingResponse], - ) - ) + protocol.data_received(generate_plaintext_packet(response)) await pair_task @@ -720,12 +713,7 @@ async def test_bluetooth_unpair( unpair_task = asyncio.create_task(client.bluetooth_device_unpair(1234)) await asyncio.sleep(0) response: message.Message = BluetoothDeviceUnpairingResponse(address=1234) - protocol.data_received( - generate_plaintext_packet( - response.SerializeToString(), - PROTO_TO_MESSAGE_TYPE[BluetoothDeviceUnpairingResponse], - ) - ) + protocol.data_received(generate_plaintext_packet(response)) await unpair_task @@ -740,10 +728,36 @@ async def test_bluetooth_clear_cache( clear_task = asyncio.create_task(client.bluetooth_device_clear_cache(1234)) await asyncio.sleep(0) response: message.Message = BluetoothDeviceClearCacheResponse(address=1234) - protocol.data_received( - generate_plaintext_packet( - response.SerializeToString(), - PROTO_TO_MESSAGE_TYPE[BluetoothDeviceClearCacheResponse], - ) - ) + protocol.data_received(generate_plaintext_packet(response)) await clear_task + + +@pytest.mark.asyncio +async def test_device_info( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test fetching device info.""" + client, connection, transport, protocol = api_client + assert client.log_name == "mydevice.local" + device_info_task = asyncio.create_task(client.device_info()) + await asyncio.sleep(0) + response: message.Message = DeviceInfoResponse( + name="realname", + friendly_name="My Device", + has_deep_sleep=True, + ) + protocol.data_received(generate_plaintext_packet(response)) + device_info = await device_info_task + assert device_info.name == "realname" + assert device_info.friendly_name == "My Device" + assert device_info.has_deep_sleep + assert client.log_name == "realname @ 10.0.0.512" + disconnect_task = asyncio.create_task(client.disconnect()) + await asyncio.sleep(0) + response: message.Message = DisconnectResponse() + protocol.data_received(generate_plaintext_packet(response)) + await disconnect_task + with pytest.raises(APIConnectionError, match="CLOSED"): + await client.device_info() diff --git a/tests/test_host_resolver.py b/tests/test_host_resolver.py index 66eeb93..d2b9287 100644 --- a/tests/test_host_resolver.py +++ b/tests/test_host_resolver.py @@ -5,20 +5,11 @@ from ipaddress import ip_address from unittest.mock import AsyncMock, MagicMock, patch import pytest -from zeroconf import DNSCache from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf import aioesphomeapi.host_resolver as hr -from aioesphomeapi.core import APIConnectionError - - -@pytest.fixture -def async_zeroconf(): - with patch("aioesphomeapi.host_resolver.AsyncZeroconf") as klass: - async_zeroconf = klass.return_value - async_zeroconf.async_close = AsyncMock() - async_zeroconf.zeroconf.cache = DNSCache() - yield async_zeroconf +from aioesphomeapi.core import APIConnectionError, ResolveAPIError +from aioesphomeapi.zeroconf import ZeroconfManager @pytest.fixture @@ -52,7 +43,9 @@ async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos): ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"), ] info.async_request = AsyncMock(return_value=True) - with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info): + with patch( + "aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info + ), patch("aioesphomeapi.zeroconf.AsyncZeroconf", return_value=async_zeroconf): ret = await hr._async_resolve_host_zeroconf("asdf", 6052) info.async_request.assert_called_once() @@ -61,10 +54,8 @@ async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos): @pytest.mark.asyncio -async def test_resolve_host_passed_zeroconf_does_not_close(addr_infos): - async_zeroconf = AsyncZeroconf(zc=MagicMock()) - async_zeroconf.async_close = AsyncMock() - async_zeroconf.zeroconf.cache = DNSCache() +async def test_resolve_host_passed_zeroconf(addr_infos, async_zeroconf): + zeroconf_manager = ZeroconfManager() info = MagicMock(auto_spec=AsyncServiceInfo) info.ip_addresses_by_version.return_value = [ ip_address(b"\n\x00\x00*"), @@ -73,11 +64,10 @@ async def test_resolve_host_passed_zeroconf_does_not_close(addr_infos): info.async_request = AsyncMock(return_value=True) with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info): ret = await hr._async_resolve_host_zeroconf( - "asdf", 6052, zeroconf_instance=async_zeroconf + "asdf", 6052, zeroconf_manager=zeroconf_manager ) info.async_request.assert_called_once() - async_zeroconf.async_close.assert_not_called() assert ret == addr_infos @@ -131,7 +121,7 @@ async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos): resolve_zc.return_value = addr_infos ret = await hr.async_resolve_host("example.local", 6052) - resolve_zc.assert_called_once_with("example", 6052, zeroconf_instance=None) + resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None) resolve_addr.assert_not_called() assert ret == addr_infos[0] @@ -144,7 +134,7 @@ async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos): resolve_addr.return_value = addr_infos ret = await hr.async_resolve_host("example.local", 6052) - resolve_zc.assert_called_once_with("example", 6052, zeroconf_instance=None) + resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None) resolve_addr.assert_called_once_with("example.local", 6052) assert ret == addr_infos[0] @@ -189,3 +179,40 @@ async def test_resolve_host_with_address(resolve_addr, resolve_zc): proto=6, sockaddr=hr.IPv4Sockaddr(address="127.0.0.1", port=6052), ) + + +@pytest.mark.asyncio +async def test_resolve_host_zeroconf_service_info_oserror( + async_zeroconf: AsyncZeroconf, addr_infos +): + info = MagicMock(auto_spec=AsyncServiceInfo) + info.ip_addresses_by_version.return_value = [ + ip_address(b"\n\x00\x00*"), + ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"), + ] + info.async_request = AsyncMock(return_value=True) + with patch( + "aioesphomeapi.host_resolver.AsyncServiceInfo.async_request", + side_effect=OSError("out of buffers"), + ), patch( + "aioesphomeapi.zeroconf.AsyncZeroconf", return_value=async_zeroconf + ), pytest.raises( + ResolveAPIError, match="out of buffers" + ): + await hr._async_resolve_host_zeroconf("asdf", 6052) + + +@pytest.mark.asyncio +async def test_resolve_host_create_zeroconf_oserror( + async_zeroconf: AsyncZeroconf, addr_infos +): + info = MagicMock(auto_spec=AsyncServiceInfo) + info.ip_addresses_by_version.return_value = [ + ip_address(b"\n\x00\x00*"), + ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"), + ] + info.async_request = AsyncMock(return_value=True) + with patch( + "aioesphomeapi.zeroconf.AsyncZeroconf", side_effect=OSError("out of buffers") + ), pytest.raises(ResolveAPIError, match="out of buffers"): + await hr._async_resolve_host_zeroconf("asdf", 6052) diff --git a/tests/test_log_runner.py b/tests/test_log_runner.py index f9c1de4..aec65ca 100644 --- a/tests/test_log_runner.py +++ b/tests/test_log_runner.py @@ -75,21 +75,11 @@ async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnec response: message.Message = SubscribeLogsResponse() response.message = b"Hello world" - protocol.data_received( - generate_plaintext_packet( - response.SerializeToString(), - PROTO_TO_MESSAGE_TYPE[SubscribeLogsResponse], - ) - ) + protocol.data_received(generate_plaintext_packet(response)) assert len(messages) == 1 assert messages[0].message == b"Hello world" stop_task = asyncio.create_task(stop()) await asyncio.sleep(0) disconnect_response = DisconnectResponse() - protocol.data_received( - generate_plaintext_packet( - disconnect_response.SerializeToString(), - PROTO_TO_MESSAGE_TYPE[DisconnectResponse], - ) - ) + protocol.data_received(generate_plaintext_packet(disconnect_response)) await stop_task diff --git a/tests/test_reconnect_logic.py b/tests/test_reconnect_logic.py index f2ec698..b145c32 100644 --- a/tests/test_reconnect_logic.py +++ b/tests/test_reconnect_logic.py @@ -41,14 +41,13 @@ async def test_reconnect_logic_name_from_host(): async def on_connect() -> None: pass - rl = ReconnectLogic( + ReconnectLogic( client=cli, on_disconnect=on_disconnect, on_connect=on_connect, zeroconf_instance=MagicMock(spec=AsyncZeroconf), ) - assert rl._log_name == "mydevice" - assert cli._log_name == "mydevice" + assert cli.log_name == "mydevice.local" @pytest.mark.asyncio @@ -66,15 +65,14 @@ async def test_reconnect_logic_name_from_host_and_set(): async def on_connect() -> None: pass - rl = ReconnectLogic( + ReconnectLogic( client=cli, on_disconnect=on_disconnect, on_connect=on_connect, zeroconf_instance=get_mock_zeroconf(), name="mydevice", ) - assert rl._log_name == "mydevice" - assert cli._log_name == "mydevice" + assert cli.log_name == "mydevice.local" @pytest.mark.asyncio @@ -92,14 +90,13 @@ async def test_reconnect_logic_name_from_address(): async def on_connect() -> None: pass - rl = ReconnectLogic( + ReconnectLogic( client=cli, on_disconnect=on_disconnect, on_connect=on_connect, zeroconf_instance=get_mock_zeroconf(), ) - assert rl._log_name == "1.2.3.4" - assert cli._log_name == "1.2.3.4" + assert cli.log_name == "1.2.3.4" @pytest.mark.asyncio @@ -117,15 +114,14 @@ async def test_reconnect_logic_name_from_name(): async def on_connect() -> None: pass - rl = ReconnectLogic( + ReconnectLogic( client=cli, on_disconnect=on_disconnect, on_connect=on_connect, zeroconf_instance=get_mock_zeroconf(), name="mydevice", ) - assert rl._log_name == "mydevice @ 1.2.3.4" - assert cli._log_name == "mydevice @ 1.2.3.4" + assert cli.log_name == "mydevice @ 1.2.3.4" @pytest.mark.asyncio @@ -164,8 +160,7 @@ async def test_reconnect_logic_state(): name="mydevice", on_connect_error=on_connect_fail, ) - assert rl._log_name == "mydevice @ 1.2.3.4" - assert cli._log_name == "mydevice @ 1.2.3.4" + assert cli.log_name == "mydevice @ 1.2.3.4" with patch.object(cli, "start_connection", side_effect=APIConnectionError): await rl.start() @@ -241,8 +236,7 @@ async def test_reconnect_retry(): name="mydevice", on_connect_error=on_connect_fail, ) - assert rl._log_name == "mydevice @ 1.2.3.4" - assert cli._log_name == "mydevice @ 1.2.3.4" + assert cli.log_name == "mydevice @ 1.2.3.4" with patch.object(cli, "start_connection", side_effect=APIConnectionError): await rl.start() @@ -338,8 +332,7 @@ async def test_reconnect_zeroconf( name="mydevice", on_connect_error=AsyncMock(), ) - assert rl._log_name == "mydevice @ 1.2.3.4" - assert cli._log_name == "mydevice @ 1.2.3.4" + assert cli.log_name == "mydevice @ 1.2.3.4" async def slow_connect_fail(*args, **kwargs): await asyncio.sleep(10) diff --git a/tests/test_zeroconf.py b/tests/test_zeroconf.py new file mode 100644 index 0000000..fe713b2 --- /dev/null +++ b/tests/test_zeroconf.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from unittest.mock import patch + +import pytest +from zeroconf.asyncio import AsyncZeroconf + +from aioesphomeapi.zeroconf import ZeroconfManager + +from .common import get_mock_async_zeroconf + + +@pytest.mark.asyncio +async def test_does_not_closed_passed_in_async_instance(async_zeroconf: AsyncZeroconf): + """Test that the passed in instance is not closed.""" + manager = ZeroconfManager() + manager.set_instance(async_zeroconf) + await manager.async_close() + assert async_zeroconf.async_close.call_count == 0 + + +@pytest.mark.asyncio +async def test_does_not_closed_passed_in_sync_instance(async_zeroconf: AsyncZeroconf): + """Test that the passed in instance is not closed.""" + manager = ZeroconfManager() + manager.set_instance(async_zeroconf.zeroconf) + await manager.async_close() + assert async_zeroconf.async_close.call_count == 0 + + +@pytest.mark.asyncio +async def test_closes_created_instance(async_zeroconf: AsyncZeroconf): + """Test that the created instance is closed.""" + with patch("aioesphomeapi.zeroconf.AsyncZeroconf", return_value=async_zeroconf): + manager = ZeroconfManager() + assert manager.get_async_zeroconf() is async_zeroconf + await manager.async_close() + assert async_zeroconf.async_close.call_count == 1 + + +@pytest.mark.asyncio +async def test_runtime_error_multiple_instances(async_zeroconf: AsyncZeroconf): + """Test runtime error is raised on multiple instances.""" + manager = ZeroconfManager(async_zeroconf) + new_instance = get_mock_async_zeroconf() + with pytest.raises(RuntimeError): + manager.set_instance(new_instance) + manager.set_instance(async_zeroconf) + manager.set_instance(async_zeroconf.zeroconf) + manager.set_instance(async_zeroconf) + await manager.async_close() + assert async_zeroconf.async_close.call_count == 0