diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 9ff949e..e337d29 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -5,7 +5,7 @@ import asyncio import logging from collections.abc import Awaitable, Coroutine from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Union from google.protobuf import message @@ -157,6 +157,20 @@ LIST_ENTITIES_MSG_TYPES = ( *LIST_ENTITIES_SERVICES_RESPONSE_TYPES, ) +USER_SERVICE_MAP_ARRAY = { + UserServiceArgType.BOOL_ARRAY: "bool_array", + UserServiceArgType.INT_ARRAY: "int_array", + UserServiceArgType.FLOAT_ARRAY: "float_array", + UserServiceArgType.STRING_ARRAY: "string_array", +} +USER_SERVICE_MAP_SINGLE = { + # Int is a special case because it is handled + # differently depending on the APIVersion + UserServiceArgType.BOOL: "bool_", + UserServiceArgType.FLOAT: "float_", + UserServiceArgType.STRING: "string_", +} + ExecuteServiceDataType = dict[ str, Union[bool, int, float, str, list[bool], list[int], list[float], list[str]] @@ -951,7 +965,9 @@ class APIClient: stop: bool = False, ) -> None: req = CoverCommandRequest(key=key) - apiv = cast(APIVersion, self.api_version) + apiv = self.api_version + if TYPE_CHECKING: + assert apiv is not None if apiv >= APIVersion(1, 1): if position is not None: req.has_position = True @@ -1096,7 +1112,9 @@ class APIClient: req.has_custom_fan_mode = True req.custom_fan_mode = custom_fan_mode if preset is not None: - apiv = cast(APIVersion, self.api_version) + apiv = self.api_version + if TYPE_CHECKING: + assert apiv is not None if apiv < APIVersion(1, 5): req.has_legacy_away = True req.legacy_away = preset == ClimatePreset.AWAY @@ -1179,26 +1197,21 @@ class APIClient: ) -> None: req = ExecuteServiceRequest(key=service.key) args = [] + apiv = self.api_version + if TYPE_CHECKING: + assert apiv is not None + int_type = "int_" if apiv >= APIVersion(1, 3) else "legacy_int" + map_single = USER_SERVICE_MAP_SINGLE + map_array = USER_SERVICE_MAP_ARRAY for arg_desc in service.args: arg = ExecuteServiceArgument() val = data[arg_desc.name] - apiv = cast(APIVersion, self.api_version) - int_type = "int_" if apiv >= APIVersion(1, 3) else "legacy_int" - map_single = { - UserServiceArgType.BOOL: "bool_", - UserServiceArgType.INT: int_type, - UserServiceArgType.FLOAT: "float_", - UserServiceArgType.STRING: "string_", - } - map_array = { - UserServiceArgType.BOOL_ARRAY: "bool_array", - UserServiceArgType.INT_ARRAY: "int_array", - UserServiceArgType.FLOAT_ARRAY: "float_array", - UserServiceArgType.STRING_ARRAY: "string_array", - } if arg_desc.type in map_array: attr = getattr(arg, map_array[arg_desc.type]) attr.extend(val) + elif arg_desc.type == UserServiceArgType.INT: + int_type = "int_" if apiv >= APIVersion(1, 3) else "legacy_int" + setattr(arg, int_type, val) else: assert arg_desc.type in map_single setattr(arg, map_single[arg_desc.type], val)