aioesphomeapi/tests/common.py
2024-08-29 11:08:13 -10:00

172 lines
5.1 KiB
Python

from __future__ import annotations
import asyncio
from collections.abc import Awaitable
from datetime import datetime, timezone
from functools import partial
import time
from typing import Callable
from unittest.mock import AsyncMock, MagicMock, patch
from google.protobuf import message
from zeroconf import Zeroconf
from zeroconf.asyncio import AsyncZeroconf
from aioesphomeapi import APIClient
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
from aioesphomeapi.api_pb2 import (
ConnectResponse,
HelloResponse,
PingRequest,
PingResponse,
)
from aioesphomeapi.connection import APIConnection
from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO
UTC = timezone.utc
# Resolution of the monotonic clock (the one loop.time() is based on); used
# as a small tolerance when deciding whether a scheduled timer is due.
_MONOTONIC_RESOLUTION = time.get_clock_info("monotonic").resolution

# ``partial`` over the natively implemented ``datetime.now`` with UTC already
# bound, so each call skips the module-global lookup of ``UTC``.
utcnow: partial[datetime] = partial(datetime.now, UTC)
utcnow.__doc__ = "Get now in UTC time."

# Inverse of MESSAGE_TYPE_TO_PROTO: protobuf message class -> numeric type id.
PROTO_TO_MESSAGE_TYPE = dict(
    zip(MESSAGE_TYPE_TO_PROTO.values(), MESSAGE_TYPE_TO_PROTO.keys())
)
def mock_data_received(
    protocol: APINoiseFrameHelper | APIPlaintextFrameHelper, data: bytes
) -> None:
    """Deliver *data* to *protocol*, turning a parse failure into a lost connection.

    Any exception raised by ``protocol.data_received`` is not propagated;
    instead ``protocol.connection_lost(exc)`` is scheduled on the running
    event loop with ``call_soon``, so it runs on a later loop iteration.
    Must be called while an event loop is running.
    """
    try:
        protocol.data_received(data)
    except Exception as exc:  # pylint: disable=broad-except
        asyncio.get_running_loop().call_soon(protocol.connection_lost, exc)
def get_mock_zeroconf() -> Zeroconf:
    """Build a real ``Zeroconf`` instance without running ``Zeroconf.start``.

    ``zeroconf.Zeroconf.start`` is patched out only while the instance is
    constructed. ``close`` is then replaced with a ``MagicMock`` so tests can
    call it and assert on it.

    Returns:
        The ``Zeroconf`` instance. The previous ``-> MagicMock`` annotation
        was wrong: only ``close`` is a mock, not the returned object.
    """
    with patch("zeroconf.Zeroconf.start"):
        zc = Zeroconf()
        zc.close = MagicMock()
    return zc
def get_mock_async_zeroconf() -> AsyncZeroconf:
    """Return an ``AsyncZeroconf`` that wraps the mocked ``Zeroconf``.

    The wrapper's ``async_close`` is swapped for an ``AsyncMock`` so tests can
    await and assert it without tearing anything down.
    """
    async_zc = AsyncZeroconf(zc=get_mock_zeroconf())
    async_zc.async_close = AsyncMock()
    return async_zc
class Estr(str):
    """A ``str`` subclass.

    Instances compare and behave like ``str``, but ``type(x) is str`` is
    False for them.
    """
def generate_plaintext_packet(msg: message.Message) -> bytes:
    """Frame *msg* as a plaintext-protocol packet.

    The packet is built from four parts: a single zero byte, then the payload
    length, then the message type id (both encoded with
    ``_cached_varuint_to_bytes``), then the serialized protobuf payload.

    Raises:
        KeyError: if the message class has no entry in PROTO_TO_MESSAGE_TYPE.
    """
    # Look up the type id before serializing, so an unknown class fails first.
    type_id = PROTO_TO_MESSAGE_TYPE[msg.__class__]
    payload = msg.SerializeToString()
    return b"".join(
        (
            b"\0",
            _cached_varuint_to_bytes(len(payload)),
            _cached_varuint_to_bytes(type_id),
            payload,
        )
    )
def as_utc(dattim: datetime) -> datetime:
    """Return *dattim* expressed in UTC.

    If its tzinfo already compares equal to ``UTC``, the same object is
    returned. Otherwise ``astimezone(UTC)`` is used, which treats a naive
    datetime as local time.
    """
    return dattim if dattim.tzinfo == UTC else dattim.astimezone(UTC)
