mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-09-27 04:22:46 +02:00
Merge branch 'main' into feature/fan_presets
This commit is contained in:
commit
3c675b6847
@ -5,7 +5,7 @@ import asyncio
|
||||
import logging
|
||||
from collections.abc import Awaitable, Coroutine
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Callable, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Callable, Union
|
||||
|
||||
from google.protobuf import message
|
||||
|
||||
@ -76,8 +76,9 @@ from .client_callbacks import (
|
||||
on_bluetooth_connections_free_response,
|
||||
on_bluetooth_device_connection_response,
|
||||
on_bluetooth_gatt_notify_data_response,
|
||||
on_bluetooth_handle_message,
|
||||
on_bluetooth_le_advertising_response,
|
||||
on_bluetooth_message,
|
||||
on_bluetooth_message_types,
|
||||
on_home_assistant_service_response,
|
||||
on_state_msg,
|
||||
on_subscribe_home_assistant_state_response,
|
||||
@ -157,6 +158,20 @@ LIST_ENTITIES_MSG_TYPES = (
|
||||
*LIST_ENTITIES_SERVICES_RESPONSE_TYPES,
|
||||
)
|
||||
|
||||
USER_SERVICE_MAP_ARRAY = {
|
||||
UserServiceArgType.BOOL_ARRAY: "bool_array",
|
||||
UserServiceArgType.INT_ARRAY: "int_array",
|
||||
UserServiceArgType.FLOAT_ARRAY: "float_array",
|
||||
UserServiceArgType.STRING_ARRAY: "string_array",
|
||||
}
|
||||
USER_SERVICE_MAP_SINGLE = {
|
||||
# Int is a special case because it is handled
|
||||
# differently depending on the APIVersion
|
||||
UserServiceArgType.BOOL: "bool_",
|
||||
UserServiceArgType.FLOAT: "float_",
|
||||
UserServiceArgType.STRING: "string_",
|
||||
}
|
||||
|
||||
|
||||
ExecuteServiceDataType = dict[
|
||||
str, Union[bool, int, float, str, list[bool], list[int], list[float], list[str]]
|
||||
@ -191,7 +206,6 @@ class APIClient:
|
||||
"cached_name",
|
||||
"_background_tasks",
|
||||
"_loop",
|
||||
"_on_stop_task",
|
||||
"log_name",
|
||||
)
|
||||
|
||||
@ -238,7 +252,6 @@ class APIClient:
|
||||
self.cached_name: str | None = None
|
||||
self._background_tasks: set[asyncio.Task[Any]] = set()
|
||||
self._loop = asyncio.get_event_loop()
|
||||
self._on_stop_task: asyncio.Task[None] | None = None
|
||||
self._set_log_name()
|
||||
|
||||
def set_debug(self, enabled: bool) -> None:
|
||||
@ -299,20 +312,7 @@ class APIClient:
|
||||
# Hook into on_stop handler to clear connection when stopped
|
||||
self._connection = None
|
||||
if on_stop:
|
||||
self._on_stop_task = asyncio.create_task(
|
||||
on_stop(expected_disconnect),
|
||||
name=f"{self.log_name} aioesphomeapi on_stop",
|
||||
)
|
||||
self._on_stop_task.add_done_callback(self._remove_on_stop_task)
|
||||
|
||||
def _remove_on_stop_task(self, _fut: asyncio.Future[None]) -> None:
|
||||
"""Remove the stop task.
|
||||
|
||||
We need to do this because the asyncio does not hold
|
||||
a strong reference to the task, so it can be garbage
|
||||
collected unexpectedly.
|
||||
"""
|
||||
self._on_stop_task = None
|
||||
self._create_background_task(on_stop(expected_disconnect))
|
||||
|
||||
async def start_connection(
|
||||
self,
|
||||
@ -321,19 +321,13 @@ class APIClient:
|
||||
"""Start connecting to the device."""
|
||||
if self._connection is not None:
|
||||
raise APIConnectionError(f"Already connected to {self.log_name}!")
|
||||
|
||||
self._connection = APIConnection(
|
||||
self._params,
|
||||
partial(self._on_stop, on_stop),
|
||||
self._debug_enabled,
|
||||
self.log_name,
|
||||
)
|
||||
|
||||
try:
|
||||
await self._connection.start_connection()
|
||||
except Exception:
|
||||
self._connection = None
|
||||
raise
|
||||
await self._execute_connection_coro(self._connection.start_connection())
|
||||
# If we resolved the address, we should set the log name now
|
||||
if self._connection.resolved_addr_info:
|
||||
self._set_log_name()
|
||||
@ -345,14 +339,20 @@ class APIClient:
|
||||
"""Finish connecting to the device."""
|
||||
if TYPE_CHECKING:
|
||||
assert self._connection is not None
|
||||
try:
|
||||
await self._connection.finish_connection(login=login)
|
||||
except Exception:
|
||||
self._connection = None
|
||||
raise
|
||||
await self._execute_connection_coro(
|
||||
self._connection.finish_connection(login=login)
|
||||
)
|
||||
if received_name := self._connection.received_name:
|
||||
self._set_name_from_device(received_name)
|
||||
|
||||
async def _execute_connection_coro(self, coro: Awaitable[None]) -> None:
|
||||
"""Execute a coroutine and reset the _connection if it fails."""
|
||||
try:
|
||||
await coro
|
||||
except Exception: # pylint: disable=broad-except
|
||||
self._connection = None
|
||||
raise
|
||||
|
||||
async def disconnect(self, force: bool = False) -> None:
|
||||
if self._connection is None:
|
||||
return
|
||||
@ -451,7 +451,7 @@ class APIClient:
|
||||
),
|
||||
timeout: float = 10.0,
|
||||
) -> message.Message:
|
||||
message_filter = partial(on_bluetooth_message, address, handle)
|
||||
message_filter = partial(on_bluetooth_handle_message, address, handle)
|
||||
msg_types = (response_type, BluetoothGATTErrorResponse)
|
||||
[resp] = await self._get_connection().send_messages_await_response_complex(
|
||||
(request,),
|
||||
@ -670,11 +670,12 @@ class APIClient:
|
||||
timeout: float,
|
||||
) -> message.Message:
|
||||
"""Send a BluetoothDeviceRequest watch for the connection state to change."""
|
||||
types_with_response = (BluetoothDeviceConnectionResponse, *msg_types)
|
||||
response = await self._bluetooth_device_request(
|
||||
address,
|
||||
request_type,
|
||||
lambda msg: msg.address == address,
|
||||
(BluetoothDeviceConnectionResponse, *msg_types),
|
||||
partial(on_bluetooth_message_types, address, types_with_response),
|
||||
types_with_response,
|
||||
timeout,
|
||||
)
|
||||
self._raise_for_ble_connection_change(address, response, msg_types)
|
||||
@ -706,13 +707,9 @@ class APIClient:
|
||||
timeout: float,
|
||||
) -> message.Message:
|
||||
"""Send a BluetoothDeviceRequest and wait for a response."""
|
||||
req = BluetoothDeviceRequest(address=address, request_type=request_type)
|
||||
[response] = await self._get_connection().send_messages_await_response_complex(
|
||||
(
|
||||
BluetoothDeviceRequest(
|
||||
address=address,
|
||||
request_type=request_type,
|
||||
),
|
||||
),
|
||||
(req,),
|
||||
predicate_func,
|
||||
predicate_func,
|
||||
msg_types,
|
||||
@ -723,42 +720,18 @@ class APIClient:
|
||||
async def bluetooth_gatt_get_services(
|
||||
self, address: int
|
||||
) -> ESPHomeBluetoothGATTServices:
|
||||
append_types = (
|
||||
BluetoothDeviceConnectionResponse,
|
||||
BluetoothGATTGetServicesResponse,
|
||||
BluetoothGATTErrorResponse,
|
||||
)
|
||||
stop_types = (
|
||||
BluetoothDeviceConnectionResponse,
|
||||
BluetoothGATTGetServicesDoneResponse,
|
||||
BluetoothGATTErrorResponse,
|
||||
)
|
||||
error_types = (BluetoothGATTErrorResponse, BluetoothDeviceConnectionResponse)
|
||||
append_types = (*error_types, BluetoothGATTGetServicesResponse)
|
||||
stop_types = (*error_types, BluetoothGATTGetServicesDoneResponse)
|
||||
msg_types = (
|
||||
BluetoothGATTGetServicesResponse,
|
||||
BluetoothGATTGetServicesDoneResponse,
|
||||
BluetoothGATTErrorResponse,
|
||||
)
|
||||
|
||||
def do_append(
|
||||
msg: BluetoothDeviceConnectionResponse
|
||||
| BluetoothGATTGetServicesResponse
|
||||
| BluetoothGATTGetServicesDoneResponse
|
||||
| BluetoothGATTErrorResponse,
|
||||
) -> bool:
|
||||
return type(msg) in append_types and msg.address == address
|
||||
|
||||
def do_stop(
|
||||
msg: BluetoothDeviceConnectionResponse
|
||||
| BluetoothGATTGetServicesResponse
|
||||
| BluetoothGATTGetServicesDoneResponse
|
||||
| BluetoothGATTErrorResponse,
|
||||
) -> bool:
|
||||
return type(msg) in stop_types and msg.address == address
|
||||
|
||||
resp = await self._get_connection().send_messages_await_response_complex(
|
||||
(BluetoothGATTGetServicesRequest(address=address),),
|
||||
do_append,
|
||||
do_stop,
|
||||
partial(on_bluetooth_message_types, address, append_types),
|
||||
partial(on_bluetooth_message_types, address, stop_types),
|
||||
(*msg_types, BluetoothDeviceConnectionResponse),
|
||||
DEFAULT_BLE_TIMEOUT,
|
||||
)
|
||||
@ -951,7 +924,9 @@ class APIClient:
|
||||
stop: bool = False,
|
||||
) -> None:
|
||||
req = CoverCommandRequest(key=key)
|
||||
apiv = cast(APIVersion, self.api_version)
|
||||
apiv = self.api_version
|
||||
if TYPE_CHECKING:
|
||||
assert apiv is not None
|
||||
if apiv >= APIVersion(1, 1):
|
||||
if position is not None:
|
||||
req.has_position = True
|
||||
@ -1100,7 +1075,9 @@ class APIClient:
|
||||
req.has_custom_fan_mode = True
|
||||
req.custom_fan_mode = custom_fan_mode
|
||||
if preset is not None:
|
||||
apiv = cast(APIVersion, self.api_version)
|
||||
apiv = self.api_version
|
||||
if TYPE_CHECKING:
|
||||
assert apiv is not None
|
||||
if apiv < APIVersion(1, 5):
|
||||
req.has_legacy_away = True
|
||||
req.legacy_away = preset == ClimatePreset.AWAY
|
||||
@ -1183,26 +1160,21 @@ class APIClient:
|
||||
) -> None:
|
||||
req = ExecuteServiceRequest(key=service.key)
|
||||
args = []
|
||||
apiv = self.api_version
|
||||
if TYPE_CHECKING:
|
||||
assert apiv is not None
|
||||
int_type = "int_" if apiv >= APIVersion(1, 3) else "legacy_int"
|
||||
map_single = USER_SERVICE_MAP_SINGLE
|
||||
map_array = USER_SERVICE_MAP_ARRAY
|
||||
for arg_desc in service.args:
|
||||
arg = ExecuteServiceArgument()
|
||||
val = data[arg_desc.name]
|
||||
apiv = cast(APIVersion, self.api_version)
|
||||
int_type = "int_" if apiv >= APIVersion(1, 3) else "legacy_int"
|
||||
map_single = {
|
||||
UserServiceArgType.BOOL: "bool_",
|
||||
UserServiceArgType.INT: int_type,
|
||||
UserServiceArgType.FLOAT: "float_",
|
||||
UserServiceArgType.STRING: "string_",
|
||||
}
|
||||
map_array = {
|
||||
UserServiceArgType.BOOL_ARRAY: "bool_array",
|
||||
UserServiceArgType.INT_ARRAY: "int_array",
|
||||
UserServiceArgType.FLOAT_ARRAY: "float_array",
|
||||
UserServiceArgType.STRING_ARRAY: "string_array",
|
||||
}
|
||||
if arg_desc.type in map_array:
|
||||
attr = getattr(arg, map_array[arg_desc.type])
|
||||
attr.extend(val)
|
||||
elif arg_desc.type == UserServiceArgType.INT:
|
||||
int_type = "int_" if apiv >= APIVersion(1, 3) else "legacy_int"
|
||||
setattr(arg, int_type, val)
|
||||
else:
|
||||
assert arg_desc.type in map_single
|
||||
setattr(arg, map_single[arg_desc.type], val)
|
||||
@ -1276,9 +1248,7 @@ class APIClient:
|
||||
# We hold a reference to the start_task in unsub function
|
||||
# so we don't need to add it to the background tasks.
|
||||
else:
|
||||
stop_task = asyncio.create_task(handle_stop())
|
||||
self._background_tasks.add(stop_task)
|
||||
stop_task.add_done_callback(self._background_tasks.discard)
|
||||
self._create_background_task(handle_stop())
|
||||
|
||||
connection.send_message(SubscribeVoiceAssistantRequest(subscribe=True))
|
||||
|
||||
@ -1300,6 +1270,12 @@ class APIClient:
|
||||
|
||||
return unsub
|
||||
|
||||
def _create_background_task(self, coro: Coroutine[Any, Any, None]) -> None:
|
||||
"""Create a background task and add it to the background tasks set."""
|
||||
task = asyncio.create_task(coro)
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
def send_voice_assistant_event(
|
||||
self, event_type: VoiceAssistantEventType, data: dict[str, str] | None
|
||||
) -> None:
|
||||
|
@ -10,6 +10,8 @@ from .api_pb2 import ( # type: ignore
|
||||
BluetoothConnectionsFreeResponse,
|
||||
BluetoothDeviceConnectionResponse,
|
||||
BluetoothGATTErrorResponse,
|
||||
BluetoothGATTGetServicesDoneResponse,
|
||||
BluetoothGATTGetServicesResponse,
|
||||
BluetoothGATTNotifyDataResponse,
|
||||
BluetoothGATTNotifyResponse,
|
||||
BluetoothGATTReadResponse,
|
||||
@ -118,7 +120,7 @@ def on_bluetooth_device_connection_response(
|
||||
connect_future.set_result(None)
|
||||
|
||||
|
||||
def on_bluetooth_message(
|
||||
def on_bluetooth_handle_message(
|
||||
address: int,
|
||||
handle: int,
|
||||
msg: BluetoothGATTErrorResponse
|
||||
@ -127,7 +129,23 @@ def on_bluetooth_message(
|
||||
| BluetoothGATTWriteResponse
|
||||
| BluetoothDeviceConnectionResponse,
|
||||
) -> bool:
|
||||
"""Handle a Bluetooth message."""
|
||||
"""Filter a Bluetooth message for an address and handle."""
|
||||
if type(msg) is BluetoothDeviceConnectionResponse:
|
||||
return bool(msg.address == address)
|
||||
return bool(msg.address == address and msg.handle == handle)
|
||||
|
||||
|
||||
def on_bluetooth_message_types(
|
||||
address: int,
|
||||
msg_types: tuple[type[message.Message]],
|
||||
msg: BluetoothGATTErrorResponse
|
||||
| BluetoothGATTNotifyResponse
|
||||
| BluetoothGATTReadResponse
|
||||
| BluetoothGATTWriteResponse
|
||||
| BluetoothDeviceConnectionResponse
|
||||
| BluetoothGATTGetServicesResponse
|
||||
| BluetoothGATTGetServicesDoneResponse
|
||||
| BluetoothGATTErrorResponse,
|
||||
) -> bool:
|
||||
"""Filter Bluetooth messages of a specific type and address."""
|
||||
return type(msg) in msg_types and bool(msg.address == address)
|
||||
|
Loading…
Reference in New Issue
Block a user