Refactor execute_service to avoid creating dict in the inner loop (#776)

This commit is contained in:
J. Nick Koston 2023-11-28 10:19:39 -06:00 committed by GitHub
parent e8560c1547
commit 07499907d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -5,7 +5,7 @@ import asyncio
import logging
from collections.abc import Awaitable, Coroutine
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Union
from google.protobuf import message
@ -157,6 +157,20 @@ LIST_ENTITIES_MSG_TYPES = (
*LIST_ENTITIES_SERVICES_RESPONSE_TYPES,
)
USER_SERVICE_MAP_ARRAY = {
UserServiceArgType.BOOL_ARRAY: "bool_array",
UserServiceArgType.INT_ARRAY: "int_array",
UserServiceArgType.FLOAT_ARRAY: "float_array",
UserServiceArgType.STRING_ARRAY: "string_array",
}
USER_SERVICE_MAP_SINGLE = {
# Int is a special case because it is handled
# differently depending on the APIVersion
UserServiceArgType.BOOL: "bool_",
UserServiceArgType.FLOAT: "float_",
UserServiceArgType.STRING: "string_",
}
ExecuteServiceDataType = dict[
str, Union[bool, int, float, str, list[bool], list[int], list[float], list[str]]
@ -951,7 +965,9 @@ class APIClient:
stop: bool = False,
) -> None:
req = CoverCommandRequest(key=key)
apiv = cast(APIVersion, self.api_version)
apiv = self.api_version
if TYPE_CHECKING:
assert apiv is not None
if apiv >= APIVersion(1, 1):
if position is not None:
req.has_position = True
@ -1096,7 +1112,9 @@ class APIClient:
req.has_custom_fan_mode = True
req.custom_fan_mode = custom_fan_mode
if preset is not None:
apiv = cast(APIVersion, self.api_version)
apiv = self.api_version
if TYPE_CHECKING:
assert apiv is not None
if apiv < APIVersion(1, 5):
req.has_legacy_away = True
req.legacy_away = preset == ClimatePreset.AWAY
@ -1179,26 +1197,21 @@ class APIClient:
) -> None:
req = ExecuteServiceRequest(key=service.key)
args = []
apiv = self.api_version
if TYPE_CHECKING:
assert apiv is not None
int_type = "int_" if apiv >= APIVersion(1, 3) else "legacy_int"
map_single = USER_SERVICE_MAP_SINGLE
map_array = USER_SERVICE_MAP_ARRAY
for arg_desc in service.args:
arg = ExecuteServiceArgument()
val = data[arg_desc.name]
apiv = cast(APIVersion, self.api_version)
int_type = "int_" if apiv >= APIVersion(1, 3) else "legacy_int"
map_single = {
UserServiceArgType.BOOL: "bool_",
UserServiceArgType.INT: int_type,
UserServiceArgType.FLOAT: "float_",
UserServiceArgType.STRING: "string_",
}
map_array = {
UserServiceArgType.BOOL_ARRAY: "bool_array",
UserServiceArgType.INT_ARRAY: "int_array",
UserServiceArgType.FLOAT_ARRAY: "float_array",
UserServiceArgType.STRING_ARRAY: "string_array",
}
if arg_desc.type in map_array:
attr = getattr(arg, map_array[arg_desc.type])
attr.extend(val)
elif arg_desc.type == UserServiceArgType.INT:
int_type = "int_" if apiv >= APIVersion(1, 3) else "legacy_int"
setattr(arg, int_type, val)
else:
assert arg_desc.type in map_single
setattr(arg, map_single[arg_desc.type], val)