mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-09-27 04:22:46 +02:00
Refactor bluetooth client functions to reduce duplicate code (#629)
This commit is contained in:
parent
0dfabef72f
commit
b227f79dad
@ -772,14 +772,11 @@ class APIClient:
|
||||
async def bluetooth_device_pair(
|
||||
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
|
||||
) -> BluetoothDevicePairing:
|
||||
self._check_authenticated()
|
||||
msg_types = (
|
||||
BluetoothDevicePairingResponse,
|
||||
BluetoothDeviceConnectionResponse,
|
||||
)
|
||||
|
||||
assert self._connection is not None
|
||||
|
||||
def predicate_func(msg: message.Message) -> bool:
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(msg, msg_types)
|
||||
@ -791,86 +788,77 @@ class APIClient:
|
||||
)
|
||||
return True
|
||||
|
||||
return BluetoothDevicePairing.from_pb(
|
||||
await self._bluetooth_device_request(
|
||||
address,
|
||||
BluetoothDeviceRequestType.PAIR,
|
||||
predicate_func,
|
||||
msg_types,
|
||||
timeout,
|
||||
)
|
||||
)
|
||||
|
||||
async def bluetooth_device_unpair(
|
||||
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
|
||||
) -> BluetoothDeviceUnpairing:
|
||||
return BluetoothDeviceUnpairing.from_pb(
|
||||
await self._bluetooth_device_request(
|
||||
address,
|
||||
BluetoothDeviceRequestType.UNPAIR,
|
||||
lambda msg: msg.address == address,
|
||||
(BluetoothDeviceUnpairingResponse,),
|
||||
timeout,
|
||||
)
|
||||
)
|
||||
|
||||
async def bluetooth_device_clear_cache(
|
||||
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
|
||||
) -> BluetoothDeviceClearCache:
|
||||
return BluetoothDeviceClearCache.from_pb(
|
||||
await self._bluetooth_device_request(
|
||||
address,
|
||||
BluetoothDeviceRequestType.CLEAR_CACHE,
|
||||
lambda msg: msg.address == address,
|
||||
(BluetoothDeviceClearCacheResponse,),
|
||||
timeout,
|
||||
)
|
||||
)
|
||||
|
||||
async def bluetooth_device_disconnect(
|
||||
self, address: int, timeout: float = DEFAULT_BLE_DISCONNECT_TIMEOUT
|
||||
) -> None:
|
||||
"""Disconnect from a Bluetooth device."""
|
||||
await self._bluetooth_device_request(
|
||||
address,
|
||||
BluetoothDeviceRequestType.DISCONNECT,
|
||||
lambda msg: msg.address == address and not msg.connected,
|
||||
(BluetoothDeviceConnectionResponse,),
|
||||
timeout,
|
||||
)
|
||||
|
||||
async def _bluetooth_device_request(
|
||||
self,
|
||||
address: int,
|
||||
request_type: BluetoothDeviceRequestType,
|
||||
predicate_func: Callable[[BluetoothDeviceConnectionResponse], bool],
|
||||
msg_types: tuple[type[message.Message], ...],
|
||||
timeout: float,
|
||||
) -> message.Message:
|
||||
self._check_authenticated()
|
||||
assert self._connection is not None
|
||||
[response] = await self._connection.send_messages_await_response_complex(
|
||||
(
|
||||
BluetoothDeviceRequest(
|
||||
address=address, request_type=BluetoothDeviceRequestType.PAIR
|
||||
address=address,
|
||||
request_type=request_type,
|
||||
),
|
||||
),
|
||||
predicate_func,
|
||||
predicate_func,
|
||||
msg_types,
|
||||
timeout=timeout,
|
||||
)
|
||||
return BluetoothDevicePairing.from_pb(response)
|
||||
|
||||
async def bluetooth_device_unpair(
|
||||
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
|
||||
) -> BluetoothDeviceUnpairing:
|
||||
self._check_authenticated()
|
||||
|
||||
assert self._connection is not None
|
||||
|
||||
def predicate_func(msg: BluetoothDeviceUnpairingResponse) -> bool:
|
||||
return bool(msg.address == address)
|
||||
|
||||
[response] = await self._connection.send_messages_await_response_complex(
|
||||
(
|
||||
BluetoothDeviceRequest(
|
||||
address=address, request_type=BluetoothDeviceRequestType.UNPAIR
|
||||
),
|
||||
),
|
||||
predicate_func,
|
||||
predicate_func,
|
||||
(BluetoothDeviceUnpairingResponse,),
|
||||
timeout=timeout,
|
||||
)
|
||||
return BluetoothDeviceUnpairing.from_pb(response)
|
||||
|
||||
async def bluetooth_device_clear_cache(
|
||||
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
|
||||
) -> BluetoothDeviceClearCache:
|
||||
self._check_authenticated()
|
||||
|
||||
assert self._connection is not None
|
||||
|
||||
def predicate_func(msg: BluetoothDeviceClearCacheResponse) -> bool:
|
||||
return bool(msg.address == address)
|
||||
|
||||
[response] = await self._connection.send_messages_await_response_complex(
|
||||
(
|
||||
BluetoothDeviceRequest(
|
||||
address=address, request_type=BluetoothDeviceRequestType.CLEAR_CACHE
|
||||
),
|
||||
),
|
||||
predicate_func,
|
||||
predicate_func,
|
||||
(BluetoothDeviceClearCacheResponse,),
|
||||
timeout=timeout,
|
||||
)
|
||||
return BluetoothDeviceClearCache.from_pb(response)
|
||||
|
||||
async def bluetooth_device_disconnect(
|
||||
self, address: int, timeout: float = DEFAULT_BLE_DISCONNECT_TIMEOUT
|
||||
) -> None:
|
||||
self._check_authenticated()
|
||||
|
||||
def predicate_func(msg: BluetoothDeviceConnectionResponse) -> bool:
|
||||
return bool(msg.address == address and not msg.connected)
|
||||
|
||||
assert self._connection is not None
|
||||
await self._connection.send_messages_await_response_complex(
|
||||
(
|
||||
BluetoothDeviceRequest(
|
||||
address=address,
|
||||
request_type=BluetoothDeviceRequestType.DISCONNECT,
|
||||
),
|
||||
),
|
||||
predicate_func,
|
||||
predicate_func,
|
||||
(BluetoothDeviceConnectionResponse,),
|
||||
timeout=timeout,
|
||||
timeout,
|
||||
)
|
||||
return response
|
||||
|
||||
async def bluetooth_gatt_get_services(
|
||||
self, address: int
|
||||
|
@ -0,0 +1,5 @@
|
||||
"""Init tests."""
|
||||
|
||||
import logging
|
||||
|
||||
logging.getLogger("aioesphomeapi").setLevel(logging.DEBUG)
|
@ -8,6 +8,10 @@ from unittest.mock import MagicMock
|
||||
|
||||
from zeroconf import Zeroconf
|
||||
|
||||
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
|
||||
from aioesphomeapi.connection import APIConnection
|
||||
from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO
|
||||
|
||||
UTC = timezone.utc
|
||||
_MONOTONIC_RESOLUTION = time.get_clock_info("monotonic").resolution
|
||||
# We use a partial here since it is implemented in native code
|
||||
@ -15,6 +19,8 @@ _MONOTONIC_RESOLUTION = time.get_clock_info("monotonic").resolution
|
||||
utcnow: partial[datetime] = partial(datetime.now, UTC)
|
||||
utcnow.__doc__ = "Get now in UTC time."
|
||||
|
||||
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
|
||||
|
||||
|
||||
def get_mock_zeroconf() -> MagicMock:
|
||||
return MagicMock(spec=Zeroconf)
|
||||
@ -24,6 +30,15 @@ class Estr(str):
|
||||
"""A subclassed string."""
|
||||
|
||||
|
||||
def generate_plaintext_packet(msg: bytes, type_: int) -> bytes:
|
||||
return (
|
||||
b"\0"
|
||||
+ _cached_varuint_to_bytes(len(msg))
|
||||
+ _cached_varuint_to_bytes(type_)
|
||||
+ msg
|
||||
)
|
||||
|
||||
|
||||
def as_utc(dattim: datetime) -> datetime:
|
||||
"""Return a datetime as UTC time."""
|
||||
if dattim.tzinfo == UTC:
|
||||
@ -60,3 +75,9 @@ def async_fire_time_changed(
|
||||
if fire_all or mock_seconds_into_future >= future_seconds:
|
||||
task._run()
|
||||
task.cancel()
|
||||
|
||||
|
||||
async def connect(conn: APIConnection, login: bool = True):
|
||||
"""Wrapper for connection logic to do both parts."""
|
||||
await conn.start_connection()
|
||||
await conn.finish_connection(login=login)
|
||||
|
146
tests/conftest.py
Normal file
146
tests/conftest.py
Normal file
@ -0,0 +1,146 @@
|
||||
"""Test fixtures."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from google.protobuf import message
|
||||
|
||||
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
||||
from aioesphomeapi.api_pb2 import HelloResponse
|
||||
from aioesphomeapi.client import APIClient, ConnectionParams
|
||||
from aioesphomeapi.connection import APIConnection
|
||||
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
|
||||
|
||||
from .common import PROTO_TO_MESSAGE_TYPE, connect, generate_plaintext_packet
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def resolve_host():
|
||||
with patch("aioesphomeapi.host_resolver.async_resolve_host") as func:
|
||||
func.return_value = AddrInfo(
|
||||
family=socket.AF_INET,
|
||||
type=socket.SOCK_STREAM,
|
||||
proto=socket.IPPROTO_TCP,
|
||||
sockaddr=IPv4Sockaddr("10.0.0.512", 6052),
|
||||
)
|
||||
yield func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def socket_socket():
|
||||
with patch("socket.socket") as func:
|
||||
yield func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connection_params() -> ConnectionParams:
|
||||
return ConnectionParams(
|
||||
address="fake.address",
|
||||
port=6052,
|
||||
password=None,
|
||||
client_info="Tests client",
|
||||
keepalive=15.0,
|
||||
zeroconf_instance=None,
|
||||
noise_psk=None,
|
||||
expected_name=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conn(connection_params) -> APIConnection:
|
||||
async def on_stop(expected_disconnect: bool) -> None:
|
||||
pass
|
||||
|
||||
return APIConnection(connection_params, on_stop)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(name="plaintext_connect_task_no_login")
|
||||
async def plaintext_connect_task_no_login(
|
||||
conn: APIConnection, resolve_host, socket_socket, event_loop
|
||||
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol: APIPlaintextFrameHelper | None = None
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
protocol.connection_made(transport)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await connected.wait()
|
||||
yield conn, transport, protocol, connect_task
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(name="plaintext_connect_task_with_login")
|
||||
async def plaintext_connect_task_with_login(
|
||||
conn: APIConnection, resolve_host, socket_socket, event_loop
|
||||
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol: APIPlaintextFrameHelper | None = None
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
protocol.connection_made(transport)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=True))
|
||||
await connected.wait()
|
||||
yield conn, transport, protocol, connect_task
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(name="api_client")
|
||||
async def api_client(
|
||||
conn: APIConnection, resolve_host, socket_socket, event_loop
|
||||
) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]:
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol: APIPlaintextFrameHelper | None = None
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
client = APIClient(
|
||||
address="mydevice.local",
|
||||
port=6052,
|
||||
password=None,
|
||||
)
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
protocol.connection_made(transport)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await connected.wait()
|
||||
hello_response: message.Message = HelloResponse()
|
||||
hello_response.api_version_major = 1
|
||||
hello_response.api_version_minor = 9
|
||||
hello_response.name = "fake"
|
||||
hello_msg = hello_response.SerializeToString()
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
|
||||
)
|
||||
client._connection = conn
|
||||
await connect_task
|
||||
transport.reset_mock()
|
||||
yield client, conn, transport, protocol
|
@ -1,11 +1,17 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from mock import AsyncMock, MagicMock, patch
|
||||
from google.protobuf import message
|
||||
|
||||
from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper
|
||||
from aioesphomeapi.api_pb2 import (
|
||||
AlarmControlPanelCommandRequest,
|
||||
BinarySensorStateResponse,
|
||||
BluetoothDeviceClearCacheResponse,
|
||||
BluetoothDeviceConnectionResponse,
|
||||
BluetoothDevicePairingResponse,
|
||||
BluetoothDeviceUnpairingResponse,
|
||||
CameraImageRequest,
|
||||
CameraImageResponse,
|
||||
ClimateCommandRequest,
|
||||
@ -25,7 +31,7 @@ from aioesphomeapi.api_pb2 import (
|
||||
TextCommandRequest,
|
||||
)
|
||||
from aioesphomeapi.client import APIClient
|
||||
from aioesphomeapi.core import APIConnectionError
|
||||
from aioesphomeapi.connection import APIConnection
|
||||
from aioesphomeapi.model import (
|
||||
AlarmControlPanelCommand,
|
||||
APIVersion,
|
||||
@ -47,7 +53,12 @@ from aioesphomeapi.model import (
|
||||
)
|
||||
from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
|
||||
|
||||
from .common import Estr, get_mock_zeroconf
|
||||
from .common import (
|
||||
PROTO_TO_MESSAGE_TYPE,
|
||||
Estr,
|
||||
generate_plaintext_packet,
|
||||
get_mock_zeroconf,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -652,3 +663,85 @@ async def test_empty_noise_psk_or_expected_name():
|
||||
assert cli._params.noise_psk is None
|
||||
assert type(cli._params.address) is str
|
||||
assert cli._params.expected_name is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bluetooth_disconnect(
|
||||
api_client: tuple[
|
||||
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
|
||||
],
|
||||
) -> None:
|
||||
"""Test bluetooth_device_disconnect."""
|
||||
client, connection, transport, protocol = api_client
|
||||
disconnect_task = asyncio.create_task(client.bluetooth_device_disconnect(1234))
|
||||
await asyncio.sleep(0)
|
||||
response: message.Message = BluetoothDeviceConnectionResponse(
|
||||
address=1234, connected=False
|
||||
)
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(
|
||||
response.SerializeToString(),
|
||||
PROTO_TO_MESSAGE_TYPE[BluetoothDeviceConnectionResponse],
|
||||
)
|
||||
)
|
||||
await disconnect_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bluetooth_pair(
|
||||
api_client: tuple[
|
||||
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
|
||||
],
|
||||
) -> None:
|
||||
"""Test bluetooth_device_pair."""
|
||||
client, connection, transport, protocol = api_client
|
||||
pair_task = asyncio.create_task(client.bluetooth_device_pair(1234))
|
||||
await asyncio.sleep(0)
|
||||
response: message.Message = BluetoothDevicePairingResponse(address=1234)
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(
|
||||
response.SerializeToString(),
|
||||
PROTO_TO_MESSAGE_TYPE[BluetoothDevicePairingResponse],
|
||||
)
|
||||
)
|
||||
await pair_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bluetooth_unpair(
|
||||
api_client: tuple[
|
||||
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
|
||||
],
|
||||
) -> None:
|
||||
"""Test bluetooth_device_unpair."""
|
||||
client, connection, transport, protocol = api_client
|
||||
unpair_task = asyncio.create_task(client.bluetooth_device_unpair(1234))
|
||||
await asyncio.sleep(0)
|
||||
response: message.Message = BluetoothDeviceUnpairingResponse(address=1234)
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(
|
||||
response.SerializeToString(),
|
||||
PROTO_TO_MESSAGE_TYPE[BluetoothDeviceUnpairingResponse],
|
||||
)
|
||||
)
|
||||
await unpair_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bluetooth_clear_cache(
|
||||
api_client: tuple[
|
||||
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
|
||||
],
|
||||
) -> None:
|
||||
"""Test bluetooth_device_clear_cache."""
|
||||
client, connection, transport, protocol = api_client
|
||||
clear_task = asyncio.create_task(client.bluetooth_device_clear_cache(1234))
|
||||
await asyncio.sleep(0)
|
||||
response: message.Message = BluetoothDeviceClearCacheResponse(address=1234)
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(
|
||||
response.SerializeToString(),
|
||||
PROTO_TO_MESSAGE_TYPE[BluetoothDeviceClearCacheResponse],
|
||||
)
|
||||
)
|
||||
await clear_task
|
||||
|
@ -1,18 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import socket
|
||||
from collections.abc import Coroutine
|
||||
from datetime import timedelta
|
||||
from typing import Any, Coroutine, Generator, Optional
|
||||
from unittest.mock import AsyncMock
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from google.protobuf import message
|
||||
from mock import MagicMock, patch
|
||||
|
||||
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
||||
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
|
||||
from aioesphomeapi.api_pb2 import (
|
||||
ConnectResponse,
|
||||
DeviceInfoResponse,
|
||||
@ -20,69 +17,22 @@ from aioesphomeapi.api_pb2 import (
|
||||
PingRequest,
|
||||
PingResponse,
|
||||
)
|
||||
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
|
||||
from aioesphomeapi.connection import APIConnection, ConnectionState
|
||||
from aioesphomeapi.core import (
|
||||
MESSAGE_TYPE_TO_PROTO,
|
||||
APIConnectionError,
|
||||
HandshakeAPIError,
|
||||
InvalidAuthAPIError,
|
||||
RequiresEncryptionAPIError,
|
||||
TimeoutAPIError,
|
||||
)
|
||||
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
|
||||
|
||||
from .common import async_fire_time_changed, utcnow
|
||||
|
||||
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
|
||||
|
||||
|
||||
logging.getLogger("aioesphomeapi").setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
async def connect(conn: APIConnection, login: bool = True):
|
||||
"""Wrapper for connection logic to do both parts."""
|
||||
await conn.start_connection()
|
||||
await conn.finish_connection(login=login)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connection_params() -> ConnectionParams:
|
||||
return ConnectionParams(
|
||||
address="fake.address",
|
||||
port=6052,
|
||||
password=None,
|
||||
client_info="Tests client",
|
||||
keepalive=15.0,
|
||||
zeroconf_instance=None,
|
||||
noise_psk=None,
|
||||
expected_name=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conn(connection_params) -> APIConnection:
|
||||
async def on_stop(expected_disconnect: bool) -> None:
|
||||
pass
|
||||
|
||||
return APIConnection(connection_params, on_stop)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def resolve_host():
|
||||
with patch("aioesphomeapi.host_resolver.async_resolve_host") as func:
|
||||
func.return_value = AddrInfo(
|
||||
family=socket.AF_INET,
|
||||
type=socket.SOCK_STREAM,
|
||||
proto=socket.IPPROTO_TCP,
|
||||
sockaddr=IPv4Sockaddr("10.0.0.512", 6052),
|
||||
)
|
||||
yield func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def socket_socket():
|
||||
with patch("socket.socket") as func:
|
||||
yield func
|
||||
from .common import (
|
||||
PROTO_TO_MESSAGE_TYPE,
|
||||
async_fire_time_changed,
|
||||
connect,
|
||||
generate_plaintext_packet,
|
||||
utcnow,
|
||||
)
|
||||
|
||||
|
||||
def _get_mock_protocol(conn: APIConnection):
|
||||
@ -98,80 +48,49 @@ def _get_mock_protocol(conn: APIConnection):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect(conn, resolve_host, socket_socket, event_loop):
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol: Optional[APIPlaintextFrameHelper] = None
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
protocol.connection_made(transport)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await connected.wait()
|
||||
protocol.data_received(
|
||||
bytes.fromhex(
|
||||
"003602080110091a216d6173746572617672656c61792028657"
|
||||
"370686f6d652076323032332e362e3329220d6d617374657261"
|
||||
"7672656c6179"
|
||||
)
|
||||
async def test_connect(
|
||||
plaintext_connect_task_no_login: tuple[
|
||||
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
|
||||
]
|
||||
) -> None:
|
||||
"""Test that a plaintext connection works."""
|
||||
conn, transport, protocol, connect_task = plaintext_connect_task_no_login
|
||||
protocol.data_received(
|
||||
bytes.fromhex(
|
||||
"003602080110091a216d6173746572617672656c61792028657"
|
||||
"370686f6d652076323032332e362e3329220d6d617374657261"
|
||||
"7672656c6179"
|
||||
)
|
||||
protocol.data_received(
|
||||
bytes.fromhex(
|
||||
"005b0a120d6d6173746572617672656c61791a1130383a33413a"
|
||||
"46323a33453a35453a36302208323032332e362e332a154a756e"
|
||||
"20323820323032332c2031383a31323a3236320965737033322d"
|
||||
"65766250506209457370726573736966"
|
||||
)
|
||||
)
|
||||
protocol.data_received(
|
||||
bytes.fromhex(
|
||||
"005b0a120d6d6173746572617672656c61791a1130383a33413a"
|
||||
"46323a33453a35453a36302208323032332e362e332a154a756e"
|
||||
"20323820323032332c2031383a31323a3236320965737033322d"
|
||||
"65766250506209457370726573736966"
|
||||
)
|
||||
|
||||
await connect_task
|
||||
|
||||
)
|
||||
await connect_task
|
||||
assert conn.is_connected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_sending_message(
|
||||
conn: APIConnection,
|
||||
resolve_host: Coroutine[Any, Any, AddrInfo],
|
||||
socket_socket: Generator[Any, Any, None],
|
||||
event_loop: asyncio.AbstractEventLoop,
|
||||
plaintext_connect_task_no_login: tuple[
|
||||
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
|
||||
],
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol: Optional[APIPlaintextFrameHelper] = None
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
conn, transport, protocol, connect_task = plaintext_connect_task_no_login
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
protocol.connection_made(transport)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
|
||||
transport = MagicMock()
|
||||
|
||||
with patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await connected.wait()
|
||||
protocol.data_received(
|
||||
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
|
||||
b"5stackatomproxy"
|
||||
b"\x00\x00$"
|
||||
b"\x00\x00\x04"
|
||||
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d'
|
||||
b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
|
||||
)
|
||||
protocol.data_received(
|
||||
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
|
||||
b"5stackatomproxy"
|
||||
b"\x00\x00$"
|
||||
b"\x00\x00\x04"
|
||||
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d'
|
||||
b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
|
||||
)
|
||||
|
||||
await connect_task
|
||||
|
||||
@ -192,31 +111,12 @@ async def test_timeout_sending_message(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_when_not_fully_connected(
|
||||
conn: APIConnection,
|
||||
resolve_host: Coroutine[Any, Any, AddrInfo],
|
||||
socket_socket: Generator[Any, Any, None],
|
||||
event_loop: asyncio.AbstractEventLoop,
|
||||
plaintext_connect_task_no_login: tuple[
|
||||
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
|
||||
],
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol: Optional[APIPlaintextFrameHelper] = None
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
protocol.connection_made(transport)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
|
||||
transport = MagicMock()
|
||||
|
||||
with patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await connected.wait()
|
||||
conn, transport, protocol, connect_task = plaintext_connect_task_no_login
|
||||
|
||||
# Only send the first part of the handshake
|
||||
# so we are stuck in the middle of the connection process
|
||||
@ -264,34 +164,20 @@ async def test_requires_encryption_propagates(conn: APIConnection):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plaintext_connection(conn: APIConnection, resolve_host, socket_socket):
|
||||
async def test_plaintext_connection(
|
||||
plaintext_connect_task_no_login: tuple[
|
||||
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
|
||||
],
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test that a plaintext connection works."""
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol = _get_mock_protocol(conn)
|
||||
messages = []
|
||||
protocol: Optional[APIPlaintextFrameHelper] = None
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
protocol.connection_made(transport)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
conn, transport, protocol, connect_task = plaintext_connect_task_no_login
|
||||
|
||||
def on_msg(msg):
|
||||
messages.append(msg)
|
||||
|
||||
remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
|
||||
transport = MagicMock()
|
||||
|
||||
with patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await connected.wait()
|
||||
|
||||
protocol.data_received(
|
||||
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
|
||||
)
|
||||
@ -412,35 +298,19 @@ async def test_finish_connection_is_cancelled(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_finish_connection_times_out(
|
||||
conn: APIConnection, resolve_host, socket_socket
|
||||
):
|
||||
plaintext_connect_task_no_login: tuple[
|
||||
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
|
||||
],
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test handling of finish connection timing out."""
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol = _get_mock_protocol(conn)
|
||||
conn, transport, protocol, connect_task = plaintext_connect_task_no_login
|
||||
messages = []
|
||||
protocol: Optional[APIPlaintextFrameHelper] = None
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
protocol.connection_made(transport)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
|
||||
def on_msg(msg):
|
||||
messages.append(msg)
|
||||
|
||||
remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
|
||||
transport = MagicMock()
|
||||
|
||||
with patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await connected.wait()
|
||||
|
||||
protocol.data_received(
|
||||
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
|
||||
)
|
||||
@ -484,7 +354,7 @@ async def test_plaintext_connection_fails_handshake(
|
||||
exception, raised_exception = exception_map
|
||||
protocol = _get_mock_protocol(conn)
|
||||
messages = []
|
||||
protocol: Optional[APIPlaintextFrameHelper] = None
|
||||
protocol: APIPlaintextFrameHelper | None = None
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
||||
@ -563,97 +433,62 @@ async def test_plaintext_connection_fails_handshake(
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
def _generate_plaintext_packet(msg: bytes, type_: int) -> bytes:
|
||||
return (
|
||||
b"\0"
|
||||
+ _cached_varuint_to_bytes(len(msg))
|
||||
+ _cached_varuint_to_bytes(type_)
|
||||
+ msg
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_wrong_password(
|
||||
plaintext_connect_task_with_login: tuple[
|
||||
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
|
||||
],
|
||||
) -> None:
|
||||
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
|
||||
|
||||
hello_response: message.Message = HelloResponse()
|
||||
hello_response.api_version_major = 1
|
||||
hello_response.api_version_minor = 9
|
||||
hello_response.name = "fake"
|
||||
hello_msg = hello_response.SerializeToString()
|
||||
|
||||
connect_response: message.Message = ConnectResponse()
|
||||
connect_response.invalid_password = True
|
||||
connect_msg = connect_response.SerializeToString()
|
||||
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
|
||||
)
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse])
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_wrong_password(conn, resolve_host, socket_socket, event_loop):
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol: Optional[APIPlaintextFrameHelper] = None
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
protocol.connection_made(transport)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=True))
|
||||
await connected.wait()
|
||||
hello_response: message.Message = HelloResponse()
|
||||
hello_response.api_version_major = 1
|
||||
hello_response.api_version_minor = 9
|
||||
hello_response.name = "fake"
|
||||
hello_msg = hello_response.SerializeToString()
|
||||
|
||||
connect_response: message.Message = ConnectResponse()
|
||||
connect_response.invalid_password = True
|
||||
connect_msg = connect_response.SerializeToString()
|
||||
|
||||
protocol.data_received(
|
||||
_generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
|
||||
)
|
||||
protocol.data_received(
|
||||
_generate_plaintext_packet(
|
||||
connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse]
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(InvalidAuthAPIError):
|
||||
await connect_task
|
||||
with pytest.raises(InvalidAuthAPIError):
|
||||
await connect_task
|
||||
|
||||
assert not conn.is_connected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_correct_password(conn, resolve_host, socket_socket, event_loop):
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol: Optional[APIPlaintextFrameHelper] = None
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
async def test_connect_correct_password(
|
||||
plaintext_connect_task_with_login: tuple[
|
||||
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
|
||||
],
|
||||
) -> None:
|
||||
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
protocol.connection_made(transport)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
hello_response: message.Message = HelloResponse()
|
||||
hello_response.api_version_major = 1
|
||||
hello_response.api_version_minor = 9
|
||||
hello_response.name = "fake"
|
||||
hello_msg = hello_response.SerializeToString()
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=True))
|
||||
await connected.wait()
|
||||
hello_response: message.Message = HelloResponse()
|
||||
hello_response.api_version_major = 1
|
||||
hello_response.api_version_minor = 9
|
||||
hello_response.name = "fake"
|
||||
hello_msg = hello_response.SerializeToString()
|
||||
connect_response: message.Message = ConnectResponse()
|
||||
connect_response.invalid_password = False
|
||||
connect_msg = connect_response.SerializeToString()
|
||||
|
||||
connect_response: message.Message = ConnectResponse()
|
||||
connect_response.invalid_password = False
|
||||
connect_msg = connect_response.SerializeToString()
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
|
||||
)
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse])
|
||||
)
|
||||
|
||||
protocol.data_received(
|
||||
_generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
|
||||
)
|
||||
protocol.data_received(
|
||||
_generate_plaintext_packet(
|
||||
connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse]
|
||||
)
|
||||
)
|
||||
|
||||
await connect_task
|
||||
await connect_task
|
||||
|
||||
assert conn.is_connected
|
||||
|
@ -1,8 +1,8 @@
|
||||
import socket
|
||||
from ipaddress import ip_address
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from mock import AsyncMock, MagicMock, patch
|
||||
from zeroconf import DNSCache
|
||||
from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
@ -136,7 +136,7 @@ class DummyAPIModel(APIModelBase):
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ListAPIModel(APIModelBase):
|
||||
val: List[DummyAPIModel] = field(default_factory=list)
|
||||
val: list[DummyAPIModel] = field(default_factory=list)
|
||||
|
||||
|
||||
def test_api_model_base_converter():
|
||||
|
Loading…
Reference in New Issue
Block a user