fix socket mocking

This commit is contained in:
J. Nick Koston 2023-12-12 11:08:37 -10:00
parent 04dcd13b27
commit 38d4d1d1c3
No known key found for this signature in database
3 changed files with 12 additions and 26 deletions

View File

@ -6,7 +6,7 @@ import socket
from dataclasses import replace
from functools import partial
from typing import Callable
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import MagicMock, create_autospec, patch
import pytest
import pytest_asyncio
@ -50,12 +50,6 @@ def resolve_host():
yield func
@pytest.fixture
def socket_socket():
with patch("socket.socket") as func:
yield func
@pytest.fixture
def patchable_api_client() -> APIClient:
class PatchableAPIClient(APIClient):
@ -119,7 +113,7 @@ def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnectio
@pytest.fixture()
def aiohappyeyeballs_start_connection():
with patch("aioesphomeapi.connection.aiohappyeyeballs.start_connection") as func:
mock_socket = Mock()
mock_socket = create_autospec(socket.socket, spec_set=True, instance=True)
mock_socket.type = socket.SOCK_STREAM
mock_socket.fileno.return_value = 1
mock_socket.getpeername.return_value = ("10.0.0.512", 323)
@ -143,7 +137,6 @@ def _create_mock_transport_protocol(
async def plaintext_connect_task_no_login(
conn: APIConnection,
resolve_host,
socket_socket,
event_loop,
aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
@ -165,7 +158,6 @@ async def plaintext_connect_task_no_login(
async def plaintext_connect_task_no_login_with_expected_name(
conn_with_expected_name: APIConnection,
resolve_host,
socket_socket,
event_loop,
aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
@ -188,7 +180,6 @@ async def plaintext_connect_task_no_login_with_expected_name(
async def plaintext_connect_task_with_login(
conn_with_password: APIConnection,
resolve_host,
socket_socket,
event_loop,
aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
@ -207,7 +198,7 @@ async def plaintext_connect_task_with_login(
@pytest_asyncio.fixture(name="api_client")
async def api_client(
resolve_host, socket_socket, event_loop, aiohappyeyeballs_start_connection
resolve_host, event_loop, aiohappyeyeballs_start_connection
) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]:
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()

View File

@ -170,7 +170,8 @@ def patch_api_version(client: APIClient, version: APIVersion):
client._connection.api_version = version
def test_expected_name(auth_client: APIClient) -> None:
@pytest.mark.asyncio
async def test_expected_name(auth_client: APIClient) -> None:
"""Ensure expected name can be set externally."""
assert auth_client.expected_name is None
auth_client.expected_name = "awesome"

View File

@ -221,7 +221,7 @@ async def test_plaintext_connection(
@pytest.mark.asyncio
async def test_start_connection_socket_error(
conn: APIConnection, resolve_host, socket_socket
conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection
):
"""Test handling of socket error during start connection."""
loop = asyncio.get_event_loop()
@ -238,7 +238,7 @@ async def test_start_connection_socket_error(
@pytest.mark.asyncio
async def test_start_connection_times_out(
conn: APIConnection, resolve_host, socket_socket
conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection
):
"""Test handling of start connection timing out."""
asyncio.get_event_loop()
@ -264,9 +264,7 @@ async def test_start_connection_times_out(
@pytest.mark.asyncio
async def test_start_connection_os_error(
conn: APIConnection, resolve_host, socket_socket
):
async def test_start_connection_os_error(conn: APIConnection, resolve_host):
"""Test handling of start connection has an OSError."""
asyncio.get_event_loop()
@ -284,9 +282,7 @@ async def test_start_connection_os_error(
@pytest.mark.asyncio
async def test_start_connection_is_cancelled(
conn: APIConnection, resolve_host, socket_socket
):
async def test_start_connection_is_cancelled(conn: APIConnection, resolve_host):
"""Test handling of start connection is cancelled."""
asyncio.get_event_loop()
@ -305,7 +301,7 @@ async def test_start_connection_is_cancelled(
@pytest.mark.asyncio
async def test_finish_connection_is_cancelled(
conn: APIConnection, resolve_host, socket_socket
conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection
):
"""Test handling of finishing connection being cancelled."""
loop = asyncio.get_event_loop()
@ -368,7 +364,7 @@ async def test_finish_connection_times_out(
async def test_plaintext_connection_fails_handshake(
conn: APIConnection,
resolve_host: AsyncMock,
socket_socket: MagicMock,
aiohappyeyeballs_start_connection: MagicMock,
exception_map: tuple[Exception, Exception],
) -> None:
"""Test that the frame helper is closed before the underlying socket.
@ -558,7 +554,7 @@ async def test_force_disconnect_fails(
@pytest.mark.asyncio
async def test_connect_resolver_times_out(
conn: APIConnection, socket_socket, event_loop, aiohappyeyeballs_start_connection
conn: APIConnection, event_loop, aiohappyeyeballs_start_connection
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
transport = MagicMock()
connected = asyncio.Event()
@ -582,7 +578,6 @@ async def test_disconnect_fails_to_send_response(
connection_params: ConnectionParams,
event_loop: asyncio.AbstractEventLoop,
resolve_host,
socket_socket,
aiohappyeyeballs_start_connection,
) -> None:
loop = asyncio.get_event_loop()
@ -633,7 +628,6 @@ async def test_disconnect_success_case(
connection_params: ConnectionParams,
event_loop: asyncio.AbstractEventLoop,
resolve_host,
socket_socket,
aiohappyeyeballs_start_connection,
) -> None:
loop = asyncio.get_event_loop()