Make log runner code reusable and add coverage (#630)

This commit is contained in:
J. Nick Koston 2023-11-11 13:06:27 -06:00 committed by GitHub
parent 89f34cbcad
commit 3ffcca3bdd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 211 additions and 90 deletions

View File

@ -4,6 +4,7 @@ source = aioesphomeapi
omit =
aioesphomeapi/api_options_pb2.py
aioesphomeapi/api_pb2.py
aioesphomeapi/log_reader.py
bench/*.py
[report]

View File

@ -7,15 +7,9 @@ import logging
import sys
from datetime import datetime
import zeroconf
from aioesphomeapi.api_pb2 import SubscribeLogsResponse # type: ignore
from aioesphomeapi.client import APIClient
from aioesphomeapi.core import APIConnectionError
from aioesphomeapi.model import LogLevel
from aioesphomeapi.reconnect_logic import ReconnectLogic
_LOGGER = logging.getLogger(__name__)
from .api_pb2 import SubscribeLogsResponse # type: ignore
from .client import APIClient
from .log_runner import async_run_logs
async def main(argv: list[str]) -> None:
@ -42,42 +36,27 @@ async def main(argv: list[str]) -> None:
)
def on_log(msg: SubscribeLogsResponse) -> None:
time_ = datetime.now().time().strftime("[%H:%M:%S]")
text = msg.message
print(time_ + text.decode("utf8", "backslashreplace"))
time_ = datetime.now()
message: bytes = msg.message
text = message.decode("utf8", "backslashreplace")
print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}]{text}")
has_connects = False
async def on_connect() -> None:
nonlocal has_connects
try:
await cli.subscribe_logs(
on_log,
log_level=LogLevel.LOG_LEVEL_VERY_VERBOSE,
dump_config=not has_connects,
)
has_connects = True
except APIConnectionError:
await cli.disconnect()
async def on_disconnect( # pylint: disable=unused-argument
expected_disconnect: bool,
) -> None:
_LOGGER.warning("Disconnected from API")
logic = ReconnectLogic(
client=cli,
on_connect=on_connect,
on_disconnect=on_disconnect,
zeroconf_instance=zeroconf.Zeroconf(),
)
await logic.start()
stop = await async_run_logs(cli, on_log)
try:
while True:
await asyncio.sleep(60)
finally:
await stop()
def cli_entry_point() -> None:
"""Run the CLI."""
try:
asyncio.run(main(sys.argv))
except KeyboardInterrupt:
await logic.stop()
pass
if __name__ == "__main__":
sys.exit(asyncio.run(main(sys.argv)) or 0)
cli_entry_point()
sys.exit(0)

View File

@ -0,0 +1,61 @@
from __future__ import annotations
import logging
from typing import Any, Callable, Coroutine
import zeroconf
from .api_pb2 import SubscribeLogsResponse # type: ignore
from .client import APIClient
from .core import APIConnectionError
from .model import LogLevel
from .reconnect_logic import ReconnectLogic
_LOGGER = logging.getLogger(__name__)
async def async_run_logs(
cli: APIClient,
on_log: Callable[[SubscribeLogsResponse], None],
log_level: LogLevel = LogLevel.LOG_LEVEL_VERY_VERBOSE,
zeroconf_instance: zeroconf.Zeroconf | None = None,
dump_config: bool = True,
) -> Callable[[], Coroutine[Any, Any, None]]:
"""Run logs until canceled.
Returns a coroutine that can be awaited to stop the logs.
"""
dumped_config = not dump_config
async def on_connect() -> None:
"""Handle a connection."""
nonlocal dumped_config
try:
await cli.subscribe_logs(
on_log,
log_level=log_level,
dump_config=not dumped_config,
)
dumped_config = True
except APIConnectionError:
await cli.disconnect()
async def on_disconnect( # pylint: disable=unused-argument
expected_disconnect: bool,
) -> None:
_LOGGER.warning("Disconnected from API")
logic = ReconnectLogic(
client=cli,
on_connect=on_connect,
on_disconnect=on_disconnect,
zeroconf_instance=zeroconf_instance or zeroconf.Zeroconf(),
)
await logic.start()
async def _stop() -> None:
await logic.stop()
await cli.disconnect()
return _stop

View File

@ -1,11 +1,9 @@
#!/usr/bin/env python3
"""aioesphomeapi setup script."""
import os
from setuptools import find_packages, setup
import os
from distutils.command.build_ext import build_ext
from setuptools import find_packages, setup
here = os.path.abspath(os.path.dirname(__file__))
@ -60,6 +58,11 @@ setup_kwargs = {
"install_requires": REQUIRES,
"python_requires": ">=3.9",
"test_suite": "tests",
"entry_points": {
"console_scripts": [
"aioesphomeapi-logs=aioesphomeapi.log_reader:cli_entry_point"
],
},
}

View File

@ -6,9 +6,12 @@ from datetime import datetime, timezone
from functools import partial
from unittest.mock import MagicMock
from google.protobuf import message
from zeroconf import Zeroconf
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
from aioesphomeapi.api_pb2 import ConnectResponse, HelloResponse
from aioesphomeapi.connection import APIConnection
from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO
@ -81,3 +84,26 @@ async def connect(conn: APIConnection, login: bool = True):
"""Wrapper for connection logic to do both parts."""
await conn.start_connection()
await conn.finish_connection(login=login)
def send_plaintext_hello(protocol: APIPlaintextFrameHelper) -> None:
hello_response: message.Message = HelloResponse()
hello_response.api_version_major = 1
hello_response.api_version_minor = 9
hello_response.name = "fake"
hello_msg = hello_response.SerializeToString()
protocol.data_received(
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
)
def send_plaintext_connect_response(
protocol: APIPlaintextFrameHelper, invalid_password: bool
) -> None:
connect_response: message.Message = ConnectResponse()
connect_response.invalid_password = invalid_password
connect_msg = connect_response.SerializeToString()
protocol.data_received(
generate_plaintext_packet(connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse])
)

View File

@ -7,15 +7,13 @@ from unittest.mock import MagicMock, patch
import pytest
import pytest_asyncio
from google.protobuf import message
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi.api_pb2 import HelloResponse
from aioesphomeapi.client import APIClient, ConnectionParams
from aioesphomeapi.connection import APIConnection
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
from .common import PROTO_TO_MESSAGE_TYPE, connect, generate_plaintext_packet
from .common import connect, send_plaintext_hello
@pytest.fixture
@ -132,14 +130,7 @@ async def api_client(
):
connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait()
hello_response: message.Message = HelloResponse()
hello_response.api_version_major = 1
hello_response.api_version_minor = 9
hello_response.name = "fake"
hello_msg = hello_response.SerializeToString()
protocol.data_received(
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
)
send_plaintext_hello(protocol)
client._connection = conn
await connect_task
transport.reset_mock()

View File

@ -7,11 +7,9 @@ from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from google.protobuf import message
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi.api_pb2 import (
ConnectResponse,
DeviceInfoResponse,
HelloResponse,
PingRequest,
@ -27,10 +25,10 @@ from aioesphomeapi.core import (
)
from .common import (
PROTO_TO_MESSAGE_TYPE,
async_fire_time_changed,
connect,
generate_plaintext_packet,
send_plaintext_connect_response,
send_plaintext_hello,
utcnow,
)
@ -441,22 +439,8 @@ async def test_connect_wrong_password(
) -> None:
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
hello_response: message.Message = HelloResponse()
hello_response.api_version_major = 1
hello_response.api_version_minor = 9
hello_response.name = "fake"
hello_msg = hello_response.SerializeToString()
connect_response: message.Message = ConnectResponse()
connect_response.invalid_password = True
connect_msg = connect_response.SerializeToString()
protocol.data_received(
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
)
protocol.data_received(
generate_plaintext_packet(connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse])
)
send_plaintext_hello(protocol)
send_plaintext_connect_response(protocol, True)
with pytest.raises(InvalidAuthAPIError):
await connect_task
@ -472,22 +456,8 @@ async def test_connect_correct_password(
) -> None:
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
hello_response: message.Message = HelloResponse()
hello_response.api_version_major = 1
hello_response.api_version_minor = 9
hello_response.name = "fake"
hello_msg = hello_response.SerializeToString()
connect_response: message.Message = ConnectResponse()
connect_response.invalid_password = False
connect_msg = connect_response.SerializeToString()
protocol.data_received(
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
)
protocol.data_received(
generate_plaintext_packet(connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse])
)
send_plaintext_hello(protocol)
send_plaintext_connect_response(protocol, False)
await connect_task

90
tests/test_log_runner.py Normal file
View File

@ -0,0 +1,90 @@
import asyncio
from unittest.mock import MagicMock, patch
import pytest
from google.protobuf import message
from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper
from aioesphomeapi.api_pb2 import SubscribeLogsResponse # type: ignore
from aioesphomeapi.api_pb2 import DisconnectResponse
from aioesphomeapi.client import APIClient
from aioesphomeapi.connection import APIConnection
from aioesphomeapi.log_runner import async_run_logs
from .common import (
PROTO_TO_MESSAGE_TYPE,
Estr,
generate_plaintext_packet,
get_mock_zeroconf,
send_plaintext_connect_response,
send_plaintext_hello,
)
@pytest.mark.asyncio
async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnection):
"""Test the log runner logic."""
loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()
connected = asyncio.Event()
class PatchableAPIClient(APIClient):
pass
cli = PatchableAPIClient(
address=Estr("1.2.3.4"),
port=6052,
password=None,
noise_psk=None,
expected_name=Estr("fake"),
)
messages = []
def on_log(msg: SubscribeLogsResponse) -> None:
messages.append(msg)
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
subscribed = asyncio.Event()
original_subscribe_logs = cli.subscribe_logs
async def _wait_subscribe_cli(*args, **kwargs):
await original_subscribe_logs(*args, **kwargs)
subscribed.set()
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
), patch.object(cli, "subscribe_logs", _wait_subscribe_cli):
stop = await async_run_logs(cli, on_log, zeroconf_instance=get_mock_zeroconf())
await connected.wait()
protocol = cli._connection._frame_helper
send_plaintext_hello(protocol)
send_plaintext_connect_response(protocol, False)
await subscribed.wait()
response: message.Message = SubscribeLogsResponse()
response.message = b"Hello world"
protocol.data_received(
generate_plaintext_packet(
response.SerializeToString(),
PROTO_TO_MESSAGE_TYPE[SubscribeLogsResponse],
)
)
assert len(messages) == 1
assert messages[0].message == b"Hello world"
stop_task = asyncio.create_task(stop())
await asyncio.sleep(0)
disconnect_response = DisconnectResponse()
protocol.data_received(
generate_plaintext_packet(
disconnect_response.SerializeToString(),
PROTO_TO_MESSAGE_TYPE[DisconnectResponse],
)
)
await stop_task