From 07499907d429fcac0e73cf03aa79b563c9c4f90e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 28 Nov 2023 10:19:39 -0600 Subject: [PATCH 1/6] Refactor execute_service to avoid creating dict in the inner loop (#776) --- aioesphomeapi/client.py | 47 ++++++++++++++++++++++++++--------------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 9ff949e..e337d29 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -5,7 +5,7 @@ import asyncio import logging from collections.abc import Awaitable, Coroutine from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Union from google.protobuf import message @@ -157,6 +157,20 @@ LIST_ENTITIES_MSG_TYPES = ( *LIST_ENTITIES_SERVICES_RESPONSE_TYPES, ) +USER_SERVICE_MAP_ARRAY = { + UserServiceArgType.BOOL_ARRAY: "bool_array", + UserServiceArgType.INT_ARRAY: "int_array", + UserServiceArgType.FLOAT_ARRAY: "float_array", + UserServiceArgType.STRING_ARRAY: "string_array", +} +USER_SERVICE_MAP_SINGLE = { + # Int is a special case because it is handled + # differently depending on the APIVersion + UserServiceArgType.BOOL: "bool_", + UserServiceArgType.FLOAT: "float_", + UserServiceArgType.STRING: "string_", +} + ExecuteServiceDataType = dict[ str, Union[bool, int, float, str, list[bool], list[int], list[float], list[str]] @@ -951,7 +965,9 @@ class APIClient: stop: bool = False, ) -> None: req = CoverCommandRequest(key=key) - apiv = cast(APIVersion, self.api_version) + apiv = self.api_version + if TYPE_CHECKING: + assert apiv is not None if apiv >= APIVersion(1, 1): if position is not None: req.has_position = True @@ -1096,7 +1112,9 @@ class APIClient: req.has_custom_fan_mode = True req.custom_fan_mode = custom_fan_mode if preset is not None: - apiv = cast(APIVersion, self.api_version) + apiv = self.api_version + if TYPE_CHECKING: + assert apiv is not None if apiv < APIVersion(1, 5): req.has_legacy_away = True req.legacy_away = preset == ClimatePreset.AWAY @@ -1179,26 +1197,21 @@ class APIClient: ) -> None: req = ExecuteServiceRequest(key=service.key) args = [] + apiv = self.api_version + if TYPE_CHECKING: + assert apiv is not None + int_type = "int_" if apiv >= APIVersion(1, 3) else "legacy_int" + map_single = USER_SERVICE_MAP_SINGLE + map_array = USER_SERVICE_MAP_ARRAY for arg_desc in service.args: arg = ExecuteServiceArgument() val = data[arg_desc.name] - apiv = cast(APIVersion, self.api_version) - int_type = "int_" if apiv >= APIVersion(1, 3) else "legacy_int" - map_single = { - UserServiceArgType.BOOL: "bool_", - UserServiceArgType.INT: int_type, - UserServiceArgType.FLOAT: "float_", - UserServiceArgType.STRING: "string_", - } - map_array = { - UserServiceArgType.BOOL_ARRAY: "bool_array", - UserServiceArgType.INT_ARRAY: "int_array", - UserServiceArgType.FLOAT_ARRAY: "float_array", - UserServiceArgType.STRING_ARRAY: "string_array", - } if arg_desc.type in map_array: attr = getattr(arg, map_array[arg_desc.type]) attr.extend(val) + elif arg_desc.type == UserServiceArgType.INT: + int_type = "int_" if apiv >= APIVersion(1, 3) else "legacy_int" + setattr(arg, int_type, val) else: assert arg_desc.type in map_single setattr(arg, map_single[arg_desc.type], val) From 5c063e22698224ab4be6668ed6833be04ec09e4b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 28 Nov 2023 16:19:57 +0000 Subject: [PATCH 2/6] Bump version to 19.1.7 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 7e9dc18..018a051 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ with open(os.path.join(here, "README.rst"), encoding="utf-8") as readme_file: long_description = readme_file.read() -VERSION = "19.1.6" +VERSION = "19.1.7" PROJECT_NAME = "aioesphomeapi" PROJECT_PACKAGE_NAME = "aioesphomeapi" PROJECT_LICENSE = "MIT" From d40acb1f85efea4434b97dc6ccc6b91003c66ead Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 28 Nov 2023 10:49:58 -0600 Subject: [PATCH 3/6] Reduce duplicate Bluetooth message filtering code (#777) --- aioesphomeapi/client.py | 52 ++++++++----------------------- aioesphomeapi/client_callbacks.py | 22 +++++++++++-- 2 files changed, 33 insertions(+), 41 deletions(-) diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index e337d29..046766a 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -76,8 +76,9 @@ from .client_callbacks import ( on_bluetooth_connections_free_response, on_bluetooth_device_connection_response, on_bluetooth_gatt_notify_data_response, + on_bluetooth_handle_message, on_bluetooth_le_advertising_response, - on_bluetooth_message, + on_bluetooth_message_types, on_home_assistant_service_response, on_state_msg, on_subscribe_home_assistant_state_response, @@ -465,7 +466,7 @@ class APIClient: ), timeout: float = 10.0, ) -> message.Message: - message_filter = partial(on_bluetooth_message, address, handle) + message_filter = partial(on_bluetooth_handle_message, address, handle) msg_types = (response_type, BluetoothGATTErrorResponse) [resp] = await self._get_connection().send_messages_await_response_complex( (request,), @@ -684,11 +685,12 @@ class APIClient: timeout: float, ) -> message.Message: """Send a BluetoothDeviceRequest watch for the connection state to change.""" + types_with_response = (BluetoothDeviceConnectionResponse, *msg_types) response = await self._bluetooth_device_request( address, request_type, - lambda msg: msg.address == address, - (BluetoothDeviceConnectionResponse, *msg_types), + partial(on_bluetooth_message_types, address, types_with_response), + types_with_response, timeout, ) self._raise_for_ble_connection_change(address, response, msg_types) @@ -720,13 +722,9 @@ class APIClient: timeout: float, ) -> message.Message: """Send a BluetoothDeviceRequest and wait for a response.""" + req = BluetoothDeviceRequest(address=address, request_type=request_type) [response] = await self._get_connection().send_messages_await_response_complex( - ( - BluetoothDeviceRequest( - address=address, - request_type=request_type, - ), - ), + (req,), predicate_func, predicate_func, msg_types, @@ -737,42 +735,18 @@ class APIClient: async def bluetooth_gatt_get_services( self, address: int ) -> ESPHomeBluetoothGATTServices: - append_types = ( - BluetoothDeviceConnectionResponse, - BluetoothGATTGetServicesResponse, - BluetoothGATTErrorResponse, - ) - stop_types = ( - BluetoothDeviceConnectionResponse, - BluetoothGATTGetServicesDoneResponse, - BluetoothGATTErrorResponse, - ) + error_types = (BluetoothGATTErrorResponse, BluetoothDeviceConnectionResponse) + append_types = (*error_types, BluetoothGATTGetServicesResponse) + stop_types = (*error_types, BluetoothGATTGetServicesDoneResponse) msg_types = ( BluetoothGATTGetServicesResponse, BluetoothGATTGetServicesDoneResponse, BluetoothGATTErrorResponse, ) - - def do_append( - msg: BluetoothDeviceConnectionResponse - | BluetoothGATTGetServicesResponse - | BluetoothGATTGetServicesDoneResponse - | BluetoothGATTErrorResponse, - ) -> bool: - return type(msg) in append_types and msg.address == address - - def do_stop( - msg: BluetoothDeviceConnectionResponse - | BluetoothGATTGetServicesResponse - | BluetoothGATTGetServicesDoneResponse - | BluetoothGATTErrorResponse, - ) -> bool: - return type(msg) in stop_types and msg.address == address - resp = await self._get_connection().send_messages_await_response_complex( (BluetoothGATTGetServicesRequest(address=address),), - do_append, - do_stop, + partial(on_bluetooth_message_types, address, append_types), + partial(on_bluetooth_message_types, address, stop_types), (*msg_types, BluetoothDeviceConnectionResponse), DEFAULT_BLE_TIMEOUT, ) diff --git a/aioesphomeapi/client_callbacks.py b/aioesphomeapi/client_callbacks.py index 4b378ae..f4cc3f2 100644 --- a/aioesphomeapi/client_callbacks.py +++ b/aioesphomeapi/client_callbacks.py @@ -10,6 +10,8 @@ from .api_pb2 import ( # type: ignore BluetoothConnectionsFreeResponse, BluetoothDeviceConnectionResponse, BluetoothGATTErrorResponse, + BluetoothGATTGetServicesDoneResponse, + BluetoothGATTGetServicesResponse, BluetoothGATTNotifyDataResponse, BluetoothGATTNotifyResponse, BluetoothGATTReadResponse, @@ -118,7 +120,7 @@ def on_bluetooth_device_connection_response( connect_future.set_result(None) -def on_bluetooth_message( +def on_bluetooth_handle_message( address: int, handle: int, msg: BluetoothGATTErrorResponse @@ -127,7 +129,23 @@ def on_bluetooth_message( | BluetoothGATTWriteResponse | BluetoothDeviceConnectionResponse, ) -> bool: - """Handle a Bluetooth message.""" + """Filter a Bluetooth message for an address and handle.""" if type(msg) is BluetoothDeviceConnectionResponse: return bool(msg.address == address) return bool(msg.address == address and msg.handle == handle) + + +def on_bluetooth_message_types( + address: int, + msg_types: tuple[type[message.Message]], + msg: BluetoothGATTErrorResponse + | BluetoothGATTNotifyResponse + | BluetoothGATTReadResponse + | BluetoothGATTWriteResponse + | BluetoothDeviceConnectionResponse + | BluetoothGATTGetServicesResponse + | BluetoothGATTGetServicesDoneResponse + | BluetoothGATTErrorResponse, +) -> bool: + """Filter Bluetooth messages of a specific type and address.""" + return type(msg) in msg_types and bool(msg.address == address) From b3a621f8097b0d49823e5b7fd125624b72648cd4 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 28 Nov 2023 11:54:29 -0600 Subject: [PATCH 4/6] Make creating background tasks in the client a bound method (#778) --- aioesphomeapi/client.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 046766a..2273ce7 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -1259,9 +1259,7 @@ class APIClient: # We hold a reference to the start_task in unsub function # so we don't need to add it to the background tasks. else: - stop_task = asyncio.create_task(handle_stop()) - self._background_tasks.add(stop_task) - stop_task.add_done_callback(self._background_tasks.discard) + self._create_background_task(handle_stop()) connection.send_message(SubscribeVoiceAssistantRequest(subscribe=True)) @@ -1283,6 +1281,12 @@ class APIClient: return unsub + def _create_background_task(self, coro: Coroutine[Any, Any, None]) -> None: + """Create a background task and add it to the background tasks set.""" + task = asyncio.create_task(coro) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + def send_voice_assistant_event( self, event_type: VoiceAssistantEventType, data: dict[str, str] | None ) -> None: From d40e046d1ab8e7f5d3de84605d71c2de113d932d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 28 Nov 2023 11:59:54 -0600 Subject: [PATCH 5/6] Reduce duplicate code in client connection setup (#779) --- aioesphomeapi/client.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 2273ce7..2d18367 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -336,19 +336,13 @@ class APIClient: """Start connecting to the device.""" if self._connection is not None: raise APIConnectionError(f"Already connected to {self.log_name}!") - self._connection = APIConnection( self._params, partial(self._on_stop, on_stop), self._debug_enabled, self.log_name, ) - - try: - await self._connection.start_connection() - except Exception: - self._connection = None - raise + await self._execute_connection_coro(self._connection.start_connection()) # If we resolved the address, we should set the log name now if self._connection.resolved_addr_info: self._set_log_name() @@ -360,14 +354,20 @@ class APIClient: """Finish connecting to the device.""" if TYPE_CHECKING: assert self._connection is not None - try: - await self._connection.finish_connection(login=login) - except Exception: - self._connection = None - raise + await self._execute_connection_coro( + self._connection.finish_connection(login=login) + ) if received_name := self._connection.received_name: self._set_name_from_device(received_name) + async def _execute_connection_coro(self, coro: Awaitable[None]) -> None: + """Execute a coroutine and reset the _connection if it fails.""" + try: + await coro + except Exception: # pylint: disable=broad-except + self._connection = None + raise + async def disconnect(self, force: bool = False) -> None: if self._connection is None: return From 5c8370c506e1477981322ed520dbeb9b0252b48a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 28 Nov 2023 12:03:55 -0600 Subject: [PATCH 6/6] Use background task logic for the on_stop callback (#780) --- aioesphomeapi/client.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 2d18367..1e81ba7 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -206,7 +206,6 @@ class APIClient: "cached_name", "_background_tasks", "_loop", - "_on_stop_task", "log_name", ) @@ -253,7 +252,6 @@ class APIClient: self.cached_name: str | None = None self._background_tasks: set[asyncio.Task[Any]] = set() self._loop = asyncio.get_event_loop() - self._on_stop_task: asyncio.Task[None] | None = None self._set_log_name() def set_debug(self, enabled: bool) -> None: @@ -314,20 +312,7 @@ class APIClient: # Hook into on_stop handler to clear connection when stopped self._connection = None if on_stop: - self._on_stop_task = asyncio.create_task( - on_stop(expected_disconnect), - name=f"{self.log_name} aioesphomeapi on_stop", - ) - self._on_stop_task.add_done_callback(self._remove_on_stop_task) - - def _remove_on_stop_task(self, _fut: asyncio.Future[None]) -> None: - """Remove the stop task. - - We need to do this because the asyncio does not hold - a strong reference to the task, so it can be garbage - collected unexpectedly. - """ - self._on_stop_task = None + self._create_background_task(on_stop(expected_disconnect)) async def start_connection( self,