Merge branch 'main' into climate_enhancements

This commit is contained in:
J. Nick Koston 2023-11-28 12:22:48 -06:00 committed by GitHub
commit fe81128146
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 84 additions and 90 deletions

View File

@ -5,7 +5,7 @@ import asyncio
import logging
from collections.abc import Awaitable, Coroutine
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Union
from google.protobuf import message
@ -76,8 +76,9 @@ from .client_callbacks import (
on_bluetooth_connections_free_response,
on_bluetooth_device_connection_response,
on_bluetooth_gatt_notify_data_response,
on_bluetooth_handle_message,
on_bluetooth_le_advertising_response,
on_bluetooth_message,
on_bluetooth_message_types,
on_home_assistant_service_response,
on_state_msg,
on_subscribe_home_assistant_state_response,
@ -157,6 +158,20 @@ LIST_ENTITIES_MSG_TYPES = (
*LIST_ENTITIES_SERVICES_RESPONSE_TYPES,
)
USER_SERVICE_MAP_ARRAY = {
UserServiceArgType.BOOL_ARRAY: "bool_array",
UserServiceArgType.INT_ARRAY: "int_array",
UserServiceArgType.FLOAT_ARRAY: "float_array",
UserServiceArgType.STRING_ARRAY: "string_array",
}
USER_SERVICE_MAP_SINGLE = {
# Int is a special case because it is handled
# differently depending on the APIVersion
UserServiceArgType.BOOL: "bool_",
UserServiceArgType.FLOAT: "float_",
UserServiceArgType.STRING: "string_",
}
ExecuteServiceDataType = dict[
str, Union[bool, int, float, str, list[bool], list[int], list[float], list[str]]
@ -191,7 +206,6 @@ class APIClient:
"cached_name",
"_background_tasks",
"_loop",
"_on_stop_task",
"log_name",
)
@ -238,7 +252,6 @@ class APIClient:
self.cached_name: str | None = None
self._background_tasks: set[asyncio.Task[Any]] = set()
self._loop = asyncio.get_event_loop()
self._on_stop_task: asyncio.Task[None] | None = None
self._set_log_name()
def set_debug(self, enabled: bool) -> None:
@ -299,20 +312,7 @@ class APIClient:
# Hook into on_stop handler to clear connection when stopped
self._connection = None
if on_stop:
self._on_stop_task = asyncio.create_task(
on_stop(expected_disconnect),
name=f"{self.log_name} aioesphomeapi on_stop",
)
self._on_stop_task.add_done_callback(self._remove_on_stop_task)
def _remove_on_stop_task(self, _fut: asyncio.Future[None]) -> None:
"""Remove the stop task.
We need to do this because the asyncio does not hold
a strong reference to the task, so it can be garbage
collected unexpectedly.
"""
self._on_stop_task = None
self._create_background_task(on_stop(expected_disconnect))
async def start_connection(
self,
@ -321,19 +321,13 @@ class APIClient:
"""Start connecting to the device."""
if self._connection is not None:
raise APIConnectionError(f"Already connected to {self.log_name}!")
self._connection = APIConnection(
self._params,
partial(self._on_stop, on_stop),
self._debug_enabled,
self.log_name,
)
try:
await self._connection.start_connection()
except Exception:
self._connection = None
raise
await self._execute_connection_coro(self._connection.start_connection())
# If we resolved the address, we should set the log name now
if self._connection.resolved_addr_info:
self._set_log_name()
@ -345,14 +339,20 @@ class APIClient:
"""Finish connecting to the device."""
if TYPE_CHECKING:
assert self._connection is not None
try:
await self._connection.finish_connection(login=login)
except Exception:
self._connection = None
raise
await self._execute_connection_coro(
self._connection.finish_connection(login=login)
)
if received_name := self._connection.received_name:
self._set_name_from_device(received_name)
async def _execute_connection_coro(self, coro: Awaitable[None]) -> None:
"""Execute a coroutine and reset the _connection if it fails."""
try:
await coro
except Exception: # pylint: disable=broad-except
self._connection = None
raise
async def disconnect(self, force: bool = False) -> None:
if self._connection is None:
return
@ -451,7 +451,7 @@ class APIClient:
),
timeout: float = 10.0,
) -> message.Message:
message_filter = partial(on_bluetooth_message, address, handle)
message_filter = partial(on_bluetooth_handle_message, address, handle)
msg_types = (response_type, BluetoothGATTErrorResponse)
[resp] = await self._get_connection().send_messages_await_response_complex(
(request,),
@ -670,11 +670,12 @@ class APIClient:
timeout: float,
) -> message.Message:
"""Send a BluetoothDeviceRequest watch for the connection state to change."""
types_with_response = (BluetoothDeviceConnectionResponse, *msg_types)
response = await self._bluetooth_device_request(
address,
request_type,
lambda msg: msg.address == address,
(BluetoothDeviceConnectionResponse, *msg_types),
partial(on_bluetooth_message_types, address, types_with_response),
types_with_response,
timeout,
)
self._raise_for_ble_connection_change(address, response, msg_types)
@ -706,13 +707,9 @@ class APIClient:
timeout: float,
) -> message.Message:
"""Send a BluetoothDeviceRequest and wait for a response."""
req = BluetoothDeviceRequest(address=address, request_type=request_type)
[response] = await self._get_connection().send_messages_await_response_complex(
(
BluetoothDeviceRequest(
address=address,
request_type=request_type,
),
),
(req,),
predicate_func,
predicate_func,
msg_types,
@ -723,42 +720,18 @@ class APIClient:
async def bluetooth_gatt_get_services(
self, address: int
) -> ESPHomeBluetoothGATTServices:
append_types = (
BluetoothDeviceConnectionResponse,
BluetoothGATTGetServicesResponse,
BluetoothGATTErrorResponse,
)
stop_types = (
BluetoothDeviceConnectionResponse,
BluetoothGATTGetServicesDoneResponse,
BluetoothGATTErrorResponse,
)
error_types = (BluetoothGATTErrorResponse, BluetoothDeviceConnectionResponse)
append_types = (*error_types, BluetoothGATTGetServicesResponse)
stop_types = (*error_types, BluetoothGATTGetServicesDoneResponse)
msg_types = (
BluetoothGATTGetServicesResponse,
BluetoothGATTGetServicesDoneResponse,
BluetoothGATTErrorResponse,
)
def do_append(
msg: BluetoothDeviceConnectionResponse
| BluetoothGATTGetServicesResponse
| BluetoothGATTGetServicesDoneResponse
| BluetoothGATTErrorResponse,
) -> bool:
return type(msg) in append_types and msg.address == address
def do_stop(
msg: BluetoothDeviceConnectionResponse
| BluetoothGATTGetServicesResponse
| BluetoothGATTGetServicesDoneResponse
| BluetoothGATTErrorResponse,
) -> bool:
return type(msg) in stop_types and msg.address == address
resp = await self._get_connection().send_messages_await_response_complex(
(BluetoothGATTGetServicesRequest(address=address),),
do_append,
do_stop,
partial(on_bluetooth_message_types, address, append_types),
partial(on_bluetooth_message_types, address, stop_types),
(*msg_types, BluetoothDeviceConnectionResponse),
DEFAULT_BLE_TIMEOUT,
)
@ -951,7 +924,9 @@ class APIClient:
stop: bool = False,
) -> None:
req = CoverCommandRequest(key=key)
apiv = cast(APIVersion, self.api_version)
apiv = self.api_version
if TYPE_CHECKING:
assert apiv is not None
if apiv >= APIVersion(1, 1):
if position is not None:
req.has_position = True
@ -1098,7 +1073,9 @@ class APIClient:
req.has_custom_fan_mode = True
req.custom_fan_mode = custom_fan_mode
if preset is not None:
apiv = cast(APIVersion, self.api_version)
apiv = self.api_version
if TYPE_CHECKING:
assert apiv is not None
if apiv < APIVersion(1, 5):
req.has_legacy_away = True
req.legacy_away = preset == ClimatePreset.AWAY
@ -1187,26 +1164,21 @@ class APIClient:
) -> None:
req = ExecuteServiceRequest(key=service.key)
args = []
apiv = self.api_version
if TYPE_CHECKING:
assert apiv is not None
int_type = "int_" if apiv >= APIVersion(1, 3) else "legacy_int"
map_single = USER_SERVICE_MAP_SINGLE
map_array = USER_SERVICE_MAP_ARRAY
for arg_desc in service.args:
arg = ExecuteServiceArgument()
val = data[arg_desc.name]
apiv = cast(APIVersion, self.api_version)
int_type = "int_" if apiv >= APIVersion(1, 3) else "legacy_int"
map_single = {
UserServiceArgType.BOOL: "bool_",
UserServiceArgType.INT: int_type,
UserServiceArgType.FLOAT: "float_",
UserServiceArgType.STRING: "string_",
}
map_array = {
UserServiceArgType.BOOL_ARRAY: "bool_array",
UserServiceArgType.INT_ARRAY: "int_array",
UserServiceArgType.FLOAT_ARRAY: "float_array",
UserServiceArgType.STRING_ARRAY: "string_array",
}
if arg_desc.type in map_array:
attr = getattr(arg, map_array[arg_desc.type])
attr.extend(val)
elif arg_desc.type == UserServiceArgType.INT:
int_type = "int_" if apiv >= APIVersion(1, 3) else "legacy_int"
setattr(arg, int_type, val)
else:
assert arg_desc.type in map_single
setattr(arg, map_single[arg_desc.type], val)
@ -1280,9 +1252,7 @@ class APIClient:
# We hold a reference to the start_task in unsub function
# so we don't need to add it to the background tasks.
else:
stop_task = asyncio.create_task(handle_stop())
self._background_tasks.add(stop_task)
stop_task.add_done_callback(self._background_tasks.discard)
self._create_background_task(handle_stop())
connection.send_message(SubscribeVoiceAssistantRequest(subscribe=True))
@ -1304,6 +1274,12 @@ class APIClient:
return unsub
def _create_background_task(self, coro: Coroutine[Any, Any, None]) -> None:
"""Create a background task and add it to the background tasks set."""
task = asyncio.create_task(coro)
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
def send_voice_assistant_event(
self, event_type: VoiceAssistantEventType, data: dict[str, str] | None
) -> None:

