Merge branch 'main' into climate_enhancements

This commit is contained in:
J. Nick Koston 2023-11-26 18:43:44 -06:00 committed by GitHub
commit c30d26ccc8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 1004 additions and 110 deletions

View File

@ -39,6 +39,6 @@ cdef class APIFrameHelper:
@cython.locals(end_of_frame_pos="unsigned int")
cdef void _remove_from_buffer(self)
cpdef write_packets(self, list packets, bint debug_enabled)
cpdef void write_packets(self, list packets, bint debug_enabled)
cdef void _write_bytes(self, bytes data, bint debug_enabled)
cdef void _write_bytes(self, object data, bint debug_enabled)

View File

@ -22,6 +22,7 @@ SOCKET_ERRORS = (
WRITE_EXCEPTIONS = (RuntimeError, ConnectionResetError, OSError)
_int = int
_bytes = bytes
class APIFrameHelper:
@ -196,7 +197,7 @@ class APIFrameHelper:
def resume_writing(self) -> None:
"""Stub."""
def _write_bytes(self, data: bytes, debug_enabled: bool) -> None:
def _write_bytes(self, data: _bytes, debug_enabled: bool) -> None:
"""Write bytes to the socket."""
if debug_enabled:
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, data.hex())

View File

@ -29,7 +29,7 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
msg_size_high="unsigned char",
msg_size_low="unsigned char",
)
cpdef data_received(self, object data)
cpdef void data_received(self, object data)
@cython.locals(
msg=bytes,
@ -64,6 +64,6 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
frame=bytes,
frame_len=cython.uint,
)
cpdef write_packets(self, list packets, bint debug_enabled)
cpdef void write_packets(self, list packets, bint debug_enabled)
cdef _error_on_incorrect_preamble(self, bytes msg)

View File

@ -140,8 +140,6 @@ class APINoiseFrameHelper(APIFrameHelper):
if (header := self._read(3)) is None:
return
preamble = header[0]
msg_size_high = header[1]
msg_size_low = header[2]
if preamble != 0x01:
self._handle_error_and_close(
ProtocolAPIError(
@ -149,6 +147,8 @@ class APINoiseFrameHelper(APIFrameHelper):
)
)
return
msg_size_high = header[1]
msg_size_low = header[2]
if (frame := self._read((msg_size_high << 8) | msg_size_low)) is None:
# The complete frame is not yet available, wait for more data
# to arrive before continuing, since callback_packet has not

View File

@ -5,12 +5,14 @@ from .base cimport APIFrameHelper
cdef object varuint_to_bytes
cdef bytes EMPTY_PACKET
cdef bint TYPE_CHECKING
cpdef _varuint_to_bytes(cython.int value)
cdef class APIPlaintextFrameHelper(APIFrameHelper):
cpdef data_received(self, object data)
cpdef void data_received(self, object data)
cdef void _error_on_incorrect_preamble(self, int preamble)
@ -20,4 +22,4 @@ cdef class APIPlaintextFrameHelper(APIFrameHelper):
packet=tuple,
type_=object
)
cpdef write_packets(self, list packets, bint debug_enabled)
cpdef void write_packets(self, list packets, bint debug_enabled)

View File

@ -2,12 +2,15 @@ from __future__ import annotations
import asyncio
from functools import lru_cache
from typing import TYPE_CHECKING
from ..core import ProtocolAPIError, RequiresEncryptionAPIError
from .base import APIFrameHelper
_int = int
EMPTY_PACKET = b""
def _varuint_to_bytes(value: _int) -> bytes:
"""Convert a varuint to bytes."""
@ -71,17 +74,19 @@ class APIPlaintextFrameHelper(APIFrameHelper):
if (msg_type := self._read_varuint()) == -1:
return
packet_data: bytes | None
if length == 0:
packet_data = b""
packet_data = EMPTY_PACKET
else:
# The packet data is not yet available, wait for more data
# to arrive before continuing, since callback_packet has not
# been called yet the buffer will not be cleared and the next
# call to data_received will continue processing the packet
# at the start of the frame.
if (maybe_packet_data := self._read(length)) is None:
if (packet_data := self._read(length)) is None:
return
packet_data = maybe_packet_data
if TYPE_CHECKING:
assert packet_data is not None, "Packet data should be set"
self._remove_from_buffer()
self._connection.process_packet(msg_type, packet_data)

View File

