From 23c3959dd297f166770b983a5e0d0a9893d807bb Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 28 Nov 2023 07:02:31 -0600 Subject: [PATCH] Handle Bluetooth connection drops in more places --- aioesphomeapi/client.py | 61 +++++++++++++++++++++++------------------ tests/test_client.py | 48 ++++++++++++++++++++++++++++++-- 2 files changed, 79 insertions(+), 30 deletions(-) diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 69cd619..8720ecf 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -86,6 +86,7 @@ from .core import ( BluetoothGATTAPIError, TimeoutAPIError, to_human_readable_address, + to_human_readable_gatt_error, ) from .model import ( AlarmControlPanelCommand, @@ -155,6 +156,10 @@ ExecuteServiceDataType = dict[ ] +class BluetoothConnectionDroppedError(APIConnectionError): + """Raised when a Bluetooth connection is dropped.""" + + def _stringify_or_none(value: str | None) -> str | None: """Convert a string like object to a str or None. @@ -632,26 +637,11 @@ class APIClient: async def bluetooth_device_pair( self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT ) -> BluetoothDevicePairing: - def predicate_func( - msg: BluetoothDevicePairingResponse | BluetoothDeviceConnectionResponse, - ) -> bool: - if msg.address != address: - return False - if isinstance(msg, BluetoothDeviceConnectionResponse): - raise APIConnectionError( - f"Peripheral changed connections status while pairing: {msg.error}" - ) - return True - return BluetoothDevicePairing.from_pb( - await self._bluetooth_device_request( + await self._bluetooth_device_request_watch_connection( address, BluetoothDeviceRequestType.PAIR, - predicate_func, - ( - BluetoothDevicePairingResponse, - BluetoothDeviceConnectionResponse, - ), + (BluetoothDevicePairingResponse,), timeout, ) ) @@ -660,10 +650,9 @@ class APIClient: self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT ) -> BluetoothDeviceUnpairing: return BluetoothDeviceUnpairing.from_pb( - await self._bluetooth_device_request( + await self._bluetooth_device_request_watch_connection( address, BluetoothDeviceRequestType.UNPAIR, - lambda msg: msg.address == address, (BluetoothDeviceUnpairingResponse,), timeout, ) @@ -673,10 +662,9 @@ class APIClient: self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT ) -> BluetoothDeviceClearCache: return BluetoothDeviceClearCache.from_pb( - await self._bluetooth_device_request( + await self._bluetooth_device_request_watch_connection( address, BluetoothDeviceRequestType.CLEAR_CACHE, - lambda msg: msg.address == address, (BluetoothDeviceClearCacheResponse,), timeout, ) @@ -694,6 +682,30 @@ class APIClient: timeout, ) + async def _bluetooth_device_request_watch_connection( + self, + address: int, + request_type: BluetoothDeviceRequestType, + msg_types: tuple[type[message.Message], ...], + timeout: float, + ) -> message.Message: + """Send a BluetoothDeviceRequest watch for the connection state to change.""" + response = await self._bluetooth_device_request( + address, + request_type, + lambda msg: msg.address == address, + (BluetoothDeviceConnectionResponse, *msg_types), + timeout, + ) + if type(response) is BluetoothDeviceConnectionResponse: + response_names = ", ".join(t.__name__ for t in msg_types) + raise BluetoothConnectionDroppedError( + "Peripheral changed connection status while waiting for " + f"{response_names}: {to_human_readable_gatt_error(response.error)} " + f"({response.error})" + ) + return response + async def _bluetooth_device_request( self, address: int, @@ -702,6 +714,7 @@ class APIClient: msg_types: tuple[type[message.Message], ...], timeout: float, ) -> message.Message: + """Send a BluetoothDeviceRequest and wait for a response.""" [response] = await self._get_connection().send_messages_await_response_complex( ( BluetoothDeviceRequest( @@ -941,7 +954,6 @@ class APIClient: elif position == 0.0: req.legacy_command = LegacyCoverCommand.CLOSE req.has_legacy_command = True - self._get_connection().send_message(req) async def fan_command( @@ -969,7 +981,6 @@ class APIClient: if direction is not None: req.has_direction = True req.direction = direction - self._get_connection().send_message(req) async def light_command( # pylint: disable=too-many-branches @@ -1027,7 +1038,6 @@ class APIClient: if effect is not None: req.has_effect = True req.effect = effect - self._get_connection().send_message(req) async def switch_command(self, key: int, state: bool) -> None: @@ -1079,7 +1089,6 @@ class APIClient: if custom_preset is not None: req.has_custom_preset = True req.custom_preset = custom_preset - self._get_connection().send_message(req) async def number_command(self, key: int, state: float) -> None: @@ -1109,7 +1118,6 @@ class APIClient: if duration is not None: req.duration = duration req.has_duration = True - self._get_connection().send_message(req) async def button_command(self, key: int) -> None: @@ -1144,7 +1152,6 @@ class APIClient: if media_url is not None: req.media_url = media_url req.has_media_url = True - self._get_connection().send_message(req) async def text_command(self, key: int, state: str) -> None: diff --git a/tests/test_client.py b/tests/test_client.py index 161b326..e45c8a1 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -66,7 +66,7 @@ from aioesphomeapi.api_pb2 import ( VoiceAssistantRequest, VoiceAssistantResponse, ) -from aioesphomeapi.client import APIClient +from aioesphomeapi.client import APIClient, BluetoothConnectionDroppedError from aioesphomeapi.connection import APIConnection from aioesphomeapi.core import ( APIConnectionError, @@ -956,8 +956,50 @@ async def test_bluetooth_pair_connection_drops( ) mock_data_received(protocol, generate_plaintext_packet(response)) with pytest.raises( - APIConnectionError, - match="Peripheral changed connections status while pairing: 13", + BluetoothConnectionDroppedError, + match="Peripheral changed connection status while waiting for BluetoothDevicePairingResponse: Invalid attribute length", + ): + await pair_task + + +@pytest.mark.asyncio +async def test_bluetooth_unpair_connection_drops( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test connection drop during bluetooth_device_unpair.""" + client, connection, transport, protocol = api_client + pair_task = asyncio.create_task(client.bluetooth_device_unpair(1234)) + await asyncio.sleep(0) + response: message.Message = BluetoothDeviceConnectionResponse( + address=1234, connected=False, error=13 + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + with pytest.raises( + BluetoothConnectionDroppedError, + match="Peripheral changed connection status while waiting for BluetoothDeviceUnpairingResponse: Invalid attribute length", + ): + await pair_task + + +@pytest.mark.asyncio +async def test_bluetooth_clear_cache_connection_drops( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test connection drop during bluetooth_device_clear_cache.""" + client, connection, transport, protocol = api_client + pair_task = asyncio.create_task(client.bluetooth_device_clear_cache(1234)) + await asyncio.sleep(0) + response: message.Message = BluetoothDeviceConnectionResponse( + address=1234, connected=False, error=13 + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + with pytest.raises( + BluetoothConnectionDroppedError, + match="Peripheral changed connection status while waiting for BluetoothDeviceClearCacheResponse: Invalid attribute length", ): await pair_task