diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index b96fe9d..76adb50 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -1,5 +1,6 @@ import asyncio import logging +from functools import partial from typing import ( TYPE_CHECKING, Any, @@ -15,7 +16,6 @@ from typing import ( cast, ) -import async_timeout from google.protobuf import message from .api_pb2 import ( # type: ignore @@ -207,12 +207,7 @@ ExecuteServiceDataType = Dict[ # pylint: disable=too-many-public-methods class APIClient: - __slots__ = ( - "_params", - "_connection", - "_cached_name", - "_background_tasks", - ) + __slots__ = ("_params", "_connection", "_cached_name", "_background_tasks", "_loop") def __init__( self, @@ -255,6 +250,7 @@ class APIClient: self._connection: Optional[APIConnection] = None self._cached_name: Optional[str] = None self._background_tasks: set[asyncio.Task[Any]] = set() + self._loop = asyncio.get_event_loop() @property def expected_name(self) -> Optional[str]: @@ -510,7 +506,7 @@ class APIClient: on_bluetooth_le_advertisement(BluetoothLEAdvertisement.from_pb(msg)) # type: ignore[misc] assert self._connection is not None - self._connection.send_message_callback_response( + unsub_callback = self._connection.send_message_callback_response( SubscribeBluetoothLEAdvertisementsRequest(flags=0), _on_bluetooth_le_advertising_response, msg_types, @@ -518,9 +514,7 @@ class APIClient: def unsub() -> None: if self._connection is not None: - self._connection.remove_message_callback( - _on_bluetooth_le_advertising_response, msg_types - ) + unsub_callback() self._connection.send_message( UnsubscribeBluetoothLEAdvertisementsRequest() ) @@ -535,7 +529,7 @@ class APIClient: assert self._connection is not None on_msg = make_ble_raw_advertisement_processor(on_advertisements) - self._connection.send_message_callback_response( + unsub_callback = self._connection.send_message_callback_response( SubscribeBluetoothLEAdvertisementsRequest( flags=BluetoothProxySubscriptionFlag.RAW_ADVERTISEMENTS ), @@ -545,7 +539,7 @@ class APIClient: def unsub() -> None: if self._connection is not None: - self._connection.remove_message_callback(on_msg, msg_types) + unsub_callback() self._connection.send_message( UnsubscribeBluetoothLEAdvertisementsRequest() ) @@ -565,21 +559,36 @@ class APIClient: on_bluetooth_connections_free_update(resp.free, resp.limit) assert self._connection is not None - self._connection.send_message_callback_response( + return self._connection.send_message_callback_response( SubscribeBluetoothConnectionsFreeRequest(), _on_bluetooth_connections_free_response, msg_types, ) - def unsub() -> None: - if self._connection is not None: - self._connection.remove_message_callback( - _on_bluetooth_connections_free_response, msg_types - ) + def _handle_timeout(self, fut: asyncio.Future[None]) -> None: + """Handle a timeout.""" + if fut.done(): + return + fut.set_exception(asyncio.TimeoutError()) - return unsub + def _on_bluetooth_device_connection_response( + self, + connect_future: asyncio.Future[None], + address: int, + on_bluetooth_connection_state: Callable[[bool, int, int], None], + msg: BluetoothDeviceConnectionResponse, + ) -> None: + """Handle a BluetoothDeviceConnectionResponse message.""" "" + resp = BluetoothDeviceConnection.from_pb(msg) + if address == resp.address: + on_bluetooth_connection_state(resp.connected, resp.mtu, resp.error) + # Resolve on ANY connection state since we do not want + # to wait the whole timeout if the device disconnects + # or we get an error. + if not connect_future.done(): + connect_future.set_result(None) - async def bluetooth_device_connect( # pylint: disable=too-many-locals + async def bluetooth_device_connect( # pylint: disable=too-many-locals, too-many-branches self, address: int, on_bluetooth_connection_state: Callable[[bool, int, int], None], @@ -591,62 +600,55 @@ class APIClient: ) -> Callable[[], None]: self._check_authenticated() msg_types = (BluetoothDeviceConnectionResponse,) - - event = asyncio.Event() - - def _on_bluetooth_device_connection_response( - msg: BluetoothDeviceConnectionResponse, - ) -> None: - resp = BluetoothDeviceConnection.from_pb(msg) - if address == resp.address: - on_bluetooth_connection_state(resp.connected, resp.mtu, resp.error) - # Resolve on ANY connection state since we do not want - # to wait the whole timeout if the device disconnects - # or we get an error. - event.set() + debug = _LOGGER.isEnabledFor(logging.DEBUG) + connect_future: asyncio.Future[None] = self._loop.create_future() assert self._connection is not None if has_cache: # REMOTE_CACHING feature with cache: requestor has services and mtu cached - _LOGGER.debug("%s: Using connection version 3 with cache", address) request_type = BluetoothDeviceRequestType.CONNECT_V3_WITH_CACHE elif feature_flags & BluetoothProxyFeature.REMOTE_CACHING: # REMOTE_CACHING feature without cache: esp will wipe the service list after sending to save memory - _LOGGER.debug("%s: Using connection version 3 without cache", address) request_type = BluetoothDeviceRequestType.CONNECT_V3_WITHOUT_CACHE else: - # Device doesnt support REMOTE_CACHING feature: esp will hold the service list in memory for the duration + # Device does not support REMOTE_CACHING feature: esp will hold the service list in memory for the duration # of the connection. This can crash the esp if the service list is too large. - _LOGGER.debug("%s: Using connection version 1", address) request_type = BluetoothDeviceRequestType.CONNECT - self._connection.send_message_callback_response( + if debug: + _LOGGER.debug("%s: Using connection version %s", address, request_type) + + unsub = self._connection.send_message_callback_response( BluetoothDeviceRequest( address=address, request_type=request_type, has_address_type=address_type is not None, address_type=address_type or 0, ), - _on_bluetooth_device_connection_response, + partial( + self._on_bluetooth_device_connection_response, + connect_future, + address, + on_bluetooth_connection_state, + ), msg_types, ) - def unsub() -> None: - if self._connection is not None: - self._connection.remove_message_callback( - _on_bluetooth_device_connection_response, msg_types - ) - + timeout_handle = self._loop.call_later( + timeout, self._handle_timeout, connect_future + ) try: try: - async with async_timeout.timeout(timeout): - await event.wait() + await connect_future except asyncio.TimeoutError as err: # Disconnect before raising the exception to ensure # the slot is recovered before the timeout is raised # to avoid race were we run out even though we have a slot. addr = to_human_readable_address(address) - _LOGGER.debug("%s: Connecting timed out, waiting for disconnect", addr) + if debug: + _LOGGER.debug( + "%s: Connecting timed out, waiting for disconnect", addr + ) disconnect_timed_out = False try: await self.bluetooth_device_disconnect( @@ -654,9 +656,10 @@ class APIClient: ) except TimeoutAPIError: disconnect_timed_out = True - _LOGGER.debug( - "%s: Disconnect timed out: %s", addr, disconnect_timed_out - ) + if debug: + _LOGGER.debug( + "%s: Disconnect timed out: %s", addr, disconnect_timed_out + ) finally: try: unsub() @@ -680,6 +683,8 @@ class APIClient: addr, ) raise + finally: + timeout_handle.cancel() return unsub diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 0a10f1d..3b9d827 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -566,9 +566,9 @@ class APIConnection: message_handlers = self._message_handlers for msg_type in msg_types: message_handlers.setdefault(msg_type, set()).add(on_message) - return partial(self.remove_message_callback, on_message, msg_types) + return partial(self._remove_message_callback, on_message, msg_types) - def remove_message_callback( + def _remove_message_callback( self, on_message: Callable[[Any], None], msg_types: Iterable[Type[Any]] ) -> None: """Remove a message callback.""" @@ -581,7 +581,7 @@ class APIConnection: send_msg: message.Message, on_message: Callable[[Any], None], msg_types: Iterable[Type[Any]], - ) -> None: + ) -> Callable[[], None]: """Send a message to the remote and register the given message handler.""" self.send_message(send_msg) # Since we do not return control to the event loop (no awaits) @@ -590,6 +590,7 @@ class APIConnection: # we register the handler after sending the message for msg_type in msg_types: self._message_handlers.setdefault(msg_type, set()).add(on_message) + return partial(self._remove_message_callback, on_message, msg_types) def _handle_timeout(self, fut: asyncio.Future[None]) -> None: """Handle a timeout."""