mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-22 12:05:12 +01:00
Fix races in bluetooth device connect (#740)
This commit is contained in:
parent
108074a2dc
commit
150ce726da
@ -72,8 +72,10 @@ from .api_pb2 import ( # type: ignore
|
||||
VoiceAssistantResponse,
|
||||
)
|
||||
from .client_callbacks import (
|
||||
handle_timeout,
|
||||
on_ble_raw_advertisement_response,
|
||||
on_bluetooth_connections_free_response,
|
||||
on_bluetooth_device_connection_response,
|
||||
on_bluetooth_gatt_notify_data_response,
|
||||
on_bluetooth_le_advertising_response,
|
||||
on_home_assistant_service_response,
|
||||
@ -528,28 +530,6 @@ class APIClient:
|
||||
(BluetoothConnectionsFreeResponse,),
|
||||
)
|
||||
|
||||
def _handle_timeout(self, fut: asyncio.Future[None]) -> None:
|
||||
"""Handle a timeout."""
|
||||
if fut.done():
|
||||
return
|
||||
fut.set_exception(asyncio.TimeoutError)
|
||||
|
||||
def _on_bluetooth_device_connection_response(
|
||||
self,
|
||||
connect_future: asyncio.Future[None],
|
||||
address: int,
|
||||
on_bluetooth_connection_state: Callable[[bool, int, int], None],
|
||||
msg: BluetoothDeviceConnectionResponse,
|
||||
) -> None:
|
||||
"""Handle a BluetoothDeviceConnectionResponse message.""" ""
|
||||
if address == msg.address:
|
||||
on_bluetooth_connection_state(msg.connected, msg.mtu, msg.error)
|
||||
# Resolve on ANY connection state since we do not want
|
||||
# to wait the whole timeout if the device disconnects
|
||||
# or we get an error.
|
||||
if not connect_future.done():
|
||||
connect_future.set_result(None)
|
||||
|
||||
async def bluetooth_device_connect( # pylint: disable=too-many-locals, too-many-branches
|
||||
self,
|
||||
address: int,
|
||||
@ -584,7 +564,7 @@ class APIClient:
|
||||
address_type=address_type or 0,
|
||||
),
|
||||
partial(
|
||||
self._on_bluetooth_device_connection_response,
|
||||
on_bluetooth_device_connection_response,
|
||||
connect_future,
|
||||
address,
|
||||
on_bluetooth_connection_state,
|
||||
@ -594,7 +574,7 @@ class APIClient:
|
||||
|
||||
loop = self._loop
|
||||
timeout_handle = loop.call_at(
|
||||
loop.time() + timeout, self._handle_timeout, connect_future
|
||||
loop.time() + timeout, handle_timeout, connect_future
|
||||
)
|
||||
timeout_expired = False
|
||||
connect_ok = False
|
||||
@ -602,6 +582,11 @@ class APIClient:
|
||||
await connect_future
|
||||
connect_ok = True
|
||||
except asyncio.TimeoutError as err:
|
||||
# If the timeout expires, make sure
|
||||
# to unsub before calling _bluetooth_device_disconnect_guard_timeout
|
||||
# so that the disconnect message is not propagated back to the caller
|
||||
# since we are going to raise a TimeoutAPIError.
|
||||
unsub()
|
||||
timeout_expired = True
|
||||
# Disconnect before raising the exception to ensure
|
||||
# the slot is recovered before the timeout is raised
|
||||
@ -620,14 +605,8 @@ class APIClient:
|
||||
f" after {disconnect_timeout}s"
|
||||
) from err
|
||||
finally:
|
||||
if not connect_ok:
|
||||
try:
|
||||
unsub()
|
||||
except (KeyError, ValueError):
|
||||
_LOGGER.warning(
|
||||
"%s: Bluetooth device connection canceled but already unsubscribed",
|
||||
to_human_readable_address(address),
|
||||
)
|
||||
if not connect_ok and not timeout_expired:
|
||||
unsub()
|
||||
if not timeout_expired:
|
||||
timeout_handle.cancel()
|
||||
|
||||
|
@ -8,3 +8,5 @@ cdef object CameraImageResponse, CameraState
|
||||
cdef object HomeassistantServiceCall
|
||||
|
||||
cdef object BluetoothLEAdvertisement
|
||||
|
||||
cdef object asyncio_TimeoutError
|
||||
|
@ -1,11 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from asyncio import Future
|
||||
from asyncio import TimeoutError as asyncio_TimeoutError
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
from google.protobuf import message
|
||||
|
||||
from .api_pb2 import ( # type: ignore
|
||||
BluetoothConnectionsFreeResponse,
|
||||
BluetoothDeviceConnectionResponse,
|
||||
BluetoothGATTNotifyDataResponse,
|
||||
BluetoothLEAdvertisementResponse,
|
||||
BluetoothLERawAdvertisement,
|
||||
@ -93,3 +96,25 @@ def on_subscribe_home_assistant_state_response(
|
||||
msg: SubscribeHomeAssistantStateResponse,
|
||||
) -> None:
|
||||
on_state_sub(msg.entity_id, msg.attribute)
|
||||
|
||||
|
||||
def handle_timeout(fut: Future[None]) -> None:
|
||||
"""Handle a timeout."""
|
||||
if not fut.done():
|
||||
fut.set_exception(asyncio_TimeoutError)
|
||||
|
||||
|
||||
def on_bluetooth_device_connection_response(
|
||||
connect_future: Future[None],
|
||||
address: int,
|
||||
on_bluetooth_connection_state: Callable[[bool, int, int], None],
|
||||
msg: BluetoothDeviceConnectionResponse,
|
||||
) -> None:
|
||||
"""Handle a BluetoothDeviceConnectionResponse message.""" ""
|
||||
if address == msg.address:
|
||||
on_bluetooth_connection_state(msg.connected, msg.mtu, msg.error)
|
||||
# Resolve on ANY connection state since we do not want
|
||||
# to wait the whole timeout if the device disconnects
|
||||
# or we get an error.
|
||||
if not connect_future.done():
|
||||
connect_future.set_result(None)
|
||||
|
@ -17,6 +17,7 @@ from aioesphomeapi.api_pb2 import (
|
||||
BluetoothDeviceClearCacheResponse,
|
||||
BluetoothDeviceConnectionResponse,
|
||||
BluetoothDevicePairingResponse,
|
||||
BluetoothDeviceRequest,
|
||||
BluetoothDeviceUnpairingResponse,
|
||||
BluetoothGATTErrorResponse,
|
||||
BluetoothGATTGetServicesDoneResponse,
|
||||
@ -68,10 +69,12 @@ from aioesphomeapi.model import (
|
||||
APIVersion,
|
||||
BinarySensorInfo,
|
||||
BinarySensorState,
|
||||
BluetoothDeviceRequestType,
|
||||
)
|
||||
from aioesphomeapi.model import BluetoothGATTService as BluetoothGATTServiceModel
|
||||
from aioesphomeapi.model import (
|
||||
BluetoothLEAdvertisement,
|
||||
BluetoothProxyFeature,
|
||||
CameraState,
|
||||
ClimateFanMode,
|
||||
ClimateMode,
|
||||
@ -1455,3 +1458,195 @@ async def test_force_disconnect(
|
||||
assert connection.is_connected is False
|
||||
await client.disconnect(force=False)
|
||||
assert connection.is_connected is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("has_cache", "feature_flags", "method"),
|
||||
[
|
||||
(False, BluetoothProxyFeature(0), BluetoothDeviceRequestType.CONNECT),
|
||||
(
|
||||
False,
|
||||
BluetoothProxyFeature.REMOTE_CACHING,
|
||||
BluetoothDeviceRequestType.CONNECT_V3_WITHOUT_CACHE,
|
||||
),
|
||||
(
|
||||
True,
|
||||
BluetoothProxyFeature.REMOTE_CACHING,
|
||||
BluetoothDeviceRequestType.CONNECT_V3_WITH_CACHE,
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_bluetooth_device_connect(
|
||||
api_client: tuple[
|
||||
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
|
||||
],
|
||||
has_cache: bool,
|
||||
feature_flags: BluetoothProxyFeature,
|
||||
method: BluetoothDeviceRequestType,
|
||||
) -> None:
|
||||
"""Test bluetooth_device_connect."""
|
||||
client, connection, transport, protocol = api_client
|
||||
states = []
|
||||
|
||||
def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None:
|
||||
states.append((connected, mtu, error))
|
||||
|
||||
connect_task = asyncio.create_task(
|
||||
client.bluetooth_device_connect(
|
||||
1234,
|
||||
on_bluetooth_connection_state,
|
||||
timeout=1,
|
||||
feature_flags=feature_flags,
|
||||
has_cache=has_cache,
|
||||
disconnect_timeout=1,
|
||||
address_type=1,
|
||||
)
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
response: message.Message = BluetoothDeviceConnectionResponse(
|
||||
address=1234, connected=True, mtu=23, error=0
|
||||
)
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
|
||||
cancel = await connect_task
|
||||
assert states == [(True, 23, 0)]
|
||||
transport.write.assert_called_once_with(
|
||||
generate_plaintext_packet(
|
||||
BluetoothDeviceRequest(
|
||||
address=1234,
|
||||
request_type=method,
|
||||
has_address_type=True,
|
||||
address_type=1,
|
||||
),
|
||||
)
|
||||
)
|
||||
response: message.Message = BluetoothDeviceConnectionResponse(
|
||||
address=1234, connected=False, mtu=23, error=7
|
||||
)
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
await asyncio.sleep(0)
|
||||
assert states == [(True, 23, 0), (False, 23, 7)]
|
||||
cancel()
|
||||
|
||||
# After cancel, no more messages should called back
|
||||
response: message.Message = BluetoothDeviceConnectionResponse(
|
||||
address=1234, connected=False, mtu=23, error=8
|
||||
)
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
await asyncio.sleep(0)
|
||||
assert states == [(True, 23, 0), (False, 23, 7)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bluetooth_device_connect_and_disconnect_times_out(
|
||||
api_client: tuple[
|
||||
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
|
||||
],
|
||||
) -> None:
|
||||
"""Test bluetooth_device_connect and disconnect times out."""
|
||||
client, connection, transport, protocol = api_client
|
||||
states = []
|
||||
|
||||
def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None:
|
||||
states.append((connected, mtu, error))
|
||||
|
||||
connect_task = asyncio.create_task(
|
||||
client.bluetooth_device_connect(
|
||||
1234,
|
||||
on_bluetooth_connection_state,
|
||||
timeout=0,
|
||||
feature_flags=0,
|
||||
has_cache=True,
|
||||
disconnect_timeout=0,
|
||||
address_type=1,
|
||||
)
|
||||
)
|
||||
with pytest.raises(TimeoutAPIError):
|
||||
await connect_task
|
||||
assert states == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bluetooth_device_connect_times_out_disconnect_ok(
|
||||
api_client: tuple[
|
||||
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
|
||||
],
|
||||
) -> None:
|
||||
"""Test bluetooth_device_connect and disconnect times out."""
|
||||
client, connection, transport, protocol = api_client
|
||||
states = []
|
||||
|
||||
def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None:
|
||||
states.append((connected, mtu, error))
|
||||
|
||||
connect_task = asyncio.create_task(
|
||||
client.bluetooth_device_connect(
|
||||
1234,
|
||||
on_bluetooth_connection_state,
|
||||
timeout=0,
|
||||
feature_flags=0,
|
||||
has_cache=True,
|
||||
disconnect_timeout=1,
|
||||
address_type=1,
|
||||
)
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
# The connect request should be written
|
||||
assert len(transport.write.mock_calls) == 1
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
# Now that we timed out, the disconnect
|
||||
# request should be written
|
||||
assert len(transport.write.mock_calls) == 2
|
||||
response: message.Message = BluetoothDeviceConnectionResponse(
|
||||
address=1234, connected=False, mtu=23, error=8
|
||||
)
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
with pytest.raises(TimeoutAPIError):
|
||||
await connect_task
|
||||
assert states == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bluetooth_device_connect_cancelled(
|
||||
api_client: tuple[
|
||||
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
|
||||
],
|
||||
) -> None:
|
||||
"""Test bluetooth_device_connect handles cancellation."""
|
||||
client, connection, transport, protocol = api_client
|
||||
states = []
|
||||
|
||||
handlers_before = len(
|
||||
list(itertools.chain(*connection._get_message_handlers().values()))
|
||||
)
|
||||
|
||||
def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None:
|
||||
states.append((connected, mtu, error))
|
||||
|
||||
connect_task = asyncio.create_task(
|
||||
client.bluetooth_device_connect(
|
||||
1234,
|
||||
on_bluetooth_connection_state,
|
||||
timeout=10,
|
||||
feature_flags=0,
|
||||
has_cache=True,
|
||||
disconnect_timeout=10,
|
||||
address_type=1,
|
||||
)
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
# The connect request should be written
|
||||
assert len(transport.write.mock_calls) == 1
|
||||
connect_task.cancel()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await connect_task
|
||||
assert states == []
|
||||
|
||||
handlers_after = len(
|
||||
list(itertools.chain(*connection._get_message_handlers().values()))
|
||||
)
|
||||
# Make sure we do not leak message handlers
|
||||
assert handlers_after == handlers_before
|
||||
|
Loading…
Reference in New Issue
Block a user