diff --git a/aioesphomeapi/_frame_helper/base.pxd b/aioesphomeapi/_frame_helper/base.pxd index e70acab..2f16b11 100644 --- a/aioesphomeapi/_frame_helper/base.pxd +++ b/aioesphomeapi/_frame_helper/base.pxd @@ -39,6 +39,6 @@ cdef class APIFrameHelper: @cython.locals(end_of_frame_pos="unsigned int") cdef void _remove_from_buffer(self) - cpdef write_packets(self, list packets, bint debug_enabled) + cpdef void write_packets(self, list packets, bint debug_enabled) - cdef void _write_bytes(self, bytes data, bint debug_enabled) + cdef void _write_bytes(self, object data, bint debug_enabled) diff --git a/aioesphomeapi/_frame_helper/base.py b/aioesphomeapi/_frame_helper/base.py index 8ed4392..597f41b 100644 --- a/aioesphomeapi/_frame_helper/base.py +++ b/aioesphomeapi/_frame_helper/base.py @@ -22,6 +22,7 @@ SOCKET_ERRORS = ( WRITE_EXCEPTIONS = (RuntimeError, ConnectionResetError, OSError) _int = int +_bytes = bytes class APIFrameHelper: @@ -196,7 +197,7 @@ class APIFrameHelper: def resume_writing(self) -> None: """Stub.""" - def _write_bytes(self, data: bytes, debug_enabled: bool) -> None: + def _write_bytes(self, data: _bytes, debug_enabled: bool) -> None: """Write bytes to the socket.""" if debug_enabled: _LOGGER.debug("%s: Sending frame: [%s]", self._log_name, data.hex()) diff --git a/aioesphomeapi/_frame_helper/noise.pxd b/aioesphomeapi/_frame_helper/noise.pxd index 1604de7..c264eb3 100644 --- a/aioesphomeapi/_frame_helper/noise.pxd +++ b/aioesphomeapi/_frame_helper/noise.pxd @@ -29,7 +29,7 @@ cdef class APINoiseFrameHelper(APIFrameHelper): msg_size_high="unsigned char", msg_size_low="unsigned char", ) - cpdef data_received(self, object data) + cpdef void data_received(self, object data) @cython.locals( msg=bytes, @@ -64,6 +64,6 @@ cdef class APINoiseFrameHelper(APIFrameHelper): frame=bytes, frame_len=cython.uint, ) - cpdef write_packets(self, list packets, bint debug_enabled) + cpdef void write_packets(self, list packets, bint debug_enabled) cdef _error_on_incorrect_preamble(self, bytes msg) diff --git a/aioesphomeapi/_frame_helper/noise.py b/aioesphomeapi/_frame_helper/noise.py index 8622ff7..41ea06f 100644 --- a/aioesphomeapi/_frame_helper/noise.py +++ b/aioesphomeapi/_frame_helper/noise.py @@ -140,8 +140,6 @@ class APINoiseFrameHelper(APIFrameHelper): if (header := self._read(3)) is None: return preamble = header[0] - msg_size_high = header[1] - msg_size_low = header[2] if preamble != 0x01: self._handle_error_and_close( ProtocolAPIError( @@ -149,6 +147,8 @@ class APINoiseFrameHelper(APIFrameHelper): ) ) return + msg_size_high = header[1] + msg_size_low = header[2] if (frame := self._read((msg_size_high << 8) | msg_size_low)) is None: # The complete frame is not yet available, wait for more data # to arrive before continuing, since callback_packet has not diff --git a/aioesphomeapi/_frame_helper/plain_text.pxd b/aioesphomeapi/_frame_helper/plain_text.pxd index 382b3b8..ac1a93d 100644 --- a/aioesphomeapi/_frame_helper/plain_text.pxd +++ b/aioesphomeapi/_frame_helper/plain_text.pxd @@ -5,12 +5,14 @@ from .base cimport APIFrameHelper cdef object varuint_to_bytes +cdef bytes EMPTY_PACKET +cdef bint TYPE_CHECKING cpdef _varuint_to_bytes(cython.int value) cdef class APIPlaintextFrameHelper(APIFrameHelper): - cpdef data_received(self, object data) + cpdef void data_received(self, object data) cdef void _error_on_incorrect_preamble(self, int preamble) @@ -20,4 +22,4 @@ cdef class APIPlaintextFrameHelper(APIFrameHelper): packet=tuple, type_=object ) - cpdef write_packets(self, list packets, bint debug_enabled) + cpdef void write_packets(self, list packets, bint debug_enabled) diff --git a/aioesphomeapi/_frame_helper/plain_text.py b/aioesphomeapi/_frame_helper/plain_text.py index 7a4beee..7be442b 100644 --- a/aioesphomeapi/_frame_helper/plain_text.py +++ b/aioesphomeapi/_frame_helper/plain_text.py @@ -2,12 +2,15 @@ from __future__ import annotations import asyncio from functools import lru_cache +from typing import TYPE_CHECKING from ..core import ProtocolAPIError, RequiresEncryptionAPIError from .base import APIFrameHelper _int = int +EMPTY_PACKET = b"" + def _varuint_to_bytes(value: _int) -> bytes: """Convert a varuint to bytes.""" @@ -71,17 +74,19 @@ class APIPlaintextFrameHelper(APIFrameHelper): if (msg_type := self._read_varuint()) == -1: return + packet_data: bytes | None if length == 0: - packet_data = b"" + packet_data = EMPTY_PACKET else: # The packet data is not yet available, wait for more data # to arrive before continuing, since callback_packet has not # been called yet the buffer will not be cleared and the next # call to data_received will continue processing the packet # at the start of the frame. - if (maybe_packet_data := self._read(length)) is None: + if (packet_data := self._read(length)) is None: return - packet_data = maybe_packet_data + if TYPE_CHECKING: + assert packet_data is not None, "Packet data should be set" self._remove_from_buffer() self._connection.process_packet(msg_type, packet_data) diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 6f0d689..37c3440 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -65,15 +65,16 @@ from .api_pb2 import ( # type: ignore SwitchCommandRequest, TextCommandRequest, UnsubscribeBluetoothLEAdvertisementsRequest, - VoiceAssistantAudioSettings, VoiceAssistantEventData, VoiceAssistantEventResponse, VoiceAssistantRequest, VoiceAssistantResponse, ) from .client_callbacks import ( + handle_timeout, on_ble_raw_advertisement_response, on_bluetooth_connections_free_response, + on_bluetooth_device_connection_response, on_bluetooth_gatt_notify_data_response, on_bluetooth_le_advertising_response, on_home_assistant_service_response, @@ -116,9 +117,9 @@ from .model import ( MediaPlayerCommand, UserService, UserServiceArgType, - VoiceAssistantCommand, - VoiceAssistantEventType, ) +from .model import VoiceAssistantAudioSettings as VoiceAssistantAudioSettingsModel +from .model import VoiceAssistantCommand, VoiceAssistantEventType from .model_conversions import ( LIST_ENTITIES_SERVICES_RESPONSE_TYPES, SUBSCRIBE_STATES_RESPONSE_TYPES, @@ -467,7 +468,7 @@ class APIClient: timeout: float = 10.0, ) -> message.Message: message_filter = partial(self._filter_bluetooth_message, address, handle) - resp = await self._get_connection().send_messages_await_response_complex( + [resp] = await self._get_connection().send_messages_await_response_complex( (request,), message_filter, message_filter, @@ -475,10 +476,21 @@ class APIClient: timeout, ) - if isinstance(resp[0], BluetoothGATTErrorResponse): - raise BluetoothGATTAPIError(BluetoothGATTError.from_pb(resp[0])) + if ( + type(resp) # pylint: disable=unidiomatic-typecheck + is BluetoothGATTErrorResponse + ): + raise BluetoothGATTAPIError(BluetoothGATTError.from_pb(resp)) - return resp[0] + return resp + + def _unsub_bluetooth_advertisements( + self, unsub_callback: Callable[[], None] + ) -> None: + """Unsubscribe Bluetooth advertisements if connected.""" + if self._connection is not None: + unsub_callback() + self._connection.send_message(UnsubscribeBluetoothLEAdvertisementsRequest()) async def subscribe_bluetooth_le_advertisements( self, on_bluetooth_le_advertisement: Callable[[BluetoothLEAdvertisement], None] @@ -491,15 +503,7 @@ class APIClient: ), (BluetoothLEAdvertisementResponse,), ) - - def unsub() -> None: - if self._connection is not None: - unsub_callback() - self._connection.send_message( - UnsubscribeBluetoothLEAdvertisementsRequest() - ) - - return unsub + return partial(self._unsub_bluetooth_advertisements, unsub_callback) async def subscribe_bluetooth_le_raw_advertisements( self, on_advertisements: Callable[[list[BluetoothLERawAdvertisement]], None] @@ -511,15 +515,7 @@ class APIClient: partial(on_ble_raw_advertisement_response, on_advertisements), (BluetoothLERawAdvertisementsResponse,), ) - - def unsub() -> None: - if self._connection is not None: - unsub_callback() - self._connection.send_message( - UnsubscribeBluetoothLEAdvertisementsRequest() - ) - - return unsub + return partial(self._unsub_bluetooth_advertisements, unsub_callback) async def subscribe_bluetooth_connections_free( self, on_bluetooth_connections_free_update: Callable[[int, int], None] @@ -533,28 +529,6 @@ class APIClient: (BluetoothConnectionsFreeResponse,), ) - def _handle_timeout(self, fut: asyncio.Future[None]) -> None: - """Handle a timeout.""" - if fut.done(): - return - fut.set_exception(asyncio.TimeoutError) - - def _on_bluetooth_device_connection_response( - self, - connect_future: asyncio.Future[None], - address: int, - on_bluetooth_connection_state: Callable[[bool, int, int], None], - msg: BluetoothDeviceConnectionResponse, - ) -> None: - """Handle a BluetoothDeviceConnectionResponse message.""" "" - if address == msg.address: - on_bluetooth_connection_state(msg.connected, msg.mtu, msg.error) - # Resolve on ANY connection state since we do not want - # to wait the whole timeout if the device disconnects - # or we get an error. - if not connect_future.done(): - connect_future.set_result(None) - async def bluetooth_device_connect( # pylint: disable=too-many-locals, too-many-branches self, address: int, @@ -589,7 +563,7 @@ class APIClient: address_type=address_type or 0, ), partial( - self._on_bluetooth_device_connection_response, + on_bluetooth_device_connection_response, connect_future, address, on_bluetooth_connection_state, @@ -599,7 +573,7 @@ class APIClient: loop = self._loop timeout_handle = loop.call_at( - loop.time() + timeout, self._handle_timeout, connect_future + loop.time() + timeout, handle_timeout, connect_future ) timeout_expired = False connect_ok = False @@ -607,6 +581,11 @@ class APIClient: await connect_future connect_ok = True except asyncio.TimeoutError as err: + # If the timeout expires, make sure + # to unsub before calling _bluetooth_device_disconnect_guard_timeout + # so that the disconnect message is not propagated back to the caller + # since we are going to raise a TimeoutAPIError. + unsub() timeout_expired = True # Disconnect before raising the exception to ensure # the slot is recovered before the timeout is raised @@ -625,14 +604,8 @@ class APIClient: f" after {disconnect_timeout}s" ) from err finally: - if not connect_ok: - try: - unsub() - except (KeyError, ValueError): - _LOGGER.warning( - "%s: Bluetooth device connection canceled but already unsubscribed", - to_human_readable_address(address), - ) + if not connect_ok and not timeout_expired: + unsub() if not timeout_expired: timeout_handle.cancel() @@ -667,7 +640,7 @@ class APIClient: return False if isinstance(msg, BluetoothDeviceConnectionResponse): raise APIConnectionError( - "Peripheral changed connections status while pairing" + f"Peripheral changed connections status while pairing: {msg.error}" ) return True @@ -1270,7 +1243,8 @@ class APIClient: async def subscribe_voice_assistant( self, handle_start: Callable[ - [str, int, VoiceAssistantAudioSettings], Coroutine[Any, Any, int | None] + [str, int, VoiceAssistantAudioSettingsModel], + Coroutine[Any, Any, int | None], ], handle_stop: Callable[[], Coroutine[Any, Any, None]], ) -> Callable[[], None]: @@ -1297,6 +1271,8 @@ class APIClient: self._connection.send_message(VoiceAssistantResponse(error=True)) def _on_voice_assistant_request(msg: VoiceAssistantRequest) -> None: + nonlocal start_task + command = VoiceAssistantCommand.from_pb(msg) if command.start: start_task = asyncio.create_task( @@ -1319,6 +1295,8 @@ class APIClient: ) def unsub() -> None: + nonlocal start_task + if self._connection is not None: remove_callback() self._connection.send_message( @@ -1333,20 +1311,15 @@ class APIClient: def send_voice_assistant_event( self, event_type: VoiceAssistantEventType, data: dict[str, str] | None ) -> None: - req = VoiceAssistantEventResponse() - req.event_type = event_type - - data_args = [] + req = VoiceAssistantEventResponse(event_type=event_type) if data is not None: - for name, value in data.items(): - arg = VoiceAssistantEventData() - arg.name = name - arg.value = value - data_args.append(arg) - - # pylint: disable=no-member - req.data.extend(data_args) - + # pylint: disable=no-member + req.data.extend( + [ + VoiceAssistantEventData(name=name, value=value) + for name, value in data.items() + ] + ) self._get_connection().send_message(req) async def alarm_control_panel_command( diff --git a/aioesphomeapi/client_callbacks.pxd b/aioesphomeapi/client_callbacks.pxd index 2611a2a..8fad3d9 100644 --- a/aioesphomeapi/client_callbacks.pxd +++ b/aioesphomeapi/client_callbacks.pxd @@ -8,3 +8,5 @@ cdef object CameraImageResponse, CameraState cdef object HomeassistantServiceCall cdef object BluetoothLEAdvertisement + +cdef object asyncio_TimeoutError diff --git a/aioesphomeapi/client_callbacks.py b/aioesphomeapi/client_callbacks.py index 483d8f4..c846c86 100644 --- a/aioesphomeapi/client_callbacks.py +++ b/aioesphomeapi/client_callbacks.py @@ -1,11 +1,14 @@ from __future__ import annotations +from asyncio import Future +from asyncio import TimeoutError as asyncio_TimeoutError from typing import TYPE_CHECKING, Callable from google.protobuf import message from .api_pb2 import ( # type: ignore BluetoothConnectionsFreeResponse, + BluetoothDeviceConnectionResponse, BluetoothGATTNotifyDataResponse, BluetoothLEAdvertisementResponse, BluetoothLERawAdvertisement, @@ -93,3 +96,25 @@ def on_subscribe_home_assistant_state_response( msg: SubscribeHomeAssistantStateResponse, ) -> None: on_state_sub(msg.entity_id, msg.attribute) + + +def handle_timeout(fut: Future[None]) -> None: + """Handle a timeout.""" + if not fut.done(): + fut.set_exception(asyncio_TimeoutError) + + +def on_bluetooth_device_connection_response( + connect_future: Future[None], + address: int, + on_bluetooth_connection_state: Callable[[bool, int, int], None], + msg: BluetoothDeviceConnectionResponse, +) -> None: + """Handle a BluetoothDeviceConnectionResponse message.""" "" + if address == msg.address: + on_bluetooth_connection_state(msg.connected, msg.mtu, msg.error) + # Resolve on ANY connection state since we do not want + # to wait the whole timeout if the device disconnects + # or we get an error. + if not connect_future.done(): + connect_future.set_result(None) diff --git a/aioesphomeapi/model.py b/aioesphomeapi/model.py index 9b19ddb..394fe18 100644 --- a/aioesphomeapi/model.py +++ b/aioesphomeapi/model.py @@ -78,12 +78,13 @@ class APIModelBase: def from_dict( cls: type[_V], data: dict[str, Any], *, ignore_missing: bool = True ) -> _V: - init_args = { - f.name: data[f.name] - for f in cached_fields(cls) # type: ignore[arg-type] - if f.name in data or (not ignore_missing) - } - return cls(**init_args) + return cls( + **{ + f.name: data[f.name] + for f in cached_fields(cls) # type: ignore[arg-type] + if f.name in data or (not ignore_missing) + } + ) @classmethod def from_pb(cls: type[_V], data: Any) -> _V: diff --git a/setup.py b/setup.py index f2cad25..c0f5bd7 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ with open(os.path.join(here, "README.rst"), encoding="utf-8") as readme_file: long_description = readme_file.read() -VERSION = "19.1.0" +VERSION = "19.1.1" PROJECT_NAME = "aioesphomeapi" PROJECT_PACKAGE_NAME = "aioesphomeapi" PROJECT_LICENSE = "MIT" diff --git a/tests/conftest.py b/tests/conftest.py index c936a4a..e7d1085 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,12 @@ from aioesphomeapi.connection import APIConnection from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr from aioesphomeapi.zeroconf import ZeroconfManager -from .common import connect, get_mock_async_zeroconf, send_plaintext_hello +from .common import ( + connect, + connect_client, + get_mock_async_zeroconf, + send_plaintext_hello, +) KEEP_ALIVE_INTERVAL = 15.0 @@ -80,19 +85,19 @@ def connection_params() -> ConnectionParams: return get_mock_connection_params() -def on_stop(expected_disconnect: bool) -> None: +def mock_on_stop(expected_disconnect: bool) -> None: pass @pytest.fixture def conn(connection_params: ConnectionParams) -> APIConnection: - return PatchableAPIConnection(connection_params, on_stop, True, None) + return PatchableAPIConnection(connection_params, mock_on_stop, True, None) @pytest.fixture def conn_with_password(connection_params: ConnectionParams) -> APIConnection: connection_params = replace(connection_params, password="password") - return PatchableAPIConnection(connection_params, on_stop, True, None) + return PatchableAPIConnection(connection_params, mock_on_stop, True, None) @pytest.fixture @@ -100,13 +105,13 @@ def noise_conn(connection_params: ConnectionParams) -> APIConnection: connection_params = replace( connection_params, noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=" ) - return PatchableAPIConnection(connection_params, on_stop, True, None) + return PatchableAPIConnection(connection_params, mock_on_stop, True, None) @pytest.fixture def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnection: connection_params = replace(connection_params, expected_name="test") - return PatchableAPIConnection(connection_params, on_stop, True, None) + return PatchableAPIConnection(connection_params, mock_on_stop, True, None) def _create_mock_transport_protocol( @@ -177,7 +182,7 @@ async def plaintext_connect_task_with_login( @pytest_asyncio.fixture(name="api_client") async def api_client( - conn: APIConnection, resolve_host, socket_socket, event_loop + resolve_host, socket_socket, event_loop ) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]: protocol: APIPlaintextFrameHelper | None = None transport = MagicMock() @@ -192,12 +197,12 @@ async def api_client( event_loop, "create_connection", side_effect=partial(_create_mock_transport_protocol, transport, connected), - ): - connect_task = asyncio.create_task(connect(conn, login=False)) + ), patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection): + connect_task = asyncio.create_task(connect_client(client, login=False)) await connected.wait() + conn = client._connection protocol = conn._frame_helper send_plaintext_hello(protocol) - client._connection = conn await connect_task transport.reset_mock() yield client, conn, transport, protocol diff --git a/tests/test_client.py b/tests/test_client.py index 5c04c1d..9dd5ebb 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio import itertools import logging +from functools import partial from typing import Any from unittest.mock import AsyncMock, MagicMock, call, patch @@ -17,7 +18,10 @@ from aioesphomeapi.api_pb2 import ( BluetoothDeviceClearCacheResponse, BluetoothDeviceConnectionResponse, BluetoothDevicePairingResponse, + BluetoothDeviceRequest, BluetoothDeviceUnpairingResponse, + BluetoothGATTCharacteristic, + BluetoothGATTDescriptor, BluetoothGATTErrorResponse, BluetoothGATTGetServicesDoneResponse, BluetoothGATTGetServicesResponse, @@ -29,6 +33,7 @@ from aioesphomeapi.api_pb2 import ( BluetoothLEAdvertisementResponse, BluetoothLERawAdvertisement, BluetoothLERawAdvertisementsResponse, + BluetoothServiceData, ButtonCommandRequest, CameraImageRequest, CameraImageResponse, @@ -40,6 +45,7 @@ from aioesphomeapi.api_pb2 import ( ExecuteServiceRequest, FanCommandRequest, HomeassistantServiceResponse, + HomeAssistantStateResponse, LightCommandRequest, ListEntitiesBinarySensorResponse, ListEntitiesDoneResponse, @@ -51,8 +57,14 @@ from aioesphomeapi.api_pb2 import ( SirenCommandRequest, SubscribeHomeAssistantStateResponse, SubscribeLogsResponse, + SubscribeVoiceAssistantRequest, SwitchCommandRequest, TextCommandRequest, + VoiceAssistantAudioSettings, + VoiceAssistantEventData, + VoiceAssistantEventResponse, + VoiceAssistantRequest, + VoiceAssistantResponse, ) from aioesphomeapi.client import APIClient from aioesphomeapi.connection import APIConnection @@ -67,10 +79,12 @@ from aioesphomeapi.model import ( APIVersion, BinarySensorInfo, BinarySensorState, + BluetoothDeviceRequestType, ) from aioesphomeapi.model import BluetoothGATTService as BluetoothGATTServiceModel from aioesphomeapi.model import ( BluetoothLEAdvertisement, + BluetoothProxyFeature, CameraState, ClimateFanMode, ClimateMode, @@ -88,6 +102,10 @@ from aioesphomeapi.model import ( UserServiceArg, UserServiceArgType, ) +from aioesphomeapi.model import ( + VoiceAssistantAudioSettings as VoiceAssistantAudioSettingsModel, +) +from aioesphomeapi.model import VoiceAssistantEventType as VoiceAssistantEventModelType from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState from .common import ( @@ -194,6 +212,29 @@ async def test_finish_connection_wraps_exceptions_as_unhandled_api_error() -> No await cli.finish_connection(False) +@pytest.mark.asyncio +async def test_request_while_handshaking(event_loop) -> None: + """Test trying a request while handshaking raises.""" + + class PatchableApiClient(APIClient): + pass + + cli = PatchableApiClient("host", 1234, None) + with patch.object( + event_loop, "sock_connect", side_effect=partial(asyncio.sleep, 1) + ), patch.object(cli, "finish_connection"): + connect_task = asyncio.create_task(cli.connect()) + + await asyncio.sleep(0) + with pytest.raises( + APIConnectionError, match="Authenticated connection not ready yet" + ): + await cli.device_info() + + connect_task.cancel() + await asyncio.sleep(0) + + @pytest.mark.asyncio async def test_connect_while_already_connected(auth_client: APIClient) -> None: """Test connecting while already connected raises.""" @@ -895,11 +936,36 @@ async def test_bluetooth_pair( client, connection, transport, protocol = api_client pair_task = asyncio.create_task(client.bluetooth_device_pair(1234)) await asyncio.sleep(0) + response: message.Message = BluetoothDevicePairingResponse(address=4567) + mock_data_received(protocol, generate_plaintext_packet(response)) + await asyncio.sleep(0) + assert not pair_task.done() response: message.Message = BluetoothDevicePairingResponse(address=1234) mock_data_received(protocol, generate_plaintext_packet(response)) await pair_task +@pytest.mark.asyncio +async def test_bluetooth_pair_connection_drops( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test connection drop during bluetooth_device_pair.""" + client, connection, transport, protocol = api_client + pair_task = asyncio.create_task(client.bluetooth_device_pair(1234)) + await asyncio.sleep(0) + response: message.Message = BluetoothDeviceConnectionResponse( + address=1234, connected=False, error=13 + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + with pytest.raises( + APIConnectionError, + match="Peripheral changed connections status while pairing: 13", + ): + await pair_task + + @pytest.mark.asyncio async def test_bluetooth_unpair( api_client: tuple[ @@ -938,7 +1004,7 @@ async def test_device_info( ) -> None: """Test fetching device info.""" client, connection, transport, protocol = api_client - assert client.log_name == "mydevice.local" + assert client.log_name == "fake @ 10.0.0.512" device_info_task = asyncio.create_task(client.device_info()) await asyncio.sleep(0) response: message.Message = DeviceInfoResponse( @@ -957,7 +1023,7 @@ async def test_device_info( response: message.Message = DisconnectResponse() mock_data_received(protocol, generate_plaintext_packet(response)) await disconnect_task - with pytest.raises(APIConnectionError, match="CLOSED"): + with pytest.raises(APIConnectionError, match="Not connected"): await client.device_info() @@ -984,6 +1050,24 @@ async def test_bluetooth_gatt_read( assert await read_task == b"1234" +@pytest.mark.asyncio +async def test_bluetooth_gatt_read_error( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test bluetooth_gatt_read that errors.""" + client, connection, transport, protocol = api_client + read_task = asyncio.create_task(client.bluetooth_gatt_read(1234, 1234)) + await asyncio.sleep(0) + error_response: message.Message = BluetoothGATTErrorResponse( + address=1234, handle=1234 + ) + mock_data_received(protocol, generate_plaintext_packet(error_response)) + with pytest.raises(BluetoothGATTAPIError): + await read_task + + @pytest.mark.asyncio async def test_bluetooth_gatt_read_descriptor( api_client: tuple[ @@ -1106,7 +1190,16 @@ async def test_bluetooth_gatt_get_services( services_task = asyncio.create_task(client.bluetooth_gatt_get_services(1234)) await asyncio.sleep(0) service1: message.Message = BluetoothGATTService( - uuid=[1, 1], handle=1, characteristics=[] + uuid=[1, 1], + handle=1, + characteristics=[ + BluetoothGATTCharacteristic( + uuid=[1, 2], + handle=2, + properties=1, + descriptors=[BluetoothGATTDescriptor(uuid=[1, 3], handle=3)], + ) + ], ) response: message.Message = BluetoothGATTGetServicesResponse( address=1234, services=[service1] @@ -1116,9 +1209,10 @@ async def test_bluetooth_gatt_get_services( mock_data_received(protocol, generate_plaintext_packet(done_response)) services = await services_task + service = BluetoothGATTServiceModel.from_pb(service1) assert services == ESPHomeBluetoothGATTServices( address=1234, - services=[BluetoothGATTServiceModel(uuid=[1, 1], handle=1, characteristics=[])], + services=[service], ) @@ -1196,6 +1290,10 @@ async def test_bluetooth_gatt_start_notify( # Ensure abort callback is a no-op after cancel # and doesn't raise abort_cb() + await client.disconnect(force=True) + # Ensure abort callback is a no-op after disconnect + # and does not raise + await cancel_cb() @pytest.mark.asyncio @@ -1250,8 +1348,18 @@ async def test_subscribe_bluetooth_le_advertisements( name=b"mydevice", rssi=-50, service_uuids=["1234"], - service_data={}, - manufacturer_data={}, + service_data=[ + BluetoothServiceData( + uuid="1234", + data=b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + ) + ], + manufacturer_data=[ + BluetoothServiceData( + uuid="1234", + data=b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + ) + ], address_type=1, ) mock_data_received(protocol, generate_plaintext_packet(response)) @@ -1262,6 +1370,33 @@ async def test_subscribe_bluetooth_le_advertisements( name="mydevice", rssi=-50, service_uuids=["000034-0000-1000-8000-00805f9b34fb"], + manufacturer_data={ + 4660: b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + }, + service_data={ + "000034-0000-1000-8000-00805f9b34fb": b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + }, + address_type=1, + ) + ] + advs.clear() + response: message.Message = BluetoothLEAdvertisementResponse( + address=1234, + name=b"mydevice", + rssi=-50, + service_uuids=[], + service_data=[], + manufacturer_data=[], + address_type=1, + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + + assert advs == [ + BluetoothLEAdvertisement( + address=1234, + name="mydevice", + rssi=-50, + service_uuids=[], manufacturer_data={}, service_data={}, address_type=1, @@ -1370,6 +1505,17 @@ async def test_subscribe_logs(auth_client: APIClient) -> None: on_logs.assert_called_with(log_msg) +@pytest.mark.asyncio +async def test_send_home_assistant_state(auth_client: APIClient) -> None: + send = patch_send(auth_client) + await auth_client.send_home_assistant_state("binary_sensor.bla", None, "on") + send.assert_called_once_with( + HomeAssistantStateResponse( + entity_id="binary_sensor.bla", state="on", attribute=None + ) + ) + + @pytest.mark.asyncio async def test_subscribe_service_calls(auth_client: APIClient) -> None: send = patch_response_callback(auth_client) @@ -1398,7 +1544,7 @@ async def test_set_debug( caplog.set_level(logging.DEBUG) client.set_debug(True) - assert client.log_name == "mydevice.local" + assert client.log_name == "fake @ 10.0.0.512" device_info_task = asyncio.create_task(client.device_info()) await asyncio.sleep(0) mock_data_received(protocol, generate_plaintext_packet(response)) @@ -1419,11 +1565,446 @@ async def test_force_disconnect( api_client: tuple[ APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper ], - caplog: pytest.LogCaptureFixture, ) -> None: """Test force disconnect can be called multiple times.""" client, connection, transport, protocol = api_client + assert connection.is_connected is True + assert connection.on_stop is not None await client.disconnect(force=True) + assert client._connection is None assert connection.is_connected is False await client.disconnect(force=False) assert connection.is_connected is False + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("has_cache", "feature_flags", "method"), + [ + (False, BluetoothProxyFeature(0), BluetoothDeviceRequestType.CONNECT), + ( + False, + BluetoothProxyFeature.REMOTE_CACHING, + BluetoothDeviceRequestType.CONNECT_V3_WITHOUT_CACHE, + ), + ( + True, + BluetoothProxyFeature.REMOTE_CACHING, + BluetoothDeviceRequestType.CONNECT_V3_WITH_CACHE, + ), + ], +) +async def test_bluetooth_device_connect( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], + has_cache: bool, + feature_flags: BluetoothProxyFeature, + method: BluetoothDeviceRequestType, +) -> None: + """Test bluetooth_device_connect.""" + client, connection, transport, protocol = api_client + states = [] + + def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None: + states.append((connected, mtu, error)) + + connect_task = asyncio.create_task( + client.bluetooth_device_connect( + 1234, + on_bluetooth_connection_state, + timeout=1, + feature_flags=feature_flags, + has_cache=has_cache, + disconnect_timeout=1, + address_type=1, + ) + ) + await asyncio.sleep(0) + response: message.Message = BluetoothDeviceConnectionResponse( + address=1234, connected=True, mtu=23, error=0 + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + + cancel = await connect_task + assert states == [(True, 23, 0)] + transport.write.assert_called_once_with( + generate_plaintext_packet( + BluetoothDeviceRequest( + address=1234, + request_type=method, + has_address_type=True, + address_type=1, + ), + ) + ) + response: message.Message = BluetoothDeviceConnectionResponse( + address=1234, connected=False, mtu=23, error=7 + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + await asyncio.sleep(0) + assert states == [(True, 23, 0), (False, 23, 7)] + cancel() + + # After cancel, no more messages should called back + response: message.Message = BluetoothDeviceConnectionResponse( + address=1234, connected=False, mtu=23, error=8 + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + await asyncio.sleep(0) + assert states == [(True, 23, 0), (False, 23, 7)] + + +@pytest.mark.asyncio +async def test_bluetooth_device_connect_and_disconnect_times_out( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test bluetooth_device_connect and disconnect times out.""" + client, connection, transport, protocol = api_client + states = [] + + def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None: + states.append((connected, mtu, error)) + + connect_task = asyncio.create_task( + client.bluetooth_device_connect( + 1234, + on_bluetooth_connection_state, + timeout=0, + feature_flags=0, + has_cache=True, + disconnect_timeout=0, + address_type=1, + ) + ) + with pytest.raises(TimeoutAPIError): + await connect_task + assert states == [] + + +@pytest.mark.asyncio +async def test_bluetooth_device_connect_times_out_disconnect_ok( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test bluetooth_device_connect and disconnect times out.""" + client, connection, transport, protocol = api_client + states = [] + + def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None: + states.append((connected, mtu, error)) + + connect_task = asyncio.create_task( + client.bluetooth_device_connect( + 1234, + on_bluetooth_connection_state, + timeout=0, + feature_flags=0, + has_cache=True, + disconnect_timeout=1, + address_type=1, + ) + ) + await asyncio.sleep(0) + # The connect request should be written + assert len(transport.write.mock_calls) == 1 + await asyncio.sleep(0) + await asyncio.sleep(0) + await asyncio.sleep(0) + # Now that we timed out, the disconnect + # request should be written + assert len(transport.write.mock_calls) == 2 + response: message.Message = BluetoothDeviceConnectionResponse( + address=1234, connected=False, mtu=23, error=8 + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + with pytest.raises(TimeoutAPIError): + await connect_task + assert states == [] + + +@pytest.mark.asyncio +async def test_bluetooth_device_connect_cancelled( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test bluetooth_device_connect handles cancellation.""" + client, connection, transport, protocol = api_client + states = [] + + handlers_before = len( + list(itertools.chain(*connection._get_message_handlers().values())) + ) + + def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None: + states.append((connected, mtu, error)) + + connect_task = asyncio.create_task( + client.bluetooth_device_connect( + 1234, + on_bluetooth_connection_state, + timeout=10, + feature_flags=0, + has_cache=True, + disconnect_timeout=10, + address_type=1, + ) + ) + await asyncio.sleep(0) + # The connect request should be written + assert len(transport.write.mock_calls) == 1 + connect_task.cancel() + with pytest.raises(asyncio.CancelledError): + await connect_task + assert states == [] + + handlers_after = len( + list(itertools.chain(*connection._get_message_handlers().values())) + ) + # Make sure we do not leak message handlers + assert handlers_after == handlers_before + + +@pytest.mark.asyncio +async def test_send_voice_assistant_event(auth_client: APIClient) -> None: + send = patch_send(auth_client) + + auth_client.send_voice_assistant_event( + VoiceAssistantEventModelType.VOICE_ASSISTANT_ERROR, + {"error": "error", "ok": "ok"}, + ) + send.assert_called_once_with( + VoiceAssistantEventResponse( + event_type=VoiceAssistantEventModelType.VOICE_ASSISTANT_ERROR.value, + data=[ + VoiceAssistantEventData(name="error", value="error"), + VoiceAssistantEventData(name="ok", value="ok"), + ], + ) + ) + + send.reset_mock() + auth_client.send_voice_assistant_event( + VoiceAssistantEventModelType.VOICE_ASSISTANT_ERROR, None + ) + send.assert_called_once_with( + VoiceAssistantEventResponse( + event_type=VoiceAssistantEventModelType.VOICE_ASSISTANT_ERROR.value, + data=[], + ) + ) + + +@pytest.mark.asyncio +async def test_subscribe_voice_assistant( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test subscribe_voice_assistant.""" + client, connection, transport, protocol = api_client + send = patch_send(client) + starts = [] + stops = [] + + async def handle_start( + conversation_id: str, flags: int, audio_settings: VoiceAssistantAudioSettings + ) -> int | None: + starts.append((conversation_id, flags, audio_settings)) + return 42 + + async def handle_stop() -> None: + stops.append(True) + + unsub = await client.subscribe_voice_assistant(handle_start, handle_stop) + send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=True)) + send.reset_mock() + audio_settings = VoiceAssistantAudioSettings( + noise_suppression_level=42, + auto_gain=42, + volume_multiplier=42, + ) + response: message.Message = VoiceAssistantRequest( + conversation_id="theone", + start=True, + flags=42, + audio_settings=audio_settings, + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + await asyncio.sleep(0) + await asyncio.sleep(0) + assert starts == [ + ( + "theone", + 42, + VoiceAssistantAudioSettingsModel( + noise_suppression_level=42, + auto_gain=42, + volume_multiplier=42, + ), + ) + ] + assert stops == [] + send.assert_called_once_with(VoiceAssistantResponse(port=42)) + send.reset_mock() + response: message.Message = VoiceAssistantRequest( + conversation_id="theone", + start=False, + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + await asyncio.sleep(0) + assert stops == [True] + send.reset_mock() + unsub() + send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False)) + send.reset_mock() + await client.disconnect(force=True) + # Ensure abort callback is a no-op after disconnect + # and does not raise + unsub() + assert len(send.mock_calls) == 0 + + +@pytest.mark.asyncio +async def test_subscribe_voice_assistant_failure( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test subscribe_voice_assistant failure.""" + client, connection, transport, protocol = api_client + send = patch_send(client) + starts = [] + stops = [] + + async def handle_start( + conversation_id: str, flags: int, audio_settings: VoiceAssistantAudioSettings + ) -> int | None: + starts.append((conversation_id, flags, audio_settings)) + # Return None to indicate failure + return None + + async def handle_stop() -> None: + stops.append(True) + + unsub = await client.subscribe_voice_assistant(handle_start, handle_stop) + send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=True)) + send.reset_mock() + audio_settings = VoiceAssistantAudioSettings( + noise_suppression_level=42, + auto_gain=42, + volume_multiplier=42, + ) + response: message.Message = VoiceAssistantRequest( + conversation_id="theone", + start=True, + flags=42, + audio_settings=audio_settings, + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + await asyncio.sleep(0) + await asyncio.sleep(0) + assert starts == [ + ( + "theone", + 42, + VoiceAssistantAudioSettingsModel( + noise_suppression_level=42, + auto_gain=42, + volume_multiplier=42, + ), + ) + ] + assert stops == [] + send.assert_called_once_with(VoiceAssistantResponse(error=True)) + send.reset_mock() + response: message.Message = VoiceAssistantRequest( + conversation_id="theone", + start=False, + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + await asyncio.sleep(0) + assert stops == [True] + send.reset_mock() + unsub() + send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False)) + send.reset_mock() + await client.disconnect(force=True) + # Ensure abort callback is a no-op after disconnect + # and does not raise + unsub() + assert len(send.mock_calls) == 0 + + +@pytest.mark.asyncio +async def test_subscribe_voice_assistant_cancels_long_running_handle_start( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test subscribe_voice_assistant cancels long running tasks on unsub.""" + client, connection, transport, protocol = api_client + send = patch_send(client) + starts = [] + stops = [] + + async def handle_start( + conversation_id: str, flags: int, audio_settings: VoiceAssistantAudioSettings + ) -> int | None: + starts.append((conversation_id, flags, audio_settings)) + await asyncio.sleep(10) + # Return None to indicate failure + starts.append("never") + return None + + async def handle_stop() -> None: + stops.append(True) + + unsub = await client.subscribe_voice_assistant(handle_start, handle_stop) + send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=True)) + send.reset_mock() + audio_settings = VoiceAssistantAudioSettings( + noise_suppression_level=42, + auto_gain=42, + volume_multiplier=42, + ) + response: message.Message = VoiceAssistantRequest( + conversation_id="theone", + start=True, + flags=42, + audio_settings=audio_settings, + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + await asyncio.sleep(0) + await asyncio.sleep(0) + unsub() + await asyncio.sleep(0) + assert not stops + assert starts == [ + ( + "theone", + 42, + VoiceAssistantAudioSettingsModel( + noise_suppression_level=42, + auto_gain=42, + volume_multiplier=42, + ), + ) + ] + + +@pytest.mark.asyncio +async def test_api_version_after_connection_closed( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test api version is None after connection close.""" + client, connection, transport, protocol = api_client + assert client.api_version == APIVersion(1, 9) + await client.disconnect(force=True) + assert client.api_version is None diff --git a/tests/test_log_runner.py b/tests/test_log_runner.py index f03610c..f446d8a 100644 --- a/tests/test_log_runner.py +++ b/tests/test_log_runner.py @@ -1,6 +1,8 @@ from __future__ import annotations import asyncio +from datetime import timedelta +from functools import partial from unittest.mock import MagicMock, patch import pytest @@ -8,18 +10,22 @@ from google.protobuf import message from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper from aioesphomeapi.api_pb2 import SubscribeLogsResponse # type: ignore -from aioesphomeapi.api_pb2 import DisconnectResponse +from aioesphomeapi.api_pb2 import DisconnectRequest, DisconnectResponse from aioesphomeapi.client import APIClient from aioesphomeapi.connection import APIConnection +from aioesphomeapi.core import APIConnectionError from aioesphomeapi.log_runner import async_run +from aioesphomeapi.reconnect_logic import EXPECTED_DISCONNECT_COOLDOWN from .common import ( Estr, + async_fire_time_changed, generate_plaintext_packet, get_mock_async_zeroconf, mock_data_received, send_plaintext_connect_response, send_plaintext_hello, + utcnow, ) @@ -83,3 +89,156 @@ async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnec disconnect_response = DisconnectResponse() mock_data_received(protocol, generate_plaintext_packet(disconnect_response)) await stop_task + + +@pytest.mark.asyncio +async def test_log_runner_reconnects_on_disconnect( + event_loop: asyncio.AbstractEventLoop, + conn: APIConnection, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test the log runner reconnects on disconnect.""" + loop = asyncio.get_event_loop() + protocol: APIPlaintextFrameHelper | None = None + transport = MagicMock() + connected = asyncio.Event() + + class PatchableAPIClient(APIClient): + pass + + async_zeroconf = get_mock_async_zeroconf() + + cli = PatchableAPIClient( + address=Estr("1.2.3.4"), + port=6052, + password=None, + noise_psk=None, + expected_name=Estr("fake"), + zeroconf_instance=async_zeroconf.zeroconf, + ) + messages = [] + + def on_log(msg: SubscribeLogsResponse) -> None: + messages.append(msg) + + def _create_mock_transport_protocol(create_func, **kwargs): + nonlocal protocol + protocol = create_func() + protocol.connection_made(transport) + connected.set() + return transport, protocol + + subscribed = asyncio.Event() + original_subscribe_logs = cli.subscribe_logs + + async def _wait_subscribe_cli(*args, **kwargs): + await original_subscribe_logs(*args, **kwargs) + subscribed.set() + + with patch.object(event_loop, "sock_connect"), patch.object( + loop, "create_connection", side_effect=_create_mock_transport_protocol + ), patch.object(cli, "subscribe_logs", _wait_subscribe_cli): + stop = await async_run(cli, on_log, aio_zeroconf_instance=async_zeroconf) + await connected.wait() + protocol = cli._connection._frame_helper + send_plaintext_hello(protocol) + send_plaintext_connect_response(protocol, False) + await subscribed.wait() + + response: message.Message = SubscribeLogsResponse() + response.message = b"Hello world" + mock_data_received(protocol, generate_plaintext_packet(response)) + assert len(messages) == 1 + assert messages[0].message == b"Hello world" + + with patch.object(cli, "start_connection") as mock_start_connection: + response: message.Message = DisconnectRequest() + mock_data_received(protocol, generate_plaintext_packet(response)) + + await asyncio.sleep(0) + assert cli._connection is None + async_fire_time_changed( + utcnow() + timedelta(seconds=EXPECTED_DISCONNECT_COOLDOWN) + ) + await asyncio.sleep(0) + + assert "Disconnected from API" in caplog.text + assert mock_start_connection.called + + await stop() + + +@pytest.mark.asyncio +async def test_log_runner_reconnects_on_subscribe_failure( + event_loop: asyncio.AbstractEventLoop, + conn: APIConnection, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test the log runner reconnects on subscribe failure.""" + loop = asyncio.get_event_loop() + protocol: APIPlaintextFrameHelper | None = None + transport = MagicMock() + connected = asyncio.Event() + + class PatchableAPIClient(APIClient): + pass + + async_zeroconf = get_mock_async_zeroconf() + + cli = PatchableAPIClient( + address=Estr("1.2.3.4"), + port=6052, + password=None, + noise_psk=None, + expected_name=Estr("fake"), + zeroconf_instance=async_zeroconf.zeroconf, + ) + messages = [] + + def on_log(msg: SubscribeLogsResponse) -> None: + messages.append(msg) + + def _create_mock_transport_protocol(create_func, **kwargs): + nonlocal protocol + protocol = create_func() + protocol.connection_made(transport) + connected.set() + return transport, protocol + + subscribed = asyncio.Event() + + async def _wait_and_fail_subscribe_cli(*args, **kwargs): + subscribed.set() + raise APIConnectionError("subscribed force to fail") + + with patch.object( + cli, "disconnect", partial(cli.disconnect, force=True) + ), patch.object(cli, "subscribe_logs", _wait_and_fail_subscribe_cli): + with patch.object(loop, "sock_connect"), patch.object( + loop, "create_connection", side_effect=_create_mock_transport_protocol + ): + stop = await async_run(cli, on_log, aio_zeroconf_instance=async_zeroconf) + await connected.wait() + protocol = cli._connection._frame_helper + send_plaintext_hello(protocol) + send_plaintext_connect_response(protocol, False) + + await subscribed.wait() + + assert cli._connection is None + + with patch.object(loop, "sock_connect"), patch.object( + loop, "create_connection", side_effect=_create_mock_transport_protocol + ), patch.object(cli, "subscribe_logs"): + connected.clear() + await asyncio.sleep(0) + async_fire_time_changed( + utcnow() + timedelta(seconds=EXPECTED_DISCONNECT_COOLDOWN) + ) + await asyncio.sleep(0) + + stop_task = asyncio.create_task(stop()) + await asyncio.sleep(0) + disconnect_response = DisconnectResponse() + mock_data_received(protocol, generate_plaintext_packet(disconnect_response)) + await stop_task diff --git a/tests/test_model.py b/tests/test_model.py index cf9df81..d81a75a 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -49,6 +49,7 @@ from aioesphomeapi.model import ( APIVersion, BinarySensorInfo, BinarySensorState, + BluetoothProxyFeature, ButtonInfo, CameraInfo, ClimateInfo, @@ -61,6 +62,7 @@ from aioesphomeapi.model import ( FanState, HomeassistantServiceCall, LegacyCoverState, + LightColorCapability, LightInfo, LightState, LockEntityState, @@ -357,3 +359,141 @@ def test_user_service_conversion(): def test_build_unique_id(model): obj = model(object_id="id") assert build_unique_id("mac", obj) == f"mac-{_TYPE_TO_NAME[type(obj)]}-id" + + +@pytest.mark.parametrize( + ("version", "flags"), + [ + (1, BluetoothProxyFeature.PASSIVE_SCAN), + ( + 2, + BluetoothProxyFeature.PASSIVE_SCAN + | BluetoothProxyFeature.ACTIVE_CONNECTIONS, + ), + ( + 3, + BluetoothProxyFeature.PASSIVE_SCAN + | BluetoothProxyFeature.ACTIVE_CONNECTIONS + | BluetoothProxyFeature.REMOTE_CACHING, + ), + ( + 4, + BluetoothProxyFeature.PASSIVE_SCAN + | BluetoothProxyFeature.ACTIVE_CONNECTIONS + | BluetoothProxyFeature.REMOTE_CACHING + | BluetoothProxyFeature.PAIRING, + ), + ( + 5, + BluetoothProxyFeature.PASSIVE_SCAN + | BluetoothProxyFeature.ACTIVE_CONNECTIONS + | BluetoothProxyFeature.REMOTE_CACHING + | BluetoothProxyFeature.PAIRING + | BluetoothProxyFeature.CACHE_CLEARING, + ), + ], +) +def test_bluetooth_backcompat_for_device_info( + version: int, flags: BluetoothProxyFeature +) -> None: + info = DeviceInfo( + legacy_bluetooth_proxy_version=version, bluetooth_proxy_feature_flags=42 + ) + assert info.bluetooth_proxy_feature_flags_compat(APIVersion(1, 8)) is flags + assert info.bluetooth_proxy_feature_flags_compat(APIVersion(1, 9)) == 42 + + +@pytest.mark.parametrize( + ( + "legacy_supports_brightness", + "legacy_supports_rgb", + "legacy_supports_white_value", + "legacy_supports_color_temperature", + "capability", + ), + [ + (False, False, False, False, [LightColorCapability.ON_OFF]), + ( + True, + False, + False, + False, + [LightColorCapability.ON_OFF | LightColorCapability.BRIGHTNESS], + ), + ( + True, + False, + False, + True, + [ + LightColorCapability.ON_OFF + | LightColorCapability.BRIGHTNESS + | LightColorCapability.COLOR_TEMPERATURE + ], + ), + ( + True, + True, + False, + False, + [ + LightColorCapability.ON_OFF + | LightColorCapability.BRIGHTNESS + | LightColorCapability.RGB + ], + ), + ( + True, + True, + True, + False, + [ + LightColorCapability.ON_OFF + | LightColorCapability.BRIGHTNESS + | LightColorCapability.RGB + | LightColorCapability.WHITE + ], + ), + ( + True, + True, + False, + True, + [ + LightColorCapability.ON_OFF + | LightColorCapability.BRIGHTNESS + | LightColorCapability.RGB + | LightColorCapability.COLOR_TEMPERATURE + ], + ), + ( + True, + True, + True, + True, + [ + LightColorCapability.ON_OFF + | LightColorCapability.BRIGHTNESS + | LightColorCapability.RGB + | LightColorCapability.WHITE + | LightColorCapability.COLOR_TEMPERATURE + ], + ), + ], +) +def test_supported_color_modes_compat( + legacy_supports_brightness: bool, + legacy_supports_rgb: bool, + legacy_supports_white_value: bool, + legacy_supports_color_temperature: bool, + capability: list[LightColorCapability], +) -> None: + info = LightInfo( + legacy_supports_brightness=legacy_supports_brightness, + legacy_supports_rgb=legacy_supports_rgb, + legacy_supports_white_value=legacy_supports_white_value, + legacy_supports_color_temperature=legacy_supports_color_temperature, + supported_color_modes=[42], + ) + assert info.supported_color_modes_compat(APIVersion(1, 5)) == capability + assert info.supported_color_modes_compat(APIVersion(1, 9)) == [42]