fix mocking

This commit is contained in:
J. Nick Koston 2023-12-12 10:28:53 -10:00
parent ddacc3e412
commit ee41f045d0
No known key found for this signature in database
4 changed files with 29 additions and 11 deletions

View File

@ -6,7 +6,7 @@ import socket
from dataclasses import replace
from functools import partial
from typing import Callable
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, Mock, patch
import pytest
import pytest_asyncio
@ -119,7 +119,11 @@ def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnectio
@pytest.fixture()
def aiohappyeyeballs_start_connection():
with patch("aioesphomeapi.connection.aiohappyeyeballs.start_connection") as func:
func.return_value = MagicMock(type=socket.SOCK_STREAM)
mock_socket = Mock()
mock_socket.type = socket.SOCK_STREAM
mock_socket.fileno.return_value = 1
mock_socket.getpeername.return_value = ("10.0.0.512", 323)
func.return_value = mock_socket
yield func

View File

@ -4,9 +4,10 @@ import asyncio
import contextlib
import itertools
import logging
import socket
from functools import partial
from typing import Any
from unittest.mock import AsyncMock, MagicMock, call, patch
from unittest.mock import AsyncMock, MagicMock, call, create_autospec, patch
import pytest
from google.protobuf import message
@ -219,9 +220,15 @@ async def test_connection_released_if_connecting_is_cancelled() -> None:
cli = APIClient("1.2.3.4", 1234, None)
asyncio.get_event_loop()
async def _start_connection_with_delay(*args, **kwargs):
await asyncio.sleep(1)
mock_socket = create_autospec(socket.socket, spec_set=True, instance=True)
mock_socket.getpeername.return_value = ("4.3.3.3", 323)
return mock_socket
with patch(
"aioesphomeapi.connection.aiohappyeyeballs.start_connection",
side_effect=partial(asyncio.sleep, 1),
_start_connection_with_delay,
):
start_task = asyncio.create_task(cli.start_connection())
await asyncio.sleep(0)
@ -232,8 +239,14 @@ async def test_connection_released_if_connecting_is_cancelled() -> None:
await start_task
assert cli._connection is None
async def _start_connection_without_delay(*args, **kwargs):
mock_socket = create_autospec(socket.socket, spec_set=True, instance=True)
mock_socket.getpeername.return_value = ("4.3.3.3", 323)
return mock_socket
with patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection), patch(
"aioesphomeapi.connection.aiohappyeyeballs.start_connection"
"aioesphomeapi.connection.aiohappyeyeballs.start_connection",
_start_connection_without_delay,
):
await cli.start_connection()
await asyncio.sleep(0)
@ -894,7 +907,7 @@ async def test_noise_psk_handles_subclassed_string():
)
# Make sure its not a subclassed string
assert type(cli._params.noise_psk) is str
assert type(cli._params.address) is str
assert type(cli._params.addresses[0]) is str
assert type(cli._params.expected_name) is str
rl = ReconnectLogic(
@ -930,7 +943,7 @@ async def test_no_noise_psk():
)
# Make sure its not a subclassed string
assert cli._params.noise_psk is None
assert type(cli._params.address) is str
assert type(cli._params.addresses[0]) is str
assert type(cli._params.expected_name) is str
@ -945,7 +958,7 @@ async def test_empty_noise_psk_or_expected_name():
expected_name="",
)
assert cli._params.noise_psk is None
assert type(cli._params.address) is str
assert type(cli._params.addresses[0]) is str
assert cli._params.expected_name is None

View File

@ -571,7 +571,8 @@ async def test_connect_resolver_times_out(
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
), pytest.raises(
ResolveAPIError, match="Timeout while resolving IP address for fake.address"
ResolveAPIError,
match=r"Timeout while resolving IP address for \['fake.address'\]",
):
await connect(conn, login=False)

View File

@ -174,7 +174,7 @@ async def test_resolve_host_mdns_no_results(resolve_addr, addr_infos):
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
async def test_resolve_host_addrinfo(resolve_addr, resolve_zc, addr_infos):
resolve_addr.return_value = addr_infos
ret = await hr.async_resolve_host(["example.local"], 6052)
ret = await hr.async_resolve_host(["example.com"], 6052)
resolve_zc.assert_not_called()
resolve_addr.assert_called_once_with("example.com", 6052)
@ -187,7 +187,7 @@ async def test_resolve_host_addrinfo(resolve_addr, resolve_zc, addr_infos):
async def test_resolve_host_addrinfo_empty(resolve_addr, resolve_zc, addr_infos):
resolve_addr.return_value = []
with pytest.raises(APIConnectionError):
await hr.async_resolve_host(["example.local"], 6052)
await hr.async_resolve_host(["example.com"], 6052)
resolve_zc.assert_not_called()
resolve_addr.assert_called_once_with("example.com", 6052)