Refactor bluetooth client functions to reduce duplicate code (#629)

This commit is contained in:
J. Nick Koston 2023-11-10 17:14:00 -06:00 committed by GitHub
parent 0dfabef72f
commit b227f79dad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 441 additions and 353 deletions

View File

@ -772,14 +772,11 @@ class APIClient:
async def bluetooth_device_pair(
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
) -> BluetoothDevicePairing:
self._check_authenticated()
msg_types = (
BluetoothDevicePairingResponse,
BluetoothDeviceConnectionResponse,
)
assert self._connection is not None
def predicate_func(msg: message.Message) -> bool:
if TYPE_CHECKING:
assert isinstance(msg, msg_types)
@ -791,86 +788,77 @@ class APIClient:
)
return True
return BluetoothDevicePairing.from_pb(
await self._bluetooth_device_request(
address,
BluetoothDeviceRequestType.PAIR,
predicate_func,
msg_types,
timeout,
)
)
async def bluetooth_device_unpair(
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
) -> BluetoothDeviceUnpairing:
return BluetoothDeviceUnpairing.from_pb(
await self._bluetooth_device_request(
address,
BluetoothDeviceRequestType.UNPAIR,
lambda msg: msg.address == address,
(BluetoothDeviceUnpairingResponse,),
timeout,
)
)
async def bluetooth_device_clear_cache(
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
) -> BluetoothDeviceClearCache:
return BluetoothDeviceClearCache.from_pb(
await self._bluetooth_device_request(
address,
BluetoothDeviceRequestType.CLEAR_CACHE,
lambda msg: msg.address == address,
(BluetoothDeviceClearCacheResponse,),
timeout,
)
)
async def bluetooth_device_disconnect(
self, address: int, timeout: float = DEFAULT_BLE_DISCONNECT_TIMEOUT
) -> None:
"""Disconnect from a Bluetooth device."""
await self._bluetooth_device_request(
address,
BluetoothDeviceRequestType.DISCONNECT,
lambda msg: msg.address == address and not msg.connected,
(BluetoothDeviceConnectionResponse,),
timeout,
)
async def _bluetooth_device_request(
self,
address: int,
request_type: BluetoothDeviceRequestType,
predicate_func: Callable[[BluetoothDeviceConnectionResponse], bool],
msg_types: tuple[type[message.Message], ...],
timeout: float,
) -> message.Message:
self._check_authenticated()
assert self._connection is not None
[response] = await self._connection.send_messages_await_response_complex(
(
BluetoothDeviceRequest(
address=address, request_type=BluetoothDeviceRequestType.PAIR
address=address,
request_type=request_type,
),
),
predicate_func,
predicate_func,
msg_types,
timeout=timeout,
)
return BluetoothDevicePairing.from_pb(response)
async def bluetooth_device_unpair(
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
) -> BluetoothDeviceUnpairing:
self._check_authenticated()
assert self._connection is not None
def predicate_func(msg: BluetoothDeviceUnpairingResponse) -> bool:
return bool(msg.address == address)
[response] = await self._connection.send_messages_await_response_complex(
(
BluetoothDeviceRequest(
address=address, request_type=BluetoothDeviceRequestType.UNPAIR
),
),
predicate_func,
predicate_func,
(BluetoothDeviceUnpairingResponse,),
timeout=timeout,
)
return BluetoothDeviceUnpairing.from_pb(response)
async def bluetooth_device_clear_cache(
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
) -> BluetoothDeviceClearCache:
self._check_authenticated()
assert self._connection is not None
def predicate_func(msg: BluetoothDeviceClearCacheResponse) -> bool:
return bool(msg.address == address)
[response] = await self._connection.send_messages_await_response_complex(
(
BluetoothDeviceRequest(
address=address, request_type=BluetoothDeviceRequestType.CLEAR_CACHE
),
),
predicate_func,
predicate_func,
(BluetoothDeviceClearCacheResponse,),
timeout=timeout,
)
return BluetoothDeviceClearCache.from_pb(response)
async def bluetooth_device_disconnect(
self, address: int, timeout: float = DEFAULT_BLE_DISCONNECT_TIMEOUT
) -> None:
self._check_authenticated()
def predicate_func(msg: BluetoothDeviceConnectionResponse) -> bool:
return bool(msg.address == address and not msg.connected)
assert self._connection is not None
await self._connection.send_messages_await_response_complex(
(
BluetoothDeviceRequest(
address=address,
request_type=BluetoothDeviceRequestType.DISCONNECT,
),
),
predicate_func,
predicate_func,
(BluetoothDeviceConnectionResponse,),
timeout=timeout,
timeout,
)
return response
async def bluetooth_gatt_get_services(
self, address: int

View File

@ -0,0 +1,5 @@
"""Init tests."""
import logging
logging.getLogger("aioesphomeapi").setLevel(logging.DEBUG)

View File

@ -8,6 +8,10 @@ from unittest.mock import MagicMock
from zeroconf import Zeroconf
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
from aioesphomeapi.connection import APIConnection
from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO
UTC = timezone.utc
_MONOTONIC_RESOLUTION = time.get_clock_info("monotonic").resolution
# We use a partial here since it is implemented in native code
@ -15,6 +19,8 @@ _MONOTONIC_RESOLUTION = time.get_clock_info("monotonic").resolution
utcnow: partial[datetime] = partial(datetime.now, UTC)
utcnow.__doc__ = "Get now in UTC time."
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
def get_mock_zeroconf() -> MagicMock:
return MagicMock(spec=Zeroconf)
@ -24,6 +30,15 @@ class Estr(str):
"""A subclassed string."""
def generate_plaintext_packet(msg: bytes, type_: int) -> bytes:
return (
b"\0"
+ _cached_varuint_to_bytes(len(msg))
+ _cached_varuint_to_bytes(type_)
+ msg
)
def as_utc(dattim: datetime) -> datetime:
"""Return a datetime as UTC time."""
if dattim.tzinfo == UTC:
@ -60,3 +75,9 @@ def async_fire_time_changed(
if fire_all or mock_seconds_into_future >= future_seconds:
task._run()
task.cancel()
async def connect(conn: APIConnection, login: bool = True):
"""Wrapper for connection logic to do both parts."""
await conn.start_connection()
await conn.finish_connection(login=login)

146
tests/conftest.py Normal file
View File

@ -0,0 +1,146 @@
"""Test fixtures."""
from __future__ import annotations
import asyncio
import socket
from unittest.mock import MagicMock, patch
import pytest
import pytest_asyncio
from google.protobuf import message
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi.api_pb2 import HelloResponse
from aioesphomeapi.client import APIClient, ConnectionParams
from aioesphomeapi.connection import APIConnection
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
from .common import PROTO_TO_MESSAGE_TYPE, connect, generate_plaintext_packet
@pytest.fixture
def resolve_host():
with patch("aioesphomeapi.host_resolver.async_resolve_host") as func:
func.return_value = AddrInfo(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
sockaddr=IPv4Sockaddr("10.0.0.512", 6052),
)
yield func
@pytest.fixture
def socket_socket():
with patch("socket.socket") as func:
yield func
@pytest.fixture
def connection_params() -> ConnectionParams:
return ConnectionParams(
address="fake.address",
port=6052,
password=None,
client_info="Tests client",
keepalive=15.0,
zeroconf_instance=None,
noise_psk=None,
expected_name=None,
)
@pytest.fixture
def conn(connection_params) -> APIConnection:
async def on_stop(expected_disconnect: bool) -> None:
pass
return APIConnection(connection_params, on_stop)
@pytest_asyncio.fixture(name="plaintext_connect_task_no_login")
async def plaintext_connect_task_no_login(
conn: APIConnection, resolve_host, socket_socket, event_loop
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()
connected = asyncio.Event()
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait()
yield conn, transport, protocol, connect_task
@pytest_asyncio.fixture(name="plaintext_connect_task_with_login")
async def plaintext_connect_task_with_login(
conn: APIConnection, resolve_host, socket_socket, event_loop
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()
connected = asyncio.Event()
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
connect_task = asyncio.create_task(connect(conn, login=True))
await connected.wait()
yield conn, transport, protocol, connect_task
@pytest_asyncio.fixture(name="api_client")
async def api_client(
conn: APIConnection, resolve_host, socket_socket, event_loop
) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]:
loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()
connected = asyncio.Event()
client = APIClient(
address="mydevice.local",
port=6052,
password=None,
)
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait()
hello_response: message.Message = HelloResponse()
hello_response.api_version_major = 1
hello_response.api_version_minor = 9
hello_response.name = "fake"
hello_msg = hello_response.SerializeToString()
protocol.data_received(
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
)
client._connection = conn
await connect_task
transport.reset_mock()
yield client, conn, transport, protocol

View File

@ -1,11 +1,17 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from mock import AsyncMock, MagicMock, patch
from google.protobuf import message
from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper
from aioesphomeapi.api_pb2 import (
AlarmControlPanelCommandRequest,
BinarySensorStateResponse,
BluetoothDeviceClearCacheResponse,
BluetoothDeviceConnectionResponse,
BluetoothDevicePairingResponse,
BluetoothDeviceUnpairingResponse,
CameraImageRequest,
CameraImageResponse,
ClimateCommandRequest,
@ -25,7 +31,7 @@ from aioesphomeapi.api_pb2 import (
TextCommandRequest,
)
from aioesphomeapi.client import APIClient
from aioesphomeapi.core import APIConnectionError
from aioesphomeapi.connection import APIConnection
from aioesphomeapi.model import (
AlarmControlPanelCommand,
APIVersion,
@ -47,7 +53,12 @@ from aioesphomeapi.model import (
)
from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
from .common import Estr, get_mock_zeroconf
from .common import (
PROTO_TO_MESSAGE_TYPE,
Estr,
generate_plaintext_packet,
get_mock_zeroconf,
)
@pytest.fixture
@ -652,3 +663,85 @@ async def test_empty_noise_psk_or_expected_name():
assert cli._params.noise_psk is None
assert type(cli._params.address) is str
assert cli._params.expected_name is None
@pytest.mark.asyncio
async def test_bluetooth_disconnect(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test bluetooth_device_disconnect."""
client, connection, transport, protocol = api_client
disconnect_task = asyncio.create_task(client.bluetooth_device_disconnect(1234))
await asyncio.sleep(0)
response: message.Message = BluetoothDeviceConnectionResponse(
address=1234, connected=False
)
protocol.data_received(
generate_plaintext_packet(
response.SerializeToString(),
PROTO_TO_MESSAGE_TYPE[BluetoothDeviceConnectionResponse],
)
)
await disconnect_task
@pytest.mark.asyncio
async def test_bluetooth_pair(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test bluetooth_device_pair."""
client, connection, transport, protocol = api_client
pair_task = asyncio.create_task(client.bluetooth_device_pair(1234))
await asyncio.sleep(0)
response: message.Message = BluetoothDevicePairingResponse(address=1234)
protocol.data_received(
generate_plaintext_packet(
response.SerializeToString(),
PROTO_TO_MESSAGE_TYPE[BluetoothDevicePairingResponse],
)
)
await pair_task
@pytest.mark.asyncio
async def test_bluetooth_unpair(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test bluetooth_device_unpair."""
client, connection, transport, protocol = api_client
unpair_task = asyncio.create_task(client.bluetooth_device_unpair(1234))
await asyncio.sleep(0)
response: message.Message = BluetoothDeviceUnpairingResponse(address=1234)
protocol.data_received(
generate_plaintext_packet(
response.SerializeToString(),
PROTO_TO_MESSAGE_TYPE[BluetoothDeviceUnpairingResponse],
)
)
await unpair_task
@pytest.mark.asyncio
async def test_bluetooth_clear_cache(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test bluetooth_device_clear_cache."""
client, connection, transport, protocol = api_client
clear_task = asyncio.create_task(client.bluetooth_device_clear_cache(1234))
await asyncio.sleep(0)
response: message.Message = BluetoothDeviceClearCacheResponse(address=1234)
protocol.data_received(
generate_plaintext_packet(
response.SerializeToString(),
PROTO_TO_MESSAGE_TYPE[BluetoothDeviceClearCacheResponse],
)
)
await clear_task

View File

@ -1,18 +1,15 @@
from __future__ import annotations
import asyncio
import logging
import socket
from collections.abc import Coroutine
from datetime import timedelta
from typing import Any, Coroutine, Generator, Optional
from unittest.mock import AsyncMock
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from google.protobuf import message
from mock import MagicMock, patch
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
from aioesphomeapi.api_pb2 import (
ConnectResponse,
DeviceInfoResponse,
@ -20,69 +17,22 @@ from aioesphomeapi.api_pb2 import (
PingRequest,
PingResponse,
)
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
from aioesphomeapi.connection import APIConnection, ConnectionState
from aioesphomeapi.core import (
MESSAGE_TYPE_TO_PROTO,
APIConnectionError,
HandshakeAPIError,
InvalidAuthAPIError,
RequiresEncryptionAPIError,
TimeoutAPIError,
)
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
from .common import async_fire_time_changed, utcnow
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
logging.getLogger("aioesphomeapi").setLevel(logging.DEBUG)
async def connect(conn: APIConnection, login: bool = True):
"""Wrapper for connection logic to do both parts."""
await conn.start_connection()
await conn.finish_connection(login=login)
@pytest.fixture
def connection_params() -> ConnectionParams:
return ConnectionParams(
address="fake.address",
port=6052,
password=None,
client_info="Tests client",
keepalive=15.0,
zeroconf_instance=None,
noise_psk=None,
expected_name=None,
)
@pytest.fixture
def conn(connection_params) -> APIConnection:
async def on_stop(expected_disconnect: bool) -> None:
pass
return APIConnection(connection_params, on_stop)
@pytest.fixture
def resolve_host():
with patch("aioesphomeapi.host_resolver.async_resolve_host") as func:
func.return_value = AddrInfo(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
sockaddr=IPv4Sockaddr("10.0.0.512", 6052),
)
yield func
@pytest.fixture
def socket_socket():
with patch("socket.socket") as func:
yield func
from .common import (
PROTO_TO_MESSAGE_TYPE,
async_fire_time_changed,
connect,
generate_plaintext_packet,
utcnow,
)
def _get_mock_protocol(conn: APIConnection):
@ -98,80 +48,49 @@ def _get_mock_protocol(conn: APIConnection):
@pytest.mark.asyncio
async def test_connect(conn, resolve_host, socket_socket, event_loop):
loop = asyncio.get_event_loop()
protocol: Optional[APIPlaintextFrameHelper] = None
transport = MagicMock()
connected = asyncio.Event()
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait()
protocol.data_received(
bytes.fromhex(
"003602080110091a216d6173746572617672656c61792028657"
"370686f6d652076323032332e362e3329220d6d617374657261"
"7672656c6179"
)
async def test_connect(
plaintext_connect_task_no_login: tuple[
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
]
) -> None:
"""Test that a plaintext connection works."""
conn, transport, protocol, connect_task = plaintext_connect_task_no_login
protocol.data_received(
bytes.fromhex(
"003602080110091a216d6173746572617672656c61792028657"
"370686f6d652076323032332e362e3329220d6d617374657261"
"7672656c6179"
)
protocol.data_received(
bytes.fromhex(
"005b0a120d6d6173746572617672656c61791a1130383a33413a"
"46323a33453a35453a36302208323032332e362e332a154a756e"
"20323820323032332c2031383a31323a3236320965737033322d"
"65766250506209457370726573736966"
)
)
protocol.data_received(
bytes.fromhex(
"005b0a120d6d6173746572617672656c61791a1130383a33413a"
"46323a33453a35453a36302208323032332e362e332a154a756e"
"20323820323032332c2031383a31323a3236320965737033322d"
"65766250506209457370726573736966"
)
await connect_task
)
await connect_task
assert conn.is_connected
@pytest.mark.asyncio
async def test_timeout_sending_message(
conn: APIConnection,
resolve_host: Coroutine[Any, Any, AddrInfo],
socket_socket: Generator[Any, Any, None],
event_loop: asyncio.AbstractEventLoop,
plaintext_connect_task_no_login: tuple[
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
],
caplog: pytest.LogCaptureFixture,
) -> None:
loop = asyncio.get_event_loop()
protocol: Optional[APIPlaintextFrameHelper] = None
transport = MagicMock()
connected = asyncio.Event()
conn, transport, protocol, connect_task = plaintext_connect_task_no_login
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
transport = MagicMock()
with patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait()
protocol.data_received(
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
b"5stackatomproxy"
b"\x00\x00$"
b"\x00\x00\x04"
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d'
b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
)
protocol.data_received(
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
b"5stackatomproxy"
b"\x00\x00$"
b"\x00\x00\x04"
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d'
b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
)
await connect_task
@ -192,31 +111,12 @@ async def test_timeout_sending_message(
@pytest.mark.asyncio
async def test_disconnect_when_not_fully_connected(
conn: APIConnection,
resolve_host: Coroutine[Any, Any, AddrInfo],
socket_socket: Generator[Any, Any, None],
event_loop: asyncio.AbstractEventLoop,
plaintext_connect_task_no_login: tuple[
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
],
caplog: pytest.LogCaptureFixture,
) -> None:
loop = asyncio.get_event_loop()
protocol: Optional[APIPlaintextFrameHelper] = None
transport = MagicMock()
connected = asyncio.Event()
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
transport = MagicMock()
with patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait()
conn, transport, protocol, connect_task = plaintext_connect_task_no_login
# Only send the first part of the handshake
# so we are stuck in the middle of the connection process
@ -264,34 +164,20 @@ async def test_requires_encryption_propagates(conn: APIConnection):
@pytest.mark.asyncio
async def test_plaintext_connection(conn: APIConnection, resolve_host, socket_socket):
async def test_plaintext_connection(
plaintext_connect_task_no_login: tuple[
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
],
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test that a plaintext connection works."""
loop = asyncio.get_event_loop()
protocol = _get_mock_protocol(conn)
messages = []
protocol: Optional[APIPlaintextFrameHelper] = None
transport = MagicMock()
connected = asyncio.Event()
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
conn, transport, protocol, connect_task = plaintext_connect_task_no_login
def on_msg(msg):
messages.append(msg)
remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
transport = MagicMock()
with patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait()
protocol.data_received(
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
)
@ -412,35 +298,19 @@ async def test_finish_connection_is_cancelled(
@pytest.mark.asyncio
async def test_finish_connection_times_out(
conn: APIConnection, resolve_host, socket_socket
):
plaintext_connect_task_no_login: tuple[
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
],
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test handling of finish connection timing out."""
loop = asyncio.get_event_loop()
protocol = _get_mock_protocol(conn)
conn, transport, protocol, connect_task = plaintext_connect_task_no_login
messages = []
protocol: Optional[APIPlaintextFrameHelper] = None
transport = MagicMock()
connected = asyncio.Event()
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
def on_msg(msg):
messages.append(msg)
remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
transport = MagicMock()
with patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait()
protocol.data_received(
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
)
@ -484,7 +354,7 @@ async def test_plaintext_connection_fails_handshake(
exception, raised_exception = exception_map
protocol = _get_mock_protocol(conn)
messages = []
protocol: Optional[APIPlaintextFrameHelper] = None
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()
connected = asyncio.Event()
@ -563,97 +433,62 @@ async def test_plaintext_connection_fails_handshake(
await asyncio.sleep(0)
def _generate_plaintext_packet(msg: bytes, type_: int) -> bytes:
return (
b"\0"
+ _cached_varuint_to_bytes(len(msg))
+ _cached_varuint_to_bytes(type_)
+ msg
@pytest.mark.asyncio
async def test_connect_wrong_password(
plaintext_connect_task_with_login: tuple[
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
],
) -> None:
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
hello_response: message.Message = HelloResponse()
hello_response.api_version_major = 1
hello_response.api_version_minor = 9
hello_response.name = "fake"
hello_msg = hello_response.SerializeToString()
connect_response: message.Message = ConnectResponse()
connect_response.invalid_password = True
connect_msg = connect_response.SerializeToString()
protocol.data_received(
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
)
protocol.data_received(
generate_plaintext_packet(connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse])
)
@pytest.mark.asyncio
async def test_connect_wrong_password(conn, resolve_host, socket_socket, event_loop):
loop = asyncio.get_event_loop()
protocol: Optional[APIPlaintextFrameHelper] = None
transport = MagicMock()
connected = asyncio.Event()
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
connect_task = asyncio.create_task(connect(conn, login=True))
await connected.wait()
hello_response: message.Message = HelloResponse()
hello_response.api_version_major = 1
hello_response.api_version_minor = 9
hello_response.name = "fake"
hello_msg = hello_response.SerializeToString()
connect_response: message.Message = ConnectResponse()
connect_response.invalid_password = True
connect_msg = connect_response.SerializeToString()
protocol.data_received(
_generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
)
protocol.data_received(
_generate_plaintext_packet(
connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse]
)
)
with pytest.raises(InvalidAuthAPIError):
await connect_task
with pytest.raises(InvalidAuthAPIError):
await connect_task
assert not conn.is_connected
@pytest.mark.asyncio
async def test_connect_correct_password(conn, resolve_host, socket_socket, event_loop):
loop = asyncio.get_event_loop()
protocol: Optional[APIPlaintextFrameHelper] = None
transport = MagicMock()
connected = asyncio.Event()
async def test_connect_correct_password(
plaintext_connect_task_with_login: tuple[
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
],
) -> None:
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
hello_response: message.Message = HelloResponse()
hello_response.api_version_major = 1
hello_response.api_version_minor = 9
hello_response.name = "fake"
hello_msg = hello_response.SerializeToString()
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
connect_task = asyncio.create_task(connect(conn, login=True))
await connected.wait()
hello_response: message.Message = HelloResponse()
hello_response.api_version_major = 1
hello_response.api_version_minor = 9
hello_response.name = "fake"
hello_msg = hello_response.SerializeToString()
connect_response: message.Message = ConnectResponse()
connect_response.invalid_password = False
connect_msg = connect_response.SerializeToString()
connect_response: message.Message = ConnectResponse()
connect_response.invalid_password = False
connect_msg = connect_response.SerializeToString()
protocol.data_received(
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
)
protocol.data_received(
generate_plaintext_packet(connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse])
)
protocol.data_received(
_generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
)
protocol.data_received(
_generate_plaintext_packet(
connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse]
)
)
await connect_task
await connect_task
assert conn.is_connected

View File

@ -1,8 +1,8 @@
import socket
from ipaddress import ip_address
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from mock import AsyncMock, MagicMock, patch
from zeroconf import DNSCache
from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import List, Optional
from typing import Optional
import pytest
@ -136,7 +136,7 @@ class DummyAPIModel(APIModelBase):
@dataclass(frozen=True)
class ListAPIModel(APIModelBase):
val: List[DummyAPIModel] = field(default_factory=list)
val: list[DummyAPIModel] = field(default_factory=list)
def test_api_model_base_converter():