Make log runner code reusable and add coverage (#630)

This commit is contained in:
J. Nick Koston 2023-11-11 13:06:27 -06:00 committed by GitHub
parent 89f34cbcad
commit 3ffcca3bdd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 211 additions and 90 deletions

View File

@ -4,6 +4,7 @@ source = aioesphomeapi
omit = omit =
aioesphomeapi/api_options_pb2.py aioesphomeapi/api_options_pb2.py
aioesphomeapi/api_pb2.py aioesphomeapi/api_pb2.py
aioesphomeapi/log_reader.py
bench/*.py bench/*.py
[report] [report]

View File

@ -7,15 +7,9 @@ import logging
import sys import sys
from datetime import datetime from datetime import datetime
import zeroconf from .api_pb2 import SubscribeLogsResponse # type: ignore
from .client import APIClient
from aioesphomeapi.api_pb2 import SubscribeLogsResponse # type: ignore from .log_runner import async_run_logs
from aioesphomeapi.client import APIClient
from aioesphomeapi.core import APIConnectionError
from aioesphomeapi.model import LogLevel
from aioesphomeapi.reconnect_logic import ReconnectLogic
_LOGGER = logging.getLogger(__name__)
async def main(argv: list[str]) -> None: async def main(argv: list[str]) -> None:
@ -42,42 +36,27 @@ async def main(argv: list[str]) -> None:
) )
def on_log(msg: SubscribeLogsResponse) -> None: def on_log(msg: SubscribeLogsResponse) -> None:
time_ = datetime.now().time().strftime("[%H:%M:%S]") time_ = datetime.now()
text = msg.message message: bytes = msg.message
print(time_ + text.decode("utf8", "backslashreplace")) text = message.decode("utf8", "backslashreplace")
print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}]{text}")
has_connects = False stop = await async_run_logs(cli, on_log)
async def on_connect() -> None:
nonlocal has_connects
try:
await cli.subscribe_logs(
on_log,
log_level=LogLevel.LOG_LEVEL_VERY_VERBOSE,
dump_config=not has_connects,
)
has_connects = True
except APIConnectionError:
await cli.disconnect()
async def on_disconnect( # pylint: disable=unused-argument
expected_disconnect: bool,
) -> None:
_LOGGER.warning("Disconnected from API")
logic = ReconnectLogic(
client=cli,
on_connect=on_connect,
on_disconnect=on_disconnect,
zeroconf_instance=zeroconf.Zeroconf(),
)
await logic.start()
try: try:
while True: while True:
await asyncio.sleep(60) await asyncio.sleep(60)
finally:
await stop()
def cli_entry_point() -> None:
"""Run the CLI."""
try:
asyncio.run(main(sys.argv))
except KeyboardInterrupt: except KeyboardInterrupt:
await logic.stop() pass
if __name__ == "__main__": if __name__ == "__main__":
sys.exit(asyncio.run(main(sys.argv)) or 0) cli_entry_point()
sys.exit(0)

View File

@ -0,0 +1,61 @@
from __future__ import annotations
import logging
from typing import Any, Callable, Coroutine
import zeroconf
from .api_pb2 import SubscribeLogsResponse # type: ignore
from .client import APIClient
from .core import APIConnectionError
from .model import LogLevel
from .reconnect_logic import ReconnectLogic
_LOGGER = logging.getLogger(__name__)
async def async_run_logs(
cli: APIClient,
on_log: Callable[[SubscribeLogsResponse], None],
log_level: LogLevel = LogLevel.LOG_LEVEL_VERY_VERBOSE,
zeroconf_instance: zeroconf.Zeroconf | None = None,
dump_config: bool = True,
) -> Callable[[], Coroutine[Any, Any, None]]:
"""Run logs until canceled.
Returns a coroutine that can be awaited to stop the logs.
"""
dumped_config = not dump_config
async def on_connect() -> None:
"""Handle a connection."""
nonlocal dumped_config
try:
await cli.subscribe_logs(
on_log,
log_level=log_level,
dump_config=not dumped_config,
)
dumped_config = True
except APIConnectionError:
await cli.disconnect()
async def on_disconnect( # pylint: disable=unused-argument
expected_disconnect: bool,
) -> None:
_LOGGER.warning("Disconnected from API")
logic = ReconnectLogic(
client=cli,
on_connect=on_connect,
on_disconnect=on_disconnect,
zeroconf_instance=zeroconf_instance or zeroconf.Zeroconf(),
)
await logic.start()
async def _stop() -> None:
await logic.stop()
await cli.disconnect()
return _stop

View File

@ -1,11 +1,9 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
"""aioesphomeapi setup script.""" """aioesphomeapi setup script."""
import os import os
from setuptools import find_packages, setup
import os
from distutils.command.build_ext import build_ext from distutils.command.build_ext import build_ext
from setuptools import find_packages, setup
here = os.path.abspath(os.path.dirname(__file__)) here = os.path.abspath(os.path.dirname(__file__))
@ -60,6 +58,11 @@ setup_kwargs = {
"install_requires": REQUIRES, "install_requires": REQUIRES,
"python_requires": ">=3.9", "python_requires": ">=3.9",
"test_suite": "tests", "test_suite": "tests",
"entry_points": {
"console_scripts": [
"aioesphomeapi-logs=aioesphomeapi.log_reader:cli_entry_point"
],
},
} }

View File

@ -6,9 +6,12 @@ from datetime import datetime, timezone
from functools import partial from functools import partial
from unittest.mock import MagicMock from unittest.mock import MagicMock
from google.protobuf import message
from zeroconf import Zeroconf from zeroconf import Zeroconf
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
from aioesphomeapi.api_pb2 import ConnectResponse, HelloResponse
from aioesphomeapi.connection import APIConnection from aioesphomeapi.connection import APIConnection
from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO
@ -81,3 +84,26 @@ async def connect(conn: APIConnection, login: bool = True):
"""Wrapper for connection logic to do both parts.""" """Wrapper for connection logic to do both parts."""
await conn.start_connection() await conn.start_connection()
await conn.finish_connection(login=login) await conn.finish_connection(login=login)
def send_plaintext_hello(protocol: APIPlaintextFrameHelper) -> None:
hello_response: message.Message = HelloResponse()
hello_response.api_version_major = 1
hello_response.api_version_minor = 9
hello_response.name = "fake"
hello_msg = hello_response.SerializeToString()
protocol.data_received(
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
)
def send_plaintext_connect_response(
protocol: APIPlaintextFrameHelper, invalid_password: bool
) -> None:
connect_response: message.Message = ConnectResponse()
connect_response.invalid_password = invalid_password
connect_msg = connect_response.SerializeToString()
protocol.data_received(
generate_plaintext_packet(connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse])
)

View File

@ -7,15 +7,13 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from google.protobuf import message
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi.api_pb2 import HelloResponse
from aioesphomeapi.client import APIClient, ConnectionParams from aioesphomeapi.client import APIClient, ConnectionParams
from aioesphomeapi.connection import APIConnection from aioesphomeapi.connection import APIConnection
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
from .common import PROTO_TO_MESSAGE_TYPE, connect, generate_plaintext_packet from .common import connect, send_plaintext_hello
@pytest.fixture @pytest.fixture
@ -132,14 +130,7 @@ async def api_client(
): ):
connect_task = asyncio.create_task(connect(conn, login=False)) connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait() await connected.wait()
hello_response: message.Message = HelloResponse() send_plaintext_hello(protocol)
hello_response.api_version_major = 1
hello_response.api_version_minor = 9
hello_response.name = "fake"
hello_msg = hello_response.SerializeToString()
protocol.data_received(
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
)
client._connection = conn client._connection = conn
await connect_task await connect_task
transport.reset_mock() transport.reset_mock()

View File

@ -7,11 +7,9 @@ from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from google.protobuf import message
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi.api_pb2 import ( from aioesphomeapi.api_pb2 import (
ConnectResponse,
DeviceInfoResponse, DeviceInfoResponse,
HelloResponse, HelloResponse,
PingRequest, PingRequest,
@ -27,10 +25,10 @@ from aioesphomeapi.core import (
) )
from .common import ( from .common import (
PROTO_TO_MESSAGE_TYPE,
async_fire_time_changed, async_fire_time_changed,
connect, connect,
generate_plaintext_packet, send_plaintext_connect_response,
send_plaintext_hello,
utcnow, utcnow,
) )
@ -441,22 +439,8 @@ async def test_connect_wrong_password(
) -> None: ) -> None:
conn, transport, protocol, connect_task = plaintext_connect_task_with_login conn, transport, protocol, connect_task = plaintext_connect_task_with_login
hello_response: message.Message = HelloResponse() send_plaintext_hello(protocol)
hello_response.api_version_major = 1 send_plaintext_connect_response(protocol, True)
hello_response.api_version_minor = 9
hello_response.name = "fake"
hello_msg = hello_response.SerializeToString()
connect_response: message.Message = ConnectResponse()
connect_response.invalid_password = True
connect_msg = connect_response.SerializeToString()
protocol.data_received(
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
)
protocol.data_received(
generate_plaintext_packet(connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse])
)
with pytest.raises(InvalidAuthAPIError): with pytest.raises(InvalidAuthAPIError):
await connect_task await connect_task
@ -472,22 +456,8 @@ async def test_connect_correct_password(
) -> None: ) -> None:
conn, transport, protocol, connect_task = plaintext_connect_task_with_login conn, transport, protocol, connect_task = plaintext_connect_task_with_login
hello_response: message.Message = HelloResponse() send_plaintext_hello(protocol)
hello_response.api_version_major = 1 send_plaintext_connect_response(protocol, False)
hello_response.api_version_minor = 9
hello_response.name = "fake"
hello_msg = hello_response.SerializeToString()
connect_response: message.Message = ConnectResponse()
connect_response.invalid_password = False
connect_msg = connect_response.SerializeToString()
protocol.data_received(
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
)
protocol.data_received(
generate_plaintext_packet(connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse])
)
await connect_task await connect_task

90
tests/test_log_runner.py Normal file
View File

@ -0,0 +1,90 @@
import asyncio
from unittest.mock import MagicMock, patch
import pytest
from google.protobuf import message
from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper
from aioesphomeapi.api_pb2 import SubscribeLogsResponse # type: ignore
from aioesphomeapi.api_pb2 import DisconnectResponse
from aioesphomeapi.client import APIClient
from aioesphomeapi.connection import APIConnection
from aioesphomeapi.log_runner import async_run_logs
from .common import (
PROTO_TO_MESSAGE_TYPE,
Estr,
generate_plaintext_packet,
get_mock_zeroconf,
send_plaintext_connect_response,
send_plaintext_hello,
)
@pytest.mark.asyncio
async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnection):
"""Test the log runner logic."""
loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()
connected = asyncio.Event()
class PatchableAPIClient(APIClient):
pass
cli = PatchableAPIClient(
address=Estr("1.2.3.4"),
port=6052,
password=None,
noise_psk=None,
expected_name=Estr("fake"),
)
messages = []
def on_log(msg: SubscribeLogsResponse) -> None:
messages.append(msg)
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
subscribed = asyncio.Event()
original_subscribe_logs = cli.subscribe_logs
async def _wait_subscribe_cli(*args, **kwargs):
await original_subscribe_logs(*args, **kwargs)
subscribed.set()
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
), patch.object(cli, "subscribe_logs", _wait_subscribe_cli):
stop = await async_run_logs(cli, on_log, zeroconf_instance=get_mock_zeroconf())
await connected.wait()
protocol = cli._connection._frame_helper
send_plaintext_hello(protocol)
send_plaintext_connect_response(protocol, False)
await subscribed.wait()
response: message.Message = SubscribeLogsResponse()
response.message = b"Hello world"
protocol.data_received(
generate_plaintext_packet(
response.SerializeToString(),
PROTO_TO_MESSAGE_TYPE[SubscribeLogsResponse],
)
)
assert len(messages) == 1
assert messages[0].message == b"Hello world"
stop_task = asyncio.create_task(stop())
await asyncio.sleep(0)
disconnect_response = DisconnectResponse()
protocol.data_received(
generate_plaintext_packet(
disconnect_response.SerializeToString(),
PROTO_TO_MESSAGE_TYPE[DisconnectResponse],
)
)
await stop_task