mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-04-08 19:06:02 +02:00
Handle Bluetooth connection drops in more places (#766)
This commit is contained in:
parent
72a8f70bcd
commit
3e920df478
@ -11,6 +11,7 @@ from .core import (
|
||||
MESSAGE_TYPE_TO_PROTO,
|
||||
APIConnectionError,
|
||||
BadNameAPIError,
|
||||
BluetoothConnectionDroppedError,
|
||||
HandshakeAPIError,
|
||||
InvalidAuthAPIError,
|
||||
InvalidEncryptionKeyAPIError,
|
||||
|
@ -83,9 +83,11 @@ from .client_callbacks import (
|
||||
from .connection import APIConnection, ConnectionParams, handle_timeout
|
||||
from .core import (
|
||||
APIConnectionError,
|
||||
BluetoothConnectionDroppedError,
|
||||
BluetoothGATTAPIError,
|
||||
TimeoutAPIError,
|
||||
to_human_readable_address,
|
||||
to_human_readable_gatt_error,
|
||||
)
|
||||
from .model import (
|
||||
AlarmControlPanelCommand,
|
||||
@ -118,7 +120,11 @@ from .model import (
|
||||
UserServiceArgType,
|
||||
)
|
||||
from .model import VoiceAssistantAudioSettings as VoiceAssistantAudioSettingsModel
|
||||
from .model import VoiceAssistantCommand, VoiceAssistantEventType
|
||||
from .model import (
|
||||
VoiceAssistantCommand,
|
||||
VoiceAssistantEventType,
|
||||
message_types_to_names,
|
||||
)
|
||||
from .model_conversions import (
|
||||
LIST_ENTITIES_SERVICES_RESPONSE_TYPES,
|
||||
SUBSCRIBE_STATES_RESPONSE_TYPES,
|
||||
@ -632,26 +638,11 @@ class APIClient:
|
||||
async def bluetooth_device_pair(
|
||||
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
|
||||
) -> BluetoothDevicePairing:
|
||||
def predicate_func(
|
||||
msg: BluetoothDevicePairingResponse | BluetoothDeviceConnectionResponse,
|
||||
) -> bool:
|
||||
if msg.address != address:
|
||||
return False
|
||||
if isinstance(msg, BluetoothDeviceConnectionResponse):
|
||||
raise APIConnectionError(
|
||||
f"Peripheral changed connections status while pairing: {msg.error}"
|
||||
)
|
||||
return True
|
||||
|
||||
return BluetoothDevicePairing.from_pb(
|
||||
await self._bluetooth_device_request(
|
||||
await self._bluetooth_device_request_watch_connection(
|
||||
address,
|
||||
BluetoothDeviceRequestType.PAIR,
|
||||
predicate_func,
|
||||
(
|
||||
BluetoothDevicePairingResponse,
|
||||
BluetoothDeviceConnectionResponse,
|
||||
),
|
||||
(BluetoothDevicePairingResponse,),
|
||||
timeout,
|
||||
)
|
||||
)
|
||||
@ -660,10 +651,9 @@ class APIClient:
|
||||
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
|
||||
) -> BluetoothDeviceUnpairing:
|
||||
return BluetoothDeviceUnpairing.from_pb(
|
||||
await self._bluetooth_device_request(
|
||||
await self._bluetooth_device_request_watch_connection(
|
||||
address,
|
||||
BluetoothDeviceRequestType.UNPAIR,
|
||||
lambda msg: msg.address == address,
|
||||
(BluetoothDeviceUnpairingResponse,),
|
||||
timeout,
|
||||
)
|
||||
@ -673,10 +663,9 @@ class APIClient:
|
||||
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
|
||||
) -> BluetoothDeviceClearCache:
|
||||
return BluetoothDeviceClearCache.from_pb(
|
||||
await self._bluetooth_device_request(
|
||||
await self._bluetooth_device_request_watch_connection(
|
||||
address,
|
||||
BluetoothDeviceRequestType.CLEAR_CACHE,
|
||||
lambda msg: msg.address == address,
|
||||
(BluetoothDeviceClearCacheResponse,),
|
||||
timeout,
|
||||
)
|
||||
@ -694,6 +683,43 @@ class APIClient:
|
||||
timeout,
|
||||
)
|
||||
|
||||
async def _bluetooth_device_request_watch_connection(
|
||||
self,
|
||||
address: int,
|
||||
request_type: BluetoothDeviceRequestType,
|
||||
msg_types: tuple[type[message.Message], ...],
|
||||
timeout: float,
|
||||
) -> message.Message:
|
||||
"""Send a BluetoothDeviceRequest watch for the connection state to change."""
|
||||
response = await self._bluetooth_device_request(
|
||||
address,
|
||||
request_type,
|
||||
lambda msg: msg.address == address,
|
||||
(BluetoothDeviceConnectionResponse, *msg_types),
|
||||
timeout,
|
||||
)
|
||||
if (
|
||||
type(response) # pylint: disable=unidiomatic-typecheck
|
||||
is BluetoothDeviceConnectionResponse
|
||||
):
|
||||
self._raise_for_ble_connection_change(address, response, msg_types)
|
||||
return response
|
||||
|
||||
def _raise_for_ble_connection_change(
|
||||
self,
|
||||
address: int,
|
||||
response: BluetoothDeviceConnectionResponse,
|
||||
msg_types: tuple[type[message.Message], ...],
|
||||
) -> None:
|
||||
"""Raise an exception if the connection status changed."""
|
||||
response_names = message_types_to_names(msg_types)
|
||||
human_readable_address = to_human_readable_address(address)
|
||||
raise BluetoothConnectionDroppedError(
|
||||
f"Peripheral {human_readable_address} changed connection status while waiting for "
|
||||
f"{response_names}: {to_human_readable_gatt_error(response.error)} "
|
||||
f"({response.error})"
|
||||
)
|
||||
|
||||
async def _bluetooth_device_request(
|
||||
self,
|
||||
address: int,
|
||||
@ -702,6 +728,7 @@ class APIClient:
|
||||
msg_types: tuple[type[message.Message], ...],
|
||||
timeout: float,
|
||||
) -> message.Message:
|
||||
"""Send a BluetoothDeviceRequest and wait for a response."""
|
||||
[response] = await self._get_connection().send_messages_await_response_complex(
|
||||
(
|
||||
BluetoothDeviceRequest(
|
||||
@ -941,7 +968,6 @@ class APIClient:
|
||||
elif position == 0.0:
|
||||
req.legacy_command = LegacyCoverCommand.CLOSE
|
||||
req.has_legacy_command = True
|
||||
|
||||
self._get_connection().send_message(req)
|
||||
|
||||
async def fan_command(
|
||||
@ -969,7 +995,6 @@ class APIClient:
|
||||
if direction is not None:
|
||||
req.has_direction = True
|
||||
req.direction = direction
|
||||
|
||||
self._get_connection().send_message(req)
|
||||
|
||||
async def light_command( # pylint: disable=too-many-branches
|
||||
@ -1027,7 +1052,6 @@ class APIClient:
|
||||
if effect is not None:
|
||||
req.has_effect = True
|
||||
req.effect = effect
|
||||
|
||||
self._get_connection().send_message(req)
|
||||
|
||||
async def switch_command(self, key: int, state: bool) -> None:
|
||||
@ -1079,7 +1103,6 @@ class APIClient:
|
||||
if custom_preset is not None:
|
||||
req.has_custom_preset = True
|
||||
req.custom_preset = custom_preset
|
||||
|
||||
self._get_connection().send_message(req)
|
||||
|
||||
async def number_command(self, key: int, state: float) -> None:
|
||||
@ -1109,7 +1132,6 @@ class APIClient:
|
||||
if duration is not None:
|
||||
req.duration = duration
|
||||
req.has_duration = True
|
||||
|
||||
self._get_connection().send_message(req)
|
||||
|
||||
async def button_command(self, key: int) -> None:
|
||||
@ -1144,7 +1166,6 @@ class APIClient:
|
||||
if media_url is not None:
|
||||
req.media_url = media_url
|
||||
req.has_media_url = True
|
||||
|
||||
self._get_connection().send_message(req)
|
||||
|
||||
async def text_command(self, key: int, state: str) -> None:
|
||||
|
@ -49,7 +49,7 @@ from .core import (
|
||||
TimeoutAPIError,
|
||||
UnhandledAPIConnectionError,
|
||||
)
|
||||
from .model import APIVersion
|
||||
from .model import APIVersion, message_types_to_names
|
||||
from .zeroconf import ZeroconfManager
|
||||
|
||||
if sys.version_info[:2] < (3, 11):
|
||||
@ -758,7 +758,7 @@ class APIConnection:
|
||||
await fut
|
||||
except asyncio_TimeoutError as err:
|
||||
timeout_expired = True
|
||||
response_names = ", ".join(t.__name__ for t in msg_types)
|
||||
response_names = message_types_to_names(msg_types)
|
||||
raise TimeoutAPIError(
|
||||
f"Timeout waiting for {response_names} after {timeout}s"
|
||||
) from err
|
||||
|
@ -228,6 +228,10 @@ class UnhandledAPIConnectionError(APIConnectionError):
|
||||
pass
|
||||
|
||||
|
||||
class BluetoothConnectionDroppedError(APIConnectionError):
|
||||
"""Raised when a Bluetooth connection is dropped."""
|
||||
|
||||
|
||||
def to_human_readable_address(address: int) -> str:
|
||||
"""Convert a MAC address to a human readable format."""
|
||||
return ":".join(TWO_CHAR.findall(f"{address:012X}"))
|
||||
|
@ -8,6 +8,8 @@ from functools import cache, lru_cache, partial
|
||||
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
|
||||
from uuid import UUID
|
||||
|
||||
from google.protobuf import message
|
||||
|
||||
from .util import fix_float_single_double_conversion
|
||||
|
||||
if sys.version_info[:2] < (3, 10):
|
||||
@ -1166,3 +1168,7 @@ def build_unique_id(formatted_mac: str, entity_info: EntityInfo) -> str:
|
||||
"""
|
||||
# <mac>-<entity type>-<object_id>
|
||||
return f"{formatted_mac}-{_TYPE_TO_NAME[type(entity_info)]}-{entity_info.object_id}"
|
||||
|
||||
|
||||
def message_types_to_names(msg_types: Iterable[type[message.Message]]) -> str:
|
||||
return ", ".join(t.__name__ for t in msg_types)
|
||||
|
@ -66,7 +66,7 @@ from aioesphomeapi.api_pb2 import (
|
||||
VoiceAssistantRequest,
|
||||
VoiceAssistantResponse,
|
||||
)
|
||||
from aioesphomeapi.client import APIClient
|
||||
from aioesphomeapi.client import APIClient, BluetoothConnectionDroppedError
|
||||
from aioesphomeapi.connection import APIConnection
|
||||
from aioesphomeapi.core import (
|
||||
APIConnectionError,
|
||||
@ -955,9 +955,63 @@ async def test_bluetooth_pair_connection_drops(
|
||||
address=1234, connected=False, error=13
|
||||
)
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
message = (
|
||||
"Peripheral 00:00:00:00:04:D2 changed connection status while waiting"
|
||||
" for BluetoothDevicePairingResponse: Invalid attribute length"
|
||||
)
|
||||
with pytest.raises(
|
||||
APIConnectionError,
|
||||
match="Peripheral changed connections status while pairing: 13",
|
||||
BluetoothConnectionDroppedError,
|
||||
match=message,
|
||||
):
|
||||
await pair_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bluetooth_unpair_connection_drops(
|
||||
api_client: tuple[
|
||||
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
|
||||
],
|
||||
) -> None:
|
||||
"""Test connection drop during bluetooth_device_unpair."""
|
||||
client, connection, transport, protocol = api_client
|
||||
pair_task = asyncio.create_task(client.bluetooth_device_unpair(1234))
|
||||
await asyncio.sleep(0)
|
||||
response: message.Message = BluetoothDeviceConnectionResponse(
|
||||
address=1234, connected=False, error=13
|
||||
)
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
message = (
|
||||
"Peripheral 00:00:00:00:04:D2 changed connection status while waiting"
|
||||
" for BluetoothDeviceUnpairingResponse: Invalid attribute length"
|
||||
)
|
||||
with pytest.raises(
|
||||
BluetoothConnectionDroppedError,
|
||||
match=message,
|
||||
):
|
||||
await pair_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bluetooth_clear_cache_connection_drops(
|
||||
api_client: tuple[
|
||||
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
|
||||
],
|
||||
) -> None:
|
||||
"""Test connection drop during bluetooth_device_clear_cache."""
|
||||
client, connection, transport, protocol = api_client
|
||||
pair_task = asyncio.create_task(client.bluetooth_device_clear_cache(1234))
|
||||
await asyncio.sleep(0)
|
||||
response: message.Message = BluetoothDeviceConnectionResponse(
|
||||
address=1234, connected=False, error=13
|
||||
)
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
message = (
|
||||
"Peripheral 00:00:00:00:04:D2 changed connection status while waiting"
|
||||
" for BluetoothDeviceClearCacheResponse: Invalid attribute length"
|
||||
)
|
||||
with pytest.raises(
|
||||
BluetoothConnectionDroppedError,
|
||||
match=message,
|
||||
):
|
||||
await pair_task
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user