From 3e920df478818271f430bb77cd4f252d07726e39 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 28 Nov 2023 07:23:21 -0600 Subject: [PATCH] Handle Bluetooth connection drops in more places (#766) --- aioesphomeapi/__init__.py | 1 + aioesphomeapi/client.py | 77 +++++++++++++++++++++++-------------- aioesphomeapi/connection.py | 4 +- aioesphomeapi/core.py | 4 ++ aioesphomeapi/model.py | 6 +++ tests/test_client.py | 60 +++++++++++++++++++++++++++-- 6 files changed, 119 insertions(+), 33 deletions(-) diff --git a/aioesphomeapi/__init__.py b/aioesphomeapi/__init__.py index 2ac018e..45b3138 100644 --- a/aioesphomeapi/__init__.py +++ b/aioesphomeapi/__init__.py @@ -11,6 +11,7 @@ from .core import ( MESSAGE_TYPE_TO_PROTO, APIConnectionError, BadNameAPIError, + BluetoothConnectionDroppedError, HandshakeAPIError, InvalidAuthAPIError, InvalidEncryptionKeyAPIError, diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 69cd619..cc9f101 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -83,9 +83,11 @@ from .client_callbacks import ( from .connection import APIConnection, ConnectionParams, handle_timeout from .core import ( APIConnectionError, + BluetoothConnectionDroppedError, BluetoothGATTAPIError, TimeoutAPIError, to_human_readable_address, + to_human_readable_gatt_error, ) from .model import ( AlarmControlPanelCommand, @@ -118,7 +120,11 @@ from .model import ( UserServiceArgType, ) from .model import VoiceAssistantAudioSettings as VoiceAssistantAudioSettingsModel -from .model import VoiceAssistantCommand, VoiceAssistantEventType +from .model import ( + VoiceAssistantCommand, + VoiceAssistantEventType, + message_types_to_names, +) from .model_conversions import ( LIST_ENTITIES_SERVICES_RESPONSE_TYPES, SUBSCRIBE_STATES_RESPONSE_TYPES, @@ -632,26 +638,11 @@ class APIClient: async def bluetooth_device_pair( self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT ) -> BluetoothDevicePairing: - def predicate_func( - msg: BluetoothDevicePairingResponse | BluetoothDeviceConnectionResponse, - ) -> bool: - if msg.address != address: - return False - if isinstance(msg, BluetoothDeviceConnectionResponse): - raise APIConnectionError( - f"Peripheral changed connections status while pairing: {msg.error}" - ) - return True - return BluetoothDevicePairing.from_pb( - await self._bluetooth_device_request( + await self._bluetooth_device_request_watch_connection( address, BluetoothDeviceRequestType.PAIR, - predicate_func, - ( - BluetoothDevicePairingResponse, - BluetoothDeviceConnectionResponse, - ), + (BluetoothDevicePairingResponse,), timeout, ) ) @@ -660,10 +651,9 @@ class APIClient: self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT ) -> BluetoothDeviceUnpairing: return BluetoothDeviceUnpairing.from_pb( - await self._bluetooth_device_request( + await self._bluetooth_device_request_watch_connection( address, BluetoothDeviceRequestType.UNPAIR, - lambda msg: msg.address == address, (BluetoothDeviceUnpairingResponse,), timeout, ) @@ -673,10 +663,9 @@ class APIClient: self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT ) -> BluetoothDeviceClearCache: return BluetoothDeviceClearCache.from_pb( - await self._bluetooth_device_request( + await self._bluetooth_device_request_watch_connection( address, BluetoothDeviceRequestType.CLEAR_CACHE, - lambda msg: msg.address == address, (BluetoothDeviceClearCacheResponse,), timeout, ) @@ -694,6 +683,43 @@ class APIClient: timeout, ) + async def _bluetooth_device_request_watch_connection( + self, + address: int, + request_type: BluetoothDeviceRequestType, + msg_types: tuple[type[message.Message], ...], + timeout: float, + ) -> message.Message: + """Send a BluetoothDeviceRequest watch for the connection state to change.""" + response = await self._bluetooth_device_request( + address, + request_type, + lambda msg: msg.address == address, + (BluetoothDeviceConnectionResponse, *msg_types), + timeout, + ) + if ( + type(response) # pylint: disable=unidiomatic-typecheck + is BluetoothDeviceConnectionResponse + ): + self._raise_for_ble_connection_change(address, response, msg_types) + return response + + def _raise_for_ble_connection_change( + self, + address: int, + response: BluetoothDeviceConnectionResponse, + msg_types: tuple[type[message.Message], ...], + ) -> None: + """Raise an exception if the connection status changed.""" + response_names = message_types_to_names(msg_types) + human_readable_address = to_human_readable_address(address) + raise BluetoothConnectionDroppedError( + f"Peripheral {human_readable_address} changed connection status while waiting for " + f"{response_names}: {to_human_readable_gatt_error(response.error)} " + f"({response.error})" + ) + async def _bluetooth_device_request( self, address: int, @@ -702,6 +728,7 @@ class APIClient: msg_types: tuple[type[message.Message], ...], timeout: float, ) -> message.Message: + """Send a BluetoothDeviceRequest and wait for a response.""" [response] = await self._get_connection().send_messages_await_response_complex( ( BluetoothDeviceRequest( @@ -941,7 +968,6 @@ class APIClient: elif position == 0.0: req.legacy_command = LegacyCoverCommand.CLOSE req.has_legacy_command = True - self._get_connection().send_message(req) async def fan_command( @@ -969,7 +995,6 @@ class APIClient: if direction is not None: req.has_direction = True req.direction = direction - self._get_connection().send_message(req) async def light_command( # pylint: disable=too-many-branches @@ -1027,7 +1052,6 @@ class APIClient: if effect is not None: req.has_effect = True req.effect = effect - self._get_connection().send_message(req) async def switch_command(self, key: int, state: bool) -> None: @@ -1079,7 +1103,6 @@ class APIClient: if custom_preset is not None: req.has_custom_preset = True req.custom_preset = custom_preset - self._get_connection().send_message(req) async def number_command(self, key: int, state: float) -> None: @@ -1109,7 +1132,6 @@ class APIClient: if duration is not None: req.duration = duration req.has_duration = True - self._get_connection().send_message(req) async def button_command(self, key: int) -> None: @@ -1144,7 +1166,6 @@ class APIClient: if media_url is not None: req.media_url = media_url req.has_media_url = True - self._get_connection().send_message(req) async def text_command(self, key: int, state: str) -> None: diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 11f2df8..2a30bcd 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -49,7 +49,7 @@ from .core import ( TimeoutAPIError, UnhandledAPIConnectionError, ) -from .model import APIVersion +from .model import APIVersion, message_types_to_names from .zeroconf import ZeroconfManager if sys.version_info[:2] < (3, 11): @@ -758,7 +758,7 @@ class APIConnection: await fut except asyncio_TimeoutError as err: timeout_expired = True - response_names = ", ".join(t.__name__ for t in msg_types) + response_names = message_types_to_names(msg_types) raise TimeoutAPIError( f"Timeout waiting for {response_names} after {timeout}s" ) from err diff --git a/aioesphomeapi/core.py b/aioesphomeapi/core.py index eeb61ef..6ed5ead 100644 --- a/aioesphomeapi/core.py +++ b/aioesphomeapi/core.py @@ -228,6 +228,10 @@ class UnhandledAPIConnectionError(APIConnectionError): pass +class BluetoothConnectionDroppedError(APIConnectionError): + """Raised when a Bluetooth connection is dropped.""" + + def to_human_readable_address(address: int) -> str: """Convert a MAC address to a human readable format.""" return ":".join(TWO_CHAR.findall(f"{address:012X}")) diff --git a/aioesphomeapi/model.py b/aioesphomeapi/model.py index f9826f6..ac3dffe 100644 --- a/aioesphomeapi/model.py +++ b/aioesphomeapi/model.py @@ -8,6 +8,8 @@ from functools import cache, lru_cache, partial from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast from uuid import UUID +from google.protobuf import message + from .util import fix_float_single_double_conversion if sys.version_info[:2] < (3, 10): @@ -1166,3 +1168,7 @@ def build_unique_id(formatted_mac: str, entity_info: EntityInfo) -> str: """ # -- return f"{formatted_mac}-{_TYPE_TO_NAME[type(entity_info)]}-{entity_info.object_id}" + + +def message_types_to_names(msg_types: Iterable[type[message.Message]]) -> str: + return ", ".join(t.__name__ for t in msg_types) diff --git a/tests/test_client.py b/tests/test_client.py index 161b326..4c3b271 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -66,7 +66,7 @@ from aioesphomeapi.api_pb2 import ( VoiceAssistantRequest, VoiceAssistantResponse, ) -from aioesphomeapi.client import APIClient +from aioesphomeapi.client import APIClient, BluetoothConnectionDroppedError from aioesphomeapi.connection import APIConnection from aioesphomeapi.core import ( APIConnectionError, @@ -955,9 +955,63 @@ async def test_bluetooth_pair_connection_drops( address=1234, connected=False, error=13 ) mock_data_received(protocol, generate_plaintext_packet(response)) + message = ( + "Peripheral 00:00:00:00:04:D2 changed connection status while waiting" + " for BluetoothDevicePairingResponse: Invalid attribute length" + ) with pytest.raises( - APIConnectionError, - match="Peripheral changed connections status while pairing: 13", + BluetoothConnectionDroppedError, + match=message, + ): + await pair_task + + +@pytest.mark.asyncio +async def test_bluetooth_unpair_connection_drops( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test connection drop during bluetooth_device_unpair.""" + client, connection, transport, protocol = api_client + pair_task = asyncio.create_task(client.bluetooth_device_unpair(1234)) + await asyncio.sleep(0) + response: message.Message = BluetoothDeviceConnectionResponse( + address=1234, connected=False, error=13 + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + message = ( + "Peripheral 00:00:00:00:04:D2 changed connection status while waiting" + " for BluetoothDeviceUnpairingResponse: Invalid attribute length" + ) + with pytest.raises( + BluetoothConnectionDroppedError, + match=message, + ): + await pair_task + + +@pytest.mark.asyncio +async def test_bluetooth_clear_cache_connection_drops( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test connection drop during bluetooth_device_clear_cache.""" + client, connection, transport, protocol = api_client + pair_task = asyncio.create_task(client.bluetooth_device_clear_cache(1234)) + await asyncio.sleep(0) + response: message.Message = BluetoothDeviceConnectionResponse( + address=1234, connected=False, error=13 + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + message = ( + "Peripheral 00:00:00:00:04:D2 changed connection status while waiting" + " for BluetoothDeviceClearCacheResponse: Invalid attribute length" + ) + with pytest.raises( + BluetoothConnectionDroppedError, + match=message, ): await pair_task