diff --git a/aioesphomeapi/host_resolver.py b/aioesphomeapi/host_resolver.py index c28edeb..97df657 100644 --- a/aioesphomeapi/host_resolver.py +++ b/aioesphomeapi/host_resolver.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import contextlib +import logging import socket from dataclasses import dataclass from ipaddress import IPv4Address, IPv6Address @@ -12,8 +13,12 @@ from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf from .core import APIConnectionError, ResolveAPIError +_LOGGER = logging.getLogger(__name__) + ZeroconfInstanceType = Union[Zeroconf, AsyncZeroconf, None] +SERVICE_TYPE = "_esphomelib._tcp.local." + @dataclass(frozen=True) class Sockaddr: @@ -89,11 +94,11 @@ async def _async_resolve_host_zeroconf( timeout: float = 3.0, zeroconf_instance: ZeroconfInstanceType = None, ) -> list[AddrInfo]: - service_type = "_esphomelib._tcp.local." - service_name = f"{host}.{service_type}" + service_name = f"{host}.{SERVICE_TYPE}" + _LOGGER.debug("Resolving host %s via mDNS", service_name) info = await _async_zeroconf_get_service_info( - zeroconf_instance, service_type, service_name, timeout + zeroconf_instance, SERVICE_TYPE, service_name, timeout ) if info is None: @@ -197,8 +202,8 @@ async def async_resolve_host( addrs: list[AddrInfo] = [] zc_error = None - if host.endswith(".local"): - name = host[: -len(".local")] + if "." not in host or host.endswith(".local"): + name = host.partition(".")[0] try: addrs.extend( await _async_resolve_host_zeroconf( diff --git a/aioesphomeapi/log_reader.py b/aioesphomeapi/log_reader.py index 3ffa4cd..0f53b17 100644 --- a/aioesphomeapi/log_reader.py +++ b/aioesphomeapi/log_reader.py @@ -7,6 +7,8 @@ import logging import sys from datetime import datetime +from zeroconf.asyncio import AsyncZeroconf + from .api_pb2 import SubscribeLogsResponse # type: ignore from .client import APIClient from .log_runner import async_run @@ -27,11 +29,14 @@ async def main(argv: list[str]) -> None: datefmt="%Y-%m-%d %H:%M:%S", ) + aiozc = AsyncZeroconf() + cli = APIClient( args.address, args.port, args.password or "", noise_psk=args.noise_psk, + zeroconf_instance=aiozc.zeroconf, keepalive=10, ) @@ -41,11 +46,12 @@ async def main(argv: list[str]) -> None: text = message.decode("utf8", "backslashreplace") print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}]{text}") - stop = await async_run(cli, on_log) + stop = await async_run(cli, on_log, aio_zeroconf_instance=aiozc) try: while True: await asyncio.sleep(60) finally: + await aiozc.async_close() await stop() diff --git a/tests/test_log_runner.py b/tests/test_log_runner.py index 92442e8..f9c1de4 100644 --- a/tests/test_log_runner.py +++ b/tests/test_log_runner.py @@ -34,12 +34,15 @@ async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnec class PatchableAPIClient(APIClient): pass + async_zeroconf = get_mock_async_zeroconf() + cli = PatchableAPIClient( address=Estr("1.2.3.4"), port=6052, password=None, noise_psk=None, expected_name=Estr("fake"), + zeroconf_instance=async_zeroconf.zeroconf, ) messages = [] @@ -60,8 +63,6 @@ async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnec await original_subscribe_logs(*args, **kwargs) subscribed.set() - async_zeroconf = get_mock_async_zeroconf() - with patch.object(event_loop, "sock_connect"), patch.object( loop, "create_connection", side_effect=_create_mock_transport_protocol ), patch.object(cli, "subscribe_logs", _wait_subscribe_cli):