From 150ce726da00d92f57cf57c7456de54a7b96024e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 26 Nov 2023 15:21:54 -0600 Subject: [PATCH] Fix races in bluetooth device connect (#740) --- aioesphomeapi/client.py | 43 ++----- aioesphomeapi/client_callbacks.pxd | 2 + aioesphomeapi/client_callbacks.py | 25 ++++ tests/test_client.py | 195 +++++++++++++++++++++++++++++ 4 files changed, 233 insertions(+), 32 deletions(-) diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 9d233ad..f76657e 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -72,8 +72,10 @@ from .api_pb2 import ( # type: ignore VoiceAssistantResponse, ) from .client_callbacks import ( + handle_timeout, on_ble_raw_advertisement_response, on_bluetooth_connections_free_response, + on_bluetooth_device_connection_response, on_bluetooth_gatt_notify_data_response, on_bluetooth_le_advertising_response, on_home_assistant_service_response, @@ -528,28 +530,6 @@ class APIClient: (BluetoothConnectionsFreeResponse,), ) - def _handle_timeout(self, fut: asyncio.Future[None]) -> None: - """Handle a timeout.""" - if fut.done(): - return - fut.set_exception(asyncio.TimeoutError) - - def _on_bluetooth_device_connection_response( - self, - connect_future: asyncio.Future[None], - address: int, - on_bluetooth_connection_state: Callable[[bool, int, int], None], - msg: BluetoothDeviceConnectionResponse, - ) -> None: - """Handle a BluetoothDeviceConnectionResponse message.""" "" - if address == msg.address: - on_bluetooth_connection_state(msg.connected, msg.mtu, msg.error) - # Resolve on ANY connection state since we do not want - # to wait the whole timeout if the device disconnects - # or we get an error. - if not connect_future.done(): - connect_future.set_result(None) - async def bluetooth_device_connect( # pylint: disable=too-many-locals, too-many-branches self, address: int, @@ -584,7 +564,7 @@ class APIClient: address_type=address_type or 0, ), partial( - self._on_bluetooth_device_connection_response, + on_bluetooth_device_connection_response, connect_future, address, on_bluetooth_connection_state, @@ -594,7 +574,7 @@ class APIClient: loop = self._loop timeout_handle = loop.call_at( - loop.time() + timeout, self._handle_timeout, connect_future + loop.time() + timeout, handle_timeout, connect_future ) timeout_expired = False connect_ok = False @@ -602,6 +582,11 @@ class APIClient: await connect_future connect_ok = True except asyncio.TimeoutError as err: + # If the timeout expires, make sure + # to unsub before calling _bluetooth_device_disconnect_guard_timeout + # so that the disconnect message is not propagated back to the caller + # since we are going to raise a TimeoutAPIError. + unsub() timeout_expired = True # Disconnect before raising the exception to ensure # the slot is recovered before the timeout is raised @@ -620,14 +605,8 @@ class APIClient: f" after {disconnect_timeout}s" ) from err finally: - if not connect_ok: - try: - unsub() - except (KeyError, ValueError): - _LOGGER.warning( - "%s: Bluetooth device connection canceled but already unsubscribed", - to_human_readable_address(address), - ) + if not connect_ok and not timeout_expired: + unsub() if not timeout_expired: timeout_handle.cancel() diff --git a/aioesphomeapi/client_callbacks.pxd b/aioesphomeapi/client_callbacks.pxd index 2611a2a..8fad3d9 100644 --- a/aioesphomeapi/client_callbacks.pxd +++ b/aioesphomeapi/client_callbacks.pxd @@ -8,3 +8,5 @@ cdef object CameraImageResponse, CameraState cdef object HomeassistantServiceCall cdef object BluetoothLEAdvertisement + +cdef object asyncio_TimeoutError diff --git a/aioesphomeapi/client_callbacks.py b/aioesphomeapi/client_callbacks.py index 483d8f4..c846c86 100644 --- a/aioesphomeapi/client_callbacks.py +++ b/aioesphomeapi/client_callbacks.py @@ -1,11 +1,14 @@ from __future__ import annotations +from asyncio import Future +from asyncio import TimeoutError as asyncio_TimeoutError from typing import TYPE_CHECKING, Callable from google.protobuf import message from .api_pb2 import ( # type: ignore BluetoothConnectionsFreeResponse, + BluetoothDeviceConnectionResponse, BluetoothGATTNotifyDataResponse, BluetoothLEAdvertisementResponse, BluetoothLERawAdvertisement, @@ -93,3 +96,25 @@ def on_subscribe_home_assistant_state_response( msg: SubscribeHomeAssistantStateResponse, ) -> None: on_state_sub(msg.entity_id, msg.attribute) + + +def handle_timeout(fut: Future[None]) -> None: + """Handle a timeout.""" + if not fut.done(): + fut.set_exception(asyncio_TimeoutError) + + +def on_bluetooth_device_connection_response( + connect_future: Future[None], + address: int, + on_bluetooth_connection_state: Callable[[bool, int, int], None], + msg: BluetoothDeviceConnectionResponse, +) -> None: + """Handle a BluetoothDeviceConnectionResponse message.""" "" + if address == msg.address: + on_bluetooth_connection_state(msg.connected, msg.mtu, msg.error) + # Resolve on ANY connection state since we do not want + # to wait the whole timeout if the device disconnects + # or we get an error. + if not connect_future.done(): + connect_future.set_result(None) diff --git a/tests/test_client.py b/tests/test_client.py index e567432..9f26873 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -17,6 +17,7 @@ from aioesphomeapi.api_pb2 import ( BluetoothDeviceClearCacheResponse, BluetoothDeviceConnectionResponse, BluetoothDevicePairingResponse, + BluetoothDeviceRequest, BluetoothDeviceUnpairingResponse, BluetoothGATTErrorResponse, BluetoothGATTGetServicesDoneResponse, @@ -68,10 +69,12 @@ from aioesphomeapi.model import ( APIVersion, BinarySensorInfo, BinarySensorState, + BluetoothDeviceRequestType, ) from aioesphomeapi.model import BluetoothGATTService as BluetoothGATTServiceModel from aioesphomeapi.model import ( BluetoothLEAdvertisement, + BluetoothProxyFeature, CameraState, ClimateFanMode, ClimateMode, @@ -1455,3 +1458,195 @@ async def test_force_disconnect( assert connection.is_connected is False await client.disconnect(force=False) assert connection.is_connected is False + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("has_cache", "feature_flags", "method"), + [ + (False, BluetoothProxyFeature(0), BluetoothDeviceRequestType.CONNECT), + ( + False, + BluetoothProxyFeature.REMOTE_CACHING, + BluetoothDeviceRequestType.CONNECT_V3_WITHOUT_CACHE, + ), + ( + True, + BluetoothProxyFeature.REMOTE_CACHING, + BluetoothDeviceRequestType.CONNECT_V3_WITH_CACHE, + ), + ], +) +async def test_bluetooth_device_connect( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], + has_cache: bool, + feature_flags: BluetoothProxyFeature, + method: BluetoothDeviceRequestType, +) -> None: + """Test bluetooth_device_connect.""" + client, connection, transport, protocol = api_client + states = [] + + def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None: + states.append((connected, mtu, error)) + + connect_task = asyncio.create_task( + client.bluetooth_device_connect( + 1234, + on_bluetooth_connection_state, + timeout=1, + feature_flags=feature_flags, + has_cache=has_cache, + disconnect_timeout=1, + address_type=1, + ) + ) + await asyncio.sleep(0) + response: message.Message = BluetoothDeviceConnectionResponse( + address=1234, connected=True, mtu=23, error=0 + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + + cancel = await connect_task + assert states == [(True, 23, 0)] + transport.write.assert_called_once_with( + generate_plaintext_packet( + BluetoothDeviceRequest( + address=1234, + request_type=method, + has_address_type=True, + address_type=1, + ), + ) + ) + response: message.Message = BluetoothDeviceConnectionResponse( + address=1234, connected=False, mtu=23, error=7 + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + await asyncio.sleep(0) + assert states == [(True, 23, 0), (False, 23, 7)] + cancel() + + # After cancel, no more messages should called back + response: message.Message = BluetoothDeviceConnectionResponse( + address=1234, connected=False, mtu=23, error=8 + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + await asyncio.sleep(0) + assert states == [(True, 23, 0), (False, 23, 7)] + + +@pytest.mark.asyncio +async def test_bluetooth_device_connect_and_disconnect_times_out( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test bluetooth_device_connect and disconnect times out.""" + client, connection, transport, protocol = api_client + states = [] + + def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None: + states.append((connected, mtu, error)) + + connect_task = asyncio.create_task( + client.bluetooth_device_connect( + 1234, + on_bluetooth_connection_state, + timeout=0, + feature_flags=0, + has_cache=True, + disconnect_timeout=0, + address_type=1, + ) + ) + with pytest.raises(TimeoutAPIError): + await connect_task + assert states == [] + + +@pytest.mark.asyncio +async def test_bluetooth_device_connect_times_out_disconnect_ok( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test bluetooth_device_connect and disconnect times out.""" + client, connection, transport, protocol = api_client + states = [] + + def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None: + states.append((connected, mtu, error)) + + connect_task = asyncio.create_task( + client.bluetooth_device_connect( + 1234, + on_bluetooth_connection_state, + timeout=0, + feature_flags=0, + has_cache=True, + disconnect_timeout=1, + address_type=1, + ) + ) + await asyncio.sleep(0) + # The connect request should be written + assert len(transport.write.mock_calls) == 1 + await asyncio.sleep(0) + await asyncio.sleep(0) + await asyncio.sleep(0) + # Now that we timed out, the disconnect + # request should be written + assert len(transport.write.mock_calls) == 2 + response: message.Message = BluetoothDeviceConnectionResponse( + address=1234, connected=False, mtu=23, error=8 + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + with pytest.raises(TimeoutAPIError): + await connect_task + assert states == [] + + +@pytest.mark.asyncio +async def test_bluetooth_device_connect_cancelled( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test bluetooth_device_connect handles cancellation.""" + client, connection, transport, protocol = api_client + states = [] + + handlers_before = len( + list(itertools.chain(*connection._get_message_handlers().values())) + ) + + def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None: + states.append((connected, mtu, error)) + + connect_task = asyncio.create_task( + client.bluetooth_device_connect( + 1234, + on_bluetooth_connection_state, + timeout=10, + feature_flags=0, + has_cache=True, + disconnect_timeout=10, + address_type=1, + ) + ) + await asyncio.sleep(0) + # The connect request should be written + assert len(transport.write.mock_calls) == 1 + connect_task.cancel() + with pytest.raises(asyncio.CancelledError): + await connect_task + assert states == [] + + handlers_after = len( + list(itertools.chain(*connection._get_message_handlers().values())) + ) + # Make sure we do not leak message handlers + assert handlers_after == handlers_before