Add test for resolver timing out while connecting (#713)

This commit is contained in:
J. Nick Koston 2023-11-25 10:08:34 -06:00 committed by GitHub
parent 66e654084b
commit 99380487a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,6 +4,7 @@ import asyncio
import logging import logging
from collections.abc import Coroutine from collections.abc import Coroutine
from datetime import timedelta from datetime import timedelta
from functools import partial
from typing import Any from typing import Any
from unittest.mock import AsyncMock, MagicMock, call, patch from unittest.mock import AsyncMock, MagicMock, call, patch
@ -28,6 +29,7 @@ from aioesphomeapi.core import (
HandshakeAPIError, HandshakeAPIError,
InvalidAuthAPIError, InvalidAuthAPIError,
RequiresEncryptionAPIError, RequiresEncryptionAPIError,
ResolveAPIError,
TimeoutAPIError, TimeoutAPIError,
) )
@ -44,7 +46,7 @@ from .common import (
send_plaintext_hello, send_plaintext_hello,
utcnow, utcnow,
) )
from .conftest import KEEP_ALIVE_INTERVAL from .conftest import KEEP_ALIVE_INTERVAL, _create_mock_transport_protocol
KEEP_ALIVE_TIMEOUT_RATIO = 4.5 KEEP_ALIVE_TIMEOUT_RATIO = 4.5
@ -361,9 +363,7 @@ async def test_plaintext_connection_fails_handshake(
""" """
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
exception, raised_exception = exception_map exception, raised_exception = exception_map
protocol = get_mock_protocol(conn)
messages = [] messages = []
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock() transport = MagicMock()
connected = asyncio.Event() connected = asyncio.Event()
@ -373,13 +373,6 @@ async def test_plaintext_connection_fails_handshake(
def perform_handshake(self, timeout: float) -> Coroutine[Any, Any, None]: def perform_handshake(self, timeout: float) -> Coroutine[Any, Any, None]:
raise exception raise exception
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
def on_msg(msg): def on_msg(msg):
messages.append(msg) messages.append(msg)
@ -390,11 +383,14 @@ async def test_plaintext_connection_fails_handshake(
"aioesphomeapi.connection.APIPlaintextFrameHelper", "aioesphomeapi.connection.APIPlaintextFrameHelper",
APIPlaintextFrameHelperHandshakeException, APIPlaintextFrameHelperHandshakeException,
), patch.object( ), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
): ):
connect_task = asyncio.create_task(connect(conn, login=False)) connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait() await connected.wait()
protocol = conn._frame_helper
assert conn._socket is not None assert conn._socket is not None
assert conn._frame_helper is not None assert conn._frame_helper is not None
@ -534,6 +530,26 @@ async def test_force_disconnect_fails(
await asyncio.sleep(0) await asyncio.sleep(0)
@pytest.mark.asyncio
async def test_connect_resolver_times_out(
conn: APIConnection, socket_socket, event_loop
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
transport = MagicMock()
connected = asyncio.Event()
with patch(
"aioesphomeapi.host_resolver.async_resolve_host",
side_effect=asyncio.TimeoutError,
), patch.object(event_loop, "sock_connect"), patch.object(
event_loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
), pytest.raises(
ResolveAPIError, match="Timeout while resolving IP address for fake.address"
):
await connect(conn, login=False)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_disconnect_fails_to_send_response( async def test_disconnect_fails_to_send_response(
connection_params: ConnectionParams, connection_params: ConnectionParams,
@ -542,7 +558,6 @@ async def test_disconnect_fails_to_send_response(
socket_socket, socket_socket,
) -> None: ) -> None:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock() transport = MagicMock()
connected = asyncio.Event() connected = asyncio.Event()
client = APIClient( client = APIClient(
@ -556,20 +571,16 @@ async def test_disconnect_fails_to_send_response(
nonlocal expected_disconnect nonlocal expected_disconnect
expected_disconnect = _expected_disconnect expected_disconnect = _expected_disconnect
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
with patch.object(event_loop, "sock_connect"), patch.object( with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
): ):
connect_task = asyncio.create_task( connect_task = asyncio.create_task(
connect_client(client, login=False, on_stop=_on_stop) connect_client(client, login=False, on_stop=_on_stop)
) )
await connected.wait() await connected.wait()
protocol = client._connection._frame_helper
send_plaintext_hello(protocol) send_plaintext_hello(protocol)
await connect_task await connect_task
transport.reset_mock() transport.reset_mock()
@ -597,7 +608,6 @@ async def test_disconnect_success_case(
socket_socket, socket_socket,
) -> None: ) -> None:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock() transport = MagicMock()
connected = asyncio.Event() connected = asyncio.Event()
client = APIClient( client = APIClient(
@ -611,20 +621,16 @@ async def test_disconnect_success_case(
nonlocal expected_disconnect nonlocal expected_disconnect
expected_disconnect = _expected_disconnect expected_disconnect = _expected_disconnect
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
with patch.object(event_loop, "sock_connect"), patch.object( with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
): ):
connect_task = asyncio.create_task( connect_task = asyncio.create_task(
connect_client(client, login=False, on_stop=_on_stop) connect_client(client, login=False, on_stop=_on_stop)
) )
await connected.wait() await connected.wait()
protocol = client._connection._frame_helper
send_plaintext_hello(protocol) send_plaintext_hello(protocol)
await connect_task await connect_task
transport.reset_mock() transport.reset_mock()