aioesphomeapi/tests/conftest.py

237 lines
7.2 KiB
Python

"""Test fixtures."""
from __future__ import annotations
import asyncio
import socket
from dataclasses import replace
from functools import partial
from typing import Callable
from unittest.mock import MagicMock, create_autospec, patch
import pytest
import pytest_asyncio
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi.client import APIClient, ConnectionParams
from aioesphomeapi.connection import APIConnection
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
from aioesphomeapi.zeroconf import ZeroconfManager
from .common import (
connect,
connect_client,
get_mock_async_zeroconf,
send_plaintext_hello,
)
# Keepalive value handed to ConnectionParams(keepalive=...) by
# get_mock_connection_params(). NOTE(review): unit presumably seconds — confirm.
KEEP_ALIVE_INTERVAL: float = 15.0
class PatchableAPIConnection(APIConnection):
    """APIConnection subclass with no overrides.

    Presumably exists so tests can patch attributes on this class without
    touching ``APIConnection`` itself — confirm against the test modules.
    """
@pytest.fixture
def async_zeroconf():
    """Return a fresh ``get_mock_async_zeroconf()`` result for each test."""
    mock_zeroconf = get_mock_async_zeroconf()
    return mock_zeroconf
@pytest.fixture
def resolve_host():
    """Patch ``aioesphomeapi.host_resolver.async_resolve_host`` for one test.

    The patched resolver returns a one-element list holding an IPv4/TCP
    ``AddrInfo`` for ``10.0.0.512:6052``. The patch mock itself is yielded to
    the test and the patch is undone on teardown.
    """
    # NOTE(review): "10.0.0.512" is not a valid IPv4 address (512 > 255);
    # presumably it is only an opaque placeholder — confirm nothing parses it.
    fake_addr = AddrInfo(
        family=socket.AF_INET,
        type=socket.SOCK_STREAM,
        proto=socket.IPPROTO_TCP,
        sockaddr=IPv4Sockaddr("10.0.0.512", 6052),
    )
    with patch(
        "aioesphomeapi.host_resolver.async_resolve_host",
        return_value=[fake_addr],
    ) as resolver_mock:
        yield resolver_mock
@pytest.fixture
def patchable_api_client() -> APIClient:
    """Return a client for 127.0.0.1:6052 (no password) of a per-call subclass.

    The subclass is defined inside the fixture, so every invocation gets a new
    class object; presumably this lets a test patch class attributes without
    leaking into ``APIClient`` or other tests — confirm.
    """

    class PatchableAPIClient(APIClient):
        """Empty APIClient subclass local to this fixture call."""

    return PatchableAPIClient(address="127.0.0.1", port=6052, password=None)
def get_mock_connection_params() -> ConnectionParams:
    """Build plaintext ``ConnectionParams`` for tests.

    Address ``fake.address`` on port 6052, with no password, no noise PSK and
    no expected name. Keepalive is ``KEEP_ALIVE_INTERVAL``. A new
    ``ZeroconfManager`` is created on every call.
    """
    params = ConnectionParams(
        addresses=["fake.address"],
        port=6052,
        password=None,
        client_info="Tests client",
        keepalive=KEEP_ALIVE_INTERVAL,
        zeroconf_manager=ZeroconfManager(),
        noise_psk=None,
        expected_name=None,
    )
    return params
@pytest.fixture
def connection_params(event_loop: asyncio.AbstractEventLoop) -> ConnectionParams:
    """Fixture wrapper around ``get_mock_connection_params()``.

    ``event_loop`` is not used in the body; presumably it is requested so the
    loop fixture is set up first.
    """
    params = get_mock_connection_params()
    return params
def mock_on_stop(expected_disconnect: bool) -> None:
    """No-op callback passed to ``PatchableAPIConnection`` by the fixtures.

    The argument is accepted and ignored.
    """
    return None
