fix mocking
This commit is contained in:
parent
ddacc3e412
commit
ee41f045d0
|
@ -6,7 +6,7 @@ import socket
|
|||
from dataclasses import replace
|
||||
from functools import partial
|
||||
from typing import Callable
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
@ -119,7 +119,11 @@ def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnectio
|
|||
@pytest.fixture()
|
||||
def aiohappyeyeballs_start_connection():
|
||||
with patch("aioesphomeapi.connection.aiohappyeyeballs.start_connection") as func:
|
||||
func.return_value = MagicMock(type=socket.SOCK_STREAM)
|
||||
mock_socket = Mock()
|
||||
mock_socket.type = socket.SOCK_STREAM
|
||||
mock_socket.fileno.return_value = 1
|
||||
mock_socket.getpeername.return_value = ("10.0.0.512", 323)
|
||||
func.return_value = mock_socket
|
||||
yield func
|
||||
|
||||
|
||||
|
|
|
@ -4,9 +4,10 @@ import asyncio
|
|||
import contextlib
|
||||
import itertools
|
||||
import logging
|
||||
import socket
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, call, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from google.protobuf import message
|
||||
|
@ -219,9 +220,15 @@ async def test_connection_released_if_connecting_is_cancelled() -> None:
|
|||
cli = APIClient("1.2.3.4", 1234, None)
|
||||
asyncio.get_event_loop()
|
||||
|
||||
async def _start_connection_with_delay(*args, **kwargs):
|
||||
await asyncio.sleep(1)
|
||||
mock_socket = create_autospec(socket.socket, spec_set=True, instance=True)
|
||||
mock_socket.getpeername.return_value = ("4.3.3.3", 323)
|
||||
return mock_socket
|
||||
|
||||
with patch(
|
||||
"aioesphomeapi.connection.aiohappyeyeballs.start_connection",
|
||||
side_effect=partial(asyncio.sleep, 1),
|
||||
_start_connection_with_delay,
|
||||
):
|
||||
start_task = asyncio.create_task(cli.start_connection())
|
||||
await asyncio.sleep(0)
|
||||
|
@ -232,8 +239,14 @@ async def test_connection_released_if_connecting_is_cancelled() -> None:
|
|||
await start_task
|
||||
assert cli._connection is None
|
||||
|
||||
async def _start_connection_without_delay(*args, **kwargs):
|
||||
mock_socket = create_autospec(socket.socket, spec_set=True, instance=True)
|
||||
mock_socket.getpeername.return_value = ("4.3.3.3", 323)
|
||||
return mock_socket
|
||||
|
||||
with patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection), patch(
|
||||
"aioesphomeapi.connection.aiohappyeyeballs.start_connection"
|
||||
"aioesphomeapi.connection.aiohappyeyeballs.start_connection",
|
||||
_start_connection_without_delay,
|
||||
):
|
||||
await cli.start_connection()
|
||||
await asyncio.sleep(0)
|
||||
|
@ -894,7 +907,7 @@ async def test_noise_psk_handles_subclassed_string():
|
|||
)
|
||||
# Make sure its not a subclassed string
|
||||
assert type(cli._params.noise_psk) is str
|
||||
assert type(cli._params.address) is str
|
||||
assert type(cli._params.addresses[0]) is str
|
||||
assert type(cli._params.expected_name) is str
|
||||
|
||||
rl = ReconnectLogic(
|
||||
|
@ -930,7 +943,7 @@ async def test_no_noise_psk():
|
|||
)
|
||||
# Make sure its not a subclassed string
|
||||
assert cli._params.noise_psk is None
|
||||
assert type(cli._params.address) is str
|
||||
assert type(cli._params.addresses[0]) is str
|
||||
assert type(cli._params.expected_name) is str
|
||||
|
||||
|
||||
|
@ -945,7 +958,7 @@ async def test_empty_noise_psk_or_expected_name():
|
|||
expected_name="",
|
||||
)
|
||||
assert cli._params.noise_psk is None
|
||||
assert type(cli._params.address) is str
|
||||
assert type(cli._params.addresses[0]) is str
|
||||
assert cli._params.expected_name is None
|
||||
|
||||
|
||||
|
|
|
@ -571,7 +571,8 @@ async def test_connect_resolver_times_out(
|
|||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
), pytest.raises(
|
||||
ResolveAPIError, match="Timeout while resolving IP address for fake.address"
|
||||
ResolveAPIError,
|
||||
match=r"Timeout while resolving IP address for \['fake.address'\]",
|
||||
):
|
||||
await connect(conn, login=False)
|
||||
|
||||
|
|
|
@ -174,7 +174,7 @@ async def test_resolve_host_mdns_no_results(resolve_addr, addr_infos):
|
|||
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
|
||||
async def test_resolve_host_addrinfo(resolve_addr, resolve_zc, addr_infos):
|
||||
resolve_addr.return_value = addr_infos
|
||||
ret = await hr.async_resolve_host(["example.local"], 6052)
|
||||
ret = await hr.async_resolve_host(["example.com"], 6052)
|
||||
|
||||
resolve_zc.assert_not_called()
|
||||
resolve_addr.assert_called_once_with("example.com", 6052)
|
||||
|
@ -187,7 +187,7 @@ async def test_resolve_host_addrinfo(resolve_addr, resolve_zc, addr_infos):
|
|||
async def test_resolve_host_addrinfo_empty(resolve_addr, resolve_zc, addr_infos):
|
||||
resolve_addr.return_value = []
|
||||
with pytest.raises(APIConnectionError):
|
||||
await hr.async_resolve_host(["example.local"], 6052)
|
||||
await hr.async_resolve_host(["example.com"], 6052)
|
||||
|
||||
resolve_zc.assert_not_called()
|
||||
resolve_addr.assert_called_once_with("example.com", 6052)
|
||||
|
|
Loading…
Reference in New Issue