diff --git a/tests/test_connection.py b/tests/test_connection.py index e26f270..3652b4b 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -4,6 +4,7 @@ import asyncio import logging from collections.abc import Coroutine from datetime import timedelta +from functools import partial from typing import Any from unittest.mock import AsyncMock, MagicMock, call, patch @@ -28,6 +29,7 @@ from aioesphomeapi.core import ( HandshakeAPIError, InvalidAuthAPIError, RequiresEncryptionAPIError, + ResolveAPIError, TimeoutAPIError, ) @@ -44,7 +46,7 @@ from .common import ( send_plaintext_hello, utcnow, ) -from .conftest import KEEP_ALIVE_INTERVAL +from .conftest import KEEP_ALIVE_INTERVAL, _create_mock_transport_protocol KEEP_ALIVE_TIMEOUT_RATIO = 4.5 @@ -361,9 +363,7 @@ async def test_plaintext_connection_fails_handshake( """ loop = asyncio.get_event_loop() exception, raised_exception = exception_map - protocol = get_mock_protocol(conn) messages = [] - protocol: APIPlaintextFrameHelper | None = None transport = MagicMock() connected = asyncio.Event() @@ -373,13 +373,6 @@ async def test_plaintext_connection_fails_handshake( def perform_handshake(self, timeout: float) -> Coroutine[Any, Any, None]: raise exception - def _create_mock_transport_protocol(create_func, **kwargs): - nonlocal protocol - protocol = create_func() - protocol.connection_made(transport) - connected.set() - return transport, protocol - def on_msg(msg): messages.append(msg) @@ -390,11 +383,14 @@ async def test_plaintext_connection_fails_handshake( "aioesphomeapi.connection.APIPlaintextFrameHelper", APIPlaintextFrameHelperHandshakeException, ), patch.object( - loop, "create_connection", side_effect=_create_mock_transport_protocol + loop, + "create_connection", + side_effect=partial(_create_mock_transport_protocol, transport, connected), ): connect_task = asyncio.create_task(connect(conn, login=False)) await connected.wait() + protocol = conn._frame_helper assert conn._socket is not None assert conn._frame_helper is not None @@ -534,6 +530,26 @@ async def test_force_disconnect_fails( await asyncio.sleep(0) +@pytest.mark.asyncio +async def test_connect_resolver_times_out( + conn: APIConnection, socket_socket, event_loop +) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: + transport = MagicMock() + connected = asyncio.Event() + + with patch( + "aioesphomeapi.host_resolver.async_resolve_host", + side_effect=asyncio.TimeoutError, + ), patch.object(event_loop, "sock_connect"), patch.object( + event_loop, + "create_connection", + side_effect=partial(_create_mock_transport_protocol, transport, connected), + ), pytest.raises( + ResolveAPIError, match="Timeout while resolving IP address for fake.address" + ): + await connect(conn, login=False) + + @pytest.mark.asyncio async def test_disconnect_fails_to_send_response( connection_params: ConnectionParams, @@ -542,7 +558,6 @@ async def test_disconnect_fails_to_send_response( socket_socket, ) -> None: loop = asyncio.get_event_loop() - protocol: APIPlaintextFrameHelper | None = None transport = MagicMock() connected = asyncio.Event() client = APIClient( @@ -556,20 +571,16 @@ async def test_disconnect_fails_to_send_response( nonlocal expected_disconnect expected_disconnect = _expected_disconnect - def _create_mock_transport_protocol(create_func, **kwargs): - nonlocal protocol - protocol = create_func() - protocol.connection_made(transport) - connected.set() - return transport, protocol - with patch.object(event_loop, "sock_connect"), patch.object( - loop, "create_connection", side_effect=_create_mock_transport_protocol + loop, + "create_connection", + side_effect=partial(_create_mock_transport_protocol, transport, connected), ): connect_task = asyncio.create_task( connect_client(client, login=False, on_stop=_on_stop) ) await connected.wait() + protocol = client._connection._frame_helper send_plaintext_hello(protocol) await connect_task transport.reset_mock() @@ -597,7 +608,6 @@ async def test_disconnect_success_case( socket_socket, ) -> None: loop = asyncio.get_event_loop() - protocol: APIPlaintextFrameHelper | None = None transport = MagicMock() connected = asyncio.Event() client = APIClient( @@ -611,20 +621,16 @@ async def test_disconnect_success_case( nonlocal expected_disconnect expected_disconnect = _expected_disconnect - def _create_mock_transport_protocol(create_func, **kwargs): - nonlocal protocol - protocol = create_func() - protocol.connection_made(transport) - connected.set() - return transport, protocol - with patch.object(event_loop, "sock_connect"), patch.object( - loop, "create_connection", side_effect=_create_mock_transport_protocol + loop, + "create_connection", + side_effect=partial(_create_mock_transport_protocol, transport, connected), ): connect_task = asyncio.create_task( connect_client(client, login=False, on_stop=_on_stop) ) await connected.wait() + protocol = client._connection._frame_helper send_plaintext_hello(protocol) await connect_task transport.reset_mock()