mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-25 17:17:42 +01:00
Make log runner code reusable and add coverage (#630)
This commit is contained in:
parent
89f34cbcad
commit
3ffcca3bdd
@ -4,6 +4,7 @@ source = aioesphomeapi
|
||||
omit =
|
||||
aioesphomeapi/api_options_pb2.py
|
||||
aioesphomeapi/api_pb2.py
|
||||
aioesphomeapi/log_reader.py
|
||||
bench/*.py
|
||||
|
||||
[report]
|
||||
|
@ -7,15 +7,9 @@ import logging
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
import zeroconf
|
||||
|
||||
from aioesphomeapi.api_pb2 import SubscribeLogsResponse # type: ignore
|
||||
from aioesphomeapi.client import APIClient
|
||||
from aioesphomeapi.core import APIConnectionError
|
||||
from aioesphomeapi.model import LogLevel
|
||||
from aioesphomeapi.reconnect_logic import ReconnectLogic
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
from .api_pb2 import SubscribeLogsResponse # type: ignore
|
||||
from .client import APIClient
|
||||
from .log_runner import async_run_logs
|
||||
|
||||
|
||||
async def main(argv: list[str]) -> None:
|
||||
@ -42,42 +36,27 @@ async def main(argv: list[str]) -> None:
|
||||
)
|
||||
|
||||
def on_log(msg: SubscribeLogsResponse) -> None:
|
||||
time_ = datetime.now().time().strftime("[%H:%M:%S]")
|
||||
text = msg.message
|
||||
print(time_ + text.decode("utf8", "backslashreplace"))
|
||||
time_ = datetime.now()
|
||||
message: bytes = msg.message
|
||||
text = message.decode("utf8", "backslashreplace")
|
||||
print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}]{text}")
|
||||
|
||||
has_connects = False
|
||||
|
||||
async def on_connect() -> None:
|
||||
nonlocal has_connects
|
||||
try:
|
||||
await cli.subscribe_logs(
|
||||
on_log,
|
||||
log_level=LogLevel.LOG_LEVEL_VERY_VERBOSE,
|
||||
dump_config=not has_connects,
|
||||
)
|
||||
has_connects = True
|
||||
except APIConnectionError:
|
||||
await cli.disconnect()
|
||||
|
||||
async def on_disconnect( # pylint: disable=unused-argument
|
||||
expected_disconnect: bool,
|
||||
) -> None:
|
||||
_LOGGER.warning("Disconnected from API")
|
||||
|
||||
logic = ReconnectLogic(
|
||||
client=cli,
|
||||
on_connect=on_connect,
|
||||
on_disconnect=on_disconnect,
|
||||
zeroconf_instance=zeroconf.Zeroconf(),
|
||||
)
|
||||
await logic.start()
|
||||
stop = await async_run_logs(cli, on_log)
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(60)
|
||||
finally:
|
||||
await stop()
|
||||
|
||||
|
||||
def cli_entry_point() -> None:
|
||||
"""Run the CLI."""
|
||||
try:
|
||||
asyncio.run(main(sys.argv))
|
||||
except KeyboardInterrupt:
|
||||
await logic.stop()
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(asyncio.run(main(sys.argv)) or 0)
|
||||
cli_entry_point()
|
||||
sys.exit(0)
|
||||
|
61
aioesphomeapi/log_runner.py
Normal file
61
aioesphomeapi/log_runner.py
Normal file
@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Coroutine
|
||||
|
||||
import zeroconf
|
||||
|
||||
from .api_pb2 import SubscribeLogsResponse # type: ignore
|
||||
from .client import APIClient
|
||||
from .core import APIConnectionError
|
||||
from .model import LogLevel
|
||||
from .reconnect_logic import ReconnectLogic
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_run_logs(
|
||||
cli: APIClient,
|
||||
on_log: Callable[[SubscribeLogsResponse], None],
|
||||
log_level: LogLevel = LogLevel.LOG_LEVEL_VERY_VERBOSE,
|
||||
zeroconf_instance: zeroconf.Zeroconf | None = None,
|
||||
dump_config: bool = True,
|
||||
) -> Callable[[], Coroutine[Any, Any, None]]:
|
||||
"""Run logs until canceled.
|
||||
|
||||
Returns a coroutine that can be awaited to stop the logs.
|
||||
"""
|
||||
|
||||
dumped_config = not dump_config
|
||||
|
||||
async def on_connect() -> None:
|
||||
"""Handle a connection."""
|
||||
nonlocal dumped_config
|
||||
try:
|
||||
await cli.subscribe_logs(
|
||||
on_log,
|
||||
log_level=log_level,
|
||||
dump_config=not dumped_config,
|
||||
)
|
||||
dumped_config = True
|
||||
except APIConnectionError:
|
||||
await cli.disconnect()
|
||||
|
||||
async def on_disconnect( # pylint: disable=unused-argument
|
||||
expected_disconnect: bool,
|
||||
) -> None:
|
||||
_LOGGER.warning("Disconnected from API")
|
||||
|
||||
logic = ReconnectLogic(
|
||||
client=cli,
|
||||
on_connect=on_connect,
|
||||
on_disconnect=on_disconnect,
|
||||
zeroconf_instance=zeroconf_instance or zeroconf.Zeroconf(),
|
||||
)
|
||||
await logic.start()
|
||||
|
||||
async def _stop() -> None:
|
||||
await logic.stop()
|
||||
await cli.disconnect()
|
||||
|
||||
return _stop
|
9
setup.py
9
setup.py
@ -1,11 +1,9 @@
|
||||
#!/usr/bin/env python3
|
||||
"""aioesphomeapi setup script."""
|
||||
import os
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
import os
|
||||
from distutils.command.build_ext import build_ext
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
here = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
@ -60,6 +58,11 @@ setup_kwargs = {
|
||||
"install_requires": REQUIRES,
|
||||
"python_requires": ">=3.9",
|
||||
"test_suite": "tests",
|
||||
"entry_points": {
|
||||
"console_scripts": [
|
||||
"aioesphomeapi-logs=aioesphomeapi.log_reader:cli_entry_point"
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
@ -6,9 +6,12 @@ from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.protobuf import message
|
||||
from zeroconf import Zeroconf
|
||||
|
||||
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
||||
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
|
||||
from aioesphomeapi.api_pb2 import ConnectResponse, HelloResponse
|
||||
from aioesphomeapi.connection import APIConnection
|
||||
from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO
|
||||
|
||||
@ -81,3 +84,26 @@ async def connect(conn: APIConnection, login: bool = True):
|
||||
"""Wrapper for connection logic to do both parts."""
|
||||
await conn.start_connection()
|
||||
await conn.finish_connection(login=login)
|
||||
|
||||
|
||||
def send_plaintext_hello(protocol: APIPlaintextFrameHelper) -> None:
|
||||
hello_response: message.Message = HelloResponse()
|
||||
hello_response.api_version_major = 1
|
||||
hello_response.api_version_minor = 9
|
||||
hello_response.name = "fake"
|
||||
hello_msg = hello_response.SerializeToString()
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
|
||||
)
|
||||
|
||||
|
||||
def send_plaintext_connect_response(
|
||||
protocol: APIPlaintextFrameHelper, invalid_password: bool
|
||||
) -> None:
|
||||
connect_response: message.Message = ConnectResponse()
|
||||
connect_response.invalid_password = invalid_password
|
||||
connect_msg = connect_response.SerializeToString()
|
||||
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse])
|
||||
)
|
||||
|
@ -7,15 +7,13 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from google.protobuf import message
|
||||
|
||||
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
||||
from aioesphomeapi.api_pb2 import HelloResponse
|
||||
from aioesphomeapi.client import APIClient, ConnectionParams
|
||||
from aioesphomeapi.connection import APIConnection
|
||||
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
|
||||
|
||||
from .common import PROTO_TO_MESSAGE_TYPE, connect, generate_plaintext_packet
|
||||
from .common import connect, send_plaintext_hello
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -132,14 +130,7 @@ async def api_client(
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await connected.wait()
|
||||
hello_response: message.Message = HelloResponse()
|
||||
hello_response.api_version_major = 1
|
||||
hello_response.api_version_minor = 9
|
||||
hello_response.name = "fake"
|
||||
hello_msg = hello_response.SerializeToString()
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
|
||||
)
|
||||
send_plaintext_hello(protocol)
|
||||
client._connection = conn
|
||||
await connect_task
|
||||
transport.reset_mock()
|
||||
|
@ -7,11 +7,9 @@ from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from google.protobuf import message
|
||||
|
||||
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
||||
from aioesphomeapi.api_pb2 import (
|
||||
ConnectResponse,
|
||||
DeviceInfoResponse,
|
||||
HelloResponse,
|
||||
PingRequest,
|
||||
@ -27,10 +25,10 @@ from aioesphomeapi.core import (
|
||||
)
|
||||
|
||||
from .common import (
|
||||
PROTO_TO_MESSAGE_TYPE,
|
||||
async_fire_time_changed,
|
||||
connect,
|
||||
generate_plaintext_packet,
|
||||
send_plaintext_connect_response,
|
||||
send_plaintext_hello,
|
||||
utcnow,
|
||||
)
|
||||
|
||||
@ -441,22 +439,8 @@ async def test_connect_wrong_password(
|
||||
) -> None:
|
||||
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
|
||||
|
||||
hello_response: message.Message = HelloResponse()
|
||||
hello_response.api_version_major = 1
|
||||
hello_response.api_version_minor = 9
|
||||
hello_response.name = "fake"
|
||||
hello_msg = hello_response.SerializeToString()
|
||||
|
||||
connect_response: message.Message = ConnectResponse()
|
||||
connect_response.invalid_password = True
|
||||
connect_msg = connect_response.SerializeToString()
|
||||
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
|
||||
)
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse])
|
||||
)
|
||||
send_plaintext_hello(protocol)
|
||||
send_plaintext_connect_response(protocol, True)
|
||||
|
||||
with pytest.raises(InvalidAuthAPIError):
|
||||
await connect_task
|
||||
@ -472,22 +456,8 @@ async def test_connect_correct_password(
|
||||
) -> None:
|
||||
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
|
||||
|
||||
hello_response: message.Message = HelloResponse()
|
||||
hello_response.api_version_major = 1
|
||||
hello_response.api_version_minor = 9
|
||||
hello_response.name = "fake"
|
||||
hello_msg = hello_response.SerializeToString()
|
||||
|
||||
connect_response: message.Message = ConnectResponse()
|
||||
connect_response.invalid_password = False
|
||||
connect_msg = connect_response.SerializeToString()
|
||||
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
|
||||
)
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse])
|
||||
)
|
||||
send_plaintext_hello(protocol)
|
||||
send_plaintext_connect_response(protocol, False)
|
||||
|
||||
await connect_task
|
||||
|
||||
|
90
tests/test_log_runner.py
Normal file
90
tests/test_log_runner.py
Normal file
@ -0,0 +1,90 @@
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from google.protobuf import message
|
||||
|
||||
from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper
|
||||
from aioesphomeapi.api_pb2 import SubscribeLogsResponse # type: ignore
|
||||
from aioesphomeapi.api_pb2 import DisconnectResponse
|
||||
from aioesphomeapi.client import APIClient
|
||||
from aioesphomeapi.connection import APIConnection
|
||||
from aioesphomeapi.log_runner import async_run_logs
|
||||
|
||||
from .common import (
|
||||
PROTO_TO_MESSAGE_TYPE,
|
||||
Estr,
|
||||
generate_plaintext_packet,
|
||||
get_mock_zeroconf,
|
||||
send_plaintext_connect_response,
|
||||
send_plaintext_hello,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnection):
|
||||
"""Test the log runner logic."""
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol: APIPlaintextFrameHelper | None = None
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
||||
class PatchableAPIClient(APIClient):
|
||||
pass
|
||||
|
||||
cli = PatchableAPIClient(
|
||||
address=Estr("1.2.3.4"),
|
||||
port=6052,
|
||||
password=None,
|
||||
noise_psk=None,
|
||||
expected_name=Estr("fake"),
|
||||
)
|
||||
messages = []
|
||||
|
||||
def on_log(msg: SubscribeLogsResponse) -> None:
|
||||
messages.append(msg)
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
protocol.connection_made(transport)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
|
||||
subscribed = asyncio.Event()
|
||||
original_subscribe_logs = cli.subscribe_logs
|
||||
|
||||
async def _wait_subscribe_cli(*args, **kwargs):
|
||||
await original_subscribe_logs(*args, **kwargs)
|
||||
subscribed.set()
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
), patch.object(cli, "subscribe_logs", _wait_subscribe_cli):
|
||||
stop = await async_run_logs(cli, on_log, zeroconf_instance=get_mock_zeroconf())
|
||||
await connected.wait()
|
||||
protocol = cli._connection._frame_helper
|
||||
send_plaintext_hello(protocol)
|
||||
send_plaintext_connect_response(protocol, False)
|
||||
await subscribed.wait()
|
||||
|
||||
response: message.Message = SubscribeLogsResponse()
|
||||
response.message = b"Hello world"
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(
|
||||
response.SerializeToString(),
|
||||
PROTO_TO_MESSAGE_TYPE[SubscribeLogsResponse],
|
||||
)
|
||||
)
|
||||
assert len(messages) == 1
|
||||
assert messages[0].message == b"Hello world"
|
||||
stop_task = asyncio.create_task(stop())
|
||||
await asyncio.sleep(0)
|
||||
disconnect_response = DisconnectResponse()
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(
|
||||
disconnect_response.SerializeToString(),
|
||||
PROTO_TO_MESSAGE_TYPE[DisconnectResponse],
|
||||
)
|
||||
)
|
||||
await stop_task
|
Loading…
Reference in New Issue
Block a user