Remove duplicate code in tests (#722)

This commit is contained in:
J. Nick Koston 2023-11-26 07:53:48 -06:00 committed by GitHub
parent d19ae94a5e
commit e8468647e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 36 deletions

View File

@ -49,6 +49,19 @@ def socket_socket():
yield func
@pytest.fixture
def patchable_api_client() -> APIClient:
class PatchableAPIClient(APIClient):
pass
cli = PatchableAPIClient(
address="1.2.3.4",
port=6052,
password=None,
)
return cli
def get_mock_connection_params() -> ConnectionParams:
return ConnectionParams(
address="fake.address",

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import logging
from functools import partial
from ipaddress import ip_address
from unittest.mock import AsyncMock, MagicMock, patch
@ -28,6 +29,7 @@ from .common import (
send_plaintext_connect_response,
send_plaintext_hello,
)
from .conftest import _create_mock_transport_protocol
logging.getLogger("aioesphomeapi").setLevel(logging.DEBUG)
@ -207,20 +209,12 @@ async def test_reconnect_logic_state():
@pytest.mark.asyncio
async def test_reconnect_retry():
async def test_reconnect_retry(patchable_api_client: APIClient):
"""Test that reconnect logic retry."""
on_disconnect_called = []
on_connect_called = []
on_connect_fail_called = []
class PatchableAPIClient(APIClient):
pass
cli = PatchableAPIClient(
address="1.2.3.4",
port=6052,
password=None,
)
cli = patchable_api_client
async def on_disconnect(expected_disconnect: bool) -> None:
nonlocal on_disconnect_called
@ -375,13 +369,9 @@ async def test_reconnect_zeroconf(
@pytest.mark.asyncio
async def test_reconnect_logic_stop_callback():
async def test_reconnect_logic_stop_callback(patchable_api_client: APIClient):
"""Test that the stop_callback stops the ReconnectLogic."""
cli = APIClient(
address="1.2.3.4",
port=6052,
password=None,
)
cli = patchable_api_client
rl = ReconnectLogic(
client=cli,
on_disconnect=AsyncMock(),
@ -403,17 +393,11 @@ async def test_reconnect_logic_stop_callback():
@pytest.mark.asyncio
async def test_reconnect_logic_stop_callback_waits_for_handshake():
async def test_reconnect_logic_stop_callback_waits_for_handshake(
patchable_api_client: APIClient,
):
"""Test that the stop_callback waits for a handshake."""
class PatchableAPIClient(APIClient):
pass
cli = PatchableAPIClient(
address="1.2.3.4",
port=6052,
password=None,
)
cli = patchable_api_client
rl = ReconnectLogic(
client=cli,
on_disconnect=AsyncMock(),
@ -473,13 +457,6 @@ async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventL
zeroconf_instance=async_zeroconf.zeroconf,
)
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
connected = asyncio.Event()
on_disconnect_calls = []
@ -498,20 +475,23 @@ async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventL
)
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
):
await logic.start()
await connected.wait()
protocol = cli._connection._frame_helper
send_plaintext_hello(protocol)
send_plaintext_connect_response(protocol, False)
await connected.wait()
assert cli._connection.is_connected is True
await asyncio.sleep(0)
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
) as mock_create_connection:
protocol.eof_received()
# Wait for the task to run