Add test for resolver timing out while connecting (#713)

This commit is contained in:
J. Nick Koston 2023-11-25 10:08:34 -06:00 committed by GitHub
parent 66e654084b
commit 99380487a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 35 additions and 29 deletions

View File

@ -4,6 +4,7 @@ import asyncio
import logging
from collections.abc import Coroutine
from datetime import timedelta
from functools import partial
from typing import Any
from unittest.mock import AsyncMock, MagicMock, call, patch
@ -28,6 +29,7 @@ from aioesphomeapi.core import (
HandshakeAPIError,
InvalidAuthAPIError,
RequiresEncryptionAPIError,
ResolveAPIError,
TimeoutAPIError,
)
@ -44,7 +46,7 @@ from .common import (
send_plaintext_hello,
utcnow,
)
from .conftest import KEEP_ALIVE_INTERVAL
from .conftest import KEEP_ALIVE_INTERVAL, _create_mock_transport_protocol
KEEP_ALIVE_TIMEOUT_RATIO = 4.5
@ -361,9 +363,7 @@ async def test_plaintext_connection_fails_handshake(
"""
loop = asyncio.get_event_loop()
exception, raised_exception = exception_map
protocol = get_mock_protocol(conn)
messages = []
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()
connected = asyncio.Event()
@ -373,13 +373,6 @@ async def test_plaintext_connection_fails_handshake(
def perform_handshake(self, timeout: float) -> Coroutine[Any, Any, None]:
raise exception
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
def on_msg(msg):
messages.append(msg)
@ -390,11 +383,14 @@ async def test_plaintext_connection_fails_handshake(
"aioesphomeapi.connection.APIPlaintextFrameHelper",
APIPlaintextFrameHelperHandshakeException,
), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
):
connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait()
protocol = conn._frame_helper
assert conn._socket is not None
assert conn._frame_helper is not None
@ -534,6 +530,26 @@ async def test_force_disconnect_fails(
await asyncio.sleep(0)
@pytest.mark.asyncio
async def test_connect_resolver_times_out(
conn: APIConnection, socket_socket, event_loop
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
transport = MagicMock()
connected = asyncio.Event()
with patch(
"aioesphomeapi.host_resolver.async_resolve_host",
side_effect=asyncio.TimeoutError,
), patch.object(event_loop, "sock_connect"), patch.object(
event_loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
), pytest.raises(
ResolveAPIError, match="Timeout while resolving IP address for fake.address"
):
await connect(conn, login=False)
@pytest.mark.asyncio
async def test_disconnect_fails_to_send_response(
connection_params: ConnectionParams,
@ -542,7 +558,6 @@ async def test_disconnect_fails_to_send_response(
socket_socket,
) -> None:
loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()
connected = asyncio.Event()
client = APIClient(
@ -556,20 +571,16 @@ async def test_disconnect_fails_to_send_response(
nonlocal expected_disconnect
expected_disconnect = _expected_disconnect
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
):
connect_task = asyncio.create_task(
connect_client(client, login=False, on_stop=_on_stop)
)
await connected.wait()
protocol = client._connection._frame_helper
send_plaintext_hello(protocol)
await connect_task
transport.reset_mock()
@ -597,7 +608,6 @@ async def test_disconnect_success_case(
socket_socket,
) -> None:
loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()
connected = asyncio.Event()
client = APIClient(
@ -611,20 +621,16 @@ async def test_disconnect_success_case(
nonlocal expected_disconnect
expected_disconnect = _expected_disconnect
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
):
connect_task = asyncio.create_task(
connect_client(client, login=False, on_stop=_on_stop)
)
await connected.wait()
protocol = client._connection._frame_helper
send_plaintext_hello(protocol)
await connect_task
transport.reset_mock()