@ -65,15 +65,16 @@ from .api_pb2 import ( # type: ignore
SwitchCommandRequest,
TextCommandRequest,
UnsubscribeBluetoothLEAdvertisementsRequest,
VoiceAssistantAudioSettings,
VoiceAssistantEventData,
VoiceAssistantEventResponse,
VoiceAssistantRequest,
VoiceAssistantResponse,
)
from .client_callbacks import (
handle_timeout,
on_ble_raw_advertisement_response,
on_bluetooth_connections_free_response,
on_bluetooth_device_connection_response,
on_bluetooth_gatt_notify_data_response,
on_bluetooth_le_advertising_response,
on_home_assistant_service_response,
@ -116,9 +117,9 @@ from .model import (
MediaPlayerCommand,
UserService,
UserServiceArgType,
VoiceAssistantCommand,
VoiceAssistantEventType,
)
from .model import VoiceAssistantAudioSettings as VoiceAssistantAudioSettingsModel
from .model import VoiceAssistantCommand, VoiceAssistantEventType
from .model_conversions import (
LIST_ENTITIES_SERVICES_RESPONSE_TYPES,
SUBSCRIBE_STATES_RESPONSE_TYPES,
@ -467,7 +468,7 @@ class APIClient:
timeout: float = 10.0,
) -> message.Message:
message_filter = partial(self._filter_bluetooth_message, address, handle)
resp = await self._get_connection().send_messages_await_response_complex(
[resp] = await self._get_connection().send_messages_await_response_complex(
(request,),
message_filter,
message_filter,
@ -475,10 +476,21 @@ class APIClient:
timeout,
)
if isinstance(resp[0], BluetoothGATTErrorResponse):
raise BluetoothGATTAPIError(BluetoothGATTError.from_pb(resp[0]))
if (
type(resp) # pylint: disable=unidiomatic-typecheck
is BluetoothGATTErrorResponse
):
raise BluetoothGATTAPIError(BluetoothGATTError.from_pb(resp))
return resp[0]
return resp
def _unsub_bluetooth_advertisements(
self, unsub_callback: Callable[[], None]
) -> None:
"""Unsubscribe Bluetooth advertisements if connected."""
if self._connection is not None:
unsub_callback()
self._connection.send_message(UnsubscribeBluetoothLEAdvertisementsRequest())
async def subscribe_bluetooth_le_advertisements(
self, on_bluetooth_le_advertisement: Callable[[BluetoothLEAdvertisement], None]
@ -491,15 +503,7 @@ class APIClient:
),
(BluetoothLEAdvertisementResponse,),
)
def unsub() -> None:
if self._connection is not None:
unsub_callback()
self._connection.send_message(
UnsubscribeBluetoothLEAdvertisementsRequest()
)
return unsub
return partial(self._unsub_bluetooth_advertisements, unsub_callback)
async def subscribe_bluetooth_le_raw_advertisements(
self, on_advertisements: Callable[[list[BluetoothLERawAdvertisement]], None]
@ -511,15 +515,7 @@ class APIClient:
partial(on_ble_raw_advertisement_response, on_advertisements),
(BluetoothLERawAdvertisementsResponse,),
)
def unsub() -> None:
if self._connection is not None:
unsub_callback()
self._connection.send_message(
UnsubscribeBluetoothLEAdvertisementsRequest()
)
return unsub
return partial(self._unsub_bluetooth_advertisements, unsub_callback)
async def subscribe_bluetooth_connections_free(
self, on_bluetooth_connections_free_update: Callable[[int, int], None]
@ -533,28 +529,6 @@ class APIClient:
(BluetoothConnectionsFreeResponse,),
)
def _handle_timeout(self, fut: asyncio.Future[None]) -> None:
"""Handle a timeout."""
if fut.done():
return
fut.set_exception(asyncio.TimeoutError)
def _on_bluetooth_device_connection_response(
self,
connect_future: asyncio.Future[None],
address: int,
on_bluetooth_connection_state: Callable[[bool, int, int], None],
msg: BluetoothDeviceConnectionResponse,
) -> None:
"""Handle a BluetoothDeviceConnectionResponse message.""" ""
if address == msg.address:
on_bluetooth_connection_state(msg.connected, msg.mtu, msg.error)
# Resolve on ANY connection state since we do not want
# to wait the whole timeout if the device disconnects
# or we get an error.
if not connect_future.done():
connect_future.set_result(None)
async def bluetooth_device_connect( # pylint: disable=too-many-locals, too-many-branches
self,
address: int,
@ -589,7 +563,7 @@ class APIClient:
address_type=address_type or 0,
),
partial(
self._on_bluetooth_device_connection_response,
on_bluetooth_device_connection_response,
connect_future,
address,
on_bluetooth_connection_state,
@ -599,7 +573,7 @@ class APIClient:
loop = self._loop
timeout_handle = loop.call_at(
loop.time() + timeout, self._handle_timeout, connect_future
loop.time() + timeout, handle_timeout, connect_future
)
timeout_expired = False
connect_ok = False
@ -607,6 +581,11 @@ class APIClient:
await connect_future
connect_ok = True
except asyncio.TimeoutError as err:
# If the timeout expires, make sure
# to unsub before calling _bluetooth_device_disconnect_guard_timeout
# so that the disconnect message is not propagated back to the caller
# since we are going to raise a TimeoutAPIError.
unsub()
timeout_expired = True
# Disconnect before raising the exception to ensure
# the slot is recovered before the timeout is raised
@ -625,14 +604,8 @@ class APIClient:
f" after {disconnect_timeout}s"
) from err
finally:
if not connect_ok:
try:
unsub()
except (KeyError, ValueError):
_LOGGER.warning(
"%s: Bluetooth device connection canceled but already unsubscribed",
to_human_readable_address(address),
)
if not connect_ok and not timeout_expired:
unsub()
if not timeout_expired:
timeout_handle.cancel()
@ -667,7 +640,7 @@ class APIClient:
return False
if isinstance(msg, BluetoothDeviceConnectionResponse):
raise APIConnectionError(
"Peripheral changed connections status while pairing"
f"Peripheral changed connections status while pairing: {msg.error}"
)
return True
@ -1274,7 +1247,8 @@ class APIClient:
async def subscribe_voice_assistant(
self,
handle_start: Callable[
[str, int, VoiceAssistantAudioSettings], Coroutine[Any, Any, int | None]
[str, int, VoiceAssistantAudioSettingsModel],
Coroutine[Any, Any, int | None],
],
handle_stop: Callable[[], Coroutine[Any, Any, None]],
) -> Callable[[], None]:
@ -1301,6 +1275,8 @@ class APIClient:
self._connection.send_message(VoiceAssistantResponse(error=True))
def _on_voice_assistant_request(msg: VoiceAssistantRequest) -> None:
nonlocal start_task
command = VoiceAssistantCommand.from_pb(msg)
if command.start:
start_task = asyncio.create_task(
@ -1323,6 +1299,8 @@ class APIClient:
)
def unsub() -> None:
nonlocal start_task
if self._connection is not None:
remove_callback()
self._connection.send_message(
@ -1337,20 +1315,15 @@ class APIClient:
def send_voice_assistant_event(
self, event_type: VoiceAssistantEventType, data: dict[str, str] | None
) -> None:
req = VoiceAssistantEventResponse()
req.event_type = event_type
data_args = []
req = VoiceAssistantEventResponse(event_type=event_type)
if data is not None:
for name, value in data.items():
arg = VoiceAssistantEventData()
arg.name = name
arg.value = value
data_args.append(arg)
# pylint: disable=no-member
req.data.extend(data_args)
# pylint: disable=no-member
req.data.extend(
[
VoiceAssistantEventData(name=name, value=value)
for name, value in data.items()
]
)
self._get_connection().send_message(req)
async def alarm_control_panel_command(

View File

@ -8,3 +8,5 @@ cdef object CameraImageResponse, CameraState
cdef object HomeassistantServiceCall
cdef object BluetoothLEAdvertisement
cdef object asyncio_TimeoutError

View File

@ -1,11 +1,14 @@
from __future__ import annotations
from asyncio import Future
from asyncio import TimeoutError as asyncio_TimeoutError
from typing import TYPE_CHECKING, Callable
from google.protobuf import message
from .api_pb2 import ( # type: ignore
BluetoothConnectionsFreeResponse,
BluetoothDeviceConnectionResponse,
BluetoothGATTNotifyDataResponse,
BluetoothLEAdvertisementResponse,
BluetoothLERawAdvertisement,
@ -93,3 +96,25 @@ def on_subscribe_home_assistant_state_response(
msg: SubscribeHomeAssistantStateResponse,
) -> None:
on_state_sub(msg.entity_id, msg.attribute)
def handle_timeout(fut: Future[None]) -> None:
"""Handle a timeout."""
if not fut.done():
fut.set_exception(asyncio_TimeoutError)
def on_bluetooth_device_connection_response(
connect_future: Future[None],
address: int,
on_bluetooth_connection_state: Callable[[bool, int, int], None],
msg: BluetoothDeviceConnectionResponse,
) -> None:
"""Handle a BluetoothDeviceConnectionResponse message.""" ""
if address == msg.address:
on_bluetooth_connection_state(msg.connected, msg.mtu, msg.error)
# Resolve on ANY connection state since we do not want
# to wait the whole timeout if the device disconnects
# or we get an error.
if not connect_future.done():
connect_future.set_result(None)

View File

@ -78,12 +78,13 @@ class APIModelBase:
def from_dict(
cls: type[_V], data: dict[str, Any], *, ignore_missing: bool = True
) -> _V:
init_args = {
f.name: data[f.name]
for f in cached_fields(cls) # type: ignore[arg-type]
if f.name in data or (not ignore_missing)
}
return cls(**init_args)
return cls(
**{
f.name: data[f.name]
for f in cached_fields(cls) # type: ignore[arg-type]
if f.name in data or (not ignore_missing)
}
)
@classmethod
def from_pb(cls: type[_V], data: Any) -> _V:

View File

@ -11,7 +11,7 @@ with open(os.path.join(here, "README.rst"), encoding="utf-8") as readme_file:
long_description = readme_file.read()
VERSION = "19.1.0"
VERSION = "19.1.1"
PROJECT_NAME = "aioesphomeapi"
PROJECT_PACKAGE_NAME = "aioesphomeapi"
PROJECT_LICENSE = "MIT"

View File

@ -17,7 +17,12 @@ from aioesphomeapi.connection import APIConnection
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
from aioesphomeapi.zeroconf import ZeroconfManager
from .common import connect, get_mock_async_zeroconf, send_plaintext_hello
from .common import (
connect,
connect_client,
get_mock_async_zeroconf,
send_plaintext_hello,
)
KEEP_ALIVE_INTERVAL = 15.0
@ -80,19 +85,19 @@ def connection_params() -> ConnectionParams:
return get_mock_connection_params()
def on_stop(expected_disconnect: bool) -> None:
def mock_on_stop(expected_disconnect: bool) -> None:
pass
@pytest.fixture
def conn(connection_params: ConnectionParams) -> APIConnection:
return PatchableAPIConnection(connection_params, on_stop, True, None)
return PatchableAPIConnection(connection_params, mock_on_stop, True, None)
@pytest.fixture
def conn_with_password(connection_params: ConnectionParams) -> APIConnection:
connection_params = replace(connection_params, password="password")
return PatchableAPIConnection(connection_params, on_stop, True, None)
return PatchableAPIConnection(connection_params, mock_on_stop, True, None)
@pytest.fixture
@ -100,13 +105,13 @@ def noise_conn(connection_params: ConnectionParams) -> APIConnection:
connection_params = replace(
connection_params, noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc="
)
return PatchableAPIConnection(connection_params, on_stop, True, None)
return PatchableAPIConnection(connection_params, mock_on_stop, True, None)
@pytest.fixture
def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnection:
connection_params = replace(connection_params, expected_name="test")
return PatchableAPIConnection(connection_params, on_stop, True, None)
return PatchableAPIConnection(connection_params, mock_on_stop, True, None)
def _create_mock_transport_protocol(
@ -177,7 +182,7 @@ async def plaintext_connect_task_with_login(
@pytest_asyncio.fixture(name="api_client")
async def api_client(
conn: APIConnection, resolve_host, socket_socket, event_loop
resolve_host, socket_socket, event_loop
) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]:
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()
@ -192,12 +197,12 @@ async def api_client(
event_loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
):
connect_task = asyncio.create_task(connect(conn, login=False))
), patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection):
connect_task = asyncio.create_task(connect_client(client, login=False))
await connected.wait()
conn = client._connection
protocol = conn._frame_helper
send_plaintext_hello(protocol)
client._connection = conn
await connect_task
transport.reset_mock()
yield client, conn, transport, protocol

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
import itertools
import logging
from functools import partial
from typing import Any
from unittest.mock import AsyncMock, MagicMock, call, patch
@ -17,7 +18,10 @@ from aioesphomeapi.api_pb2 import (
BluetoothDeviceClearCacheResponse,
BluetoothDeviceConnectionResponse,
BluetoothDevicePairingResponse,
BluetoothDeviceRequest,
BluetoothDeviceUnpairingResponse,
BluetoothGATTCharacteristic,
BluetoothGATTDescriptor,
BluetoothGATTErrorResponse,
BluetoothGATTGetServicesDoneResponse,
BluetoothGATTGetServicesResponse,
@ -29,6 +33,7 @@ from aioesphomeapi.api_pb2 import (
BluetoothLEAdvertisementResponse,
BluetoothLERawAdvertisement,
BluetoothLERawAdvertisementsResponse,
BluetoothServiceData,
ButtonCommandRequest,
CameraImageRequest,
CameraImageResponse,
@ -40,6 +45,7 @@ from aioesphomeapi.api_pb2 import (
ExecuteServiceRequest,
FanCommandRequest,
HomeassistantServiceResponse,
HomeAssistantStateResponse,
LightCommandRequest,
ListEntitiesBinarySensorResponse,
ListEntitiesDoneResponse,
@ -51,8 +57,14 @@ from aioesphomeapi.api_pb2 import (
SirenCommandRequest,
SubscribeHomeAssistantStateResponse,
SubscribeLogsResponse,
SubscribeVoiceAssistantRequest,
SwitchCommandRequest,
TextCommandRequest,
VoiceAssistantAudioSettings,
VoiceAssistantEventData,
VoiceAssistantEventResponse,
VoiceAssistantRequest,
VoiceAssistantResponse,
)
from aioesphomeapi.client import APIClient
from aioesphomeapi.connection import APIConnection
@ -67,10 +79,12 @@ from aioesphomeapi.model import (
APIVersion,
BinarySensorInfo,
BinarySensorState,
BluetoothDeviceRequestType,
)
from aioesphomeapi.model import BluetoothGATTService as BluetoothGATTServiceModel
from aioesphomeapi.model import (
BluetoothLEAdvertisement,
BluetoothProxyFeature,
CameraState,
ClimateFanMode,
ClimateMode,
@ -88,6 +102,10 @@ from aioesphomeapi.model import (
UserServiceArg,
UserServiceArgType,
)
from aioesphomeapi.model import (
VoiceAssistantAudioSettings as VoiceAssistantAudioSettingsModel,
)
from aioesphomeapi.model import VoiceAssistantEventType as VoiceAssistantEventModelType
from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
from .common import (
@ -194,6 +212,29 @@ async def test_finish_connection_wraps_exceptions_as_unhandled_api_error() -> No
await cli.finish_connection(False)
@pytest.mark.asyncio
async def test_request_while_handshaking(event_loop) -> None:
"""Test trying a request while handshaking raises."""
class PatchableApiClient(APIClient):
pass
cli = PatchableApiClient("host", 1234, None)
with patch.object(
event_loop, "sock_connect", side_effect=partial(asyncio.sleep, 1)
), patch.object(cli, "finish_connection"):
connect_task = asyncio.create_task(cli.connect())
await asyncio.sleep(0)
with pytest.raises(
APIConnectionError, match="Authenticated connection not ready yet"
):
await cli.device_info()
connect_task.cancel()
await asyncio.sleep(0)
@pytest.mark.asyncio
async def test_connect_while_already_connected(auth_client: APIClient) -> None:
"""Test connecting while already connected raises."""
@ -899,11 +940,36 @@ async def test_bluetooth_pair(
client, connection, transport, protocol = api_client
pair_task = asyncio.create_task(client.bluetooth_device_pair(1234))
await asyncio.sleep(0)
response: message.Message = BluetoothDevicePairingResponse(address=4567)
mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0)
assert not pair_task.done()
response: message.Message = BluetoothDevicePairingResponse(address=1234)
mock_data_received(protocol, generate_plaintext_packet(response))
await pair_task
@pytest.mark.asyncio
async def test_bluetooth_pair_connection_drops(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test connection drop during bluetooth_device_pair."""
client, connection, transport, protocol = api_client
pair_task = asyncio.create_task(client.bluetooth_device_pair(1234))
await asyncio.sleep(0)
response: message.Message = BluetoothDeviceConnectionResponse(
address=1234, connected=False, error=13
)
mock_data_received(protocol, generate_plaintext_packet(response))
with pytest.raises(
APIConnectionError,
match="Peripheral changed connections status while pairing: 13",
):
await pair_task
@pytest.mark.asyncio
async def test_bluetooth_unpair(
api_client: tuple[
@ -942,7 +1008,7 @@ async def test_device_info(
) -> None:
"""Test fetching device info."""
client, connection, transport, protocol = api_client
assert client.log_name == "mydevice.local"
assert client.log_name == "fake @ 10.0.0.512"
device_info_task = asyncio.create_task(client.device_info())
await asyncio.sleep(0)
response: message.Message = DeviceInfoResponse(
@ -961,7 +1027,7 @@ async def test_device_info(
response: message.Message = DisconnectResponse()
mock_data_received(protocol, generate_plaintext_packet(response))
await disconnect_task
with pytest.raises(APIConnectionError, match="CLOSED"):
with pytest.raises(APIConnectionError, match="Not connected"):
await client.device_info()
@ -988,6 +1054,24 @@ async def test_bluetooth_gatt_read(
assert await read_task == b"1234"
@pytest.mark.asyncio
async def test_bluetooth_gatt_read_error(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test bluetooth_gatt_read that errors."""
client, connection, transport, protocol = api_client
read_task = asyncio.create_task(client.bluetooth_gatt_read(1234, 1234))
await asyncio.sleep(0)
error_response: message.Message = BluetoothGATTErrorResponse(
address=1234, handle=1234
)
mock_data_received(protocol, generate_plaintext_packet(error_response))
with pytest.raises(BluetoothGATTAPIError):
await read_task
@pytest.mark.asyncio
async def test_bluetooth_gatt_read_descriptor(
api_client: tuple[
@ -1110,7 +1194,16 @@ async def test_bluetooth_gatt_get_services(
services_task = asyncio.create_task(client.bluetooth_gatt_get_services(1234))
await asyncio.sleep(0)
service1: message.Message = BluetoothGATTService(
uuid=[1, 1], handle=1, characteristics=[]
uuid=[1, 1],
handle=1,
characteristics=[
BluetoothGATTCharacteristic(
uuid=[1, 2],
handle=2,
properties=1,
descriptors=[BluetoothGATTDescriptor(uuid=[1, 3], handle=3)],
)
],
)
response: message.Message = BluetoothGATTGetServicesResponse(
address=1234, services=[service1]
@ -1120,9 +1213,10 @@ async def test_bluetooth_gatt_get_services(
mock_data_received(protocol, generate_plaintext_packet(done_response))
services = await services_task
service = BluetoothGATTServiceModel.from_pb(service1)
assert services == ESPHomeBluetoothGATTServices(
address=1234,
services=[BluetoothGATTServiceModel(uuid=[1, 1], handle=1, characteristics=[])],
services=[service],
)
@ -1200,6 +1294,10 @@ async def test_bluetooth_gatt_start_notify(
# Ensure abort callback is a no-op after cancel
# and doesn't raise
abort_cb()
await client.disconnect(force=True)
# Ensure abort callback is a no-op after disconnect
# and does not raise
await cancel_cb()
@pytest.mark.asyncio
@ -1254,8 +1352,18 @@ async def test_subscribe_bluetooth_le_advertisements(
name=b"mydevice",
rssi=-50,
service_uuids=["1234"],
service_data={},
manufacturer_data={},
service_data=[
BluetoothServiceData(
uuid="1234",
data=b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
)
],
manufacturer_data=[
BluetoothServiceData(
uuid="1234",
data=b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
)
],
address_type=1,
)
mock_data_received(protocol, generate_plaintext_packet(response))
@ -1266,6 +1374,33 @@ async def test_subscribe_bluetooth_le_advertisements(
name="mydevice",
rssi=-50,
service_uuids=["000034-0000-1000-8000-00805f9b34fb"],
manufacturer_data={
4660: b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
},
service_data={
"000034-0000-1000-8000-00805f9b34fb": b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
},
address_type=1,
)
]
advs.clear()
response: message.Message = BluetoothLEAdvertisementResponse(
address=1234,
name=b"mydevice",
rssi=-50,
service_uuids=[],
service_data=[],
manufacturer_data=[],
address_type=1,
)
mock_data_received(protocol, generate_plaintext_packet(response))
assert advs == [
BluetoothLEAdvertisement(
address=1234,
name="mydevice",
rssi=-50,
service_uuids=[],
manufacturer_data={},
service_data={},
address_type=1,
@ -1374,6 +1509,17 @@ async def test_subscribe_logs(auth_client: APIClient) -> None:
on_logs.assert_called_with(log_msg)
@pytest.mark.asyncio
async def test_send_home_assistant_state(auth_client: APIClient) -> None:
send = patch_send(auth_client)
await auth_client.send_home_assistant_state("binary_sensor.bla", None, "on")
send.assert_called_once_with(
HomeAssistantStateResponse(
entity_id="binary_sensor.bla", state="on", attribute=None
)
)
@pytest.mark.asyncio
async def test_subscribe_service_calls(auth_client: APIClient) -> None:
send = patch_response_callback(auth_client)
@ -1402,7 +1548,7 @@ async def test_set_debug(
caplog.set_level(logging.DEBUG)
client.set_debug(True)
assert client.log_name == "mydevice.local"
assert client.log_name == "fake @ 10.0.0.512"
device_info_task = asyncio.create_task(client.device_info())
await asyncio.sleep(0)
mock_data_received(protocol, generate_plaintext_packet(response))
@ -1423,11 +1569,446 @@ async def test_force_disconnect(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test force disconnect can be called multiple times."""
client, connection, transport, protocol = api_client
assert connection.is_connected is True
assert connection.on_stop is not None
await client.disconnect(force=True)
assert client._connection is None
assert connection.is_connected is False
await client.disconnect(force=False)
assert connection.is_connected is False
@pytest.mark.asyncio
@pytest.mark.parametrize(
("has_cache", "feature_flags", "method"),
[
(False, BluetoothProxyFeature(0), BluetoothDeviceRequestType.CONNECT),
(
False,
BluetoothProxyFeature.REMOTE_CACHING,
BluetoothDeviceRequestType.CONNECT_V3_WITHOUT_CACHE,
),
(
True,
BluetoothProxyFeature.REMOTE_CACHING,
BluetoothDeviceRequestType.CONNECT_V3_WITH_CACHE,
),
],
)
async def test_bluetooth_device_connect(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
has_cache: bool,
feature_flags: BluetoothProxyFeature,
method: BluetoothDeviceRequestType,
) -> None:
"""Test bluetooth_device_connect."""
client, connection, transport, protocol = api_client
states = []
def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None:
states.append((connected, mtu, error))
connect_task = asyncio.create_task(
client.bluetooth_device_connect(
1234,
on_bluetooth_connection_state,
timeout=1,
feature_flags=feature_flags,
has_cache=has_cache,
disconnect_timeout=1,
address_type=1,
)
)
await asyncio.sleep(0)
response: message.Message = BluetoothDeviceConnectionResponse(
address=1234, connected=True, mtu=23, error=0
)
mock_data_received(protocol, generate_plaintext_packet(response))
cancel = await connect_task
assert states == [(True, 23, 0)]
transport.write.assert_called_once_with(
generate_plaintext_packet(
BluetoothDeviceRequest(
address=1234,
request_type=method,
has_address_type=True,
address_type=1,
),
)
)
response: message.Message = BluetoothDeviceConnectionResponse(
address=1234, connected=False, mtu=23, error=7
)
mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0)
assert states == [(True, 23, 0), (False, 23, 7)]
cancel()
# After cancel, no more messages should called back
response: message.Message = BluetoothDeviceConnectionResponse(
address=1234, connected=False, mtu=23, error=8
)
mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0)
assert states == [(True, 23, 0), (False, 23, 7)]
@pytest.mark.asyncio
async def test_bluetooth_device_connect_and_disconnect_times_out(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test bluetooth_device_connect and disconnect times out."""
client, connection, transport, protocol = api_client
states = []
def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None:
states.append((connected, mtu, error))
connect_task = asyncio.create_task(
client.bluetooth_device_connect(
1234,
on_bluetooth_connection_state,
timeout=0,
feature_flags=0,
has_cache=True,
disconnect_timeout=0,
address_type=1,
)
)
with pytest.raises(TimeoutAPIError):
await connect_task
assert states == []
@pytest.mark.asyncio
async def test_bluetooth_device_connect_times_out_disconnect_ok(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test bluetooth_device_connect and disconnect times out."""
client, connection, transport, protocol = api_client
states = []
def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None:
states.append((connected, mtu, error))
connect_task = asyncio.create_task(
client.bluetooth_device_connect(
1234,
on_bluetooth_connection_state,
timeout=0,
feature_flags=0,
has_cache=True,
disconnect_timeout=1,
address_type=1,
)
)
await asyncio.sleep(0)
# The connect request should be written
assert len(transport.write.mock_calls) == 1
await asyncio.sleep(0)
await asyncio.sleep(0)
await asyncio.sleep(0)
# Now that we timed out, the disconnect
# request should be written
assert len(transport.write.mock_calls) == 2
response: message.Message = BluetoothDeviceConnectionResponse(
address=1234, connected=False, mtu=23, error=8
)
mock_data_received(protocol, generate_plaintext_packet(response))
with pytest.raises(TimeoutAPIError):
await connect_task
assert states == []
@pytest.mark.asyncio
async def test_bluetooth_device_connect_cancelled(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test bluetooth_device_connect handles cancellation."""
client, connection, transport, protocol = api_client
states = []
handlers_before = len(
list(itertools.chain(*connection._get_message_handlers().values()))
)
def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None:
states.append((connected, mtu, error))
connect_task = asyncio.create_task(
client.bluetooth_device_connect(
1234,
on_bluetooth_connection_state,
timeout=10,
feature_flags=0,
has_cache=True,
disconnect_timeout=10,
address_type=1,
)
)
await asyncio.sleep(0)
# The connect request should be written
assert len(transport.write.mock_calls) == 1
connect_task.cancel()
with pytest.raises(asyncio.CancelledError):
await connect_task
assert states == []
handlers_after = len(
list(itertools.chain(*connection._get_message_handlers().values()))
)
# Make sure we do not leak message handlers
assert handlers_after == handlers_before
@pytest.mark.asyncio
async def test_send_voice_assistant_event(auth_client: APIClient) -> None:
send = patch_send(auth_client)
auth_client.send_voice_assistant_event(
VoiceAssistantEventModelType.VOICE_ASSISTANT_ERROR,
{"error": "error", "ok": "ok"},
)
send.assert_called_once_with(
VoiceAssistantEventResponse(
event_type=VoiceAssistantEventModelType.VOICE_ASSISTANT_ERROR.value,
data=[
VoiceAssistantEventData(name="error", value="error"),
VoiceAssistantEventData(name="ok", value="ok"),
],
)
)
send.reset_mock()
auth_client.send_voice_assistant_event(
VoiceAssistantEventModelType.VOICE_ASSISTANT_ERROR, None
)
send.assert_called_once_with(
VoiceAssistantEventResponse(
event_type=VoiceAssistantEventModelType.VOICE_ASSISTANT_ERROR.value,
data=[],
)
)
@pytest.mark.asyncio
async def test_subscribe_voice_assistant(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test subscribe_voice_assistant."""
client, connection, transport, protocol = api_client
send = patch_send(client)
starts = []
stops = []
async def handle_start(
conversation_id: str, flags: int, audio_settings: VoiceAssistantAudioSettings
) -> int | None:
starts.append((conversation_id, flags, audio_settings))
return 42
async def handle_stop() -> None:
stops.append(True)
unsub = await client.subscribe_voice_assistant(handle_start, handle_stop)
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=True))
send.reset_mock()
audio_settings = VoiceAssistantAudioSettings(
noise_suppression_level=42,
auto_gain=42,
volume_multiplier=42,
)
response: message.Message = VoiceAssistantRequest(
conversation_id="theone",
start=True,
flags=42,
audio_settings=audio_settings,
)
mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0)
await asyncio.sleep(0)
assert starts == [
(
"theone",
42,
VoiceAssistantAudioSettingsModel(
noise_suppression_level=42,
auto_gain=42,
volume_multiplier=42,
),
)
]
assert stops == []
send.assert_called_once_with(VoiceAssistantResponse(port=42))
send.reset_mock()
response: message.Message = VoiceAssistantRequest(
conversation_id="theone",
start=False,
)
mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0)
assert stops == [True]
send.reset_mock()
unsub()
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False))
send.reset_mock()
await client.disconnect(force=True)
# Ensure abort callback is a no-op after disconnect
# and does not raise
unsub()
assert len(send.mock_calls) == 0
@pytest.mark.asyncio
async def test_subscribe_voice_assistant_failure(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test subscribe_voice_assistant failure."""
client, connection, transport, protocol = api_client
send = patch_send(client)
starts = []
stops = []
async def handle_start(
conversation_id: str, flags: int, audio_settings: VoiceAssistantAudioSettings
) -> int | None:
starts.append((conversation_id, flags, audio_settings))
# Return None to indicate failure
return None
async def handle_stop() -> None:
stops.append(True)
unsub = await client.subscribe_voice_assistant(handle_start, handle_stop)
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=True))
send.reset_mock()
audio_settings = VoiceAssistantAudioSettings(
noise_suppression_level=42,
auto_gain=42,
volume_multiplier=42,
)
response: message.Message = VoiceAssistantRequest(
conversation_id="theone",
start=True,
flags=42,
audio_settings=audio_settings,
)
mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0)
await asyncio.sleep(0)
assert starts == [
(
"theone",
42,
VoiceAssistantAudioSettingsModel(
noise_suppression_level=42,
auto_gain=42,
volume_multiplier=42,
),
)
]
assert stops == []
send.assert_called_once_with(VoiceAssistantResponse(error=True))
send.reset_mock()
response: message.Message = VoiceAssistantRequest(
conversation_id="theone",
start=False,
)
mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0)
assert stops == [True]
send.reset_mock()
unsub()
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False))
send.reset_mock()
await client.disconnect(force=True)
# Ensure abort callback is a no-op after disconnect
# and does not raise
unsub()
assert len(send.mock_calls) == 0
@pytest.mark.asyncio
async def test_subscribe_voice_assistant_cancels_long_running_handle_start(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test subscribe_voice_assistant cancels long running tasks on unsub."""
client, connection, transport, protocol = api_client
send = patch_send(client)
starts = []
stops = []
async def handle_start(
conversation_id: str, flags: int, audio_settings: VoiceAssistantAudioSettings
) -> int | None:
starts.append((conversation_id, flags, audio_settings))
await asyncio.sleep(10)
# Return None to indicate failure
starts.append("never")
return None
async def handle_stop() -> None:
stops.append(True)
unsub = await client.subscribe_voice_assistant(handle_start, handle_stop)
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=True))
send.reset_mock()
audio_settings = VoiceAssistantAudioSettings(
noise_suppression_level=42,
auto_gain=42,
volume_multiplier=42,
)
response: message.Message = VoiceAssistantRequest(
conversation_id="theone",
start=True,
flags=42,
audio_settings=audio_settings,
)
mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0)
await asyncio.sleep(0)
unsub()
await asyncio.sleep(0)
assert not stops
assert starts == [
(
"theone",
42,
VoiceAssistantAudioSettingsModel(
noise_suppression_level=42,
auto_gain=42,
volume_multiplier=42,
),
)
]
@pytest.mark.asyncio
async def test_api_version_after_connection_closed(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test api version is None after connection close."""
client, connection, transport, protocol = api_client
assert client.api_version == APIVersion(1, 9)
await client.disconnect(force=True)
assert client.api_version is None

View File

@ -1,6 +1,8 @@
from __future__ import annotations
import asyncio
from datetime import timedelta
from functools import partial
from unittest.mock import MagicMock, patch
import pytest
@ -8,18 +10,22 @@ from google.protobuf import message
from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper
from aioesphomeapi.api_pb2 import SubscribeLogsResponse # type: ignore
from aioesphomeapi.api_pb2 import DisconnectResponse
from aioesphomeapi.api_pb2 import DisconnectRequest, DisconnectResponse
from aioesphomeapi.client import APIClient
from aioesphomeapi.connection import APIConnection
from aioesphomeapi.core import APIConnectionError
from aioesphomeapi.log_runner import async_run
from aioesphomeapi.reconnect_logic import EXPECTED_DISCONNECT_COOLDOWN
from .common import (
Estr,
async_fire_time_changed,
generate_plaintext_packet,
get_mock_async_zeroconf,
mock_data_received,
send_plaintext_connect_response,
send_plaintext_hello,
utcnow,
)
@ -83,3 +89,156 @@ async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnec
disconnect_response = DisconnectResponse()
mock_data_received(protocol, generate_plaintext_packet(disconnect_response))
await stop_task
@pytest.mark.asyncio
async def test_log_runner_reconnects_on_disconnect(
event_loop: asyncio.AbstractEventLoop,
conn: APIConnection,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test the log runner reconnects on disconnect."""
loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()
connected = asyncio.Event()
class PatchableAPIClient(APIClient):
pass
async_zeroconf = get_mock_async_zeroconf()
cli = PatchableAPIClient(
address=Estr("1.2.3.4"),
port=6052,
password=None,
noise_psk=None,
expected_name=Estr("fake"),
zeroconf_instance=async_zeroconf.zeroconf,
)
messages = []
def on_log(msg: SubscribeLogsResponse) -> None:
messages.append(msg)
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
subscribed = asyncio.Event()
original_subscribe_logs = cli.subscribe_logs
async def _wait_subscribe_cli(*args, **kwargs):
await original_subscribe_logs(*args, **kwargs)
subscribed.set()
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
), patch.object(cli, "subscribe_logs", _wait_subscribe_cli):
stop = await async_run(cli, on_log, aio_zeroconf_instance=async_zeroconf)
await connected.wait()
protocol = cli._connection._frame_helper
send_plaintext_hello(protocol)
send_plaintext_connect_response(protocol, False)
await subscribed.wait()
response: message.Message = SubscribeLogsResponse()
response.message = b"Hello world"
mock_data_received(protocol, generate_plaintext_packet(response))
assert len(messages) == 1
assert messages[0].message == b"Hello world"
with patch.object(cli, "start_connection") as mock_start_connection:
response: message.Message = DisconnectRequest()
mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0)
assert cli._connection is None
async_fire_time_changed(
utcnow() + timedelta(seconds=EXPECTED_DISCONNECT_COOLDOWN)
)
await asyncio.sleep(0)
assert "Disconnected from API" in caplog.text
assert mock_start_connection.called
await stop()
@pytest.mark.asyncio
async def test_log_runner_reconnects_on_subscribe_failure(
event_loop: asyncio.AbstractEventLoop,
conn: APIConnection,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test the log runner reconnects on subscribe failure."""
loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()
connected = asyncio.Event()
class PatchableAPIClient(APIClient):
pass
async_zeroconf = get_mock_async_zeroconf()
cli = PatchableAPIClient(
address=Estr("1.2.3.4"),
port=6052,
password=None,
noise_psk=None,
expected_name=Estr("fake"),
zeroconf_instance=async_zeroconf.zeroconf,
)
messages = []
def on_log(msg: SubscribeLogsResponse) -> None:
messages.append(msg)
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
subscribed = asyncio.Event()
async def _wait_and_fail_subscribe_cli(*args, **kwargs):
subscribed.set()
raise APIConnectionError("subscribed force to fail")
with patch.object(
cli, "disconnect", partial(cli.disconnect, force=True)
), patch.object(cli, "subscribe_logs", _wait_and_fail_subscribe_cli):
with patch.object(loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
stop = await async_run(cli, on_log, aio_zeroconf_instance=async_zeroconf)
await connected.wait()
protocol = cli._connection._frame_helper
send_plaintext_hello(protocol)
send_plaintext_connect_response(protocol, False)
await subscribed.wait()
assert cli._connection is None
with patch.object(loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
), patch.object(cli, "subscribe_logs"):
connected.clear()
await asyncio.sleep(0)
async_fire_time_changed(
utcnow() + timedelta(seconds=EXPECTED_DISCONNECT_COOLDOWN)
)
await asyncio.sleep(0)
stop_task = asyncio.create_task(stop())
await asyncio.sleep(0)
disconnect_response = DisconnectResponse()
mock_data_received(protocol, generate_plaintext_packet(disconnect_response))
await stop_task

View File

@ -49,6 +49,7 @@ from aioesphomeapi.model import (
APIVersion,
BinarySensorInfo,
BinarySensorState,
BluetoothProxyFeature,
ButtonInfo,
CameraInfo,
ClimateInfo,
@ -61,6 +62,7 @@ from aioesphomeapi.model import (
FanState,
HomeassistantServiceCall,
LegacyCoverState,
LightColorCapability,
LightInfo,
LightState,
LockEntityState,
@ -357,3 +359,141 @@ def test_user_service_conversion():
def test_build_unique_id(model):
obj = model(object_id="id")
assert build_unique_id("mac", obj) == f"mac-{_TYPE_TO_NAME[type(obj)]}-id"
@pytest.mark.parametrize(
("version", "flags"),
[
(1, BluetoothProxyFeature.PASSIVE_SCAN),
(
2,
BluetoothProxyFeature.PASSIVE_SCAN
| BluetoothProxyFeature.ACTIVE_CONNECTIONS,
),
(
3,
BluetoothProxyFeature.PASSIVE_SCAN
| BluetoothProxyFeature.ACTIVE_CONNECTIONS
| BluetoothProxyFeature.REMOTE_CACHING,
),
(
4,
BluetoothProxyFeature.PASSIVE_SCAN
| BluetoothProxyFeature.ACTIVE_CONNECTIONS
| BluetoothProxyFeature.REMOTE_CACHING
| BluetoothProxyFeature.PAIRING,
),
(
5,
BluetoothProxyFeature.PASSIVE_SCAN
| BluetoothProxyFeature.ACTIVE_CONNECTIONS
| BluetoothProxyFeature.REMOTE_CACHING
| BluetoothProxyFeature.PAIRING
| BluetoothProxyFeature.CACHE_CLEARING,
),
],
)
def test_bluetooth_backcompat_for_device_info(
version: int, flags: BluetoothProxyFeature
) -> None:
info = DeviceInfo(
legacy_bluetooth_proxy_version=version, bluetooth_proxy_feature_flags=42
)
assert info.bluetooth_proxy_feature_flags_compat(APIVersion(1, 8)) is flags
assert info.bluetooth_proxy_feature_flags_compat(APIVersion(1, 9)) == 42
@pytest.mark.parametrize(
(
"legacy_supports_brightness",
"legacy_supports_rgb",
"legacy_supports_white_value",
"legacy_supports_color_temperature",
"capability",
),
[
(False, False, False, False, [LightColorCapability.ON_OFF]),
(
True,
False,
False,
False,
[LightColorCapability.ON_OFF | LightColorCapability.BRIGHTNESS],
),
(
True,
False,
False,
True,
[
LightColorCapability.ON_OFF
| LightColorCapability.BRIGHTNESS
| LightColorCapability.COLOR_TEMPERATURE
],
),
(
True,
True,
False,
False,
[
LightColorCapability.ON_OFF
| LightColorCapability.BRIGHTNESS
| LightColorCapability.RGB
],
),
(
True,
True,
True,
False,
[
LightColorCapability.ON_OFF
| LightColorCapability.BRIGHTNESS
| LightColorCapability.RGB
| LightColorCapability.WHITE
],
),
(
True,
True,
False,
True,
[
LightColorCapability.ON_OFF
| LightColorCapability.BRIGHTNESS
| LightColorCapability.RGB
| LightColorCapability.COLOR_TEMPERATURE
],
),
(
True,
True,
True,
True,
[
LightColorCapability.ON_OFF
| LightColorCapability.BRIGHTNESS
| LightColorCapability.RGB
| LightColorCapability.WHITE
| LightColorCapability.COLOR_TEMPERATURE
],
),
],
)
def test_supported_color_modes_compat(
legacy_supports_brightness: bool,
legacy_supports_rgb: bool,
legacy_supports_white_value: bool,
legacy_supports_color_temperature: bool,
capability: list[LightColorCapability],
) -> None:
info = LightInfo(
legacy_supports_brightness=legacy_supports_brightness,
legacy_supports_rgb=legacy_supports_rgb,
legacy_supports_white_value=legacy_supports_white_value,
legacy_supports_color_temperature=legacy_supports_color_temperature,
supported_color_modes=[42],
)
assert info.supported_color_modes_compat(APIVersion(1, 5)) == capability
assert info.supported_color_modes_compat(APIVersion(1, 9)) == [42]