Handle Bluetooth connection drops in more places (#766)

This commit is contained in:
J. Nick Koston 2023-11-28 07:23:21 -06:00 committed by GitHub
parent 72a8f70bcd
commit 3e920df478
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 119 additions and 33 deletions

View File

@ -11,6 +11,7 @@ from .core import (
MESSAGE_TYPE_TO_PROTO,
APIConnectionError,
BadNameAPIError,
BluetoothConnectionDroppedError,
HandshakeAPIError,
InvalidAuthAPIError,
InvalidEncryptionKeyAPIError,

View File

@ -83,9 +83,11 @@ from .client_callbacks import (
from .connection import APIConnection, ConnectionParams, handle_timeout
from .core import (
APIConnectionError,
BluetoothConnectionDroppedError,
BluetoothGATTAPIError,
TimeoutAPIError,
to_human_readable_address,
to_human_readable_gatt_error,
)
from .model import (
AlarmControlPanelCommand,
@ -118,7 +120,11 @@ from .model import (
UserServiceArgType,
)
from .model import VoiceAssistantAudioSettings as VoiceAssistantAudioSettingsModel
from .model import VoiceAssistantCommand, VoiceAssistantEventType
from .model import (
VoiceAssistantCommand,
VoiceAssistantEventType,
message_types_to_names,
)
from .model_conversions import (
LIST_ENTITIES_SERVICES_RESPONSE_TYPES,
SUBSCRIBE_STATES_RESPONSE_TYPES,
@ -632,26 +638,11 @@ class APIClient:
async def bluetooth_device_pair(
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
) -> BluetoothDevicePairing:
def predicate_func(
msg: BluetoothDevicePairingResponse | BluetoothDeviceConnectionResponse,
) -> bool:
if msg.address != address:
return False
if isinstance(msg, BluetoothDeviceConnectionResponse):
raise APIConnectionError(
f"Peripheral changed connections status while pairing: {msg.error}"
)
return True
return BluetoothDevicePairing.from_pb(
await self._bluetooth_device_request(
await self._bluetooth_device_request_watch_connection(
address,
BluetoothDeviceRequestType.PAIR,
predicate_func,
(
BluetoothDevicePairingResponse,
BluetoothDeviceConnectionResponse,
),
(BluetoothDevicePairingResponse,),
timeout,
)
)
@ -660,10 +651,9 @@ class APIClient:
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
) -> BluetoothDeviceUnpairing:
return BluetoothDeviceUnpairing.from_pb(
await self._bluetooth_device_request(
await self._bluetooth_device_request_watch_connection(
address,
BluetoothDeviceRequestType.UNPAIR,
lambda msg: msg.address == address,
(BluetoothDeviceUnpairingResponse,),
timeout,
)
@ -673,10 +663,9 @@ class APIClient:
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
) -> BluetoothDeviceClearCache:
return BluetoothDeviceClearCache.from_pb(
await self._bluetooth_device_request(
await self._bluetooth_device_request_watch_connection(
address,
BluetoothDeviceRequestType.CLEAR_CACHE,
lambda msg: msg.address == address,
(BluetoothDeviceClearCacheResponse,),
timeout,
)
@ -694,6 +683,43 @@ class APIClient:
timeout,
)
async def _bluetooth_device_request_watch_connection(
self,
address: int,
request_type: BluetoothDeviceRequestType,
msg_types: tuple[type[message.Message], ...],
timeout: float,
) -> message.Message:
"""Send a BluetoothDeviceRequest watch for the connection state to change."""
response = await self._bluetooth_device_request(
address,
request_type,
lambda msg: msg.address == address,
(BluetoothDeviceConnectionResponse, *msg_types),
timeout,
)
if (
type(response) # pylint: disable=unidiomatic-typecheck
is BluetoothDeviceConnectionResponse
):
self._raise_for_ble_connection_change(address, response, msg_types)
return response
def _raise_for_ble_connection_change(
self,
address: int,
response: BluetoothDeviceConnectionResponse,
msg_types: tuple[type[message.Message], ...],
) -> None:
"""Raise an exception if the connection status changed."""
response_names = message_types_to_names(msg_types)
human_readable_address = to_human_readable_address(address)
raise BluetoothConnectionDroppedError(
f"Peripheral {human_readable_address} changed connection status while waiting for "
f"{response_names}: {to_human_readable_gatt_error(response.error)} "
f"({response.error})"
)
async def _bluetooth_device_request(
self,
address: int,
@ -702,6 +728,7 @@ class APIClient:
msg_types: tuple[type[message.Message], ...],
timeout: float,
) -> message.Message:
"""Send a BluetoothDeviceRequest and wait for a response."""
[response] = await self._get_connection().send_messages_await_response_complex(
(
BluetoothDeviceRequest(
@ -941,7 +968,6 @@ class APIClient:
elif position == 0.0:
req.legacy_command = LegacyCoverCommand.CLOSE
req.has_legacy_command = True
self._get_connection().send_message(req)
async def fan_command(
@ -969,7 +995,6 @@ class APIClient:
if direction is not None:
req.has_direction = True
req.direction = direction
self._get_connection().send_message(req)
async def light_command( # pylint: disable=too-many-branches
@ -1027,7 +1052,6 @@ class APIClient:
if effect is not None:
req.has_effect = True
req.effect = effect
self._get_connection().send_message(req)
async def switch_command(self, key: int, state: bool) -> None:
@ -1079,7 +1103,6 @@ class APIClient:
if custom_preset is not None:
req.has_custom_preset = True
req.custom_preset = custom_preset
self._get_connection().send_message(req)
async def number_command(self, key: int, state: float) -> None:
@ -1109,7 +1132,6 @@ class APIClient:
if duration is not None:
req.duration = duration
req.has_duration = True
self._get_connection().send_message(req)
async def button_command(self, key: int) -> None:
@ -1144,7 +1166,6 @@ class APIClient:
if media_url is not None:
req.media_url = media_url
req.has_media_url = True
self._get_connection().send_message(req)
async def text_command(self, key: int, state: str) -> None:

View File

@ -49,7 +49,7 @@ from .core import (
TimeoutAPIError,
UnhandledAPIConnectionError,
)
from .model import APIVersion
from .model import APIVersion, message_types_to_names
from .zeroconf import ZeroconfManager
if sys.version_info[:2] < (3, 11):
@ -758,7 +758,7 @@ class APIConnection:
await fut
except asyncio_TimeoutError as err:
timeout_expired = True
response_names = ", ".join(t.__name__ for t in msg_types)
response_names = message_types_to_names(msg_types)
raise TimeoutAPIError(
f"Timeout waiting for {response_names} after {timeout}s"
) from err

View File

@ -228,6 +228,10 @@ class UnhandledAPIConnectionError(APIConnectionError):
pass
class BluetoothConnectionDroppedError(APIConnectionError):
"""Raised when a Bluetooth connection is dropped."""
def to_human_readable_address(address: int) -> str:
"""Convert a MAC address to a human readable format."""
return ":".join(TWO_CHAR.findall(f"{address:012X}"))

View File

@ -8,6 +8,8 @@ from functools import cache, lru_cache, partial
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
from uuid import UUID
from google.protobuf import message
from .util import fix_float_single_double_conversion
if sys.version_info[:2] < (3, 10):
@ -1166,3 +1168,7 @@ def build_unique_id(formatted_mac: str, entity_info: EntityInfo) -> str:
"""
# <mac>-<entity type>-<object_id>
return f"{formatted_mac}-{_TYPE_TO_NAME[type(entity_info)]}-{entity_info.object_id}"
def message_types_to_names(msg_types: Iterable[type[message.Message]]) -> str:
return ", ".join(t.__name__ for t in msg_types)

View File

@ -66,7 +66,7 @@ from aioesphomeapi.api_pb2 import (
VoiceAssistantRequest,
VoiceAssistantResponse,
)
from aioesphomeapi.client import APIClient
from aioesphomeapi.client import APIClient, BluetoothConnectionDroppedError
from aioesphomeapi.connection import APIConnection
from aioesphomeapi.core import (
APIConnectionError,
@ -955,9 +955,63 @@ async def test_bluetooth_pair_connection_drops(
address=1234, connected=False, error=13
)
mock_data_received(protocol, generate_plaintext_packet(response))
message = (
"Peripheral 00:00:00:00:04:D2 changed connection status while waiting"
" for BluetoothDevicePairingResponse: Invalid attribute length"
)
with pytest.raises(
APIConnectionError,
match="Peripheral changed connections status while pairing: 13",
BluetoothConnectionDroppedError,
match=message,
):
await pair_task
@pytest.mark.asyncio
async def test_bluetooth_unpair_connection_drops(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test connection drop during bluetooth_device_unpair."""
client, connection, transport, protocol = api_client
pair_task = asyncio.create_task(client.bluetooth_device_unpair(1234))
await asyncio.sleep(0)
response: message.Message = BluetoothDeviceConnectionResponse(
address=1234, connected=False, error=13
)
mock_data_received(protocol, generate_plaintext_packet(response))
message = (
"Peripheral 00:00:00:00:04:D2 changed connection status while waiting"
" for BluetoothDeviceUnpairingResponse: Invalid attribute length"
)
with pytest.raises(
BluetoothConnectionDroppedError,
match=message,
):
await pair_task
@pytest.mark.asyncio
async def test_bluetooth_clear_cache_connection_drops(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test connection drop during bluetooth_device_clear_cache."""
client, connection, transport, protocol = api_client
pair_task = asyncio.create_task(client.bluetooth_device_clear_cache(1234))
await asyncio.sleep(0)
response: message.Message = BluetoothDeviceConnectionResponse(
address=1234, connected=False, error=13
)
mock_data_received(protocol, generate_plaintext_packet(response))
message = (
"Peripheral 00:00:00:00:04:D2 changed connection status while waiting"
" for BluetoothDeviceClearCacheResponse: Invalid attribute length"
)
with pytest.raises(
BluetoothConnectionDroppedError,
match=message,
):
await pair_task