View File

@ -10,6 +10,8 @@ from .api_pb2 import ( # type: ignore
BluetoothConnectionsFreeResponse,
BluetoothDeviceConnectionResponse,
BluetoothGATTErrorResponse,
BluetoothGATTGetServicesDoneResponse,
BluetoothGATTGetServicesResponse,
BluetoothGATTNotifyDataResponse,
BluetoothGATTNotifyResponse,
BluetoothGATTReadResponse,
@ -118,7 +120,7 @@ def on_bluetooth_device_connection_response(
connect_future.set_result(None)
def on_bluetooth_message(
def on_bluetooth_handle_message(
address: int,
handle: int,
msg: BluetoothGATTErrorResponse
@ -127,7 +129,23 @@ def on_bluetooth_message(
| BluetoothGATTWriteResponse
| BluetoothDeviceConnectionResponse,
) -> bool:
"""Handle a Bluetooth message."""
"""Filter a Bluetooth message for an address and handle."""
if type(msg) is BluetoothDeviceConnectionResponse:
return bool(msg.address == address)
return bool(msg.address == address and msg.handle == handle)
def on_bluetooth_message_types(
address: int,
msg_types: tuple[type[message.Message]],
msg: BluetoothGATTErrorResponse
| BluetoothGATTNotifyResponse
| BluetoothGATTReadResponse
| BluetoothGATTWriteResponse
| BluetoothDeviceConnectionResponse
| BluetoothGATTGetServicesResponse
| BluetoothGATTGetServicesDoneResponse
| BluetoothGATTErrorResponse,
) -> bool:
"""Filter Bluetooth messages of a specific type and address."""
return type(msg) in msg_types and bool(msg.address == address)

View File

@ -11,7 +11,7 @@ with open(os.path.join(here, "README.rst"), encoding="utf-8") as readme_file:
long_description = readme_file.read()
VERSION = "19.1.6"
VERSION = "19.1.7"
PROJECT_NAME = "aioesphomeapi"
PROJECT_PACKAGE_NAME = "aioesphomeapi"
PROJECT_LICENSE = "MIT"