From 0202e00eaee51f8d84c80eb449cf40493c86f9b0 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 11 Nov 2023 13:43:31 -0600 Subject: [PATCH] Ensure zeroconf instance is closed when log runner ends (#632) --- aioesphomeapi/log_reader.py | 4 ++-- aioesphomeapi/log_runner.py | 14 +++++++++----- tests/common.py | 10 +++++++++- tests/test_log_runner.py | 8 +++++--- 4 files changed, 25 insertions(+), 11 deletions(-) diff --git a/aioesphomeapi/log_reader.py b/aioesphomeapi/log_reader.py index cb81e0d..3ffa4cd 100644 --- a/aioesphomeapi/log_reader.py +++ b/aioesphomeapi/log_reader.py @@ -9,7 +9,7 @@ from datetime import datetime from .api_pb2 import SubscribeLogsResponse # type: ignore from .client import APIClient -from .log_runner import async_run_logs +from .log_runner import async_run async def main(argv: list[str]) -> None: @@ -41,7 +41,7 @@ async def main(argv: list[str]) -> None: text = message.decode("utf8", "backslashreplace") print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}]{text}") - stop = await async_run_logs(cli, on_log) + stop = await async_run(cli, on_log) try: while True: await asyncio.sleep(60) diff --git a/aioesphomeapi/log_runner.py b/aioesphomeapi/log_runner.py index 9580e51..cf35d9d 100644 --- a/aioesphomeapi/log_runner.py +++ b/aioesphomeapi/log_runner.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging from typing import Any, Callable, Coroutine -import zeroconf +from zeroconf.asyncio import AsyncZeroconf from .api_pb2 import SubscribeLogsResponse # type: ignore from .client import APIClient @@ -14,18 +14,17 @@ from .reconnect_logic import ReconnectLogic _LOGGER = logging.getLogger(__name__) -async def async_run_logs( +async def async_run( cli: APIClient, on_log: Callable[[SubscribeLogsResponse], None], log_level: LogLevel = LogLevel.LOG_LEVEL_VERY_VERBOSE, - zeroconf_instance: zeroconf.Zeroconf | None = None, + aio_zeroconf_instance: AsyncZeroconf | None = None, dump_config: bool = True, ) -> Callable[[], Coroutine[Any, Any, None]]: """Run logs until canceled. Returns a coroutine that can be awaited to stop the logs. """ - dumped_config = not dump_config async def on_connect() -> None: @@ -46,15 +45,20 @@ async def async_run_logs( ) -> None: _LOGGER.warning("Disconnected from API") + passed_in_zeroconf = aio_zeroconf_instance is not None + aiozc = aio_zeroconf_instance or AsyncZeroconf() + logic = ReconnectLogic( client=cli, on_connect=on_connect, on_disconnect=on_disconnect, - zeroconf_instance=zeroconf_instance or zeroconf.Zeroconf(), + zeroconf_instance=aiozc.zeroconf, ) await logic.start() async def _stop() -> None: + if not passed_in_zeroconf: + await aiozc.async_close() await logic.stop() await cli.disconnect() diff --git a/tests/common.py b/tests/common.py index 5a43f80..414f7d2 100644 --- a/tests/common.py +++ b/tests/common.py @@ -4,10 +4,11 @@ import asyncio import time from datetime import datetime, timezone from functools import partial -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock from google.protobuf import message from zeroconf import Zeroconf +from zeroconf.asyncio import AsyncZeroconf from aioesphomeapi._frame_helper import APIPlaintextFrameHelper from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes @@ -29,6 +30,13 @@ def get_mock_zeroconf() -> MagicMock: return MagicMock(spec=Zeroconf) +def get_mock_async_zeroconf() -> MagicMock: + mock = MagicMock(spec=AsyncZeroconf) + mock.zeroconf = get_mock_zeroconf() + mock.async_close = AsyncMock() + return mock + + class Estr(str): """A subclassed string.""" diff --git a/tests/test_log_runner.py b/tests/test_log_runner.py index 3a748ab..5ff1686 100644 --- a/tests/test_log_runner.py +++ b/tests/test_log_runner.py @@ -9,13 +9,13 @@ from aioesphomeapi.api_pb2 import SubscribeLogsResponse # type: ignore from aioesphomeapi.api_pb2 import DisconnectResponse from aioesphomeapi.client import APIClient from aioesphomeapi.connection import APIConnection -from aioesphomeapi.log_runner import async_run_logs +from aioesphomeapi.log_runner import async_run from .common import ( PROTO_TO_MESSAGE_TYPE, Estr, generate_plaintext_packet, - get_mock_zeroconf, + get_mock_async_zeroconf, send_plaintext_connect_response, send_plaintext_hello, ) @@ -58,10 +58,12 @@ async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnec await original_subscribe_logs(*args, **kwargs) subscribed.set() + async_zeroconf = get_mock_async_zeroconf() + with patch.object(event_loop, "sock_connect"), patch.object( loop, "create_connection", side_effect=_create_mock_transport_protocol ), patch.object(cli, "subscribe_logs", _wait_subscribe_cli): - stop = await async_run_logs(cli, on_log, zeroconf_instance=get_mock_zeroconf()) + stop = await async_run(cli, on_log, aio_zeroconf_instance=async_zeroconf) await connected.wait() protocol = cli._connection._frame_helper send_plaintext_hello(protocol)