diff --git a/.coveragerc b/.coveragerc index 8a595e7..dba3dc7 100644 --- a/.coveragerc +++ b/.coveragerc @@ -4,6 +4,7 @@ source = aioesphomeapi omit = aioesphomeapi/api_options_pb2.py aioesphomeapi/api_pb2.py + aioesphomeapi/log_reader.py bench/*.py [report] diff --git a/aioesphomeapi/log_reader.py b/aioesphomeapi/log_reader.py index b55d9e4..cb81e0d 100644 --- a/aioesphomeapi/log_reader.py +++ b/aioesphomeapi/log_reader.py @@ -7,15 +7,9 @@ import logging import sys from datetime import datetime -import zeroconf - -from aioesphomeapi.api_pb2 import SubscribeLogsResponse # type: ignore -from aioesphomeapi.client import APIClient -from aioesphomeapi.core import APIConnectionError -from aioesphomeapi.model import LogLevel -from aioesphomeapi.reconnect_logic import ReconnectLogic - -_LOGGER = logging.getLogger(__name__) +from .api_pb2 import SubscribeLogsResponse # type: ignore +from .client import APIClient +from .log_runner import async_run_logs async def main(argv: list[str]) -> None: @@ -42,42 +36,27 @@ async def main(argv: list[str]) -> None: ) def on_log(msg: SubscribeLogsResponse) -> None: - time_ = datetime.now().time().strftime("[%H:%M:%S]") - text = msg.message - print(time_ + text.decode("utf8", "backslashreplace")) + time_ = datetime.now() + message: bytes = msg.message + text = message.decode("utf8", "backslashreplace") + print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}]{text}") - has_connects = False - - async def on_connect() -> None: - nonlocal has_connects - try: - await cli.subscribe_logs( - on_log, - log_level=LogLevel.LOG_LEVEL_VERY_VERBOSE, - dump_config=not has_connects, - ) - has_connects = True - except APIConnectionError: - await cli.disconnect() - - async def on_disconnect( # pylint: disable=unused-argument - expected_disconnect: bool, - ) -> None: - _LOGGER.warning("Disconnected from API") - - logic = ReconnectLogic( - client=cli, - on_connect=on_connect, - on_disconnect=on_disconnect, - zeroconf_instance=zeroconf.Zeroconf(), - ) - await logic.start() + stop = await async_run_logs(cli, on_log) try: while True: await asyncio.sleep(60) + finally: + await stop() + + +def cli_entry_point() -> None: + """Run the CLI.""" + try: + asyncio.run(main(sys.argv)) except KeyboardInterrupt: - await logic.stop() + pass if __name__ == "__main__": - sys.exit(asyncio.run(main(sys.argv)) or 0) + cli_entry_point() + sys.exit(0) diff --git a/aioesphomeapi/log_runner.py b/aioesphomeapi/log_runner.py new file mode 100644 index 0000000..9580e51 --- /dev/null +++ b/aioesphomeapi/log_runner.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import logging +from typing import Any, Callable, Coroutine + +import zeroconf + +from .api_pb2 import SubscribeLogsResponse # type: ignore +from .client import APIClient +from .core import APIConnectionError +from .model import LogLevel +from .reconnect_logic import ReconnectLogic + +_LOGGER = logging.getLogger(__name__) + + +async def async_run_logs( + cli: APIClient, + on_log: Callable[[SubscribeLogsResponse], None], + log_level: LogLevel = LogLevel.LOG_LEVEL_VERY_VERBOSE, + zeroconf_instance: zeroconf.Zeroconf | None = None, + dump_config: bool = True, +) -> Callable[[], Coroutine[Any, Any, None]]: + """Run logs until canceled. + + Returns a coroutine that can be awaited to stop the logs. + """ + + dumped_config = not dump_config + + async def on_connect() -> None: + """Handle a connection.""" + nonlocal dumped_config + try: + await cli.subscribe_logs( + on_log, + log_level=log_level, + dump_config=not dumped_config, + ) + dumped_config = True + except APIConnectionError: + await cli.disconnect() + + async def on_disconnect( # pylint: disable=unused-argument + expected_disconnect: bool, + ) -> None: + _LOGGER.warning("Disconnected from API") + + logic = ReconnectLogic( + client=cli, + on_connect=on_connect, + on_disconnect=on_disconnect, + zeroconf_instance=zeroconf_instance or zeroconf.Zeroconf(), + ) + await logic.start() + + async def _stop() -> None: + await logic.stop() + await cli.disconnect() + + return _stop diff --git a/setup.py b/setup.py index 21499ea..f5c8cd1 100644 --- a/setup.py +++ b/setup.py @@ -1,11 +1,9 @@ #!/usr/bin/env python3 """aioesphomeapi setup script.""" import os - -from setuptools import find_packages, setup -import os from distutils.command.build_ext import build_ext +from setuptools import find_packages, setup here = os.path.abspath(os.path.dirname(__file__)) @@ -60,6 +58,11 @@ setup_kwargs = { "install_requires": REQUIRES, "python_requires": ">=3.9", "test_suite": "tests", + "entry_points": { + "console_scripts": [ + "aioesphomeapi-logs=aioesphomeapi.log_reader:cli_entry_point" + ], + }, } diff --git a/tests/common.py b/tests/common.py index 17b0118..5a43f80 100644 --- a/tests/common.py +++ b/tests/common.py @@ -6,9 +6,12 @@ from datetime import datetime, timezone from functools import partial from unittest.mock import MagicMock +from google.protobuf import message from zeroconf import Zeroconf +from aioesphomeapi._frame_helper import APIPlaintextFrameHelper from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes +from aioesphomeapi.api_pb2 import ConnectResponse, HelloResponse from aioesphomeapi.connection import APIConnection from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO @@ -81,3 +84,26 @@ async def connect(conn: APIConnection, login: bool = True): """Wrapper for connection logic to do both parts.""" await conn.start_connection() await conn.finish_connection(login=login) + + +def send_plaintext_hello(protocol: APIPlaintextFrameHelper) -> None: + hello_response: message.Message = HelloResponse() + hello_response.api_version_major = 1 + hello_response.api_version_minor = 9 + hello_response.name = "fake" + hello_msg = hello_response.SerializeToString() + protocol.data_received( + generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse]) + ) + + +def send_plaintext_connect_response( + protocol: APIPlaintextFrameHelper, invalid_password: bool +) -> None: + connect_response: message.Message = ConnectResponse() + connect_response.invalid_password = invalid_password + connect_msg = connect_response.SerializeToString() + + protocol.data_received( + generate_plaintext_packet(connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse]) + ) diff --git a/tests/conftest.py b/tests/conftest.py index 0c54616..1a0ebeb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,15 +7,13 @@ from unittest.mock import MagicMock, patch import pytest import pytest_asyncio -from google.protobuf import message from aioesphomeapi._frame_helper import APIPlaintextFrameHelper -from aioesphomeapi.api_pb2 import HelloResponse from aioesphomeapi.client import APIClient, ConnectionParams from aioesphomeapi.connection import APIConnection from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr -from .common import PROTO_TO_MESSAGE_TYPE, connect, generate_plaintext_packet +from .common import connect, send_plaintext_hello @pytest.fixture @@ -132,14 +130,7 @@ async def api_client( ): connect_task = asyncio.create_task(connect(conn, login=False)) await connected.wait() - hello_response: message.Message = HelloResponse() - hello_response.api_version_major = 1 - hello_response.api_version_minor = 9 - hello_response.name = "fake" - hello_msg = hello_response.SerializeToString() - protocol.data_received( - generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse]) - ) + send_plaintext_hello(protocol) client._connection = conn await connect_task transport.reset_mock() diff --git a/tests/test_connection.py b/tests/test_connection.py index 5caabec..918088e 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -7,11 +7,9 @@ from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest -from google.protobuf import message from aioesphomeapi._frame_helper import APIPlaintextFrameHelper from aioesphomeapi.api_pb2 import ( - ConnectResponse, DeviceInfoResponse, HelloResponse, PingRequest, @@ -27,10 +25,10 @@ from aioesphomeapi.core import ( ) from .common import ( - PROTO_TO_MESSAGE_TYPE, async_fire_time_changed, connect, - generate_plaintext_packet, + send_plaintext_connect_response, + send_plaintext_hello, utcnow, ) @@ -441,22 +439,8 @@ async def test_connect_wrong_password( ) -> None: conn, transport, protocol, connect_task = plaintext_connect_task_with_login - hello_response: message.Message = HelloResponse() - hello_response.api_version_major = 1 - hello_response.api_version_minor = 9 - hello_response.name = "fake" - hello_msg = hello_response.SerializeToString() - - connect_response: message.Message = ConnectResponse() - connect_response.invalid_password = True - connect_msg = connect_response.SerializeToString() - - protocol.data_received( - generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse]) - ) - protocol.data_received( - generate_plaintext_packet(connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse]) - ) + send_plaintext_hello(protocol) + send_plaintext_connect_response(protocol, True) with pytest.raises(InvalidAuthAPIError): await connect_task @@ -472,22 +456,8 @@ async def test_connect_correct_password( ) -> None: conn, transport, protocol, connect_task = plaintext_connect_task_with_login - hello_response: message.Message = HelloResponse() - hello_response.api_version_major = 1 - hello_response.api_version_minor = 9 - hello_response.name = "fake" - hello_msg = hello_response.SerializeToString() - - connect_response: message.Message = ConnectResponse() - connect_response.invalid_password = False - connect_msg = connect_response.SerializeToString() - - protocol.data_received( - generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse]) - ) - protocol.data_received( - generate_plaintext_packet(connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse]) - ) + send_plaintext_hello(protocol) + send_plaintext_connect_response(protocol, False) await connect_task diff --git a/tests/test_log_runner.py b/tests/test_log_runner.py new file mode 100644 index 0000000..3a748ab --- /dev/null +++ b/tests/test_log_runner.py @@ -0,0 +1,90 @@ +import asyncio +from unittest.mock import MagicMock, patch + +import pytest +from google.protobuf import message + +from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper +from aioesphomeapi.api_pb2 import SubscribeLogsResponse # type: ignore +from aioesphomeapi.api_pb2 import DisconnectResponse +from aioesphomeapi.client import APIClient +from aioesphomeapi.connection import APIConnection +from aioesphomeapi.log_runner import async_run_logs + +from .common import ( + PROTO_TO_MESSAGE_TYPE, + Estr, + generate_plaintext_packet, + get_mock_zeroconf, + send_plaintext_connect_response, + send_plaintext_hello, +) + + +@pytest.mark.asyncio +async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnection): + """Test the log runner logic.""" + loop = asyncio.get_event_loop() + protocol: APIPlaintextFrameHelper | None = None + transport = MagicMock() + connected = asyncio.Event() + + class PatchableAPIClient(APIClient): + pass + + cli = PatchableAPIClient( + address=Estr("1.2.3.4"), + port=6052, + password=None, + noise_psk=None, + expected_name=Estr("fake"), + ) + messages = [] + + def on_log(msg: SubscribeLogsResponse) -> None: + messages.append(msg) + + def _create_mock_transport_protocol(create_func, **kwargs): + nonlocal protocol + protocol = create_func() + protocol.connection_made(transport) + connected.set() + return transport, protocol + + subscribed = asyncio.Event() + original_subscribe_logs = cli.subscribe_logs + + async def _wait_subscribe_cli(*args, **kwargs): + await original_subscribe_logs(*args, **kwargs) + subscribed.set() + + with patch.object(event_loop, "sock_connect"), patch.object( + loop, "create_connection", side_effect=_create_mock_transport_protocol + ), patch.object(cli, "subscribe_logs", _wait_subscribe_cli): + stop = await async_run_logs(cli, on_log, zeroconf_instance=get_mock_zeroconf()) + await connected.wait() + protocol = cli._connection._frame_helper + send_plaintext_hello(protocol) + send_plaintext_connect_response(protocol, False) + await subscribed.wait() + + response: message.Message = SubscribeLogsResponse() + response.message = b"Hello world" + protocol.data_received( + generate_plaintext_packet( + response.SerializeToString(), + PROTO_TO_MESSAGE_TYPE[SubscribeLogsResponse], + ) + ) + assert len(messages) == 1 + assert messages[0].message == b"Hello world" + stop_task = asyncio.create_task(stop()) + await asyncio.sleep(0) + disconnect_response = DisconnectResponse() + protocol.data_received( + generate_plaintext_packet( + disconnect_response.SerializeToString(), + PROTO_TO_MESSAGE_TYPE[DisconnectResponse], + ) + ) + await stop_task