Speed up BLE connections (#482)

This commit is contained in:
J. Nick Koston 2023-07-15 11:16:44 -10:00 committed by GitHub
parent ed0a611994
commit 6aeea79884
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 61 additions and 55 deletions

View File

@ -1,5 +1,6 @@
import asyncio
import logging
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
@ -15,7 +16,6 @@ from typing import (
cast,
)
import async_timeout
from google.protobuf import message
from .api_pb2 import ( # type: ignore
@ -207,12 +207,7 @@ ExecuteServiceDataType = Dict[
# pylint: disable=too-many-public-methods
class APIClient:
__slots__ = (
"_params",
"_connection",
"_cached_name",
"_background_tasks",
)
__slots__ = ("_params", "_connection", "_cached_name", "_background_tasks", "_loop")
def __init__(
self,
@ -255,6 +250,7 @@ class APIClient:
self._connection: Optional[APIConnection] = None
self._cached_name: Optional[str] = None
self._background_tasks: set[asyncio.Task[Any]] = set()
self._loop = asyncio.get_event_loop()
@property
def expected_name(self) -> Optional[str]:
@ -510,7 +506,7 @@ class APIClient:
on_bluetooth_le_advertisement(BluetoothLEAdvertisement.from_pb(msg)) # type: ignore[misc]
assert self._connection is not None
self._connection.send_message_callback_response(
unsub_callback = self._connection.send_message_callback_response(
SubscribeBluetoothLEAdvertisementsRequest(flags=0),
_on_bluetooth_le_advertising_response,
msg_types,
@ -518,9 +514,7 @@ class APIClient:
def unsub() -> None:
if self._connection is not None:
self._connection.remove_message_callback(
_on_bluetooth_le_advertising_response, msg_types
)
unsub_callback()
self._connection.send_message(
UnsubscribeBluetoothLEAdvertisementsRequest()
)
@ -535,7 +529,7 @@ class APIClient:
assert self._connection is not None
on_msg = make_ble_raw_advertisement_processor(on_advertisements)
self._connection.send_message_callback_response(
unsub_callback = self._connection.send_message_callback_response(
SubscribeBluetoothLEAdvertisementsRequest(
flags=BluetoothProxySubscriptionFlag.RAW_ADVERTISEMENTS
),
@ -545,7 +539,7 @@ class APIClient:
def unsub() -> None:
if self._connection is not None:
self._connection.remove_message_callback(on_msg, msg_types)
unsub_callback()
self._connection.send_message(
UnsubscribeBluetoothLEAdvertisementsRequest()
)
@ -565,21 +559,36 @@ class APIClient:
on_bluetooth_connections_free_update(resp.free, resp.limit)
assert self._connection is not None
self._connection.send_message_callback_response(
return self._connection.send_message_callback_response(
SubscribeBluetoothConnectionsFreeRequest(),
_on_bluetooth_connections_free_response,
msg_types,
)
def unsub() -> None:
if self._connection is not None:
self._connection.remove_message_callback(
_on_bluetooth_connections_free_response, msg_types
)
def _handle_timeout(self, fut: asyncio.Future[None]) -> None:
"""Handle a timeout."""
if fut.done():
return
fut.set_exception(asyncio.TimeoutError())
return unsub
def _on_bluetooth_device_connection_response(
self,
connect_future: asyncio.Future[None],
address: int,
on_bluetooth_connection_state: Callable[[bool, int, int], None],
msg: BluetoothDeviceConnectionResponse,
) -> None:
"""Handle a BluetoothDeviceConnectionResponse message.""" ""
resp = BluetoothDeviceConnection.from_pb(msg)
if address == resp.address:
on_bluetooth_connection_state(resp.connected, resp.mtu, resp.error)
# Resolve on ANY connection state since we do not want
# to wait the whole timeout if the device disconnects
# or we get an error.
if not connect_future.done():
connect_future.set_result(None)
async def bluetooth_device_connect( # pylint: disable=too-many-locals
async def bluetooth_device_connect( # pylint: disable=too-many-locals, too-many-branches
self,
address: int,
on_bluetooth_connection_state: Callable[[bool, int, int], None],
@ -591,62 +600,55 @@ class APIClient:
) -> Callable[[], None]:
self._check_authenticated()
msg_types = (BluetoothDeviceConnectionResponse,)
event = asyncio.Event()
def _on_bluetooth_device_connection_response(
msg: BluetoothDeviceConnectionResponse,
) -> None:
resp = BluetoothDeviceConnection.from_pb(msg)
if address == resp.address:
on_bluetooth_connection_state(resp.connected, resp.mtu, resp.error)
# Resolve on ANY connection state since we do not want
# to wait the whole timeout if the device disconnects
# or we get an error.
event.set()
debug = _LOGGER.isEnabledFor(logging.DEBUG)
connect_future: asyncio.Future[None] = self._loop.create_future()
assert self._connection is not None
if has_cache:
# REMOTE_CACHING feature with cache: requestor has services and mtu cached
_LOGGER.debug("%s: Using connection version 3 with cache", address)
request_type = BluetoothDeviceRequestType.CONNECT_V3_WITH_CACHE
elif feature_flags & BluetoothProxyFeature.REMOTE_CACHING:
# REMOTE_CACHING feature without cache: esp will wipe the service list after sending to save memory
_LOGGER.debug("%s: Using connection version 3 without cache", address)
request_type = BluetoothDeviceRequestType.CONNECT_V3_WITHOUT_CACHE
else:
# Device doesnt support REMOTE_CACHING feature: esp will hold the service list in memory for the duration
# Device does not support REMOTE_CACHING feature: esp will hold the service list in memory for the duration
# of the connection. This can crash the esp if the service list is too large.
_LOGGER.debug("%s: Using connection version 1", address)
request_type = BluetoothDeviceRequestType.CONNECT
self._connection.send_message_callback_response(
if debug:
_LOGGER.debug("%s: Using connection version %s", address, request_type)
unsub = self._connection.send_message_callback_response(
BluetoothDeviceRequest(
address=address,
request_type=request_type,
has_address_type=address_type is not None,
address_type=address_type or 0,
),
_on_bluetooth_device_connection_response,
partial(
self._on_bluetooth_device_connection_response,
connect_future,
address,
on_bluetooth_connection_state,
),
msg_types,
)
def unsub() -> None:
if self._connection is not None:
self._connection.remove_message_callback(
_on_bluetooth_device_connection_response, msg_types
)
timeout_handle = self._loop.call_later(
timeout, self._handle_timeout, connect_future
)
try:
try:
async with async_timeout.timeout(timeout):
await event.wait()
await connect_future
except asyncio.TimeoutError as err:
# Disconnect before raising the exception to ensure
# the slot is recovered before the timeout is raised
# to avoid race were we run out even though we have a slot.
addr = to_human_readable_address(address)
_LOGGER.debug("%s: Connecting timed out, waiting for disconnect", addr)
if debug:
_LOGGER.debug(
"%s: Connecting timed out, waiting for disconnect", addr
)
disconnect_timed_out = False
try:
await self.bluetooth_device_disconnect(
@ -654,9 +656,10 @@ class APIClient:
)
except TimeoutAPIError:
disconnect_timed_out = True
_LOGGER.debug(
"%s: Disconnect timed out: %s", addr, disconnect_timed_out
)
if debug:
_LOGGER.debug(
"%s: Disconnect timed out: %s", addr, disconnect_timed_out
)
finally:
try:
unsub()
@ -680,6 +683,8 @@ class APIClient:
addr,
)
raise
finally:
timeout_handle.cancel()
return unsub

View File

@ -566,9 +566,9 @@ class APIConnection:
message_handlers = self._message_handlers
for msg_type in msg_types:
message_handlers.setdefault(msg_type, set()).add(on_message)
return partial(self.remove_message_callback, on_message, msg_types)
return partial(self._remove_message_callback, on_message, msg_types)
def remove_message_callback(
def _remove_message_callback(
self, on_message: Callable[[Any], None], msg_types: Iterable[Type[Any]]
) -> None:
"""Remove a message callback."""
@ -581,7 +581,7 @@ class APIConnection:
send_msg: message.Message,
on_message: Callable[[Any], None],
msg_types: Iterable[Type[Any]],
) -> None:
) -> Callable[[], None]:
"""Send a message to the remote and register the given message handler."""
self.send_message(send_msg)
# Since we do not return control to the event loop (no awaits)
@ -590,6 +590,7 @@ class APIConnection:
# we register the handler after sending the message
for msg_type in msg_types:
self._message_handlers.setdefault(msg_type, set()).add(on_message)
return partial(self._remove_message_callback, on_message, msg_types)
def _handle_timeout(self, fut: asyncio.Future[None]) -> None:
"""Handle a timeout."""