Handle Bluetooth connection drops in more places

This commit is contained in:
J. Nick Koston 2023-11-28 07:02:31 -06:00
parent 72a8f70bcd
commit 23c3959dd2
No known key found for this signature in database
2 changed files with 79 additions and 30 deletions

View File

@ -86,6 +86,7 @@ from .core import (
BluetoothGATTAPIError,
TimeoutAPIError,
to_human_readable_address,
to_human_readable_gatt_error,
)
from .model import (
AlarmControlPanelCommand,
@ -155,6 +156,10 @@ ExecuteServiceDataType = dict[
]
class BluetoothConnectionDroppedError(APIConnectionError):
"""Raised when a Bluetooth connection is dropped."""
def _stringify_or_none(value: str | None) -> str | None:
"""Convert a string like object to a str or None.
@ -632,26 +637,11 @@ class APIClient:
async def bluetooth_device_pair(
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
) -> BluetoothDevicePairing:
def predicate_func(
msg: BluetoothDevicePairingResponse | BluetoothDeviceConnectionResponse,
) -> bool:
if msg.address != address:
return False
if isinstance(msg, BluetoothDeviceConnectionResponse):
raise APIConnectionError(
f"Peripheral changed connections status while pairing: {msg.error}"
)
return True
return BluetoothDevicePairing.from_pb(
await self._bluetooth_device_request(
await self._bluetooth_device_request_watch_connection(
address,
BluetoothDeviceRequestType.PAIR,
predicate_func,
(
BluetoothDevicePairingResponse,
BluetoothDeviceConnectionResponse,
),
(BluetoothDevicePairingResponse,),
timeout,
)
)
@ -660,10 +650,9 @@ class APIClient:
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
) -> BluetoothDeviceUnpairing:
return BluetoothDeviceUnpairing.from_pb(
await self._bluetooth_device_request(
await self._bluetooth_device_request_watch_connection(
address,
BluetoothDeviceRequestType.UNPAIR,
lambda msg: msg.address == address,
(BluetoothDeviceUnpairingResponse,),
timeout,
)
@ -673,10 +662,9 @@ class APIClient:
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
) -> BluetoothDeviceClearCache:
return BluetoothDeviceClearCache.from_pb(
await self._bluetooth_device_request(
await self._bluetooth_device_request_watch_connection(
address,
BluetoothDeviceRequestType.CLEAR_CACHE,
lambda msg: msg.address == address,
(BluetoothDeviceClearCacheResponse,),
timeout,
)
@ -694,6 +682,30 @@ class APIClient:
timeout,
)
async def _bluetooth_device_request_watch_connection(
self,
address: int,
request_type: BluetoothDeviceRequestType,
msg_types: tuple[type[message.Message], ...],
timeout: float,
) -> message.Message:
"""Send a BluetoothDeviceRequest watch for the connection state to change."""
response = await self._bluetooth_device_request(
address,
request_type,
lambda msg: msg.address == address,
(BluetoothDeviceConnectionResponse, *msg_types),
timeout,
)
if type(response) is BluetoothDeviceConnectionResponse:
response_names = ", ".join(t.__name__ for t in msg_types)
raise BluetoothConnectionDroppedError(
"Peripheral changed connection status while waiting for "
f"{response_names}: {to_human_readable_gatt_error(response.error)} "
f"({response.error})"
)
return response
async def _bluetooth_device_request(
self,
address: int,
@ -702,6 +714,7 @@ class APIClient:
msg_types: tuple[type[message.Message], ...],
timeout: float,
) -> message.Message:
"""Send a BluetoothDeviceRequest and wait for a response."""
[response] = await self._get_connection().send_messages_await_response_complex(
(
BluetoothDeviceRequest(
@ -941,7 +954,6 @@ class APIClient:
elif position == 0.0:
req.legacy_command = LegacyCoverCommand.CLOSE
req.has_legacy_command = True
self._get_connection().send_message(req)
async def fan_command(
@ -969,7 +981,6 @@ class APIClient:
if direction is not None:
req.has_direction = True
req.direction = direction
self._get_connection().send_message(req)
async def light_command( # pylint: disable=too-many-branches
@ -1027,7 +1038,6 @@ class APIClient:
if effect is not None:
req.has_effect = True
req.effect = effect
self._get_connection().send_message(req)
async def switch_command(self, key: int, state: bool) -> None:
@ -1079,7 +1089,6 @@ class APIClient:
if custom_preset is not None:
req.has_custom_preset = True
req.custom_preset = custom_preset
self._get_connection().send_message(req)
async def number_command(self, key: int, state: float) -> None:
@ -1109,7 +1118,6 @@ class APIClient:
if duration is not None:
req.duration = duration
req.has_duration = True
self._get_connection().send_message(req)
async def button_command(self, key: int) -> None:
@ -1144,7 +1152,6 @@ class APIClient:
if media_url is not None:
req.media_url = media_url
req.has_media_url = True
self._get_connection().send_message(req)
async def text_command(self, key: int, state: str) -> None:

View File

@ -66,7 +66,7 @@ from aioesphomeapi.api_pb2 import (
VoiceAssistantRequest,
VoiceAssistantResponse,
)
from aioesphomeapi.client import APIClient
from aioesphomeapi.client import APIClient, BluetoothConnectionDroppedError
from aioesphomeapi.connection import APIConnection
from aioesphomeapi.core import (
APIConnectionError,
@ -956,8 +956,50 @@ async def test_bluetooth_pair_connection_drops(
)
mock_data_received(protocol, generate_plaintext_packet(response))
with pytest.raises(
APIConnectionError,
match="Peripheral changed connections status while pairing: 13",
BluetoothConnectionDroppedError,
match="Peripheral changed connection status while waiting for BluetoothDevicePairingResponse: Invalid attribute length",
):
await pair_task
@pytest.mark.asyncio
async def test_bluetooth_unpair_connection_drops(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test connection drop during bluetooth_device_unpair."""
client, connection, transport, protocol = api_client
pair_task = asyncio.create_task(client.bluetooth_device_unpair(1234))
await asyncio.sleep(0)
response: message.Message = BluetoothDeviceConnectionResponse(
address=1234, connected=False, error=13
)
mock_data_received(protocol, generate_plaintext_packet(response))
with pytest.raises(
BluetoothConnectionDroppedError,
match="Peripheral changed connection status while waiting for BluetoothDeviceUnpairingResponse: Invalid attribute length",
):
await pair_task
@pytest.mark.asyncio
async def test_bluetooth_clear_cache_connection_drops(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test connection drop during bluetooth_device_clear_cache."""
client, connection, transport, protocol = api_client
pair_task = asyncio.create_task(client.bluetooth_device_clear_cache(1234))
await asyncio.sleep(0)
response: message.Message = BluetoothDeviceConnectionResponse(
address=1234, connected=False, error=13
)
mock_data_received(protocol, generate_plaintext_packet(response))
with pytest.raises(
BluetoothConnectionDroppedError,
match="Peripheral changed connection status while waiting for BluetoothDeviceClearCacheResponse: Invalid attribute length",
):
await pair_task