From 8743b7e49c1d8e7ae5fa425f5d6d4c16affd8276 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 9 Dec 2023 17:58:42 -1000 Subject: [PATCH] cleanup more tests --- tests/conftest.py | 47 ++++++++++++++++++++++++---------------- tests/test_client.py | 8 +++---- tests/test_connection.py | 14 +++++------- 3 files changed, 37 insertions(+), 32 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 367608f..45d46a0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -114,6 +114,13 @@ def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnectio return PatchableAPIConnection(connection_params, mock_on_stop, True, None) +@pytest.fixture() +def aiohappyeyeballs_start_connection(): + with patch("aioesphomeapi.connection.aiohappyeyeballs.start_connection") as func: + func.return_value = MagicMock(type=socket.SOCK_STREAM) + yield func + + def _create_mock_transport_protocol( transport: asyncio.Transport, connected: asyncio.Event, @@ -128,15 +135,17 @@ def _create_mock_transport_protocol( @pytest_asyncio.fixture(name="plaintext_connect_task_no_login") async def plaintext_connect_task_no_login( - conn: APIConnection, resolve_host, socket_socket, event_loop + conn: APIConnection, + resolve_host, + socket_socket, + event_loop, + aiohappyeyeballs_start_connection, ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: loop = asyncio.get_event_loop() transport = MagicMock() connected = asyncio.Event() - with patch( - "aioesphomeapi.connection.aiohappyeyeballs.start_connection" - ), patch.object( + with patch.object( loop, "create_connection", side_effect=partial(_create_mock_transport_protocol, transport, connected), @@ -148,14 +157,16 @@ async def plaintext_connect_task_no_login( @pytest_asyncio.fixture(name="plaintext_connect_task_expected_name") async def plaintext_connect_task_no_login_with_expected_name( - conn_with_expected_name: APIConnection, resolve_host, socket_socket, event_loop + conn_with_expected_name: APIConnection, + resolve_host, + socket_socket, + event_loop, + aiohappyeyeballs_start_connection, ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: transport = MagicMock() connected = asyncio.Event() - with patch( - "aioesphomeapi.connection.aiohappyeyeballs.start_connection" - ), patch.object( + with patch.object( event_loop, "create_connection", side_effect=partial(_create_mock_transport_protocol, transport, connected), @@ -169,14 +180,16 @@ async def plaintext_connect_task_no_login_with_expected_name( @pytest_asyncio.fixture(name="plaintext_connect_task_with_login") async def plaintext_connect_task_with_login( - conn_with_password: APIConnection, resolve_host, socket_socket, event_loop + conn_with_password: APIConnection, + resolve_host, + socket_socket, + event_loop, + aiohappyeyeballs_start_connection, ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: transport = MagicMock() connected = asyncio.Event() - with patch( - "aioesphomeapi.connection.aiohappyeyeballs.start_connection" - ), patch.object( + with patch.object( event_loop, "create_connection", side_effect=partial(_create_mock_transport_protocol, transport, connected), @@ -188,7 +201,7 @@ async def plaintext_connect_task_with_login( @pytest_asyncio.fixture(name="api_client") async def api_client( - resolve_host, socket_socket, event_loop + resolve_host, socket_socket, event_loop, aiohappyeyeballs_start_connection ) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]: protocol: APIPlaintextFrameHelper | None = None transport = MagicMock() @@ -199,15 +212,11 @@ async def api_client( password=None, ) - with patch( - "aioesphomeapi.connection.aiohappyeyeballs.start_connection" - ), patch.object( + with patch.object( event_loop, "create_connection", side_effect=partial(_create_mock_transport_protocol, transport, connected), - ), patch( - "aioesphomeapi.client.APIConnection", PatchableAPIConnection - ): + ), patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection): connect_task = asyncio.create_task(connect_client(client, login=False)) await connected.wait() conn = client._connection diff --git a/tests/test_client.py b/tests/test_client.py index 809c028..313559d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -194,14 +194,14 @@ async def test_connect_backwards_compat() -> None: @pytest.mark.asyncio -async def test_finish_connection_wraps_exceptions_as_unhandled_api_error() -> None: +async def test_finish_connection_wraps_exceptions_as_unhandled_api_error( + aiohappyeyeballs_start_connection, +) -> None: """Verify finish_connect re-wraps exceptions as UnhandledAPIError.""" cli = APIClient("1.2.3.4", 1234, None) asyncio.get_event_loop() - with patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection), patch( - "aioesphomeapi.connection.aiohappyeyeballs.start_connection" - ): + with patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection): await cli.start_connection() with patch.object( diff --git a/tests/test_connection.py b/tests/test_connection.py index 19b9f23..73cee88 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -558,7 +558,7 @@ async def test_force_disconnect_fails( @pytest.mark.asyncio async def test_connect_resolver_times_out( - conn: APIConnection, socket_socket, event_loop + conn: APIConnection, socket_socket, event_loop, aiohappyeyeballs_start_connection ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: transport = MagicMock() connected = asyncio.Event() @@ -566,8 +566,6 @@ async def test_connect_resolver_times_out( with patch( "aioesphomeapi.host_resolver.async_resolve_host", side_effect=asyncio.TimeoutError, - ), patch( - "aioesphomeapi.connection.aiohappyeyeballs.start_connection" ), patch.object( event_loop, "create_connection", @@ -584,6 +582,7 @@ async def test_disconnect_fails_to_send_response( event_loop: asyncio.AbstractEventLoop, resolve_host, socket_socket, + aiohappyeyeballs_start_connection, ) -> None: loop = asyncio.get_event_loop() transport = MagicMock() @@ -599,9 +598,7 @@ async def test_disconnect_fails_to_send_response( nonlocal expected_disconnect expected_disconnect = _expected_disconnect - with patch( - "aioesphomeapi.connection.aiohappyeyeballs.start_connection" - ), patch.object( + with patch.object( loop, "create_connection", side_effect=partial(_create_mock_transport_protocol, transport, connected), @@ -636,6 +633,7 @@ async def test_disconnect_success_case( event_loop: asyncio.AbstractEventLoop, resolve_host, socket_socket, + aiohappyeyeballs_start_connection, ) -> None: loop = asyncio.get_event_loop() transport = MagicMock() @@ -651,9 +649,7 @@ async def test_disconnect_success_case( nonlocal expected_disconnect expected_disconnect = _expected_disconnect - with patch( - "aioesphomeapi.connection.aiohappyeyeballs.start_connection" - ), patch.object( + with patch.object( loop, "create_connection", side_effect=partial(_create_mock_transport_protocol, transport, connected),