@pytest.fixture
def conn(
    event_loop: asyncio.AbstractEventLoop, connection_params: ConnectionParams
) -> APIConnection:
    """Return a ``PatchableAPIConnection`` built from the default test params.

    ``mock_on_stop`` plus the positional ``True, None`` are forwarded unchanged.
    NOTE(review): their meaning is defined by ``APIConnection.__init__``.
    """
    connection = PatchableAPIConnection(connection_params, mock_on_stop, True, None)
    return connection
@pytest.fixture
def conn_with_password(
    event_loop: asyncio.AbstractEventLoop, connection_params: ConnectionParams
) -> APIConnection:
    """Like ``conn``, but the params carry ``password="password"``.

    ``dataclasses.replace`` returns a copy, so the shared ``connection_params``
    fixture value is left unmodified.
    """
    params_with_password = replace(connection_params, password="password")
    return PatchableAPIConnection(params_with_password, mock_on_stop, True, None)
@pytest.fixture
def noise_conn(
    event_loop: asyncio.AbstractEventLoop, connection_params: ConnectionParams
) -> APIConnection:
    """Like ``conn``, but the params carry a fixed ``noise_psk`` string.

    The params are copied with ``dataclasses.replace``; the original fixture
    value is not mutated.
    """
    psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc="
    noise_params = replace(connection_params, noise_psk=psk)
    return PatchableAPIConnection(noise_params, mock_on_stop, True, None)
@pytest.fixture
def conn_with_expected_name(
    event_loop: asyncio.AbstractEventLoop, connection_params: ConnectionParams
) -> APIConnection:
    """Like ``conn``, but the params carry ``expected_name="test"``.

    The params are copied with ``dataclasses.replace``; the original fixture
    value is not mutated.
    """
    named_params = replace(connection_params, expected_name="test")
    return PatchableAPIConnection(named_params, mock_on_stop, True, None)
@pytest.fixture
def aiohappyeyeballs_start_connection(event_loop: asyncio.AbstractEventLoop):
    """Patch ``aiohappyeyeballs.start_connection`` as seen from ``connection``.

    The patched function returns an autospecced ``socket.socket`` instance:
    SOCK_STREAM type, ``fileno()`` of 1, and peer ``("10.0.0.512", 323)``.
    The patch mock is yielded and the patch is undone on teardown.
    """
    fake_sock = create_autospec(socket.socket, spec_set=True, instance=True)
    fake_sock.type = socket.SOCK_STREAM
    fake_sock.fileno.return_value = 1
    fake_sock.getpeername.return_value = ("10.0.0.512", 323)
    with patch(
        "aioesphomeapi.connection.aiohappyeyeballs.start_connection",
        return_value=fake_sock,
    ) as start_connection_mock:
        yield start_connection_mock
def _create_mock_transport_protocol(
    transport: asyncio.Transport,
    connected: asyncio.Event,
    create_func: Callable[[], APIPlaintextFrameHelper],
    **_ignored,
) -> tuple[asyncio.Transport, APIPlaintextFrameHelper]:
    """Fake ``loop.create_connection`` body, bound via ``functools.partial``.

    ``create_func`` is invoked to build the protocol, which is then attached
    to ``transport`` through ``connection_made``. Afterwards ``connected`` is
    set so the waiting fixture can proceed. Any other keyword arguments from
    the real call site are accepted and ignored.

    Returns:
        ``(transport, protocol)``, the same pair shape the real
        ``create_connection`` produces.
    """
    frame_helper: APIPlaintextFrameHelper = create_func()
    frame_helper.connection_made(transport)
    # Signal only after the protocol is wired up, so waiters see a live helper.
    connected.set()
    return (transport, frame_helper)
