Speed up BLE connections (#482)

This commit is contained in:
J. Nick Koston 2023-07-15 11:16:44 -10:00 committed by GitHub
parent ed0a611994
commit 6aeea79884
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 61 additions and 55 deletions

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
import logging import logging
from functools import partial
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -15,7 +16,6 @@ from typing import (
cast, cast,
) )
import async_timeout
from google.protobuf import message from google.protobuf import message
from .api_pb2 import ( # type: ignore from .api_pb2 import ( # type: ignore
@ -207,12 +207,7 @@ ExecuteServiceDataType = Dict[
# pylint: disable=too-many-public-methods # pylint: disable=too-many-public-methods
class APIClient: class APIClient:
__slots__ = ( __slots__ = ("_params", "_connection", "_cached_name", "_background_tasks", "_loop")
"_params",
"_connection",
"_cached_name",
"_background_tasks",
)
def __init__( def __init__(
self, self,
@ -255,6 +250,7 @@ class APIClient:
self._connection: Optional[APIConnection] = None self._connection: Optional[APIConnection] = None
self._cached_name: Optional[str] = None self._cached_name: Optional[str] = None
self._background_tasks: set[asyncio.Task[Any]] = set() self._background_tasks: set[asyncio.Task[Any]] = set()
self._loop = asyncio.get_event_loop()
@property @property
def expected_name(self) -> Optional[str]: def expected_name(self) -> Optional[str]:
@ -510,7 +506,7 @@ class APIClient:
on_bluetooth_le_advertisement(BluetoothLEAdvertisement.from_pb(msg)) # type: ignore[misc] on_bluetooth_le_advertisement(BluetoothLEAdvertisement.from_pb(msg)) # type: ignore[misc]
assert self._connection is not None assert self._connection is not None
self._connection.send_message_callback_response( unsub_callback = self._connection.send_message_callback_response(
SubscribeBluetoothLEAdvertisementsRequest(flags=0), SubscribeBluetoothLEAdvertisementsRequest(flags=0),
_on_bluetooth_le_advertising_response, _on_bluetooth_le_advertising_response,
msg_types, msg_types,
@ -518,9 +514,7 @@ class APIClient:
def unsub() -> None: def unsub() -> None:
if self._connection is not None: if self._connection is not None:
self._connection.remove_message_callback( unsub_callback()
_on_bluetooth_le_advertising_response, msg_types
)
self._connection.send_message( self._connection.send_message(
UnsubscribeBluetoothLEAdvertisementsRequest() UnsubscribeBluetoothLEAdvertisementsRequest()
) )
@ -535,7 +529,7 @@ class APIClient:
assert self._connection is not None assert self._connection is not None
on_msg = make_ble_raw_advertisement_processor(on_advertisements) on_msg = make_ble_raw_advertisement_processor(on_advertisements)
self._connection.send_message_callback_response( unsub_callback = self._connection.send_message_callback_response(
SubscribeBluetoothLEAdvertisementsRequest( SubscribeBluetoothLEAdvertisementsRequest(
flags=BluetoothProxySubscriptionFlag.RAW_ADVERTISEMENTS flags=BluetoothProxySubscriptionFlag.RAW_ADVERTISEMENTS
), ),
@ -545,7 +539,7 @@ class APIClient:
def unsub() -> None: def unsub() -> None:
if self._connection is not None: if self._connection is not None:
self._connection.remove_message_callback(on_msg, msg_types) unsub_callback()
self._connection.send_message( self._connection.send_message(
UnsubscribeBluetoothLEAdvertisementsRequest() UnsubscribeBluetoothLEAdvertisementsRequest()
) )
@ -565,21 +559,36 @@ class APIClient:
on_bluetooth_connections_free_update(resp.free, resp.limit) on_bluetooth_connections_free_update(resp.free, resp.limit)
assert self._connection is not None assert self._connection is not None
self._connection.send_message_callback_response( return self._connection.send_message_callback_response(
SubscribeBluetoothConnectionsFreeRequest(), SubscribeBluetoothConnectionsFreeRequest(),
_on_bluetooth_connections_free_response, _on_bluetooth_connections_free_response,
msg_types, msg_types,
) )
def unsub() -> None: def _handle_timeout(self, fut: asyncio.Future[None]) -> None:
if self._connection is not None: """Handle a timeout."""
self._connection.remove_message_callback( if fut.done():
_on_bluetooth_connections_free_response, msg_types return
) fut.set_exception(asyncio.TimeoutError())
return unsub def _on_bluetooth_device_connection_response(
self,
connect_future: asyncio.Future[None],
address: int,
on_bluetooth_connection_state: Callable[[bool, int, int], None],
msg: BluetoothDeviceConnectionResponse,
) -> None:
"""Handle a BluetoothDeviceConnectionResponse message.""" ""
resp = BluetoothDeviceConnection.from_pb(msg)
if address == resp.address:
on_bluetooth_connection_state(resp.connected, resp.mtu, resp.error)
# Resolve on ANY connection state since we do not want
# to wait the whole timeout if the device disconnects
# or we get an error.
if not connect_future.done():
connect_future.set_result(None)
async def bluetooth_device_connect( # pylint: disable=too-many-locals async def bluetooth_device_connect( # pylint: disable=too-many-locals, too-many-branches
self, self,
address: int, address: int,
on_bluetooth_connection_state: Callable[[bool, int, int], None], on_bluetooth_connection_state: Callable[[bool, int, int], None],
@ -591,62 +600,55 @@ class APIClient:
) -> Callable[[], None]: ) -> Callable[[], None]:
self._check_authenticated() self._check_authenticated()
msg_types = (BluetoothDeviceConnectionResponse,) msg_types = (BluetoothDeviceConnectionResponse,)
debug = _LOGGER.isEnabledFor(logging.DEBUG)
event = asyncio.Event() connect_future: asyncio.Future[None] = self._loop.create_future()
def _on_bluetooth_device_connection_response(
msg: BluetoothDeviceConnectionResponse,
) -> None:
resp = BluetoothDeviceConnection.from_pb(msg)
if address == resp.address:
on_bluetooth_connection_state(resp.connected, resp.mtu, resp.error)
# Resolve on ANY connection state since we do not want
# to wait the whole timeout if the device disconnects
# or we get an error.
event.set()
assert self._connection is not None assert self._connection is not None
if has_cache: if has_cache:
# REMOTE_CACHING feature with cache: requestor has services and mtu cached # REMOTE_CACHING feature with cache: requestor has services and mtu cached
_LOGGER.debug("%s: Using connection version 3 with cache", address)
request_type = BluetoothDeviceRequestType.CONNECT_V3_WITH_CACHE request_type = BluetoothDeviceRequestType.CONNECT_V3_WITH_CACHE
elif feature_flags & BluetoothProxyFeature.REMOTE_CACHING: elif feature_flags & BluetoothProxyFeature.REMOTE_CACHING:
# REMOTE_CACHING feature without cache: esp will wipe the service list after sending to save memory # REMOTE_CACHING feature without cache: esp will wipe the service list after sending to save memory
_LOGGER.debug("%s: Using connection version 3 without cache", address)
request_type = BluetoothDeviceRequestType.CONNECT_V3_WITHOUT_CACHE request_type = BluetoothDeviceRequestType.CONNECT_V3_WITHOUT_CACHE
else: else:
# Device doesnt support REMOTE_CACHING feature: esp will hold the service list in memory for the duration # Device does not support REMOTE_CACHING feature: esp will hold the service list in memory for the duration
# of the connection. This can crash the esp if the service list is too large. # of the connection. This can crash the esp if the service list is too large.
_LOGGER.debug("%s: Using connection version 1", address)
request_type = BluetoothDeviceRequestType.CONNECT request_type = BluetoothDeviceRequestType.CONNECT
self._connection.send_message_callback_response( if debug:
_LOGGER.debug("%s: Using connection version %s", address, request_type)
unsub = self._connection.send_message_callback_response(
BluetoothDeviceRequest( BluetoothDeviceRequest(
address=address, address=address,
request_type=request_type, request_type=request_type,
has_address_type=address_type is not None, has_address_type=address_type is not None,
address_type=address_type or 0, address_type=address_type or 0,
), ),
_on_bluetooth_device_connection_response, partial(
self._on_bluetooth_device_connection_response,
connect_future,
address,
on_bluetooth_connection_state,
),
msg_types, msg_types,
) )
def unsub() -> None: timeout_handle = self._loop.call_later(
if self._connection is not None: timeout, self._handle_timeout, connect_future
self._connection.remove_message_callback( )
_on_bluetooth_device_connection_response, msg_types
)
try: try:
try: try:
async with async_timeout.timeout(timeout): await connect_future
await event.wait()
except asyncio.TimeoutError as err: except asyncio.TimeoutError as err:
# Disconnect before raising the exception to ensure # Disconnect before raising the exception to ensure
# the slot is recovered before the timeout is raised # the slot is recovered before the timeout is raised
# to avoid race were we run out even though we have a slot. # to avoid race were we run out even though we have a slot.
addr = to_human_readable_address(address) addr = to_human_readable_address(address)
_LOGGER.debug("%s: Connecting timed out, waiting for disconnect", addr) if debug:
_LOGGER.debug(
"%s: Connecting timed out, waiting for disconnect", addr
)
disconnect_timed_out = False disconnect_timed_out = False
try: try:
await self.bluetooth_device_disconnect( await self.bluetooth_device_disconnect(
@ -654,9 +656,10 @@ class APIClient:
) )
except TimeoutAPIError: except TimeoutAPIError:
disconnect_timed_out = True disconnect_timed_out = True
_LOGGER.debug( if debug:
"%s: Disconnect timed out: %s", addr, disconnect_timed_out _LOGGER.debug(
) "%s: Disconnect timed out: %s", addr, disconnect_timed_out
)
finally: finally:
try: try:
unsub() unsub()
@ -680,6 +683,8 @@ class APIClient:
addr, addr,
) )
raise raise
finally:
timeout_handle.cancel()
return unsub return unsub

View File

@ -566,9 +566,9 @@ class APIConnection:
message_handlers = self._message_handlers message_handlers = self._message_handlers
for msg_type in msg_types: for msg_type in msg_types:
message_handlers.setdefault(msg_type, set()).add(on_message) message_handlers.setdefault(msg_type, set()).add(on_message)
return partial(self.remove_message_callback, on_message, msg_types) return partial(self._remove_message_callback, on_message, msg_types)
def remove_message_callback( def _remove_message_callback(
self, on_message: Callable[[Any], None], msg_types: Iterable[Type[Any]] self, on_message: Callable[[Any], None], msg_types: Iterable[Type[Any]]
) -> None: ) -> None:
"""Remove a message callback.""" """Remove a message callback."""
@ -581,7 +581,7 @@ class APIConnection:
send_msg: message.Message, send_msg: message.Message,
on_message: Callable[[Any], None], on_message: Callable[[Any], None],
msg_types: Iterable[Type[Any]], msg_types: Iterable[Type[Any]],
) -> None: ) -> Callable[[], None]:
"""Send a message to the remote and register the given message handler.""" """Send a message to the remote and register the given message handler."""
self.send_message(send_msg) self.send_message(send_msg)
# Since we do not return control to the event loop (no awaits) # Since we do not return control to the event loop (no awaits)
@ -590,6 +590,7 @@ class APIConnection:
# we register the handler after sending the message # we register the handler after sending the message
for msg_type in msg_types: for msg_type in msg_types:
self._message_handlers.setdefault(msg_type, set()).add(on_message) self._message_handlers.setdefault(msg_type, set()).add(on_message)
return partial(self._remove_message_callback, on_message, msg_types)
def _handle_timeout(self, fut: asyncio.Future[None]) -> None: def _handle_timeout(self, fut: asyncio.Future[None]) -> None:
"""Handle a timeout.""" """Handle a timeout."""