Fix races in bluetooth device connect (#740)

This commit is contained in:
J. Nick Koston 2023-11-26 15:21:54 -06:00 committed by GitHub
parent 108074a2dc
commit 150ce726da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 233 additions and 32 deletions

View File

@ -72,8 +72,10 @@ from .api_pb2 import ( # type: ignore
VoiceAssistantResponse,
)
from .client_callbacks import (
handle_timeout,
on_ble_raw_advertisement_response,
on_bluetooth_connections_free_response,
on_bluetooth_device_connection_response,
on_bluetooth_gatt_notify_data_response,
on_bluetooth_le_advertising_response,
on_home_assistant_service_response,
@ -528,28 +530,6 @@ class APIClient:
(BluetoothConnectionsFreeResponse,),
)
def _handle_timeout(self, fut: asyncio.Future[None]) -> None:
"""Handle a timeout."""
if fut.done():
return
fut.set_exception(asyncio.TimeoutError)
def _on_bluetooth_device_connection_response(
self,
connect_future: asyncio.Future[None],
address: int,
on_bluetooth_connection_state: Callable[[bool, int, int], None],
msg: BluetoothDeviceConnectionResponse,
) -> None:
"""Handle a BluetoothDeviceConnectionResponse message.""" ""
if address == msg.address:
on_bluetooth_connection_state(msg.connected, msg.mtu, msg.error)
# Resolve on ANY connection state since we do not want
# to wait the whole timeout if the device disconnects
# or we get an error.
if not connect_future.done():
connect_future.set_result(None)
async def bluetooth_device_connect( # pylint: disable=too-many-locals, too-many-branches
self,
address: int,
@ -584,7 +564,7 @@ class APIClient:
address_type=address_type or 0,
),
partial(
self._on_bluetooth_device_connection_response,
on_bluetooth_device_connection_response,
connect_future,
address,
on_bluetooth_connection_state,
@ -594,7 +574,7 @@ class APIClient:
loop = self._loop
timeout_handle = loop.call_at(
loop.time() + timeout, self._handle_timeout, connect_future
loop.time() + timeout, handle_timeout, connect_future
)
timeout_expired = False
connect_ok = False
@ -602,6 +582,11 @@ class APIClient:
await connect_future
connect_ok = True
except asyncio.TimeoutError as err:
# If the timeout expires, make sure
# to unsub before calling _bluetooth_device_disconnect_guard_timeout
# so that the disconnect message is not propagated back to the caller
# since we are going to raise a TimeoutAPIError.
unsub()
timeout_expired = True
# Disconnect before raising the exception to ensure
# the slot is recovered before the timeout is raised
@ -620,14 +605,8 @@ class APIClient:
f" after {disconnect_timeout}s"
) from err
finally:
if not connect_ok:
try:
unsub()
except (KeyError, ValueError):
_LOGGER.warning(
"%s: Bluetooth device connection canceled but already unsubscribed",
to_human_readable_address(address),
)
if not connect_ok and not timeout_expired:
unsub()
if not timeout_expired:
timeout_handle.cancel()

View File

@ -8,3 +8,5 @@ cdef object CameraImageResponse, CameraState
cdef object HomeassistantServiceCall
cdef object BluetoothLEAdvertisement
cdef object asyncio_TimeoutError

View File

@ -1,11 +1,14 @@
from __future__ import annotations
from asyncio import Future
from asyncio import TimeoutError as asyncio_TimeoutError
from typing import TYPE_CHECKING, Callable
from google.protobuf import message
from .api_pb2 import ( # type: ignore
BluetoothConnectionsFreeResponse,
BluetoothDeviceConnectionResponse,
BluetoothGATTNotifyDataResponse,
BluetoothLEAdvertisementResponse,
BluetoothLERawAdvertisement,
@ -93,3 +96,25 @@ def on_subscribe_home_assistant_state_response(
msg: SubscribeHomeAssistantStateResponse,
) -> None:
on_state_sub(msg.entity_id, msg.attribute)
def handle_timeout(fut: Future[None]) -> None:
"""Handle a timeout."""
if not fut.done():
fut.set_exception(asyncio_TimeoutError)
def on_bluetooth_device_connection_response(
connect_future: Future[None],
address: int,
on_bluetooth_connection_state: Callable[[bool, int, int], None],
msg: BluetoothDeviceConnectionResponse,
) -> None:
"""Handle a BluetoothDeviceConnectionResponse message.""" ""
if address == msg.address:
on_bluetooth_connection_state(msg.connected, msg.mtu, msg.error)
# Resolve on ANY connection state since we do not want
# to wait the whole timeout if the device disconnects
# or we get an error.
if not connect_future.done():
connect_future.set_result(None)

View File

@ -17,6 +17,7 @@ from aioesphomeapi.api_pb2 import (
BluetoothDeviceClearCacheResponse,
BluetoothDeviceConnectionResponse,
BluetoothDevicePairingResponse,
BluetoothDeviceRequest,
BluetoothDeviceUnpairingResponse,
BluetoothGATTErrorResponse,
BluetoothGATTGetServicesDoneResponse,
@ -68,10 +69,12 @@ from aioesphomeapi.model import (
APIVersion,
BinarySensorInfo,
BinarySensorState,
BluetoothDeviceRequestType,
)
from aioesphomeapi.model import BluetoothGATTService as BluetoothGATTServiceModel
from aioesphomeapi.model import (
BluetoothLEAdvertisement,
BluetoothProxyFeature,
CameraState,
ClimateFanMode,
ClimateMode,
@ -1455,3 +1458,195 @@ async def test_force_disconnect(
assert connection.is_connected is False
await client.disconnect(force=False)
assert connection.is_connected is False
@pytest.mark.asyncio
@pytest.mark.parametrize(
("has_cache", "feature_flags", "method"),
[
(False, BluetoothProxyFeature(0), BluetoothDeviceRequestType.CONNECT),
(
False,
BluetoothProxyFeature.REMOTE_CACHING,
BluetoothDeviceRequestType.CONNECT_V3_WITHOUT_CACHE,
),
(
True,
BluetoothProxyFeature.REMOTE_CACHING,
BluetoothDeviceRequestType.CONNECT_V3_WITH_CACHE,
),
],
)
async def test_bluetooth_device_connect(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
has_cache: bool,
feature_flags: BluetoothProxyFeature,
method: BluetoothDeviceRequestType,
) -> None:
"""Test bluetooth_device_connect."""
client, connection, transport, protocol = api_client
states = []
def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None:
states.append((connected, mtu, error))
connect_task = asyncio.create_task(
client.bluetooth_device_connect(
1234,
on_bluetooth_connection_state,
timeout=1,
feature_flags=feature_flags,
has_cache=has_cache,
disconnect_timeout=1,
address_type=1,
)
)
await asyncio.sleep(0)
response: message.Message = BluetoothDeviceConnectionResponse(
address=1234, connected=True, mtu=23, error=0
)
mock_data_received(protocol, generate_plaintext_packet(response))
cancel = await connect_task
assert states == [(True, 23, 0)]
transport.write.assert_called_once_with(
generate_plaintext_packet(
BluetoothDeviceRequest(
address=1234,
request_type=method,
has_address_type=True,
address_type=1,
),
)
)
response: message.Message = BluetoothDeviceConnectionResponse(
address=1234, connected=False, mtu=23, error=7
)
mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0)
assert states == [(True, 23, 0), (False, 23, 7)]
cancel()
# After cancel, no more messages should called back
response: message.Message = BluetoothDeviceConnectionResponse(
address=1234, connected=False, mtu=23, error=8
)
mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0)
assert states == [(True, 23, 0), (False, 23, 7)]
@pytest.mark.asyncio
async def test_bluetooth_device_connect_and_disconnect_times_out(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test bluetooth_device_connect and disconnect times out."""
client, connection, transport, protocol = api_client
states = []
def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None:
states.append((connected, mtu, error))
connect_task = asyncio.create_task(
client.bluetooth_device_connect(
1234,
on_bluetooth_connection_state,
timeout=0,
feature_flags=0,
has_cache=True,
disconnect_timeout=0,
address_type=1,
)
)
with pytest.raises(TimeoutAPIError):
await connect_task
assert states == []
@pytest.mark.asyncio
async def test_bluetooth_device_connect_times_out_disconnect_ok(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test bluetooth_device_connect and disconnect times out."""
client, connection, transport, protocol = api_client
states = []
def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None:
states.append((connected, mtu, error))
connect_task = asyncio.create_task(
client.bluetooth_device_connect(
1234,
on_bluetooth_connection_state,
timeout=0,
feature_flags=0,
has_cache=True,
disconnect_timeout=1,
address_type=1,
)
)
await asyncio.sleep(0)
# The connect request should be written
assert len(transport.write.mock_calls) == 1
await asyncio.sleep(0)
await asyncio.sleep(0)
await asyncio.sleep(0)
# Now that we timed out, the disconnect
# request should be written
assert len(transport.write.mock_calls) == 2
response: message.Message = BluetoothDeviceConnectionResponse(
address=1234, connected=False, mtu=23, error=8
)
mock_data_received(protocol, generate_plaintext_packet(response))
with pytest.raises(TimeoutAPIError):
await connect_task
assert states == []
@pytest.mark.asyncio
async def test_bluetooth_device_connect_cancelled(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test bluetooth_device_connect handles cancellation."""
client, connection, transport, protocol = api_client
states = []
handlers_before = len(
list(itertools.chain(*connection._get_message_handlers().values()))
)
def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None:
states.append((connected, mtu, error))
connect_task = asyncio.create_task(
client.bluetooth_device_connect(
1234,
on_bluetooth_connection_state,
timeout=10,
feature_flags=0,
has_cache=True,
disconnect_timeout=10,
address_type=1,
)
)
await asyncio.sleep(0)
# The connect request should be written
assert len(transport.write.mock_calls) == 1
connect_task.cancel()
with pytest.raises(asyncio.CancelledError):
await connect_task
assert states == []
handlers_after = len(
list(itertools.chain(*connection._get_message_handlers().values()))
)
# Make sure we do not leak message handlers
assert handlers_after == handlers_before