From 86726e90790400c09022b8130b35169bdeca1b1b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 26 Nov 2023 18:25:29 -0600 Subject: [PATCH] Add test to ensure log runner reconnects on subscribe failure (#757) --- tests/test_log_runner.py | 78 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/tests/test_log_runner.py b/tests/test_log_runner.py index 5705387..f446d8a 100644 --- a/tests/test_log_runner.py +++ b/tests/test_log_runner.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio from datetime import timedelta +from functools import partial from unittest.mock import MagicMock, patch import pytest @@ -12,6 +13,7 @@ from aioesphomeapi.api_pb2 import SubscribeLogsResponse # type: ignore from aioesphomeapi.api_pb2 import DisconnectRequest, DisconnectResponse from aioesphomeapi.client import APIClient from aioesphomeapi.connection import APIConnection +from aioesphomeapi.core import APIConnectionError from aioesphomeapi.log_runner import async_run from aioesphomeapi.reconnect_logic import EXPECTED_DISCONNECT_COOLDOWN @@ -164,3 +166,79 @@ async def test_log_runner_reconnects_on_disconnect( assert mock_start_connection.called await stop() + + +@pytest.mark.asyncio +async def test_log_runner_reconnects_on_subscribe_failure( + event_loop: asyncio.AbstractEventLoop, + conn: APIConnection, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test the log runner reconnects on subscribe failure.""" + loop = asyncio.get_event_loop() + protocol: APIPlaintextFrameHelper | None = None + transport = MagicMock() + connected = asyncio.Event() + + class PatchableAPIClient(APIClient): + pass + + async_zeroconf = get_mock_async_zeroconf() + + cli = PatchableAPIClient( + address=Estr("1.2.3.4"), + port=6052, + password=None, + noise_psk=None, + expected_name=Estr("fake"), + zeroconf_instance=async_zeroconf.zeroconf, + ) + messages = [] + + def on_log(msg: SubscribeLogsResponse) -> None: + messages.append(msg) + + def _create_mock_transport_protocol(create_func, **kwargs): + nonlocal protocol + protocol = create_func() + protocol.connection_made(transport) + connected.set() + return transport, protocol + + subscribed = asyncio.Event() + + async def _wait_and_fail_subscribe_cli(*args, **kwargs): + subscribed.set() + raise APIConnectionError("subscribed force to fail") + + with patch.object( + cli, "disconnect", partial(cli.disconnect, force=True) + ), patch.object(cli, "subscribe_logs", _wait_and_fail_subscribe_cli): + with patch.object(loop, "sock_connect"), patch.object( + loop, "create_connection", side_effect=_create_mock_transport_protocol + ): + stop = await async_run(cli, on_log, aio_zeroconf_instance=async_zeroconf) + await connected.wait() + protocol = cli._connection._frame_helper + send_plaintext_hello(protocol) + send_plaintext_connect_response(protocol, False) + + await subscribed.wait() + + assert cli._connection is None + + with patch.object(loop, "sock_connect"), patch.object( + loop, "create_connection", side_effect=_create_mock_transport_protocol + ), patch.object(cli, "subscribe_logs"): + connected.clear() + await asyncio.sleep(0) + async_fire_time_changed( + utcnow() + timedelta(seconds=EXPECTED_DISCONNECT_COOLDOWN) + ) + await asyncio.sleep(0) + + stop_task = asyncio.create_task(stop()) + await asyncio.sleep(0) + disconnect_response = DisconnectResponse() + mock_data_received(protocol, generate_plaintext_packet(disconnect_response)) + await stop_task