mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-25 12:35:19 +01:00
Make log runner code reusable and add coverage (#630)
This commit is contained in:
parent
89f34cbcad
commit
3ffcca3bdd
@ -4,6 +4,7 @@ source = aioesphomeapi
|
|||||||
omit =
|
omit =
|
||||||
aioesphomeapi/api_options_pb2.py
|
aioesphomeapi/api_options_pb2.py
|
||||||
aioesphomeapi/api_pb2.py
|
aioesphomeapi/api_pb2.py
|
||||||
|
aioesphomeapi/log_reader.py
|
||||||
bench/*.py
|
bench/*.py
|
||||||
|
|
||||||
[report]
|
[report]
|
||||||
|
@ -7,15 +7,9 @@ import logging
|
|||||||
import sys
|
import sys
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
import zeroconf
|
from .api_pb2 import SubscribeLogsResponse # type: ignore
|
||||||
|
from .client import APIClient
|
||||||
from aioesphomeapi.api_pb2 import SubscribeLogsResponse # type: ignore
|
from .log_runner import async_run_logs
|
||||||
from aioesphomeapi.client import APIClient
|
|
||||||
from aioesphomeapi.core import APIConnectionError
|
|
||||||
from aioesphomeapi.model import LogLevel
|
|
||||||
from aioesphomeapi.reconnect_logic import ReconnectLogic
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
async def main(argv: list[str]) -> None:
|
async def main(argv: list[str]) -> None:
|
||||||
@ -42,42 +36,27 @@ async def main(argv: list[str]) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def on_log(msg: SubscribeLogsResponse) -> None:
|
def on_log(msg: SubscribeLogsResponse) -> None:
|
||||||
time_ = datetime.now().time().strftime("[%H:%M:%S]")
|
time_ = datetime.now()
|
||||||
text = msg.message
|
message: bytes = msg.message
|
||||||
print(time_ + text.decode("utf8", "backslashreplace"))
|
text = message.decode("utf8", "backslashreplace")
|
||||||
|
print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}]{text}")
|
||||||
|
|
||||||
has_connects = False
|
stop = await async_run_logs(cli, on_log)
|
||||||
|
|
||||||
async def on_connect() -> None:
|
|
||||||
nonlocal has_connects
|
|
||||||
try:
|
|
||||||
await cli.subscribe_logs(
|
|
||||||
on_log,
|
|
||||||
log_level=LogLevel.LOG_LEVEL_VERY_VERBOSE,
|
|
||||||
dump_config=not has_connects,
|
|
||||||
)
|
|
||||||
has_connects = True
|
|
||||||
except APIConnectionError:
|
|
||||||
await cli.disconnect()
|
|
||||||
|
|
||||||
async def on_disconnect( # pylint: disable=unused-argument
|
|
||||||
expected_disconnect: bool,
|
|
||||||
) -> None:
|
|
||||||
_LOGGER.warning("Disconnected from API")
|
|
||||||
|
|
||||||
logic = ReconnectLogic(
|
|
||||||
client=cli,
|
|
||||||
on_connect=on_connect,
|
|
||||||
on_disconnect=on_disconnect,
|
|
||||||
zeroconf_instance=zeroconf.Zeroconf(),
|
|
||||||
)
|
|
||||||
await logic.start()
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(60)
|
await asyncio.sleep(60)
|
||||||
|
finally:
|
||||||
|
await stop()
|
||||||
|
|
||||||
|
|
||||||
|
def cli_entry_point() -> None:
|
||||||
|
"""Run the CLI."""
|
||||||
|
try:
|
||||||
|
asyncio.run(main(sys.argv))
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
await logic.stop()
|
pass
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
sys.exit(asyncio.run(main(sys.argv)) or 0)
|
cli_entry_point()
|
||||||
|
sys.exit(0)
|
||||||
|
61
aioesphomeapi/log_runner.py
Normal file
61
aioesphomeapi/log_runner.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable, Coroutine
|
||||||
|
|
||||||
|
import zeroconf
|
||||||
|
|
||||||
|
from .api_pb2 import SubscribeLogsResponse # type: ignore
|
||||||
|
from .client import APIClient
|
||||||
|
from .core import APIConnectionError
|
||||||
|
from .model import LogLevel
|
||||||
|
from .reconnect_logic import ReconnectLogic
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_run_logs(
|
||||||
|
cli: APIClient,
|
||||||
|
on_log: Callable[[SubscribeLogsResponse], None],
|
||||||
|
log_level: LogLevel = LogLevel.LOG_LEVEL_VERY_VERBOSE,
|
||||||
|
zeroconf_instance: zeroconf.Zeroconf | None = None,
|
||||||
|
dump_config: bool = True,
|
||||||
|
) -> Callable[[], Coroutine[Any, Any, None]]:
|
||||||
|
"""Run logs until canceled.
|
||||||
|
|
||||||
|
Returns a coroutine that can be awaited to stop the logs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dumped_config = not dump_config
|
||||||
|
|
||||||
|
async def on_connect() -> None:
|
||||||
|
"""Handle a connection."""
|
||||||
|
nonlocal dumped_config
|
||||||
|
try:
|
||||||
|
await cli.subscribe_logs(
|
||||||
|
on_log,
|
||||||
|
log_level=log_level,
|
||||||
|
dump_config=not dumped_config,
|
||||||
|
)
|
||||||
|
dumped_config = True
|
||||||
|
except APIConnectionError:
|
||||||
|
await cli.disconnect()
|
||||||
|
|
||||||
|
async def on_disconnect( # pylint: disable=unused-argument
|
||||||
|
expected_disconnect: bool,
|
||||||
|
) -> None:
|
||||||
|
_LOGGER.warning("Disconnected from API")
|
||||||
|
|
||||||
|
logic = ReconnectLogic(
|
||||||
|
client=cli,
|
||||||
|
on_connect=on_connect,
|
||||||
|
on_disconnect=on_disconnect,
|
||||||
|
zeroconf_instance=zeroconf_instance or zeroconf.Zeroconf(),
|
||||||
|
)
|
||||||
|
await logic.start()
|
||||||
|
|
||||||
|
async def _stop() -> None:
|
||||||
|
await logic.stop()
|
||||||
|
await cli.disconnect()
|
||||||
|
|
||||||
|
return _stop
|
9
setup.py
9
setup.py
@ -1,11 +1,9 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""aioesphomeapi setup script."""
|
"""aioesphomeapi setup script."""
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from setuptools import find_packages, setup
|
|
||||||
import os
|
|
||||||
from distutils.command.build_ext import build_ext
|
from distutils.command.build_ext import build_ext
|
||||||
|
|
||||||
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
here = os.path.abspath(os.path.dirname(__file__))
|
here = os.path.abspath(os.path.dirname(__file__))
|
||||||
|
|
||||||
@ -60,6 +58,11 @@ setup_kwargs = {
|
|||||||
"install_requires": REQUIRES,
|
"install_requires": REQUIRES,
|
||||||
"python_requires": ">=3.9",
|
"python_requires": ">=3.9",
|
||||||
"test_suite": "tests",
|
"test_suite": "tests",
|
||||||
|
"entry_points": {
|
||||||
|
"console_scripts": [
|
||||||
|
"aioesphomeapi-logs=aioesphomeapi.log_reader:cli_entry_point"
|
||||||
|
],
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,9 +6,12 @@ from datetime import datetime, timezone
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from google.protobuf import message
|
||||||
from zeroconf import Zeroconf
|
from zeroconf import Zeroconf
|
||||||
|
|
||||||
|
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
||||||
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
|
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
|
||||||
|
from aioesphomeapi.api_pb2 import ConnectResponse, HelloResponse
|
||||||
from aioesphomeapi.connection import APIConnection
|
from aioesphomeapi.connection import APIConnection
|
||||||
from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO
|
from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO
|
||||||
|
|
||||||
@ -81,3 +84,26 @@ async def connect(conn: APIConnection, login: bool = True):
|
|||||||
"""Wrapper for connection logic to do both parts."""
|
"""Wrapper for connection logic to do both parts."""
|
||||||
await conn.start_connection()
|
await conn.start_connection()
|
||||||
await conn.finish_connection(login=login)
|
await conn.finish_connection(login=login)
|
||||||
|
|
||||||
|
|
||||||
|
def send_plaintext_hello(protocol: APIPlaintextFrameHelper) -> None:
|
||||||
|
hello_response: message.Message = HelloResponse()
|
||||||
|
hello_response.api_version_major = 1
|
||||||
|
hello_response.api_version_minor = 9
|
||||||
|
hello_response.name = "fake"
|
||||||
|
hello_msg = hello_response.SerializeToString()
|
||||||
|
protocol.data_received(
|
||||||
|
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def send_plaintext_connect_response(
|
||||||
|
protocol: APIPlaintextFrameHelper, invalid_password: bool
|
||||||
|
) -> None:
|
||||||
|
connect_response: message.Message = ConnectResponse()
|
||||||
|
connect_response.invalid_password = invalid_password
|
||||||
|
connect_msg = connect_response.SerializeToString()
|
||||||
|
|
||||||
|
protocol.data_received(
|
||||||
|
generate_plaintext_packet(connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse])
|
||||||
|
)
|
||||||
|
@ -7,15 +7,13 @@ from unittest.mock import MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
from google.protobuf import message
|
|
||||||
|
|
||||||
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
||||||
from aioesphomeapi.api_pb2 import HelloResponse
|
|
||||||
from aioesphomeapi.client import APIClient, ConnectionParams
|
from aioesphomeapi.client import APIClient, ConnectionParams
|
||||||
from aioesphomeapi.connection import APIConnection
|
from aioesphomeapi.connection import APIConnection
|
||||||
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
|
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
|
||||||
|
|
||||||
from .common import PROTO_TO_MESSAGE_TYPE, connect, generate_plaintext_packet
|
from .common import connect, send_plaintext_hello
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -132,14 +130,7 @@ async def api_client(
|
|||||||
):
|
):
|
||||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||||
await connected.wait()
|
await connected.wait()
|
||||||
hello_response: message.Message = HelloResponse()
|
send_plaintext_hello(protocol)
|
||||||
hello_response.api_version_major = 1
|
|
||||||
hello_response.api_version_minor = 9
|
|
||||||
hello_response.name = "fake"
|
|
||||||
hello_msg = hello_response.SerializeToString()
|
|
||||||
protocol.data_received(
|
|
||||||
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
|
|
||||||
)
|
|
||||||
client._connection = conn
|
client._connection = conn
|
||||||
await connect_task
|
await connect_task
|
||||||
transport.reset_mock()
|
transport.reset_mock()
|
||||||
|
@ -7,11 +7,9 @@ from typing import Any
|
|||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from google.protobuf import message
|
|
||||||
|
|
||||||
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
||||||
from aioesphomeapi.api_pb2 import (
|
from aioesphomeapi.api_pb2 import (
|
||||||
ConnectResponse,
|
|
||||||
DeviceInfoResponse,
|
DeviceInfoResponse,
|
||||||
HelloResponse,
|
HelloResponse,
|
||||||
PingRequest,
|
PingRequest,
|
||||||
@ -27,10 +25,10 @@ from aioesphomeapi.core import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from .common import (
|
from .common import (
|
||||||
PROTO_TO_MESSAGE_TYPE,
|
|
||||||
async_fire_time_changed,
|
async_fire_time_changed,
|
||||||
connect,
|
connect,
|
||||||
generate_plaintext_packet,
|
send_plaintext_connect_response,
|
||||||
|
send_plaintext_hello,
|
||||||
utcnow,
|
utcnow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -441,22 +439,8 @@ async def test_connect_wrong_password(
|
|||||||
) -> None:
|
) -> None:
|
||||||
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
|
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
|
||||||
|
|
||||||
hello_response: message.Message = HelloResponse()
|
send_plaintext_hello(protocol)
|
||||||
hello_response.api_version_major = 1
|
send_plaintext_connect_response(protocol, True)
|
||||||
hello_response.api_version_minor = 9
|
|
||||||
hello_response.name = "fake"
|
|
||||||
hello_msg = hello_response.SerializeToString()
|
|
||||||
|
|
||||||
connect_response: message.Message = ConnectResponse()
|
|
||||||
connect_response.invalid_password = True
|
|
||||||
connect_msg = connect_response.SerializeToString()
|
|
||||||
|
|
||||||
protocol.data_received(
|
|
||||||
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
|
|
||||||
)
|
|
||||||
protocol.data_received(
|
|
||||||
generate_plaintext_packet(connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse])
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(InvalidAuthAPIError):
|
with pytest.raises(InvalidAuthAPIError):
|
||||||
await connect_task
|
await connect_task
|
||||||
@ -472,22 +456,8 @@ async def test_connect_correct_password(
|
|||||||
) -> None:
|
) -> None:
|
||||||
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
|
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
|
||||||
|
|
||||||
hello_response: message.Message = HelloResponse()
|
send_plaintext_hello(protocol)
|
||||||
hello_response.api_version_major = 1
|
send_plaintext_connect_response(protocol, False)
|
||||||
hello_response.api_version_minor = 9
|
|
||||||
hello_response.name = "fake"
|
|
||||||
hello_msg = hello_response.SerializeToString()
|
|
||||||
|
|
||||||
connect_response: message.Message = ConnectResponse()
|
|
||||||
connect_response.invalid_password = False
|
|
||||||
connect_msg = connect_response.SerializeToString()
|
|
||||||
|
|
||||||
protocol.data_received(
|
|
||||||
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
|
|
||||||
)
|
|
||||||
protocol.data_received(
|
|
||||||
generate_plaintext_packet(connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse])
|
|
||||||
)
|
|
||||||
|
|
||||||
await connect_task
|
await connect_task
|
||||||
|
|
||||||
|
90
tests/test_log_runner.py
Normal file
90
tests/test_log_runner.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
import asyncio
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from google.protobuf import message
|
||||||
|
|
||||||
|
from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper
|
||||||
|
from aioesphomeapi.api_pb2 import SubscribeLogsResponse # type: ignore
|
||||||
|
from aioesphomeapi.api_pb2 import DisconnectResponse
|
||||||
|
from aioesphomeapi.client import APIClient
|
||||||
|
from aioesphomeapi.connection import APIConnection
|
||||||
|
from aioesphomeapi.log_runner import async_run_logs
|
||||||
|
|
||||||
|
from .common import (
|
||||||
|
PROTO_TO_MESSAGE_TYPE,
|
||||||
|
Estr,
|
||||||
|
generate_plaintext_packet,
|
||||||
|
get_mock_zeroconf,
|
||||||
|
send_plaintext_connect_response,
|
||||||
|
send_plaintext_hello,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnection):
|
||||||
|
"""Test the log runner logic."""
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
protocol: APIPlaintextFrameHelper | None = None
|
||||||
|
transport = MagicMock()
|
||||||
|
connected = asyncio.Event()
|
||||||
|
|
||||||
|
class PatchableAPIClient(APIClient):
|
||||||
|
pass
|
||||||
|
|
||||||
|
cli = PatchableAPIClient(
|
||||||
|
address=Estr("1.2.3.4"),
|
||||||
|
port=6052,
|
||||||
|
password=None,
|
||||||
|
noise_psk=None,
|
||||||
|
expected_name=Estr("fake"),
|
||||||
|
)
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
def on_log(msg: SubscribeLogsResponse) -> None:
|
||||||
|
messages.append(msg)
|
||||||
|
|
||||||
|
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||||
|
nonlocal protocol
|
||||||
|
protocol = create_func()
|
||||||
|
protocol.connection_made(transport)
|
||||||
|
connected.set()
|
||||||
|
return transport, protocol
|
||||||
|
|
||||||
|
subscribed = asyncio.Event()
|
||||||
|
original_subscribe_logs = cli.subscribe_logs
|
||||||
|
|
||||||
|
async def _wait_subscribe_cli(*args, **kwargs):
|
||||||
|
await original_subscribe_logs(*args, **kwargs)
|
||||||
|
subscribed.set()
|
||||||
|
|
||||||
|
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||||
|
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||||
|
), patch.object(cli, "subscribe_logs", _wait_subscribe_cli):
|
||||||
|
stop = await async_run_logs(cli, on_log, zeroconf_instance=get_mock_zeroconf())
|
||||||
|
await connected.wait()
|
||||||
|
protocol = cli._connection._frame_helper
|
||||||
|
send_plaintext_hello(protocol)
|
||||||
|
send_plaintext_connect_response(protocol, False)
|
||||||
|
await subscribed.wait()
|
||||||
|
|
||||||
|
response: message.Message = SubscribeLogsResponse()
|
||||||
|
response.message = b"Hello world"
|
||||||
|
protocol.data_received(
|
||||||
|
generate_plaintext_packet(
|
||||||
|
response.SerializeToString(),
|
||||||
|
PROTO_TO_MESSAGE_TYPE[SubscribeLogsResponse],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert len(messages) == 1
|
||||||
|
assert messages[0].message == b"Hello world"
|
||||||
|
stop_task = asyncio.create_task(stop())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
disconnect_response = DisconnectResponse()
|
||||||
|
protocol.data_received(
|
||||||
|
generate_plaintext_packet(
|
||||||
|
disconnect_response.SerializeToString(),
|
||||||
|
PROTO_TO_MESSAGE_TYPE[DisconnectResponse],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await stop_task
|
Loading…
Reference in New Issue
Block a user