def async_fire_time_changed(
    datetime_: datetime | None = None, fire_all: bool = False
) -> None:
    """Fire a time changed event at an exact microsecond.

    Consider that it is not possible to actually achieve an exact
    microsecond in production as the event loop is not precise enough.
    If your code relies on this level of precision, consider a different
    approach, as this is only for testing.

    Every non-cancelled ``TimerHandle`` on the running loop whose due time
    falls at or before the mocked wall-clock moment (*datetime_*, or now
    when omitted) is run immediately and then cancelled. With
    ``fire_all=True``, every such handle is run regardless of its due time.
    """
    loop = asyncio.get_running_loop()
    target = datetime.now(UTC) if datetime_ is None else as_utc(datetime_)
    timestamp = target.timestamp()
    # Iterate over a copy: the callbacks run below may schedule or cancel
    # timers on the loop while we are still walking the list.
    for handle in list(loop._scheduled):
        if not isinstance(handle, asyncio.TimerHandle) or handle.cancelled():
            continue
        # Both clocks are re-read per handle, because running earlier
        # callbacks takes real time.
        seconds_until_mock = timestamp - time.time()
        seconds_until_due = handle.when() - (loop.time() + _MONOTONIC_RESOLUTION)
        if fire_all or seconds_until_mock >= seconds_until_due:
            handle._run()
            # Cancel so the loop does not run the same callback again later.
            handle.cancel()
async def connect(conn: APIConnection, login: bool = True) -> None:
    """Run both connection phases on *conn*.

    Awaits ``conn.start_connection()`` and then
    ``conn.finish_connection(login=login)``. This mirrors ``connect_client``,
    which now shares the same explicit ``-> None`` annotation.

    Args:
        conn: The connection to bring up.
        login: Forwarded unchanged to ``finish_connection``.
    """
    await conn.start_connection()
    await conn.finish_connection(login=login)
async def connect_client(
    client: APIClient,
    login: bool = True,
    on_stop: Callable[[bool], Awaitable[None]] | None = None,
) -> None:
    """Bring *client* up by running its start and finish connection steps.

    Args:
        client: The API client to connect.
        login: Forwarded as ``login=`` to ``client.finish_connection``.
        on_stop: Forwarded as ``on_stop=`` to ``client.start_connection``.
    """
    # Phase one: open the connection, registering the stop callback.
    await client.start_connection(on_stop=on_stop)
    # Phase two: complete the handshake, optionally logging in.
    await client.finish_connection(login=login)
def send_plaintext_hello(
    protocol: APIPlaintextFrameHelper,
    major: int | None = None,
    minor: int | None = None,
) -> None:
    """Feed a plaintext ``HelloResponse`` named "fake" into *protocol*.

    *major* and *minor* fall back to API version 1.9 when left as ``None``.
    """
    hello_response: message.Message = HelloResponse(
        api_version_major=1 if major is None else major,
        api_version_minor=9 if minor is None else minor,
        name="fake",
    )
    protocol.data_received(generate_plaintext_packet(hello_response))
def send_plaintext_connect_response(
    protocol: APIPlaintextFrameHelper, invalid_password: bool
) -> None:
    """Feed a plaintext ``ConnectResponse`` carrying *invalid_password* into *protocol*."""
    response: message.Message = ConnectResponse(invalid_password=invalid_password)
    packet = generate_plaintext_packet(response)
    protocol.data_received(packet)
def send_ping_response(protocol: APIPlaintextFrameHelper) -> None:
    """Feed an empty plaintext ``PingResponse`` packet into *protocol*."""
    packet = generate_plaintext_packet(PingResponse())
    protocol.data_received(packet)
def send_ping_request(protocol: APIPlaintextFrameHelper) -> None:
    """Feed an empty plaintext ``PingRequest`` packet into *protocol*."""
    packet = generate_plaintext_packet(PingRequest())
    protocol.data_received(packet)
def get_mock_protocol(conn: APIConnection) -> APIPlaintextFrameHelper:
    """Return a plaintext frame helper for *conn* attached to a mock transport.

    The helper is created with ``client_info="mock"`` and
    ``log_name="mock_device"``. ``connection_made`` is then called with a
    ``MagicMock`` transport, so the helper is wired up without any real socket.

    Args:
        conn: The connection the frame helper reports to.

    Returns:
        The connected ``APIPlaintextFrameHelper``. The return annotation was
        previously missing.
    """
    protocol = APIPlaintextFrameHelper(
        connection=conn,
        client_info="mock",
        log_name="mock_device",
    )
    transport = MagicMock()
    protocol.connection_made(transport)
    return protocol