Ensure zeroconf instance is closed when log runner ends (#632)

This commit is contained in:
J. Nick Koston 2023-11-11 13:43:31 -06:00 committed by GitHub
parent 338e89069f
commit 0202e00eae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 25 additions and 11 deletions

View File

@ -9,7 +9,7 @@ from datetime import datetime
from .api_pb2 import SubscribeLogsResponse # type: ignore
from .client import APIClient
from .log_runner import async_run_logs
from .log_runner import async_run
async def main(argv: list[str]) -> None:
@ -41,7 +41,7 @@ async def main(argv: list[str]) -> None:
text = message.decode("utf8", "backslashreplace")
print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}]{text}")
stop = await async_run_logs(cli, on_log)
stop = await async_run(cli, on_log)
try:
while True:
await asyncio.sleep(60)

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import logging
from typing import Any, Callable, Coroutine
import zeroconf
from zeroconf.asyncio import AsyncZeroconf
from .api_pb2 import SubscribeLogsResponse # type: ignore
from .client import APIClient
@ -14,18 +14,17 @@ from .reconnect_logic import ReconnectLogic
_LOGGER = logging.getLogger(__name__)
async def async_run_logs(
async def async_run(
cli: APIClient,
on_log: Callable[[SubscribeLogsResponse], None],
log_level: LogLevel = LogLevel.LOG_LEVEL_VERY_VERBOSE,
zeroconf_instance: zeroconf.Zeroconf | None = None,
aio_zeroconf_instance: AsyncZeroconf | None = None,
dump_config: bool = True,
) -> Callable[[], Coroutine[Any, Any, None]]:
"""Run logs until canceled.
Returns a coroutine that can be awaited to stop the logs.
"""
dumped_config = not dump_config
async def on_connect() -> None:
@ -46,15 +45,20 @@ async def async_run_logs(
) -> None:
_LOGGER.warning("Disconnected from API")
passed_in_zeroconf = aio_zeroconf_instance is not None
aiozc = aio_zeroconf_instance or AsyncZeroconf()
logic = ReconnectLogic(
client=cli,
on_connect=on_connect,
on_disconnect=on_disconnect,
zeroconf_instance=zeroconf_instance or zeroconf.Zeroconf(),
zeroconf_instance=aiozc.zeroconf,
)
await logic.start()
async def _stop() -> None:
if not passed_in_zeroconf:
await aiozc.async_close()
await logic.stop()
await cli.disconnect()

View File

@ -4,10 +4,11 @@ import asyncio
import time
from datetime import datetime, timezone
from functools import partial
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock
from google.protobuf import message
from zeroconf import Zeroconf
from zeroconf.asyncio import AsyncZeroconf
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
@ -29,6 +30,13 @@ def get_mock_zeroconf() -> MagicMock:
return MagicMock(spec=Zeroconf)
def get_mock_async_zeroconf() -> MagicMock:
mock = MagicMock(spec=AsyncZeroconf)
mock.zeroconf = get_mock_zeroconf()
mock.async_close = AsyncMock()
return mock
class Estr(str):
"""A subclassed string."""

View File

@ -9,13 +9,13 @@ from aioesphomeapi.api_pb2 import SubscribeLogsResponse # type: ignore
from aioesphomeapi.api_pb2 import DisconnectResponse
from aioesphomeapi.client import APIClient
from aioesphomeapi.connection import APIConnection
from aioesphomeapi.log_runner import async_run_logs
from aioesphomeapi.log_runner import async_run
from .common import (
PROTO_TO_MESSAGE_TYPE,
Estr,
generate_plaintext_packet,
get_mock_zeroconf,
get_mock_async_zeroconf,
send_plaintext_connect_response,
send_plaintext_hello,
)
@ -58,10 +58,12 @@ async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnec
await original_subscribe_logs(*args, **kwargs)
subscribed.set()
async_zeroconf = get_mock_async_zeroconf()
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
), patch.object(cli, "subscribe_logs", _wait_subscribe_cli):
stop = await async_run_logs(cli, on_log, zeroconf_instance=get_mock_zeroconf())
stop = await async_run(cli, on_log, aio_zeroconf_instance=async_zeroconf)
await connected.wait()
protocol = cli._connection._frame_helper
send_plaintext_hello(protocol)