@pytest_asyncio.fixture(name="plaintext_connect_task_no_login")
async def plaintext_connect_task_no_login(
    conn: APIConnection,
    resolve_host,
    aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
    """Start ``connect(conn, login=False)`` against a mocked transport.

    The running loop's ``create_connection`` is patched with
    ``_create_mock_transport_protocol``. The fixture waits until the fake
    transport reports connected, then removes the patch. It yields
    ``(conn, transport, conn._frame_helper, connect_task)``. The connect task
    is not awaited here; the test decides how to drive it.
    """
    # get_running_loop() is the documented way to fetch the loop inside a
    # coroutine and matches the sibling fixtures. Unlike get_event_loop(), it
    # never falls back to creating/looking up a policy loop.
    loop = asyncio.get_running_loop()
    transport = MagicMock()
    connected = asyncio.Event()
    with patch.object(
        loop,
        "create_connection",
        side_effect=partial(_create_mock_transport_protocol, transport, connected),
    ):
        connect_task = asyncio.create_task(connect(conn, login=False))
        await connected.wait()
    yield conn, transport, conn._frame_helper, connect_task
@pytest_asyncio.fixture(name="plaintext_connect_task_expected_name")
async def plaintext_connect_task_no_login_with_expected_name(
    conn_with_expected_name: APIConnection,
    resolve_host,
    aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
    """Start ``connect(..., login=False)`` for ``conn_with_expected_name``.

    Same mechanics as the plain no-login fixture. Yields
    ``(connection, transport, frame_helper, connect_task)``; the connect task
    is not awaited here.
    """
    connection = conn_with_expected_name
    mock_transport = MagicMock()
    transport_ready = asyncio.Event()
    fake_create_connection = partial(
        _create_mock_transport_protocol, mock_transport, transport_ready
    )
    with patch.object(
        asyncio.get_running_loop(),
        "create_connection",
        side_effect=fake_create_connection,
    ):
        task = asyncio.create_task(connect(connection, login=False))
        await transport_ready.wait()
    frame_helper = connection._frame_helper
    yield connection, mock_transport, frame_helper, task
@pytest_asyncio.fixture(name="plaintext_connect_task_with_login")
async def plaintext_connect_task_with_login(
    conn_with_password: APIConnection,
    resolve_host,
    aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
    """Start ``connect(..., login=True)`` for ``conn_with_password``.

    ``loop.create_connection`` is patched with the mock-transport factory. The
    fixture waits until the fake transport is connected, then yields
    ``(connection, transport, frame_helper, connect_task)``; the connect task
    is not awaited here.
    """
    connection = conn_with_password
    mock_transport = MagicMock()
    transport_ready = asyncio.Event()
    running_loop = asyncio.get_running_loop()
    fake_create_connection = partial(
        _create_mock_transport_protocol, mock_transport, transport_ready
    )
    with patch.object(
        running_loop, "create_connection", side_effect=fake_create_connection
    ):
        task = asyncio.create_task(connect(connection, login=True))
        await transport_ready.wait()
    frame_helper = connection._frame_helper
    yield connection, mock_transport, frame_helper, task
@pytest_asyncio.fixture(name="api_client")
async def api_client(
    resolve_host, aiohappyeyeballs_start_connection
) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]:
    """Yield an ``APIClient`` that has completed a plaintext, no-login connect.

    The client targets ``mydevice.local:6052`` with no password. While
    connecting, ``loop.create_connection`` is patched with the mock-transport
    factory and ``aioesphomeapi.client.APIConnection`` is swapped for
    ``PatchableAPIConnection``. After the fake transport connects, the patches
    are removed and a plaintext hello is fed to the frame helper. The connect
    task is then awaited and the transport mock is reset, so tests start with
    a clean call history.

    Yields:
        ``(client, connection, transport, frame_helper)``.
    """
    loop = asyncio.get_running_loop()
    mock_transport = MagicMock()
    transport_ready = asyncio.Event()
    client = APIClient(address="mydevice.local", port=6052, password=None)
    fake_create_connection = partial(
        _create_mock_transport_protocol, mock_transport, transport_ready
    )
    with patch.object(loop, "create_connection", side_effect=fake_create_connection):
        with patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection):
            pending_connect = asyncio.create_task(
                connect_client(client, login=False)
            )
            await transport_ready.wait()
    connection = client._connection
    frame_helper: APIPlaintextFrameHelper = connection._frame_helper
    send_plaintext_hello(frame_helper)
    await pending_connect
    mock_transport.reset_mock()
    yield client, connection, mock_transport, frame_helper