fix socket mocking
This commit is contained in:
parent
04dcd13b27
commit
38d4d1d1c3
|
@ -6,7 +6,7 @@ import socket
|
|||
from dataclasses import replace
|
||||
from functools import partial
|
||||
from typing import Callable
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
from unittest.mock import MagicMock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
@ -50,12 +50,6 @@ def resolve_host():
|
|||
yield func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def socket_socket():
|
||||
with patch("socket.socket") as func:
|
||||
yield func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patchable_api_client() -> APIClient:
|
||||
class PatchableAPIClient(APIClient):
|
||||
|
@ -119,7 +113,7 @@ def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnectio
|
|||
@pytest.fixture()
|
||||
def aiohappyeyeballs_start_connection():
|
||||
with patch("aioesphomeapi.connection.aiohappyeyeballs.start_connection") as func:
|
||||
mock_socket = Mock()
|
||||
mock_socket = create_autospec(socket.socket, spec_set=True, instance=True)
|
||||
mock_socket.type = socket.SOCK_STREAM
|
||||
mock_socket.fileno.return_value = 1
|
||||
mock_socket.getpeername.return_value = ("10.0.0.512", 323)
|
||||
|
@ -143,7 +137,6 @@ def _create_mock_transport_protocol(
|
|||
async def plaintext_connect_task_no_login(
|
||||
conn: APIConnection,
|
||||
resolve_host,
|
||||
socket_socket,
|
||||
event_loop,
|
||||
aiohappyeyeballs_start_connection,
|
||||
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||
|
@ -165,7 +158,6 @@ async def plaintext_connect_task_no_login(
|
|||
async def plaintext_connect_task_no_login_with_expected_name(
|
||||
conn_with_expected_name: APIConnection,
|
||||
resolve_host,
|
||||
socket_socket,
|
||||
event_loop,
|
||||
aiohappyeyeballs_start_connection,
|
||||
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||
|
@ -188,7 +180,6 @@ async def plaintext_connect_task_no_login_with_expected_name(
|
|||
async def plaintext_connect_task_with_login(
|
||||
conn_with_password: APIConnection,
|
||||
resolve_host,
|
||||
socket_socket,
|
||||
event_loop,
|
||||
aiohappyeyeballs_start_connection,
|
||||
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||
|
@ -207,7 +198,7 @@ async def plaintext_connect_task_with_login(
|
|||
|
||||
@pytest_asyncio.fixture(name="api_client")
|
||||
async def api_client(
|
||||
resolve_host, socket_socket, event_loop, aiohappyeyeballs_start_connection
|
||||
resolve_host, event_loop, aiohappyeyeballs_start_connection
|
||||
) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]:
|
||||
protocol: APIPlaintextFrameHelper | None = None
|
||||
transport = MagicMock()
|
||||
|
|
|
@ -170,7 +170,8 @@ def patch_api_version(client: APIClient, version: APIVersion):
|
|||
client._connection.api_version = version
|
||||
|
||||
|
||||
def test_expected_name(auth_client: APIClient) -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_expected_name(auth_client: APIClient) -> None:
|
||||
"""Ensure expected name can be set externally."""
|
||||
assert auth_client.expected_name is None
|
||||
auth_client.expected_name = "awesome"
|
||||
|
|
|
@ -221,7 +221,7 @@ async def test_plaintext_connection(
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_connection_socket_error(
|
||||
conn: APIConnection, resolve_host, socket_socket
|
||||
conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection
|
||||
):
|
||||
"""Test handling of socket error during start connection."""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
@ -238,7 +238,7 @@ async def test_start_connection_socket_error(
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_connection_times_out(
|
||||
conn: APIConnection, resolve_host, socket_socket
|
||||
conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection
|
||||
):
|
||||
"""Test handling of start connection timing out."""
|
||||
asyncio.get_event_loop()
|
||||
|
@ -264,9 +264,7 @@ async def test_start_connection_times_out(
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_connection_os_error(
|
||||
conn: APIConnection, resolve_host, socket_socket
|
||||
):
|
||||
async def test_start_connection_os_error(conn: APIConnection, resolve_host):
|
||||
"""Test handling of start connection has an OSError."""
|
||||
asyncio.get_event_loop()
|
||||
|
||||
|
@ -284,9 +282,7 @@ async def test_start_connection_os_error(
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_connection_is_cancelled(
|
||||
conn: APIConnection, resolve_host, socket_socket
|
||||
):
|
||||
async def test_start_connection_is_cancelled(conn: APIConnection, resolve_host):
|
||||
"""Test handling of start connection is cancelled."""
|
||||
asyncio.get_event_loop()
|
||||
|
||||
|
@ -305,7 +301,7 @@ async def test_start_connection_is_cancelled(
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_finish_connection_is_cancelled(
|
||||
conn: APIConnection, resolve_host, socket_socket
|
||||
conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection
|
||||
):
|
||||
"""Test handling of finishing connection being cancelled."""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
@ -368,7 +364,7 @@ async def test_finish_connection_times_out(
|
|||
async def test_plaintext_connection_fails_handshake(
|
||||
conn: APIConnection,
|
||||
resolve_host: AsyncMock,
|
||||
socket_socket: MagicMock,
|
||||
aiohappyeyeballs_start_connection: MagicMock,
|
||||
exception_map: tuple[Exception, Exception],
|
||||
) -> None:
|
||||
"""Test that the frame helper is closed before the underlying socket.
|
||||
|
@ -558,7 +554,7 @@ async def test_force_disconnect_fails(
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_resolver_times_out(
|
||||
conn: APIConnection, socket_socket, event_loop, aiohappyeyeballs_start_connection
|
||||
conn: APIConnection, event_loop, aiohappyeyeballs_start_connection
|
||||
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
@ -582,7 +578,6 @@ async def test_disconnect_fails_to_send_response(
|
|||
connection_params: ConnectionParams,
|
||||
event_loop: asyncio.AbstractEventLoop,
|
||||
resolve_host,
|
||||
socket_socket,
|
||||
aiohappyeyeballs_start_connection,
|
||||
) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
@ -633,7 +628,6 @@ async def test_disconnect_success_case(
|
|||
connection_params: ConnectionParams,
|
||||
event_loop: asyncio.AbstractEventLoop,
|
||||
resolve_host,
|
||||
socket_socket,
|
||||
aiohappyeyeballs_start_connection,
|
||||
) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
|
Loading…
Reference in New Issue