Refactor tests for new pytest-asyncio (#820)

This commit is contained in:
J. Nick Koston 2024-02-04 21:30:19 -06:00 committed by GitHub
parent 999dc5ab28
commit 5b4bdb8716
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 28 additions and 25 deletions

View File

@ -77,7 +77,7 @@ def get_mock_connection_params() -> ConnectionParams:
@pytest.fixture @pytest.fixture
def connection_params() -> ConnectionParams: def connection_params(event_loop: asyncio.AbstractEventLoop) -> ConnectionParams:
return get_mock_connection_params() return get_mock_connection_params()
@ -86,18 +86,24 @@ def mock_on_stop(expected_disconnect: bool) -> None:
@pytest.fixture @pytest.fixture
def conn(connection_params: ConnectionParams) -> APIConnection: def conn(
event_loop: asyncio.AbstractEventLoop, connection_params: ConnectionParams
) -> APIConnection:
return PatchableAPIConnection(connection_params, mock_on_stop, True, None) return PatchableAPIConnection(connection_params, mock_on_stop, True, None)
@pytest.fixture @pytest.fixture
def conn_with_password(connection_params: ConnectionParams) -> APIConnection: def conn_with_password(
event_loop: asyncio.AbstractEventLoop, connection_params: ConnectionParams
) -> APIConnection:
connection_params = replace(connection_params, password="password") connection_params = replace(connection_params, password="password")
return PatchableAPIConnection(connection_params, mock_on_stop, True, None) return PatchableAPIConnection(connection_params, mock_on_stop, True, None)
@pytest.fixture @pytest.fixture
def noise_conn(connection_params: ConnectionParams) -> APIConnection: def noise_conn(
event_loop: asyncio.AbstractEventLoop, connection_params: ConnectionParams
) -> APIConnection:
connection_params = replace( connection_params = replace(
connection_params, noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=" connection_params, noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc="
) )
@ -105,13 +111,15 @@ def noise_conn(connection_params: ConnectionParams) -> APIConnection:
@pytest.fixture @pytest.fixture
def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnection: def conn_with_expected_name(
event_loop: asyncio.AbstractEventLoop, connection_params: ConnectionParams
) -> APIConnection:
connection_params = replace(connection_params, expected_name="test") connection_params = replace(connection_params, expected_name="test")
return PatchableAPIConnection(connection_params, mock_on_stop, True, None) return PatchableAPIConnection(connection_params, mock_on_stop, True, None)
@pytest.fixture() @pytest.fixture()
def aiohappyeyeballs_start_connection(): def aiohappyeyeballs_start_connection(event_loop: asyncio.AbstractEventLoop):
with patch("aioesphomeapi.connection.aiohappyeyeballs.start_connection") as func: with patch("aioesphomeapi.connection.aiohappyeyeballs.start_connection") as func:
mock_socket = create_autospec(socket.socket, spec_set=True, instance=True) mock_socket = create_autospec(socket.socket, spec_set=True, instance=True)
mock_socket.type = socket.SOCK_STREAM mock_socket.type = socket.SOCK_STREAM
@ -137,7 +145,6 @@ def _create_mock_transport_protocol(
async def plaintext_connect_task_no_login( async def plaintext_connect_task_no_login(
conn: APIConnection, conn: APIConnection,
resolve_host, resolve_host,
event_loop,
aiohappyeyeballs_start_connection, aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -158,9 +165,9 @@ async def plaintext_connect_task_no_login(
async def plaintext_connect_task_no_login_with_expected_name( async def plaintext_connect_task_no_login_with_expected_name(
conn_with_expected_name: APIConnection, conn_with_expected_name: APIConnection,
resolve_host, resolve_host,
event_loop,
aiohappyeyeballs_start_connection, aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
event_loop = asyncio.get_running_loop()
transport = MagicMock() transport = MagicMock()
connected = asyncio.Event() connected = asyncio.Event()
@ -180,11 +187,11 @@ async def plaintext_connect_task_no_login_with_expected_name(
async def plaintext_connect_task_with_login( async def plaintext_connect_task_with_login(
conn_with_password: APIConnection, conn_with_password: APIConnection,
resolve_host, resolve_host,
event_loop,
aiohappyeyeballs_start_connection, aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
transport = MagicMock() transport = MagicMock()
connected = asyncio.Event() connected = asyncio.Event()
event_loop = asyncio.get_running_loop()
with patch.object( with patch.object(
event_loop, event_loop,
@ -198,8 +205,9 @@ async def plaintext_connect_task_with_login(
@pytest_asyncio.fixture(name="api_client") @pytest_asyncio.fixture(name="api_client")
async def api_client( async def api_client(
resolve_host, event_loop, aiohappyeyeballs_start_connection resolve_host, aiohappyeyeballs_start_connection
) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]: ) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]:
event_loop = asyncio.get_running_loop()
protocol: APIPlaintextFrameHelper | None = None protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock() transport = MagicMock()
connected = asyncio.Event() connected = asyncio.Event()

View File

@ -202,7 +202,6 @@ async def test_finish_connection_wraps_exceptions_as_unhandled_api_error(
"""Verify finish_connect re-wraps exceptions as UnhandledAPIError.""" """Verify finish_connect re-wraps exceptions as UnhandledAPIError."""
cli = APIClient("1.2.3.4", 1234, None) cli = APIClient("1.2.3.4", 1234, None)
asyncio.get_event_loop()
with patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection): with patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection):
await cli.start_connection() await cli.start_connection()
@ -262,7 +261,7 @@ async def test_connection_released_if_connecting_is_cancelled() -> None:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_request_while_handshaking(event_loop) -> None: async def test_request_while_handshaking() -> None:
"""Test trying a request while handshaking raises.""" """Test trying a request while handshaking raises."""
class PatchableApiClient(APIClient): class PatchableApiClient(APIClient):

View File

@ -652,10 +652,11 @@ async def test_force_disconnect_fails(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_connect_resolver_times_out( async def test_connect_resolver_times_out(
conn: APIConnection, event_loop, aiohappyeyeballs_start_connection conn: APIConnection, aiohappyeyeballs_start_connection
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
transport = MagicMock() transport = MagicMock()
connected = asyncio.Event() connected = asyncio.Event()
event_loop = asyncio.get_running_loop()
with patch( with patch(
"aioesphomeapi.host_resolver.async_resolve_host", "aioesphomeapi.host_resolver.async_resolve_host",
@ -674,7 +675,6 @@ async def test_connect_resolver_times_out(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_disconnect_fails_to_send_response( async def test_disconnect_fails_to_send_response(
connection_params: ConnectionParams, connection_params: ConnectionParams,
event_loop: asyncio.AbstractEventLoop,
resolve_host, resolve_host,
aiohappyeyeballs_start_connection, aiohappyeyeballs_start_connection,
) -> None: ) -> None:
@ -724,11 +724,10 @@ async def test_disconnect_fails_to_send_response(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_disconnect_success_case( async def test_disconnect_success_case(
connection_params: ConnectionParams, connection_params: ConnectionParams,
event_loop: asyncio.AbstractEventLoop,
resolve_host, resolve_host,
aiohappyeyeballs_start_connection, aiohappyeyeballs_start_connection,
) -> None: ) -> None:
loop = asyncio.get_event_loop() loop = asyncio.get_running_loop()
transport = MagicMock() transport = MagicMock()
connected = asyncio.Event() connected = asyncio.Event()
client = APIClient( client = APIClient(

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import socket import socket
from ipaddress import IPv6Address, ip_address from ipaddress import IPv6Address, ip_address
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
@ -103,7 +104,8 @@ async def test_resolve_host_zeroconf_fails_end_to_end(async_zeroconf: AsyncZeroc
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_resolve_host_getaddrinfo(event_loop, addr_infos): async def test_resolve_host_getaddrinfo(addr_infos):
event_loop = asyncio.get_running_loop()
with patch.object(event_loop, "getaddrinfo") as mock: with patch.object(event_loop, "getaddrinfo") as mock:
mock.return_value = [ mock.return_value = [
( (
@ -128,7 +130,8 @@ async def test_resolve_host_getaddrinfo(event_loop, addr_infos):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_resolve_host_getaddrinfo_oserror(event_loop): async def test_resolve_host_getaddrinfo_oserror():
event_loop = asyncio.get_running_loop()
with patch.object(event_loop, "getaddrinfo") as mock: with patch.object(event_loop, "getaddrinfo") as mock:
mock.side_effect = OSError() mock.side_effect = OSError()
with pytest.raises(APIConnectionError): with pytest.raises(APIConnectionError):

View File

@ -31,7 +31,6 @@ from .common import (
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_log_runner( async def test_log_runner(
event_loop: asyncio.AbstractEventLoop,
conn: APIConnection, conn: APIConnection,
aiohappyeyeballs_start_connection, aiohappyeyeballs_start_connection,
): ):
@ -97,7 +96,6 @@ async def test_log_runner(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_log_runner_reconnects_on_disconnect( async def test_log_runner_reconnects_on_disconnect(
event_loop: asyncio.AbstractEventLoop,
conn: APIConnection, conn: APIConnection,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
aiohappyeyeballs_start_connection, aiohappyeyeballs_start_connection,
@ -175,7 +173,6 @@ async def test_log_runner_reconnects_on_disconnect(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_log_runner_reconnects_on_subscribe_failure( async def test_log_runner_reconnects_on_subscribe_failure(
event_loop: asyncio.AbstractEventLoop,
conn: APIConnection, conn: APIConnection,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
aiohappyeyeballs_start_connection, aiohappyeyeballs_start_connection,

View File

@ -672,9 +672,7 @@ async def test_reconnect_logic_stop_callback_waits_for_handshake(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handling_unexpected_disconnect( async def test_handling_unexpected_disconnect(aiohappyeyeballs_start_connection):
event_loop: asyncio.AbstractEventLoop, aiohappyeyeballs_start_connection
):
"""Test the disconnect callback fires with expected_disconnect=False.""" """Test the disconnect callback fires with expected_disconnect=False."""
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None protocol: APIPlaintextFrameHelper | None = None
@ -748,7 +746,6 @@ async def test_handling_unexpected_disconnect(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_backoff_on_encryption_error( async def test_backoff_on_encryption_error(
event_loop: asyncio.AbstractEventLoop,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
aiohappyeyeballs_start_connection, aiohappyeyeballs_start_connection,
) -> None: ) -> None: