Raise BluetoothConnectionDroppedError if connection drops while getting GATT services (#768)

This commit is contained in:
J. Nick Koston 2023-11-28 08:03:08 -06:00 committed by GitHub
parent e182f68b42
commit 176c7bc4b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 58 additions and 24 deletions

View File

@ -1,3 +1,4 @@
# pylint: disable=unidiomatic-typecheck
from __future__ import annotations
import asyncio
@ -459,10 +460,7 @@ class APIClient:
BluetoothDeviceConnectionResponse,
),
)
if (
type(msg) # pylint: disable=unidiomatic-typecheck
is BluetoothDeviceConnectionResponse
):
if type(msg) is BluetoothDeviceConnectionResponse:
return bool(msg.address == address)
return bool(msg.address == address and msg.handle == handle)
@ -488,10 +486,7 @@ class APIClient:
timeout,
)
if (
type(resp) # pylint: disable=unidiomatic-typecheck
is BluetoothGATTErrorResponse
):
if type(resp) is BluetoothGATTErrorResponse:
raise BluetoothGATTAPIError(BluetoothGATTError.from_pb(resp))
self._raise_for_ble_connection_change(address, resp, msg_types)
@ -717,10 +712,7 @@ class APIClient:
msg_types: tuple[type[message.Message], ...],
) -> None:
"""Raise an exception if the connection status changed."""
if (
type(response) # pylint: disable=unidiomatic-typecheck
is not BluetoothDeviceConnectionResponse
):
if type(response) is not BluetoothDeviceConnectionResponse:
return
response_names = message_types_to_names(msg_types)
human_readable_address = to_human_readable_address(address)
@ -756,31 +748,50 @@ class APIClient:
async def bluetooth_gatt_get_services(
self, address: int
) -> ESPHomeBluetoothGATTServices:
append_types = (BluetoothGATTGetServicesResponse, BluetoothGATTErrorResponse)
stop_types = (BluetoothGATTGetServicesDoneResponse, BluetoothGATTErrorResponse)
append_types = (
BluetoothDeviceConnectionResponse,
BluetoothGATTGetServicesResponse,
BluetoothGATTErrorResponse,
)
stop_types = (
BluetoothDeviceConnectionResponse,
BluetoothGATTGetServicesDoneResponse,
BluetoothGATTErrorResponse,
)
msg_types = (
BluetoothGATTGetServicesResponse,
BluetoothGATTGetServicesDoneResponse,
BluetoothGATTErrorResponse,
)
def do_append(msg: message.Message) -> bool:
return isinstance(msg, append_types) and msg.address == address
def do_append(
msg: BluetoothDeviceConnectionResponse
| BluetoothGATTGetServicesResponse
| BluetoothGATTGetServicesDoneResponse
| BluetoothGATTErrorResponse,
) -> bool:
return type(msg) in append_types and msg.address == address
def do_stop(msg: message.Message) -> bool:
return isinstance(msg, stop_types) and msg.address == address
def do_stop(
msg: BluetoothDeviceConnectionResponse
| BluetoothGATTGetServicesResponse
| BluetoothGATTGetServicesDoneResponse
| BluetoothGATTErrorResponse,
) -> bool:
return type(msg) in stop_types and msg.address == address
resp = await self._get_connection().send_messages_await_response_complex(
(BluetoothGATTGetServicesRequest(address=address),),
do_append,
do_stop,
(
BluetoothGATTGetServicesResponse,
BluetoothGATTGetServicesDoneResponse,
BluetoothGATTErrorResponse,
),
(*msg_types, BluetoothDeviceConnectionResponse),
DEFAULT_BLE_TIMEOUT,
)
services = []
for msg in resp:
self._raise_for_ble_connection_change(address, msg, msg_types)
if isinstance(msg, BluetoothGATTErrorResponse):
raise BluetoothGATTAPIError(BluetoothGATTError.from_pb(msg))
services.extend(BluetoothGATTServices.from_pb(msg).services)
return ESPHomeBluetoothGATTServices(address=address, services=services) # type: ignore[call-arg]

View File

@ -1266,6 +1266,29 @@ async def test_bluetooth_gatt_write_descriptor_without_response(
await client.bluetooth_gatt_write_descriptor(1234, 1234, b"1234", timeout=0)
@pytest.mark.asyncio
async def test_bluetooth_gatt_get_services_connection_drops(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test connection drop during bluetooth_gatt_get_services."""
client, connection, transport, protocol = api_client
services_task = asyncio.create_task(client.bluetooth_gatt_get_services(1234))
await asyncio.sleep(0)
response: message.Message = BluetoothDeviceConnectionResponse(
address=1234, connected=False, error=13
)
mock_data_received(protocol, generate_plaintext_packet(response))
message = (
"Peripheral 00:00:00:00:04:D2 changed connection status while waiting"
" for BluetoothGATTGetServicesResponse, BluetoothGATTGetServicesDoneResponse, "
"BluetoothGATTErrorResponse: Invalid attribute length"
)
with pytest.raises(BluetoothConnectionDroppedError, match=message):
await services_task
@pytest.mark.asyncio
async def test_bluetooth_gatt_get_services(
api_client: tuple[