diff --git a/tests/conftest.py b/tests/conftest.py index a39e267..c936a4a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -49,6 +49,19 @@ def socket_socket(): yield func +@pytest.fixture +def patchable_api_client() -> APIClient: + class PatchableAPIClient(APIClient): + pass + + cli = PatchableAPIClient( + address="1.2.3.4", + port=6052, + password=None, + ) + return cli + + def get_mock_connection_params() -> ConnectionParams: return ConnectionParams( address="fake.address", diff --git a/tests/test_reconnect_logic.py b/tests/test_reconnect_logic.py index d5694fe..e735d71 100644 --- a/tests/test_reconnect_logic.py +++ b/tests/test_reconnect_logic.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import logging +from functools import partial from ipaddress import ip_address from unittest.mock import AsyncMock, MagicMock, patch @@ -28,6 +29,7 @@ from .common import ( send_plaintext_connect_response, send_plaintext_hello, ) +from .conftest import _create_mock_transport_protocol logging.getLogger("aioesphomeapi").setLevel(logging.DEBUG) @@ -207,20 +209,12 @@ async def test_reconnect_logic_state(): @pytest.mark.asyncio -async def test_reconnect_retry(): +async def test_reconnect_retry(patchable_api_client: APIClient): """Test that reconnect logic retry.""" on_disconnect_called = [] on_connect_called = [] on_connect_fail_called = [] - - class PatchableAPIClient(APIClient): - pass - - cli = PatchableAPIClient( - address="1.2.3.4", - port=6052, - password=None, - ) + cli = patchable_api_client async def on_disconnect(expected_disconnect: bool) -> None: nonlocal on_disconnect_called @@ -375,13 +369,9 @@ async def test_reconnect_zeroconf( @pytest.mark.asyncio -async def test_reconnect_logic_stop_callback(): +async def test_reconnect_logic_stop_callback(patchable_api_client: APIClient): """Test that the stop_callback stops the ReconnectLogic.""" - cli = APIClient( - address="1.2.3.4", - port=6052, - password=None, - ) + cli = patchable_api_client rl = ReconnectLogic( client=cli, on_disconnect=AsyncMock(), @@ -403,17 +393,11 @@ async def test_reconnect_logic_stop_callback(): @pytest.mark.asyncio -async def test_reconnect_logic_stop_callback_waits_for_handshake(): +async def test_reconnect_logic_stop_callback_waits_for_handshake( + patchable_api_client: APIClient, +): """Test that the stop_callback waits for a handshake.""" - - class PatchableAPIClient(APIClient): - pass - - cli = PatchableAPIClient( - address="1.2.3.4", - port=6052, - password=None, - ) + cli = patchable_api_client rl = ReconnectLogic( client=cli, on_disconnect=AsyncMock(), @@ -473,13 +457,6 @@ async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventL zeroconf_instance=async_zeroconf.zeroconf, ) - def _create_mock_transport_protocol(create_func, **kwargs): - nonlocal protocol - protocol = create_func() - protocol.connection_made(transport) - connected.set() - return transport, protocol - connected = asyncio.Event() on_disconnect_calls = [] @@ -498,20 +475,23 @@ async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventL ) with patch.object(event_loop, "sock_connect"), patch.object( - loop, "create_connection", side_effect=_create_mock_transport_protocol + loop, + "create_connection", + side_effect=partial(_create_mock_transport_protocol, transport, connected), ): await logic.start() await connected.wait() protocol = cli._connection._frame_helper send_plaintext_hello(protocol) send_plaintext_connect_response(protocol, False) - await connected.wait() assert cli._connection.is_connected is True await asyncio.sleep(0) with patch.object(event_loop, "sock_connect"), patch.object( - loop, "create_connection", side_effect=_create_mock_transport_protocol + loop, + "create_connection", + side_effect=partial(_create_mock_transport_protocol, transport, connected), ) as mock_create_connection: protocol.eof_received() # Wait for the task to run