Mirror of https://github.com/esphome/aioesphomeapi.git
Synced 2025-01-22 21:41:26 +01:00
237 lines · 7.2 KiB · Python
"""Test fixtures."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import socket
|
|
from dataclasses import replace
|
|
from functools import partial
|
|
from typing import Callable
|
|
from unittest.mock import MagicMock, create_autospec, patch
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
|
|
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
|
from aioesphomeapi.client import APIClient, ConnectionParams
|
|
from aioesphomeapi.connection import APIConnection
|
|
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
|
|
from aioesphomeapi.zeroconf import ZeroconfManager
|
|
|
|
from .common import (
|
|
connect,
|
|
connect_client,
|
|
get_mock_async_zeroconf,
|
|
send_plaintext_hello,
|
|
)
|
|
|
|
# Value passed as ``keepalive`` to ConnectionParams by get_mock_connection_params().
# NOTE(review): unit not visible here (presumably seconds) — confirm in ConnectionParams.
KEEP_ALIVE_INTERVAL: float = 15.0
|
|
|
|
|
|
class PatchableAPIConnection(APIConnection):
    """Empty subclass of APIConnection instantiated by this module's fixtures.

    It adds no behaviour. Presumably it exists so tests can patch attributes
    on this class without touching the library's own APIConnection — confirm
    against the tests that use it.
    """
|
|
|
|
|
|
@pytest.fixture
def async_zeroconf():
    """Return a fresh ``get_mock_async_zeroconf()`` result for each test."""
    mock_zc = get_mock_async_zeroconf()
    return mock_zc
|
|
|
|
|
|
@pytest.fixture
def resolve_host():
    """Patch ``aioesphomeapi.host_resolver.async_resolve_host`` for one test.

    The patched callable returns a one-element list holding a TCP/IPv4
    ``AddrInfo``. The mock itself is yielded so tests can inspect or
    reconfigure it. The patch is undone when the test finishes.
    """
    # NOTE(review): "10.0.0.512" is not a valid IPv4 address (512 > 255);
    # presumably deliberate so nothing real can ever be reached — confirm.
    fake_addr = AddrInfo(
        family=socket.AF_INET,
        type=socket.SOCK_STREAM,
        proto=socket.IPPROTO_TCP,
        sockaddr=IPv4Sockaddr("10.0.0.512", 6052),
    )
    with patch("aioesphomeapi.host_resolver.async_resolve_host") as resolver_mock:
        resolver_mock.return_value = [fake_addr]
        yield resolver_mock
|
|
|
|
|
|
@pytest.fixture
def patchable_api_client() -> APIClient:
    """Return an APIClient for 127.0.0.1:6052 with no password.

    A brand-new APIClient subclass is defined on every call. Attributes set on
    ``type(client)`` therefore never reach ``APIClient`` itself (presumably
    the reason for the "Patchable" name).
    """

    class PatchableAPIClient(APIClient):
        """Per-call subclass of APIClient with no additions."""

    return PatchableAPIClient(address="127.0.0.1", port=6052, password=None)
|
|
|
|
|
|
def get_mock_connection_params() -> ConnectionParams:
    """Build ConnectionParams for a test connection to ``fake.address:6052``.

    No password, noise PSK or expected name is set. ``keepalive`` comes from
    ``KEEP_ALIVE_INTERVAL``, and a new ZeroconfManager is created on every call.
    """
    params = ConnectionParams(
        addresses=["fake.address"],
        port=6052,
        password=None,
        client_info="Tests client",
        keepalive=KEEP_ALIVE_INTERVAL,
        zeroconf_manager=ZeroconfManager(),
        noise_psk=None,
        expected_name=None,
    )
    return params
|
|
|
|
|
|
@pytest.fixture
def connection_params(event_loop: asyncio.AbstractEventLoop) -> ConnectionParams:
    """Fixture form of ``get_mock_connection_params()``.

    ``event_loop`` is not used in the body. Requesting it presumably ensures
    the loop fixture is set up before ZeroconfManager is constructed — confirm.
    """
    params = get_mock_connection_params()
    return params
|
|
|
|
|
|
def mock_on_stop(expected_disconnect: bool) -> None:
    """No-op stop callback handed to the PatchableAPIConnection fixtures here."""
    return None
|
|
|
|
|
|
@pytest.fixture
def conn(
    event_loop: asyncio.AbstractEventLoop, connection_params: ConnectionParams
) -> APIConnection:
    """Return a PatchableAPIConnection built from the default mock params.

    ``event_loop`` is unused in the body; presumably it is requested only for
    fixture ordering.
    """
    # NOTE(review): ``True, None`` go positionally to APIConnection.__init__
    # (presumably a debug flag and a log name) — confirm.
    connection = PatchableAPIConnection(connection_params, mock_on_stop, True, None)
    return connection
|
|
|
|
|
|
@pytest.fixture
def conn_with_password(
    event_loop: asyncio.AbstractEventLoop, connection_params: ConnectionParams
) -> APIConnection:
    """Return a PatchableAPIConnection whose params carry password "password".

    The shared ``connection_params`` object is not mutated; ``replace`` makes
    a modified copy.
    """
    params_with_password = replace(connection_params, password="password")
    return PatchableAPIConnection(params_with_password, mock_on_stop, True, None)
|
|
|
|
|
|
@pytest.fixture
def noise_conn(
    event_loop: asyncio.AbstractEventLoop, connection_params: ConnectionParams
) -> APIConnection:
    """Return a PatchableAPIConnection configured with a fixed noise PSK.

    The PSK is a 44-character base64 string (32 bytes once decoded). It is
    applied to a copy of ``connection_params`` via ``replace``.
    """
    psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc="
    noise_params = replace(connection_params, noise_psk=psk)
    return PatchableAPIConnection(noise_params, mock_on_stop, True, None)
|
|
|
|
|
|
@pytest.fixture
def conn_with_expected_name(
    event_loop: asyncio.AbstractEventLoop, connection_params: ConnectionParams
) -> APIConnection:
    """Return a PatchableAPIConnection whose params expect device name "test"."""
    named_params = replace(connection_params, expected_name="test")
    return PatchableAPIConnection(named_params, mock_on_stop, True, None)
|
|
|
|
|
|
@pytest.fixture
def aiohappyeyeballs_start_connection(event_loop: asyncio.AbstractEventLoop):
    """Patch ``aiohappyeyeballs.start_connection`` as looked up by aioesphomeapi.connection.

    The patched function returns an autospecced ``socket.socket`` instance
    (``spec_set=True``). Its ``type`` is SOCK_STREAM, ``fileno()`` returns 1,
    and ``getpeername()`` returns ``("10.0.0.512", 323)``. The patch mock
    itself is yielded.
    """
    fake_sock = create_autospec(socket.socket, spec_set=True, instance=True)
    fake_sock.type = socket.SOCK_STREAM
    fake_sock.fileno.return_value = 1
    fake_sock.getpeername.return_value = ("10.0.0.512", 323)

    target = "aioesphomeapi.connection.aiohappyeyeballs.start_connection"
    with patch(target) as start_conn_mock:
        start_conn_mock.return_value = fake_sock
        yield start_conn_mock
|
|
|
|
|
|
def _create_mock_transport_protocol(
    transport: asyncio.Transport,
    connected: asyncio.Event,
    create_func: Callable[[], APIPlaintextFrameHelper],
    **kwargs,
) -> tuple[asyncio.Transport, APIPlaintextFrameHelper]:
    """Fake ``loop.create_connection`` used (via ``partial``) by the fixtures.

    Builds the protocol with ``create_func``, passes ``transport`` to its
    ``connection_made`` and then sets ``connected`` so the waiting fixture can
    resume. Any keyword arguments the caller passes to ``create_connection``
    are accepted and ignored.

    Returns:
        ``(transport, protocol)``, the same pair shape ``create_connection`` returns.
    """
    frame_helper = create_func()
    frame_helper.connection_made(transport)
    connected.set()
    return (transport, frame_helper)
|
|
|
|
|
|
@pytest_asyncio.fixture(name="plaintext_connect_task_no_login")
async def plaintext_connect_task_no_login(
    conn: APIConnection,
    resolve_host,
    aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
    """Start ``conn`` connecting (``login=False``) over a mocked transport.

    ``create_connection`` on the running loop is patched to
    ``_create_mock_transport_protocol``, so no real socket is opened. Once the
    protocol has received ``connection_made``, the fixture yields
    ``(conn, transport, frame_helper, connect_task)``. ``connect_task`` may
    still be pending at that point; the test is expected to drive or await it.
    The patch stays active until the test finishes.
    """
    # Inside a coroutine, get_running_loop() is the documented API, and it is
    # what the sibling plaintext fixtures use; get_event_loop() is
    # discouraged here.
    loop = asyncio.get_running_loop()
    transport = MagicMock()
    connected = asyncio.Event()

    with patch.object(
        loop,
        "create_connection",
        side_effect=partial(_create_mock_transport_protocol, transport, connected),
    ):
        connect_task = asyncio.create_task(connect(conn, login=False))
        # Wait for the fake create_connection to hand the transport over.
        await connected.wait()
        yield conn, transport, conn._frame_helper, connect_task
|
|
|
|
|
|
@pytest_asyncio.fixture(name="plaintext_connect_task_expected_name")
async def plaintext_connect_task_no_login_with_expected_name(
    conn_with_expected_name: APIConnection,
    resolve_host,
    aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
    """Like ``plaintext_connect_task_no_login`` but using ``conn_with_expected_name``.

    Yields ``(connection, transport, frame_helper, connect_task)`` once the
    mocked transport is connected. The ``create_connection`` patch remains
    in place for the rest of the test.
    """
    running_loop = asyncio.get_running_loop()
    mock_transport = MagicMock()
    connected = asyncio.Event()
    fake_create_connection = partial(
        _create_mock_transport_protocol, mock_transport, connected
    )

    with patch.object(
        running_loop, "create_connection", side_effect=fake_create_connection
    ):
        connect_task = asyncio.create_task(
            connect(conn_with_expected_name, login=False)
        )
        await connected.wait()
        frame_helper = conn_with_expected_name._frame_helper
        yield conn_with_expected_name, mock_transport, frame_helper, connect_task
|
|
|
|
|
|
@pytest_asyncio.fixture(name="plaintext_connect_task_with_login")
async def plaintext_connect_task_with_login(
    conn_with_password: APIConnection,
    resolve_host,
    aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
    """Start ``conn_with_password`` connecting with ``login=True`` over a mocked transport.

    Yields ``(connection, transport, frame_helper, connect_task)`` as soon as
    the fake ``create_connection`` has delivered the transport. The patch
    remains active until the test finishes.
    """
    mock_transport = MagicMock()
    connected = asyncio.Event()
    running_loop = asyncio.get_running_loop()
    fake_create_connection = partial(
        _create_mock_transport_protocol, mock_transport, connected
    )

    with patch.object(
        running_loop, "create_connection", side_effect=fake_create_connection
    ):
        connect_task = asyncio.create_task(connect(conn_with_password, login=True))
        await connected.wait()
        frame_helper = conn_with_password._frame_helper
        yield conn_with_password, mock_transport, frame_helper, connect_task
|
|
|
|
|
|
@pytest_asyncio.fixture(name="api_client")
async def api_client(
    resolve_host, aiohappyeyeballs_start_connection
) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]:
    """Yield an APIClient for "mydevice.local" that has finished connecting.

    ``create_connection`` on the running loop is faked, and
    ``aioesphomeapi.client.APIConnection`` is swapped for
    ``PatchableAPIConnection``. The client connects with ``login=False``, and
    the fixture awaits the connect task to completion. Afterwards the
    transport mock is reset, and ``(client, connection, transport,
    frame_helper)`` is yielded with both patches still active.
    """
    running_loop = asyncio.get_running_loop()
    mock_transport = MagicMock()
    connected = asyncio.Event()
    client = APIClient(address="mydevice.local", port=6052, password=None)
    fake_create_connection = partial(
        _create_mock_transport_protocol, mock_transport, connected
    )

    with (
        patch.object(
            running_loop, "create_connection", side_effect=fake_create_connection
        ),
        patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection),
    ):
        connect_task = asyncio.create_task(connect_client(client, login=False))
        await connected.wait()
        connection = client._connection
        frame_helper = connection._frame_helper
        # Presumably injects the device's plaintext hello so connect_client()
        # can finish (send_plaintext_hello lives in tests.common) — confirm.
        send_plaintext_hello(frame_helper)
        await connect_task
        # Clear recorded calls from the handshake so tests start from a clean mock.
        mock_transport.reset_mock()
        yield client, connection, mock_transport, frame_helper
|