"""Test fixtures.""" from __future__ import annotations import asyncio import socket from dataclasses import replace from functools import partial from typing import Callable from unittest.mock import MagicMock, create_autospec, patch import pytest import pytest_asyncio from aioesphomeapi._frame_helper import APIPlaintextFrameHelper from aioesphomeapi.client import APIClient, ConnectionParams from aioesphomeapi.connection import APIConnection from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr from aioesphomeapi.zeroconf import ZeroconfManager from .common import ( connect, connect_client, get_mock_async_zeroconf, send_plaintext_hello, ) KEEP_ALIVE_INTERVAL = 15.0 class PatchableAPIConnection(APIConnection): pass @pytest.fixture def async_zeroconf(): return get_mock_async_zeroconf() @pytest.fixture def resolve_host(): with patch("aioesphomeapi.host_resolver.async_resolve_host") as func: func.return_value = [ AddrInfo( family=socket.AF_INET, type=socket.SOCK_STREAM, proto=socket.IPPROTO_TCP, sockaddr=IPv4Sockaddr("10.0.0.512", 6052), ) ] yield func @pytest.fixture def patchable_api_client() -> APIClient: class PatchableAPIClient(APIClient): pass cli = PatchableAPIClient( address="127.0.0.1", port=6052, password=None, ) return cli def get_mock_connection_params() -> ConnectionParams: return ConnectionParams( addresses=["fake.address"], port=6052, password=None, client_info="Tests client", keepalive=KEEP_ALIVE_INTERVAL, zeroconf_manager=ZeroconfManager(), noise_psk=None, expected_name=None, ) @pytest.fixture def connection_params(event_loop: asyncio.AbstractEventLoop) -> ConnectionParams: return get_mock_connection_params() def mock_on_stop(expected_disconnect: bool) -> None: pass @pytest.fixture def conn( event_loop: asyncio.AbstractEventLoop, connection_params: ConnectionParams ) -> APIConnection: return PatchableAPIConnection(connection_params, mock_on_stop, True, None) @pytest.fixture def conn_with_password( event_loop: asyncio.AbstractEventLoop, connection_params: ConnectionParams ) -> APIConnection: connection_params = replace(connection_params, password="password") return PatchableAPIConnection(connection_params, mock_on_stop, True, None) @pytest.fixture def noise_conn( event_loop: asyncio.AbstractEventLoop, connection_params: ConnectionParams ) -> APIConnection: connection_params = replace( connection_params, noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=" ) return PatchableAPIConnection(connection_params, mock_on_stop, True, None) @pytest.fixture def conn_with_expected_name( event_loop: asyncio.AbstractEventLoop, connection_params: ConnectionParams ) -> APIConnection: connection_params = replace(connection_params, expected_name="test") return PatchableAPIConnection(connection_params, mock_on_stop, True, None) @pytest.fixture() def aiohappyeyeballs_start_connection(event_loop: asyncio.AbstractEventLoop): with patch("aioesphomeapi.connection.aiohappyeyeballs.start_connection") as func: mock_socket = create_autospec(socket.socket, spec_set=True, instance=True) mock_socket.type = socket.SOCK_STREAM mock_socket.fileno.return_value = 1 mock_socket.getpeername.return_value = ("10.0.0.512", 323) func.return_value = mock_socket yield func def _create_mock_transport_protocol( transport: asyncio.Transport, connected: asyncio.Event, create_func: Callable[[], APIPlaintextFrameHelper], **kwargs, ) -> tuple[asyncio.Transport, APIPlaintextFrameHelper]: protocol: APIPlaintextFrameHelper = create_func() protocol.connection_made(transport) connected.set() return transport, protocol @pytest_asyncio.fixture(name="plaintext_connect_task_no_login") async def plaintext_connect_task_no_login( conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection, ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: loop = asyncio.get_event_loop() transport = MagicMock() connected = asyncio.Event() with patch.object( loop, "create_connection", side_effect=partial(_create_mock_transport_protocol, transport, connected), ): connect_task = asyncio.create_task(connect(conn, login=False)) await connected.wait() yield conn, transport, conn._frame_helper, connect_task @pytest_asyncio.fixture(name="plaintext_connect_task_expected_name") async def plaintext_connect_task_no_login_with_expected_name( conn_with_expected_name: APIConnection, resolve_host, aiohappyeyeballs_start_connection, ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: event_loop = asyncio.get_running_loop() transport = MagicMock() connected = asyncio.Event() with patch.object( event_loop, "create_connection", side_effect=partial(_create_mock_transport_protocol, transport, connected), ): connect_task = asyncio.create_task( connect(conn_with_expected_name, login=False) ) await connected.wait() yield conn_with_expected_name, transport, conn_with_expected_name._frame_helper, connect_task @pytest_asyncio.fixture(name="plaintext_connect_task_with_login") async def plaintext_connect_task_with_login( conn_with_password: APIConnection, resolve_host, aiohappyeyeballs_start_connection, ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: transport = MagicMock() connected = asyncio.Event() event_loop = asyncio.get_running_loop() with patch.object( event_loop, "create_connection", side_effect=partial(_create_mock_transport_protocol, transport, connected), ): connect_task = asyncio.create_task(connect(conn_with_password, login=True)) await connected.wait() yield conn_with_password, transport, conn_with_password._frame_helper, connect_task @pytest_asyncio.fixture(name="api_client") async def api_client( resolve_host, aiohappyeyeballs_start_connection ) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]: event_loop = asyncio.get_running_loop() protocol: APIPlaintextFrameHelper | None = None transport = MagicMock() connected = asyncio.Event() client = APIClient( address="mydevice.local", port=6052, password=None, ) with ( patch.object( event_loop, "create_connection", side_effect=partial(_create_mock_transport_protocol, transport, connected), ), patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection), ): connect_task = asyncio.create_task(connect_client(client, login=False)) await connected.wait() conn = client._connection protocol = conn._frame_helper send_plaintext_hello(protocol) await connect_task transport.reset_mock() yield client, conn, transport, protocol