From e182f68b427b2f194719ab07cf374b4684e4e372 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 28 Nov 2023 07:42:21 -0600 Subject: [PATCH] Raise BluetoothConnectionDroppedError if connection drops during GATT read/write/notify (#767) --- aioesphomeapi/client.py | 22 ++++++++--- tests/test_client.py | 85 +++++++++++++++++++++++++++++++++++------ 2 files changed, 89 insertions(+), 18 deletions(-) diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index cc9f101..201d002 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -456,8 +456,14 @@ class APIClient: BluetoothGATTNotifyResponse, BluetoothGATTReadResponse, BluetoothGATTWriteResponse, + BluetoothDeviceConnectionResponse, ), ) + if ( + type(msg) # pylint: disable=unidiomatic-typecheck + is BluetoothDeviceConnectionResponse + ): + return bool(msg.address == address) return bool(msg.address == address and msg.handle == handle) async def _send_bluetooth_message_await_response( @@ -473,11 +479,12 @@ class APIClient: timeout: float = 10.0, ) -> message.Message: message_filter = partial(self._filter_bluetooth_message, address, handle) + msg_types = (response_type, BluetoothGATTErrorResponse) [resp] = await self._get_connection().send_messages_await_response_complex( (request,), message_filter, message_filter, - (response_type, BluetoothGATTErrorResponse), + (*msg_types, BluetoothDeviceConnectionResponse), timeout, ) @@ -487,6 +494,8 @@ class APIClient: ): raise BluetoothGATTAPIError(BluetoothGATTError.from_pb(resp)) + self._raise_for_ble_connection_change(address, resp, msg_types) + return resp def _unsub_bluetooth_advertisements( @@ -698,11 +707,7 @@ class APIClient: (BluetoothDeviceConnectionResponse, *msg_types), timeout, ) - if ( - type(response) # pylint: disable=unidiomatic-typecheck - is BluetoothDeviceConnectionResponse - ): - self._raise_for_ble_connection_change(address, response, msg_types) + self._raise_for_ble_connection_change(address, response, msg_types) return response def _raise_for_ble_connection_change( @@ -712,6 +717,11 @@ class APIClient: msg_types: tuple[type[message.Message], ...], ) -> None: """Raise an exception if the connection status changed.""" + if ( + type(response) # pylint: disable=unidiomatic-typecheck + is not BluetoothDeviceConnectionResponse + ): + return response_names = message_types_to_names(msg_types) human_readable_address = to_human_readable_address(address) raise BluetoothConnectionDroppedError( diff --git a/tests/test_client.py b/tests/test_client.py index 4c3b271..5f700cc 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -959,10 +959,7 @@ async def test_bluetooth_pair_connection_drops( "Peripheral 00:00:00:00:04:D2 changed connection status while waiting" " for BluetoothDevicePairingResponse: Invalid attribute length" ) - with pytest.raises( - BluetoothConnectionDroppedError, - match=message, - ): + with pytest.raises(BluetoothConnectionDroppedError, match=message): await pair_task @@ -984,10 +981,7 @@ async def test_bluetooth_unpair_connection_drops( "Peripheral 00:00:00:00:04:D2 changed connection status while waiting" " for BluetoothDeviceUnpairingResponse: Invalid attribute length" ) - with pytest.raises( - BluetoothConnectionDroppedError, - match=message, - ): + with pytest.raises(BluetoothConnectionDroppedError, match=message): await pair_task @@ -1009,10 +1003,7 @@ async def test_bluetooth_clear_cache_connection_drops( "Peripheral 00:00:00:00:04:D2 changed connection status while waiting" " for BluetoothDeviceClearCacheResponse: Invalid attribute length" ) - with pytest.raises( - BluetoothConnectionDroppedError, - match=message, - ): + with pytest.raises(BluetoothConnectionDroppedError, match=message): await pair_task @@ -1100,6 +1091,28 @@ async def test_bluetooth_gatt_read( assert await read_task == b"1234" +@pytest.mark.asyncio +async def test_bluetooth_gatt_read_connection_drops( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test connection drop during bluetooth_gatt_read.""" + client, connection, transport, protocol = api_client + read_task = asyncio.create_task(client.bluetooth_gatt_read(1234, 1234)) + await asyncio.sleep(0) + response: message.Message = BluetoothDeviceConnectionResponse( + address=1234, connected=False, error=13 + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + message = ( + "Peripheral 00:00:00:00:04:D2 changed connection status while waiting" + " for BluetoothGATTReadResponse, BluetoothGATTErrorResponse: Invalid attribute length" + ) + with pytest.raises(BluetoothConnectionDroppedError, match=message): + await read_task + + @pytest.mark.asyncio async def test_bluetooth_gatt_read_error( api_client: tuple[ @@ -1164,6 +1177,30 @@ async def test_bluetooth_gatt_write( await write_task +@pytest.mark.asyncio +async def test_bluetooth_gatt_write_connection_drops( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test connection drop during bluetooth_gatt_read.""" + client, connection, transport, protocol = api_client + write_task = asyncio.create_task( + client.bluetooth_gatt_write(1234, 1234, b"1234", True) + ) + await asyncio.sleep(0) + response: message.Message = BluetoothDeviceConnectionResponse( + address=1234, connected=False, error=13 + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + message = ( + "Peripheral 00:00:00:00:04:D2 changed connection status while waiting" + " for BluetoothGATTWriteResponse, BluetoothGATTErrorResponse: Invalid attribute length" + ) + with pytest.raises(BluetoothConnectionDroppedError, match=message): + await write_task + + @pytest.mark.asyncio async def test_bluetooth_gatt_write_without_response( api_client: tuple[ @@ -1290,6 +1327,30 @@ async def test_bluetooth_gatt_get_services_errors( await services_task +@pytest.mark.asyncio +async def test_bluetooth_gatt_start_notify_connection_drops( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test connection drop during bluetooth_gatt_start_notify.""" + client, connection, transport, protocol = api_client + notify_task = asyncio.create_task( + client.bluetooth_gatt_start_notify(1234, 1, lambda handle, data: None) + ) + await asyncio.sleep(0) + response: message.Message = BluetoothDeviceConnectionResponse( + address=1234, connected=False, error=13 + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + message = ( + "Peripheral 00:00:00:00:04:D2 changed connection status while waiting" + " for BluetoothGATTNotifyResponse, BluetoothGATTErrorResponse: Invalid attribute length" + ) + with pytest.raises(BluetoothConnectionDroppedError, match=message): + await notify_task + + @pytest.mark.asyncio async def test_bluetooth_gatt_start_notify( api_client: tuple[