Add test to ensure log runner reconnects on subscribe failure (#757)

This commit is contained in:
J. Nick Koston 2023-11-26 18:25:29 -06:00 committed by GitHub
parent e93ee7f313
commit 86726e9079
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 78 additions and 0 deletions

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio import asyncio
from datetime import timedelta from datetime import timedelta
from functools import partial
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
@ -12,6 +13,7 @@ from aioesphomeapi.api_pb2 import SubscribeLogsResponse # type: ignore
from aioesphomeapi.api_pb2 import DisconnectRequest, DisconnectResponse from aioesphomeapi.api_pb2 import DisconnectRequest, DisconnectResponse
from aioesphomeapi.client import APIClient from aioesphomeapi.client import APIClient
from aioesphomeapi.connection import APIConnection from aioesphomeapi.connection import APIConnection
from aioesphomeapi.core import APIConnectionError
from aioesphomeapi.log_runner import async_run from aioesphomeapi.log_runner import async_run
from aioesphomeapi.reconnect_logic import EXPECTED_DISCONNECT_COOLDOWN from aioesphomeapi.reconnect_logic import EXPECTED_DISCONNECT_COOLDOWN
@ -164,3 +166,79 @@ async def test_log_runner_reconnects_on_disconnect(
assert mock_start_connection.called assert mock_start_connection.called
await stop() await stop()
@pytest.mark.asyncio
async def test_log_runner_reconnects_on_subscribe_failure(
event_loop: asyncio.AbstractEventLoop,
conn: APIConnection,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test the log runner reconnects on subscribe failure."""
loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()
connected = asyncio.Event()
class PatchableAPIClient(APIClient):
pass
async_zeroconf = get_mock_async_zeroconf()
cli = PatchableAPIClient(
address=Estr("1.2.3.4"),
port=6052,
password=None,
noise_psk=None,
expected_name=Estr("fake"),
zeroconf_instance=async_zeroconf.zeroconf,
)
messages = []
def on_log(msg: SubscribeLogsResponse) -> None:
messages.append(msg)
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
subscribed = asyncio.Event()
async def _wait_and_fail_subscribe_cli(*args, **kwargs):
subscribed.set()
raise APIConnectionError("subscribed force to fail")
with patch.object(
cli, "disconnect", partial(cli.disconnect, force=True)
), patch.object(cli, "subscribe_logs", _wait_and_fail_subscribe_cli):
with patch.object(loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
stop = await async_run(cli, on_log, aio_zeroconf_instance=async_zeroconf)
await connected.wait()
protocol = cli._connection._frame_helper
send_plaintext_hello(protocol)
send_plaintext_connect_response(protocol, False)
await subscribed.wait()
assert cli._connection is None
with patch.object(loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
), patch.object(cli, "subscribe_logs"):
connected.clear()
await asyncio.sleep(0)
async_fire_time_changed(
utcnow() + timedelta(seconds=EXPECTED_DISCONNECT_COOLDOWN)
)
await asyncio.sleep(0)
stop_task = asyncio.create_task(stop())
await asyncio.sleep(0)
disconnect_response = DisconnectResponse()
mock_data_received(protocol, generate_plaintext_packet(disconnect_response))
await stop_task