mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-01-08 19:38:09 +01:00
Add test for resolver timing out while connecting (#713)
This commit is contained in:
parent
66e654084b
commit
99380487a5
@ -4,6 +4,7 @@ import asyncio
|
||||
import logging
|
||||
from collections.abc import Coroutine
|
||||
from datetime import timedelta
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||
|
||||
@ -28,6 +29,7 @@ from aioesphomeapi.core import (
|
||||
HandshakeAPIError,
|
||||
InvalidAuthAPIError,
|
||||
RequiresEncryptionAPIError,
|
||||
ResolveAPIError,
|
||||
TimeoutAPIError,
|
||||
)
|
||||
|
||||
@ -44,7 +46,7 @@ from .common import (
|
||||
send_plaintext_hello,
|
||||
utcnow,
|
||||
)
|
||||
from .conftest import KEEP_ALIVE_INTERVAL
|
||||
from .conftest import KEEP_ALIVE_INTERVAL, _create_mock_transport_protocol
|
||||
|
||||
KEEP_ALIVE_TIMEOUT_RATIO = 4.5
|
||||
|
||||
@ -361,9 +363,7 @@ async def test_plaintext_connection_fails_handshake(
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
exception, raised_exception = exception_map
|
||||
protocol = get_mock_protocol(conn)
|
||||
messages = []
|
||||
protocol: APIPlaintextFrameHelper | None = None
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
||||
@ -373,13 +373,6 @@ async def test_plaintext_connection_fails_handshake(
|
||||
def perform_handshake(self, timeout: float) -> Coroutine[Any, Any, None]:
|
||||
raise exception
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
protocol.connection_made(transport)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
|
||||
def on_msg(msg):
|
||||
messages.append(msg)
|
||||
|
||||
@ -390,11 +383,14 @@ async def test_plaintext_connection_fails_handshake(
|
||||
"aioesphomeapi.connection.APIPlaintextFrameHelper",
|
||||
APIPlaintextFrameHelperHandshakeException,
|
||||
), patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await connected.wait()
|
||||
|
||||
protocol = conn._frame_helper
|
||||
assert conn._socket is not None
|
||||
assert conn._frame_helper is not None
|
||||
|
||||
@ -534,6 +530,26 @@ async def test_force_disconnect_fails(
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_resolver_times_out(
|
||||
conn: APIConnection, socket_socket, event_loop
|
||||
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
||||
with patch(
|
||||
"aioesphomeapi.host_resolver.async_resolve_host",
|
||||
side_effect=asyncio.TimeoutError,
|
||||
), patch.object(event_loop, "sock_connect"), patch.object(
|
||||
event_loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
), pytest.raises(
|
||||
ResolveAPIError, match="Timeout while resolving IP address for fake.address"
|
||||
):
|
||||
await connect(conn, login=False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_fails_to_send_response(
|
||||
connection_params: ConnectionParams,
|
||||
@ -542,7 +558,6 @@ async def test_disconnect_fails_to_send_response(
|
||||
socket_socket,
|
||||
) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol: APIPlaintextFrameHelper | None = None
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
client = APIClient(
|
||||
@ -556,20 +571,16 @@ async def test_disconnect_fails_to_send_response(
|
||||
nonlocal expected_disconnect
|
||||
expected_disconnect = _expected_disconnect
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
protocol.connection_made(transport)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
):
|
||||
connect_task = asyncio.create_task(
|
||||
connect_client(client, login=False, on_stop=_on_stop)
|
||||
)
|
||||
await connected.wait()
|
||||
protocol = client._connection._frame_helper
|
||||
send_plaintext_hello(protocol)
|
||||
await connect_task
|
||||
transport.reset_mock()
|
||||
@ -597,7 +608,6 @@ async def test_disconnect_success_case(
|
||||
socket_socket,
|
||||
) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol: APIPlaintextFrameHelper | None = None
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
client = APIClient(
|
||||
@ -611,20 +621,16 @@ async def test_disconnect_success_case(
|
||||
nonlocal expected_disconnect
|
||||
expected_disconnect = _expected_disconnect
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
protocol.connection_made(transport)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
):
|
||||
connect_task = asyncio.create_task(
|
||||
connect_client(client, login=False, on_stop=_on_stop)
|
||||
)
|
||||
await connected.wait()
|
||||
protocol = client._connection._frame_helper
|
||||
send_plaintext_hello(protocol)
|
||||
await connect_task
|
||||
transport.reset_mock()
|
||||
|
Loading…
Reference in New Issue
Block a user