From 5b4bdb87167dd768c7a45934d84cac9880e93337 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 4 Feb 2024 21:30:19 -0600 Subject: [PATCH] Refactor tests for new pytest-asyncio (#820) --- tests/conftest.py | 28 ++++++++++++++++++---------- tests/test_client.py | 3 +-- tests/test_connection.py | 7 +++---- tests/test_host_resolver.py | 7 +++++-- tests/test_log_runner.py | 3 --- tests/test_reconnect_logic.py | 5 +---- 6 files changed, 28 insertions(+), 25 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 4e25771..33815da 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -77,7 +77,7 @@ def get_mock_connection_params() -> ConnectionParams: @pytest.fixture -def connection_params() -> ConnectionParams: +def connection_params(event_loop: asyncio.AbstractEventLoop) -> ConnectionParams: return get_mock_connection_params() @@ -86,18 +86,24 @@ def mock_on_stop(expected_disconnect: bool) -> None: @pytest.fixture -def conn(connection_params: ConnectionParams) -> APIConnection: +def conn( + event_loop: asyncio.AbstractEventLoop, connection_params: ConnectionParams +) -> APIConnection: return PatchableAPIConnection(connection_params, mock_on_stop, True, None) @pytest.fixture -def conn_with_password(connection_params: ConnectionParams) -> APIConnection: +def conn_with_password( + event_loop: asyncio.AbstractEventLoop, connection_params: ConnectionParams +) -> APIConnection: connection_params = replace(connection_params, password="password") return PatchableAPIConnection(connection_params, mock_on_stop, True, None) @pytest.fixture -def noise_conn(connection_params: ConnectionParams) -> APIConnection: +def noise_conn( + event_loop: asyncio.AbstractEventLoop, connection_params: ConnectionParams +) -> APIConnection: connection_params = replace( connection_params, noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=" ) @@ -105,13 +111,15 @@ def noise_conn(connection_params: ConnectionParams) -> APIConnection: @pytest.fixture -def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnection: +def conn_with_expected_name( + event_loop: asyncio.AbstractEventLoop, connection_params: ConnectionParams +) -> APIConnection: connection_params = replace(connection_params, expected_name="test") return PatchableAPIConnection(connection_params, mock_on_stop, True, None) @pytest.fixture() -def aiohappyeyeballs_start_connection(): +def aiohappyeyeballs_start_connection(event_loop: asyncio.AbstractEventLoop): with patch("aioesphomeapi.connection.aiohappyeyeballs.start_connection") as func: mock_socket = create_autospec(socket.socket, spec_set=True, instance=True) mock_socket.type = socket.SOCK_STREAM @@ -137,7 +145,6 @@ def _create_mock_transport_protocol( async def plaintext_connect_task_no_login( conn: APIConnection, resolve_host, - event_loop, aiohappyeyeballs_start_connection, ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: loop = asyncio.get_event_loop() @@ -158,9 +165,9 @@ async def plaintext_connect_task_no_login( async def plaintext_connect_task_no_login_with_expected_name( conn_with_expected_name: APIConnection, resolve_host, - event_loop, aiohappyeyeballs_start_connection, ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: + event_loop = asyncio.get_running_loop() transport = MagicMock() connected = asyncio.Event() @@ -180,11 +187,11 @@ async def plaintext_connect_task_no_login_with_expected_name( async def plaintext_connect_task_with_login( conn_with_password: APIConnection, resolve_host, - event_loop, aiohappyeyeballs_start_connection, ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: transport = MagicMock() connected = asyncio.Event() + event_loop = asyncio.get_running_loop() with patch.object( event_loop, @@ -198,8 +205,9 @@ async def plaintext_connect_task_with_login( @pytest_asyncio.fixture(name="api_client") async def api_client( - resolve_host, event_loop, aiohappyeyeballs_start_connection + resolve_host, aiohappyeyeballs_start_connection ) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]: + event_loop = asyncio.get_running_loop() protocol: APIPlaintextFrameHelper | None = None transport = MagicMock() connected = asyncio.Event() diff --git a/tests/test_client.py b/tests/test_client.py index 3466e73..cbd86c6 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -202,7 +202,6 @@ async def test_finish_connection_wraps_exceptions_as_unhandled_api_error( """Verify finish_connect re-wraps exceptions as UnhandledAPIError.""" cli = APIClient("1.2.3.4", 1234, None) - asyncio.get_event_loop() with patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection): await cli.start_connection() @@ -262,7 +261,7 @@ async def test_connection_released_if_connecting_is_cancelled() -> None: @pytest.mark.asyncio -async def test_request_while_handshaking(event_loop) -> None: +async def test_request_while_handshaking() -> None: """Test trying a request while handshaking raises.""" class PatchableApiClient(APIClient): diff --git a/tests/test_connection.py b/tests/test_connection.py index 0e60130..8c31ab7 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -652,10 +652,11 @@ async def test_force_disconnect_fails( @pytest.mark.asyncio async def test_connect_resolver_times_out( - conn: APIConnection, event_loop, aiohappyeyeballs_start_connection + conn: APIConnection, aiohappyeyeballs_start_connection ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: transport = MagicMock() connected = asyncio.Event() + event_loop = asyncio.get_running_loop() with patch( "aioesphomeapi.host_resolver.async_resolve_host", @@ -674,7 +675,6 @@ async def test_connect_resolver_times_out( @pytest.mark.asyncio async def test_disconnect_fails_to_send_response( connection_params: ConnectionParams, - event_loop: asyncio.AbstractEventLoop, resolve_host, aiohappyeyeballs_start_connection, ) -> None: @@ -724,11 +724,10 @@ async def test_disconnect_fails_to_send_response( @pytest.mark.asyncio async def test_disconnect_success_case( connection_params: ConnectionParams, - event_loop: asyncio.AbstractEventLoop, resolve_host, aiohappyeyeballs_start_connection, ) -> None: - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() transport = MagicMock() connected = asyncio.Event() client = APIClient( diff --git a/tests/test_host_resolver.py b/tests/test_host_resolver.py index 82f1095..5b333e2 100644 --- a/tests/test_host_resolver.py +++ b/tests/test_host_resolver.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import socket from ipaddress import IPv6Address, ip_address from unittest.mock import AsyncMock, MagicMock, patch @@ -103,7 +104,8 @@ async def test_resolve_host_zeroconf_fails_end_to_end(async_zeroconf: AsyncZeroc @pytest.mark.asyncio -async def test_resolve_host_getaddrinfo(event_loop, addr_infos): +async def test_resolve_host_getaddrinfo(addr_infos): + event_loop = asyncio.get_running_loop() with patch.object(event_loop, "getaddrinfo") as mock: mock.return_value = [ ( @@ -128,7 +130,8 @@ async def test_resolve_host_getaddrinfo(event_loop, addr_infos): @pytest.mark.asyncio -async def test_resolve_host_getaddrinfo_oserror(event_loop): +async def test_resolve_host_getaddrinfo_oserror(): + event_loop = asyncio.get_running_loop() with patch.object(event_loop, "getaddrinfo") as mock: mock.side_effect = OSError() with pytest.raises(APIConnectionError): diff --git a/tests/test_log_runner.py b/tests/test_log_runner.py index 1fdf826..c6725b4 100644 --- a/tests/test_log_runner.py +++ b/tests/test_log_runner.py @@ -31,7 +31,6 @@ from .common import ( @pytest.mark.asyncio async def test_log_runner( - event_loop: asyncio.AbstractEventLoop, conn: APIConnection, aiohappyeyeballs_start_connection, ): @@ -97,7 +96,6 @@ async def test_log_runner( @pytest.mark.asyncio async def test_log_runner_reconnects_on_disconnect( - event_loop: asyncio.AbstractEventLoop, conn: APIConnection, caplog: pytest.LogCaptureFixture, aiohappyeyeballs_start_connection, @@ -175,7 +173,6 @@ async def test_log_runner_reconnects_on_disconnect( @pytest.mark.asyncio async def test_log_runner_reconnects_on_subscribe_failure( - event_loop: asyncio.AbstractEventLoop, conn: APIConnection, caplog: pytest.LogCaptureFixture, aiohappyeyeballs_start_connection, diff --git a/tests/test_reconnect_logic.py b/tests/test_reconnect_logic.py index 7b8f67b..29b39b5 100644 --- a/tests/test_reconnect_logic.py +++ b/tests/test_reconnect_logic.py @@ -672,9 +672,7 @@ async def test_reconnect_logic_stop_callback_waits_for_handshake( @pytest.mark.asyncio -async def test_handling_unexpected_disconnect( - event_loop: asyncio.AbstractEventLoop, aiohappyeyeballs_start_connection -): +async def test_handling_unexpected_disconnect(aiohappyeyeballs_start_connection): """Test the disconnect callback fires with expected_disconnect=False.""" loop = asyncio.get_event_loop() protocol: APIPlaintextFrameHelper | None = None @@ -748,7 +746,6 @@ async def test_handling_unexpected_disconnect( @pytest.mark.asyncio async def test_backoff_on_encryption_error( - event_loop: asyncio.AbstractEventLoop, caplog: pytest.LogCaptureFixture, aiohappyeyeballs_start_connection, ) -> None: