diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 201d002..f3f848d 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -1,3 +1,4 @@ +# pylint: disable=unidiomatic-typecheck from __future__ import annotations import asyncio @@ -459,10 +460,7 @@ class APIClient: BluetoothDeviceConnectionResponse, ), ) - if ( - type(msg) # pylint: disable=unidiomatic-typecheck - is BluetoothDeviceConnectionResponse - ): + if type(msg) is BluetoothDeviceConnectionResponse: return bool(msg.address == address) return bool(msg.address == address and msg.handle == handle) @@ -488,10 +486,7 @@ class APIClient: timeout, ) - if ( - type(resp) # pylint: disable=unidiomatic-typecheck - is BluetoothGATTErrorResponse - ): + if type(resp) is BluetoothGATTErrorResponse: raise BluetoothGATTAPIError(BluetoothGATTError.from_pb(resp)) self._raise_for_ble_connection_change(address, resp, msg_types) @@ -717,10 +712,7 @@ class APIClient: msg_types: tuple[type[message.Message], ...], ) -> None: """Raise an exception if the connection status changed.""" - if ( - type(response) # pylint: disable=unidiomatic-typecheck - is not BluetoothDeviceConnectionResponse - ): + if type(response) is not BluetoothDeviceConnectionResponse: return response_names = message_types_to_names(msg_types) human_readable_address = to_human_readable_address(address) @@ -756,31 +748,50 @@ class APIClient: async def bluetooth_gatt_get_services( self, address: int ) -> ESPHomeBluetoothGATTServices: - append_types = (BluetoothGATTGetServicesResponse, BluetoothGATTErrorResponse) - stop_types = (BluetoothGATTGetServicesDoneResponse, BluetoothGATTErrorResponse) + append_types = ( + BluetoothDeviceConnectionResponse, + BluetoothGATTGetServicesResponse, + BluetoothGATTErrorResponse, + ) + stop_types = ( + BluetoothDeviceConnectionResponse, + BluetoothGATTGetServicesDoneResponse, + BluetoothGATTErrorResponse, + ) + msg_types = ( + BluetoothGATTGetServicesResponse, + BluetoothGATTGetServicesDoneResponse, + BluetoothGATTErrorResponse, + ) - def do_append(msg: message.Message) -> bool: - return isinstance(msg, append_types) and msg.address == address + def do_append( + msg: BluetoothDeviceConnectionResponse + | BluetoothGATTGetServicesResponse + | BluetoothGATTGetServicesDoneResponse + | BluetoothGATTErrorResponse, + ) -> bool: + return type(msg) in append_types and msg.address == address - def do_stop(msg: message.Message) -> bool: - return isinstance(msg, stop_types) and msg.address == address + def do_stop( + msg: BluetoothDeviceConnectionResponse + | BluetoothGATTGetServicesResponse + | BluetoothGATTGetServicesDoneResponse + | BluetoothGATTErrorResponse, + ) -> bool: + return type(msg) in stop_types and msg.address == address resp = await self._get_connection().send_messages_await_response_complex( (BluetoothGATTGetServicesRequest(address=address),), do_append, do_stop, - ( - BluetoothGATTGetServicesResponse, - BluetoothGATTGetServicesDoneResponse, - BluetoothGATTErrorResponse, - ), + (*msg_types, BluetoothDeviceConnectionResponse), DEFAULT_BLE_TIMEOUT, ) services = [] for msg in resp: + self._raise_for_ble_connection_change(address, msg, msg_types) if isinstance(msg, BluetoothGATTErrorResponse): raise BluetoothGATTAPIError(BluetoothGATTError.from_pb(msg)) - services.extend(BluetoothGATTServices.from_pb(msg).services) return ESPHomeBluetoothGATTServices(address=address, services=services) # type: ignore[call-arg] diff --git a/tests/test_client.py b/tests/test_client.py index 5f700cc..825bcd3 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1266,6 +1266,29 @@ async def test_bluetooth_gatt_write_descriptor_without_response( await client.bluetooth_gatt_write_descriptor(1234, 1234, b"1234", timeout=0) +@pytest.mark.asyncio +async def test_bluetooth_gatt_get_services_connection_drops( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test connection drop during bluetooth_gatt_get_services.""" + client, connection, transport, protocol = api_client + services_task = asyncio.create_task(client.bluetooth_gatt_get_services(1234)) + await asyncio.sleep(0) + response: message.Message = BluetoothDeviceConnectionResponse( + address=1234, connected=False, error=13 + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + message = ( + "Peripheral 00:00:00:00:04:D2 changed connection status while waiting" + " for BluetoothGATTGetServicesResponse, BluetoothGATTGetServicesDoneResponse, " + "BluetoothGATTErrorResponse: Invalid attribute length" + ) + with pytest.raises(BluetoothConnectionDroppedError, match=message): + await services_task + + @pytest.mark.asyncio async def test_bluetooth_gatt_get_services( api_client: tuple[