# aioesphomeapi/tests/test_log_runner.py
#
# (source listing: 91 lines, 2.8 KiB, Python)

import asyncio
from unittest.mock import MagicMock, patch
import pytest
from google.protobuf import message
from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper
from aioesphomeapi.api_pb2 import SubscribeLogsResponse # type: ignore
from aioesphomeapi.api_pb2 import DisconnectResponse
from aioesphomeapi.client import APIClient
from aioesphomeapi.connection import APIConnection
from aioesphomeapi.log_runner import async_run_logs
from .common import (
PROTO_TO_MESSAGE_TYPE,
Estr,
generate_plaintext_packet,
get_mock_zeroconf,
send_plaintext_connect_response,
send_plaintext_hello,
)
@pytest.mark.asyncio
async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnection):
    """Test the log runner logic.

    Drives ``async_run_logs`` against a mocked transport:

    1. completes the plaintext hello / connect handshake,
    2. waits until the runner has called ``subscribe_logs``,
    3. feeds one ``SubscribeLogsResponse`` and checks ``on_log`` received it,
    4. stops the runner and answers with a ``DisconnectResponse``.

    The ``event_loop`` and ``conn`` fixtures are requested but not used
    directly in the body.
    """
    # Use the single loop that is running this coroutine for every patch, so
    # ``sock_connect`` and ``create_connection`` are mocked on the same object
    # (presumably the loop the client picks up when constructed below).
    loop = asyncio.get_running_loop()
    protocol: APIPlaintextFrameHelper | None = None
    transport = MagicMock()
    connected = asyncio.Event()

    # NOTE(review): trivial subclass, presumably so ``patch.object`` can set
    # attributes on the instance (e.g. if ``APIClient`` restricts instance
    # attributes) — confirm.
    class PatchableAPIClient(APIClient):
        pass

    cli = PatchableAPIClient(
        address=Estr("1.2.3.4"),
        port=6052,
        password=None,
        noise_psk=None,
        expected_name=Estr("fake"),
    )
    messages: list[SubscribeLogsResponse] = []

    def on_log(msg: SubscribeLogsResponse) -> None:
        """Collect every log message delivered by the runner."""
        messages.append(msg)

    def _create_mock_transport_protocol(create_func, **kwargs):
        """Stand-in for ``loop.create_connection`` that uses a MagicMock transport."""
        nonlocal protocol
        protocol = create_func()
        protocol.connection_made(transport)
        connected.set()
        return transport, protocol

    subscribed = asyncio.Event()
    original_subscribe_logs = cli.subscribe_logs

    async def _wait_subscribe_cli(*args, **kwargs):
        """Call the real ``subscribe_logs`` and then signal the test."""
        await original_subscribe_logs(*args, **kwargs)
        subscribed.set()

    # Keep all patches active for the whole test, including shutdown, so
    # nothing can fall through to a real network call.
    with patch.object(loop, "sock_connect"), patch.object(
        loop, "create_connection", side_effect=_create_mock_transport_protocol
    ), patch.object(cli, "subscribe_logs", _wait_subscribe_cli):
        stop = await async_run_logs(cli, on_log, zeroconf_instance=get_mock_zeroconf())
        await connected.wait()
        # Talk to the frame helper the connection actually holds.
        protocol = cli._connection._frame_helper
        send_plaintext_hello(protocol)
        send_plaintext_connect_response(protocol, False)
        await subscribed.wait()

        # Deliver one log line and check it reaches on_log.
        response: message.Message = SubscribeLogsResponse()
        response.message = b"Hello world"
        protocol.data_received(
            generate_plaintext_packet(
                response.SerializeToString(),
                PROTO_TO_MESSAGE_TYPE[SubscribeLogsResponse],
            )
        )
        assert len(messages) == 1
        assert messages[0].message == b"Hello world"

        # Start stopping the runner, and yield once so stop() can progress to
        # waiting on the disconnect before we feed the DisconnectResponse.
        stop_task = asyncio.create_task(stop())
        await asyncio.sleep(0)
        disconnect_response = DisconnectResponse()
        protocol.data_received(
            generate_plaintext_packet(
                disconnect_response.SerializeToString(),
                PROTO_TO_MESSAGE_TYPE[DisconnectResponse],
            )
        )
        await stop_task