mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-09-27 04:22:46 +02:00
Handle Bluetooth connection drops in more places (#766)
This commit is contained in:
parent
72a8f70bcd
commit
3e920df478
@ -11,6 +11,7 @@ from .core import (
|
|||||||
MESSAGE_TYPE_TO_PROTO,
|
MESSAGE_TYPE_TO_PROTO,
|
||||||
APIConnectionError,
|
APIConnectionError,
|
||||||
BadNameAPIError,
|
BadNameAPIError,
|
||||||
|
BluetoothConnectionDroppedError,
|
||||||
HandshakeAPIError,
|
HandshakeAPIError,
|
||||||
InvalidAuthAPIError,
|
InvalidAuthAPIError,
|
||||||
InvalidEncryptionKeyAPIError,
|
InvalidEncryptionKeyAPIError,
|
||||||
|
@ -83,9 +83,11 @@ from .client_callbacks import (
|
|||||||
from .connection import APIConnection, ConnectionParams, handle_timeout
|
from .connection import APIConnection, ConnectionParams, handle_timeout
|
||||||
from .core import (
|
from .core import (
|
||||||
APIConnectionError,
|
APIConnectionError,
|
||||||
|
BluetoothConnectionDroppedError,
|
||||||
BluetoothGATTAPIError,
|
BluetoothGATTAPIError,
|
||||||
TimeoutAPIError,
|
TimeoutAPIError,
|
||||||
to_human_readable_address,
|
to_human_readable_address,
|
||||||
|
to_human_readable_gatt_error,
|
||||||
)
|
)
|
||||||
from .model import (
|
from .model import (
|
||||||
AlarmControlPanelCommand,
|
AlarmControlPanelCommand,
|
||||||
@ -118,7 +120,11 @@ from .model import (
|
|||||||
UserServiceArgType,
|
UserServiceArgType,
|
||||||
)
|
)
|
||||||
from .model import VoiceAssistantAudioSettings as VoiceAssistantAudioSettingsModel
|
from .model import VoiceAssistantAudioSettings as VoiceAssistantAudioSettingsModel
|
||||||
from .model import VoiceAssistantCommand, VoiceAssistantEventType
|
from .model import (
|
||||||
|
VoiceAssistantCommand,
|
||||||
|
VoiceAssistantEventType,
|
||||||
|
message_types_to_names,
|
||||||
|
)
|
||||||
from .model_conversions import (
|
from .model_conversions import (
|
||||||
LIST_ENTITIES_SERVICES_RESPONSE_TYPES,
|
LIST_ENTITIES_SERVICES_RESPONSE_TYPES,
|
||||||
SUBSCRIBE_STATES_RESPONSE_TYPES,
|
SUBSCRIBE_STATES_RESPONSE_TYPES,
|
||||||
@ -632,26 +638,11 @@ class APIClient:
|
|||||||
async def bluetooth_device_pair(
|
async def bluetooth_device_pair(
|
||||||
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
|
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
|
||||||
) -> BluetoothDevicePairing:
|
) -> BluetoothDevicePairing:
|
||||||
def predicate_func(
|
|
||||||
msg: BluetoothDevicePairingResponse | BluetoothDeviceConnectionResponse,
|
|
||||||
) -> bool:
|
|
||||||
if msg.address != address:
|
|
||||||
return False
|
|
||||||
if isinstance(msg, BluetoothDeviceConnectionResponse):
|
|
||||||
raise APIConnectionError(
|
|
||||||
f"Peripheral changed connections status while pairing: {msg.error}"
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
return BluetoothDevicePairing.from_pb(
|
return BluetoothDevicePairing.from_pb(
|
||||||
await self._bluetooth_device_request(
|
await self._bluetooth_device_request_watch_connection(
|
||||||
address,
|
address,
|
||||||
BluetoothDeviceRequestType.PAIR,
|
BluetoothDeviceRequestType.PAIR,
|
||||||
predicate_func,
|
(BluetoothDevicePairingResponse,),
|
||||||
(
|
|
||||||
BluetoothDevicePairingResponse,
|
|
||||||
BluetoothDeviceConnectionResponse,
|
|
||||||
),
|
|
||||||
timeout,
|
timeout,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -660,10 +651,9 @@ class APIClient:
|
|||||||
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
|
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
|
||||||
) -> BluetoothDeviceUnpairing:
|
) -> BluetoothDeviceUnpairing:
|
||||||
return BluetoothDeviceUnpairing.from_pb(
|
return BluetoothDeviceUnpairing.from_pb(
|
||||||
await self._bluetooth_device_request(
|
await self._bluetooth_device_request_watch_connection(
|
||||||
address,
|
address,
|
||||||
BluetoothDeviceRequestType.UNPAIR,
|
BluetoothDeviceRequestType.UNPAIR,
|
||||||
lambda msg: msg.address == address,
|
|
||||||
(BluetoothDeviceUnpairingResponse,),
|
(BluetoothDeviceUnpairingResponse,),
|
||||||
timeout,
|
timeout,
|
||||||
)
|
)
|
||||||
@ -673,10 +663,9 @@ class APIClient:
|
|||||||
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
|
self, address: int, timeout: float = DEFAULT_BLE_TIMEOUT
|
||||||
) -> BluetoothDeviceClearCache:
|
) -> BluetoothDeviceClearCache:
|
||||||
return BluetoothDeviceClearCache.from_pb(
|
return BluetoothDeviceClearCache.from_pb(
|
||||||
await self._bluetooth_device_request(
|
await self._bluetooth_device_request_watch_connection(
|
||||||
address,
|
address,
|
||||||
BluetoothDeviceRequestType.CLEAR_CACHE,
|
BluetoothDeviceRequestType.CLEAR_CACHE,
|
||||||
lambda msg: msg.address == address,
|
|
||||||
(BluetoothDeviceClearCacheResponse,),
|
(BluetoothDeviceClearCacheResponse,),
|
||||||
timeout,
|
timeout,
|
||||||
)
|
)
|
||||||
@ -694,6 +683,43 @@ class APIClient:
|
|||||||
timeout,
|
timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _bluetooth_device_request_watch_connection(
|
||||||
|
self,
|
||||||
|
address: int,
|
||||||
|
request_type: BluetoothDeviceRequestType,
|
||||||
|
msg_types: tuple[type[message.Message], ...],
|
||||||
|
timeout: float,
|
||||||
|
) -> message.Message:
|
||||||
|
"""Send a BluetoothDeviceRequest watch for the connection state to change."""
|
||||||
|
response = await self._bluetooth_device_request(
|
||||||
|
address,
|
||||||
|
request_type,
|
||||||
|
lambda msg: msg.address == address,
|
||||||
|
(BluetoothDeviceConnectionResponse, *msg_types),
|
||||||
|
timeout,
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
type(response) # pylint: disable=unidiomatic-typecheck
|
||||||
|
is BluetoothDeviceConnectionResponse
|
||||||
|
):
|
||||||
|
self._raise_for_ble_connection_change(address, response, msg_types)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def _raise_for_ble_connection_change(
|
||||||
|
self,
|
||||||
|
address: int,
|
||||||
|
response: BluetoothDeviceConnectionResponse,
|
||||||
|
msg_types: tuple[type[message.Message], ...],
|
||||||
|
) -> None:
|
||||||
|
"""Raise an exception if the connection status changed."""
|
||||||
|
response_names = message_types_to_names(msg_types)
|
||||||
|
human_readable_address = to_human_readable_address(address)
|
||||||
|
raise BluetoothConnectionDroppedError(
|
||||||
|
f"Peripheral {human_readable_address} changed connection status while waiting for "
|
||||||
|
f"{response_names}: {to_human_readable_gatt_error(response.error)} "
|
||||||
|
f"({response.error})"
|
||||||
|
)
|
||||||
|
|
||||||
async def _bluetooth_device_request(
|
async def _bluetooth_device_request(
|
||||||
self,
|
self,
|
||||||
address: int,
|
address: int,
|
||||||
@ -702,6 +728,7 @@ class APIClient:
|
|||||||
msg_types: tuple[type[message.Message], ...],
|
msg_types: tuple[type[message.Message], ...],
|
||||||
timeout: float,
|
timeout: float,
|
||||||
) -> message.Message:
|
) -> message.Message:
|
||||||
|
"""Send a BluetoothDeviceRequest and wait for a response."""
|
||||||
[response] = await self._get_connection().send_messages_await_response_complex(
|
[response] = await self._get_connection().send_messages_await_response_complex(
|
||||||
(
|
(
|
||||||
BluetoothDeviceRequest(
|
BluetoothDeviceRequest(
|
||||||
@ -941,7 +968,6 @@ class APIClient:
|
|||||||
elif position == 0.0:
|
elif position == 0.0:
|
||||||
req.legacy_command = LegacyCoverCommand.CLOSE
|
req.legacy_command = LegacyCoverCommand.CLOSE
|
||||||
req.has_legacy_command = True
|
req.has_legacy_command = True
|
||||||
|
|
||||||
self._get_connection().send_message(req)
|
self._get_connection().send_message(req)
|
||||||
|
|
||||||
async def fan_command(
|
async def fan_command(
|
||||||
@ -969,7 +995,6 @@ class APIClient:
|
|||||||
if direction is not None:
|
if direction is not None:
|
||||||
req.has_direction = True
|
req.has_direction = True
|
||||||
req.direction = direction
|
req.direction = direction
|
||||||
|
|
||||||
self._get_connection().send_message(req)
|
self._get_connection().send_message(req)
|
||||||
|
|
||||||
async def light_command( # pylint: disable=too-many-branches
|
async def light_command( # pylint: disable=too-many-branches
|
||||||
@ -1027,7 +1052,6 @@ class APIClient:
|
|||||||
if effect is not None:
|
if effect is not None:
|
||||||
req.has_effect = True
|
req.has_effect = True
|
||||||
req.effect = effect
|
req.effect = effect
|
||||||
|
|
||||||
self._get_connection().send_message(req)
|
self._get_connection().send_message(req)
|
||||||
|
|
||||||
async def switch_command(self, key: int, state: bool) -> None:
|
async def switch_command(self, key: int, state: bool) -> None:
|
||||||
@ -1079,7 +1103,6 @@ class APIClient:
|
|||||||
if custom_preset is not None:
|
if custom_preset is not None:
|
||||||
req.has_custom_preset = True
|
req.has_custom_preset = True
|
||||||
req.custom_preset = custom_preset
|
req.custom_preset = custom_preset
|
||||||
|
|
||||||
self._get_connection().send_message(req)
|
self._get_connection().send_message(req)
|
||||||
|
|
||||||
async def number_command(self, key: int, state: float) -> None:
|
async def number_command(self, key: int, state: float) -> None:
|
||||||
@ -1109,7 +1132,6 @@ class APIClient:
|
|||||||
if duration is not None:
|
if duration is not None:
|
||||||
req.duration = duration
|
req.duration = duration
|
||||||
req.has_duration = True
|
req.has_duration = True
|
||||||
|
|
||||||
self._get_connection().send_message(req)
|
self._get_connection().send_message(req)
|
||||||
|
|
||||||
async def button_command(self, key: int) -> None:
|
async def button_command(self, key: int) -> None:
|
||||||
@ -1144,7 +1166,6 @@ class APIClient:
|
|||||||
if media_url is not None:
|
if media_url is not None:
|
||||||
req.media_url = media_url
|
req.media_url = media_url
|
||||||
req.has_media_url = True
|
req.has_media_url = True
|
||||||
|
|
||||||
self._get_connection().send_message(req)
|
self._get_connection().send_message(req)
|
||||||
|
|
||||||
async def text_command(self, key: int, state: str) -> None:
|
async def text_command(self, key: int, state: str) -> None:
|
||||||
|
@ -49,7 +49,7 @@ from .core import (
|
|||||||
TimeoutAPIError,
|
TimeoutAPIError,
|
||||||
UnhandledAPIConnectionError,
|
UnhandledAPIConnectionError,
|
||||||
)
|
)
|
||||||
from .model import APIVersion
|
from .model import APIVersion, message_types_to_names
|
||||||
from .zeroconf import ZeroconfManager
|
from .zeroconf import ZeroconfManager
|
||||||
|
|
||||||
if sys.version_info[:2] < (3, 11):
|
if sys.version_info[:2] < (3, 11):
|
||||||
@ -758,7 +758,7 @@ class APIConnection:
|
|||||||
await fut
|
await fut
|
||||||
except asyncio_TimeoutError as err:
|
except asyncio_TimeoutError as err:
|
||||||
timeout_expired = True
|
timeout_expired = True
|
||||||
response_names = ", ".join(t.__name__ for t in msg_types)
|
response_names = message_types_to_names(msg_types)
|
||||||
raise TimeoutAPIError(
|
raise TimeoutAPIError(
|
||||||
f"Timeout waiting for {response_names} after {timeout}s"
|
f"Timeout waiting for {response_names} after {timeout}s"
|
||||||
) from err
|
) from err
|
||||||
|
@ -228,6 +228,10 @@ class UnhandledAPIConnectionError(APIConnectionError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BluetoothConnectionDroppedError(APIConnectionError):
|
||||||
|
"""Raised when a Bluetooth connection is dropped."""
|
||||||
|
|
||||||
|
|
||||||
def to_human_readable_address(address: int) -> str:
|
def to_human_readable_address(address: int) -> str:
|
||||||
"""Convert a MAC address to a human readable format."""
|
"""Convert a MAC address to a human readable format."""
|
||||||
return ":".join(TWO_CHAR.findall(f"{address:012X}"))
|
return ":".join(TWO_CHAR.findall(f"{address:012X}"))
|
||||||
|
@ -8,6 +8,8 @@ from functools import cache, lru_cache, partial
|
|||||||
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
|
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
from google.protobuf import message
|
||||||
|
|
||||||
from .util import fix_float_single_double_conversion
|
from .util import fix_float_single_double_conversion
|
||||||
|
|
||||||
if sys.version_info[:2] < (3, 10):
|
if sys.version_info[:2] < (3, 10):
|
||||||
@ -1166,3 +1168,7 @@ def build_unique_id(formatted_mac: str, entity_info: EntityInfo) -> str:
|
|||||||
"""
|
"""
|
||||||
# <mac>-<entity type>-<object_id>
|
# <mac>-<entity type>-<object_id>
|
||||||
return f"{formatted_mac}-{_TYPE_TO_NAME[type(entity_info)]}-{entity_info.object_id}"
|
return f"{formatted_mac}-{_TYPE_TO_NAME[type(entity_info)]}-{entity_info.object_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def message_types_to_names(msg_types: Iterable[type[message.Message]]) -> str:
|
||||||
|
return ", ".join(t.__name__ for t in msg_types)
|
||||||
|
@ -66,7 +66,7 @@ from aioesphomeapi.api_pb2 import (
|
|||||||
VoiceAssistantRequest,
|
VoiceAssistantRequest,
|
||||||
VoiceAssistantResponse,
|
VoiceAssistantResponse,
|
||||||
)
|
)
|
||||||
from aioesphomeapi.client import APIClient
|
from aioesphomeapi.client import APIClient, BluetoothConnectionDroppedError
|
||||||
from aioesphomeapi.connection import APIConnection
|
from aioesphomeapi.connection import APIConnection
|
||||||
from aioesphomeapi.core import (
|
from aioesphomeapi.core import (
|
||||||
APIConnectionError,
|
APIConnectionError,
|
||||||
@ -955,9 +955,63 @@ async def test_bluetooth_pair_connection_drops(
|
|||||||
address=1234, connected=False, error=13
|
address=1234, connected=False, error=13
|
||||||
)
|
)
|
||||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
|
message = (
|
||||||
|
"Peripheral 00:00:00:00:04:D2 changed connection status while waiting"
|
||||||
|
" for BluetoothDevicePairingResponse: Invalid attribute length"
|
||||||
|
)
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
APIConnectionError,
|
BluetoothConnectionDroppedError,
|
||||||
match="Peripheral changed connections status while pairing: 13",
|
match=message,
|
||||||
|
):
|
||||||
|
await pair_task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bluetooth_unpair_connection_drops(
|
||||||
|
api_client: tuple[
|
||||||
|
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
|
||||||
|
],
|
||||||
|
) -> None:
|
||||||
|
"""Test connection drop during bluetooth_device_unpair."""
|
||||||
|
client, connection, transport, protocol = api_client
|
||||||
|
pair_task = asyncio.create_task(client.bluetooth_device_unpair(1234))
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
response: message.Message = BluetoothDeviceConnectionResponse(
|
||||||
|
address=1234, connected=False, error=13
|
||||||
|
)
|
||||||
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
|
message = (
|
||||||
|
"Peripheral 00:00:00:00:04:D2 changed connection status while waiting"
|
||||||
|
" for BluetoothDeviceUnpairingResponse: Invalid attribute length"
|
||||||
|
)
|
||||||
|
with pytest.raises(
|
||||||
|
BluetoothConnectionDroppedError,
|
||||||
|
match=message,
|
||||||
|
):
|
||||||
|
await pair_task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bluetooth_clear_cache_connection_drops(
|
||||||
|
api_client: tuple[
|
||||||
|
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
|
||||||
|
],
|
||||||
|
) -> None:
|
||||||
|
"""Test connection drop during bluetooth_device_clear_cache."""
|
||||||
|
client, connection, transport, protocol = api_client
|
||||||
|
pair_task = asyncio.create_task(client.bluetooth_device_clear_cache(1234))
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
response: message.Message = BluetoothDeviceConnectionResponse(
|
||||||
|
address=1234, connected=False, error=13
|
||||||
|
)
|
||||||
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
|
message = (
|
||||||
|
"Peripheral 00:00:00:00:04:D2 changed connection status while waiting"
|
||||||
|
" for BluetoothDeviceClearCacheResponse: Invalid attribute length"
|
||||||
|
)
|
||||||
|
with pytest.raises(
|
||||||
|
BluetoothConnectionDroppedError,
|
||||||
|
match=message,
|
||||||
):
|
):
|
||||||
await pair_task
|
await pair_task
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user