Fix missed GATT notify if the device responds immediately after subscribe (#669)

This commit is contained in:
J. Nick Koston 2023-11-23 16:46:56 +01:00 committed by GitHub
parent cf2fd3c92a
commit 1cc6b3ed52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 92 additions and 18 deletions

View File

@ -978,6 +978,17 @@ class APIClient:
timeout=timeout, timeout=timeout,
) )
def _on_bluetooth_gatt_notify_data_response(
self,
address: int,
handle: int,
on_bluetooth_gatt_notify: Callable[[int, bytearray], None],
msg: BluetoothGATTNotifyDataResponse,
) -> None:
"""Handle a BluetoothGATTNotifyDataResponse message."""
if address == msg.address and handle == msg.handle:
on_bluetooth_gatt_notify(handle, bytearray(msg.data))
async def bluetooth_gatt_start_notify( async def bluetooth_gatt_start_notify(
self, self,
address: int, address: int,
@ -994,23 +1005,27 @@ class APIClient:
callbacks without stopping the notify session on the remote device, which callbacks without stopping the notify session on the remote device, which
should be used when the connection is lost. should be used when the connection is lost.
""" """
await self._send_bluetooth_message_await_response(
address,
handle,
BluetoothGATTNotifyRequest(address=address, handle=handle, enable=True),
BluetoothGATTNotifyResponse,
)
def _on_bluetooth_gatt_notify_data_response(
msg: BluetoothGATTNotifyDataResponse,
) -> None:
if address == msg.address and handle == msg.handle:
on_bluetooth_gatt_notify(handle, bytearray(msg.data))
remove_callback = self._get_connection().add_message_callback( remove_callback = self._get_connection().add_message_callback(
_on_bluetooth_gatt_notify_data_response, (BluetoothGATTNotifyDataResponse,) partial(
self._on_bluetooth_gatt_notify_data_response,
address,
handle,
on_bluetooth_gatt_notify,
),
(BluetoothGATTNotifyDataResponse,),
) )
try:
await self._send_bluetooth_message_await_response(
address,
handle,
BluetoothGATTNotifyRequest(address=address, handle=handle, enable=True),
BluetoothGATTNotifyResponse,
)
except Exception:
remove_callback()
raise
async def stop_notify() -> None: async def stop_notify() -> None:
if self._connection is None: if self._connection is None:
return return

View File

@ -946,3 +946,15 @@ class APIConnection:
) )
self._cleanup() self._cleanup()
def _get_message_handlers(
self,
) -> dict[Any, set[Callable[[message.Message], None]]]:
"""Get the message handlers.
This function is only used for testing for leaks.
It has to be bound to the real instance to work since
_message_handlers is not a public attribute.
"""
return self._message_handlers

View File

@ -19,6 +19,10 @@ from .common import connect, get_mock_async_zeroconf, send_plaintext_hello
KEEP_ALIVE_INTERVAL = 15.0 KEEP_ALIVE_INTERVAL = 15.0
class PatchableAPIConnection(APIConnection):
pass
@pytest.fixture @pytest.fixture
def async_zeroconf(): def async_zeroconf():
return get_mock_async_zeroconf() return get_mock_async_zeroconf()
@ -76,12 +80,12 @@ async def on_stop(expected_disconnect: bool) -> None:
@pytest.fixture @pytest.fixture
def conn(connection_params: ConnectionParams) -> APIConnection: def conn(connection_params: ConnectionParams) -> APIConnection:
return APIConnection(connection_params, on_stop) return PatchableAPIConnection(connection_params, on_stop)
@pytest.fixture @pytest.fixture
def noise_conn(noise_connection_params: ConnectionParams) -> APIConnection: def noise_conn(noise_connection_params: ConnectionParams) -> APIConnection:
return APIConnection(noise_connection_params, on_stop) return PatchableAPIConnection(noise_connection_params, on_stop)
@pytest_asyncio.fixture(name="plaintext_connect_task_no_login") @pytest_asyncio.fixture(name="plaintext_connect_task_no_login")

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import itertools
from typing import Any from typing import Any
from unittest.mock import AsyncMock, MagicMock, call, patch from unittest.mock import AsyncMock, MagicMock, call, patch
@ -1130,7 +1131,6 @@ async def test_bluetooth_gatt_get_services_errors(
await services_task await services_task
@pytest.mark.xfail(reason="There is a race condition here")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bluetooth_gatt_start_notify( async def test_bluetooth_gatt_start_notify(
api_client: tuple[ api_client: tuple[
@ -1141,6 +1141,10 @@ async def test_bluetooth_gatt_start_notify(
client, connection, transport, protocol = api_client client, connection, transport, protocol = api_client
notifies = [] notifies = []
handlers_before = len(
list(itertools.chain(*connection._get_message_handlers().values()))
)
def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None: def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None:
notifies.append((handle, data)) notifies.append((handle, data))
@ -1159,7 +1163,7 @@ async def test_bluetooth_gatt_start_notify(
+ generate_plaintext_packet(data_response) + generate_plaintext_packet(data_response)
) )
await notify_task cancel_cb, abort_cb = await notify_task
assert notifies == [(1, b"gotit")] assert notifies == [(1, b"gotit")]
second_data_response: message.Message = BluetoothGATTNotifyDataResponse( second_data_response: message.Message = BluetoothGATTNotifyDataResponse(
@ -1167,6 +1171,45 @@ async def test_bluetooth_gatt_start_notify(
) )
protocol.data_received(generate_plaintext_packet(second_data_response)) protocol.data_received(generate_plaintext_packet(second_data_response))
assert notifies == [(1, b"gotit"), (1, b"after finished")] assert notifies == [(1, b"gotit"), (1, b"after finished")]
await cancel_cb()
assert (
len(list(itertools.chain(*connection._get_message_handlers().values())))
== handlers_before
)
# Ensure abort callback is a no-op after cancel
# and doesn't raise
abort_cb()
@pytest.mark.asyncio
async def test_bluetooth_gatt_start_notify_fails(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test bluetooth_gatt_start_notify failure does not leak."""
client, connection, transport, protocol = api_client
notifies = []
def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None:
notifies.append((handle, data))
handlers_before = len(
list(itertools.chain(*connection._get_message_handlers().values()))
)
with patch.object(
connection,
"send_messages_await_response_complex",
side_effect=APIConnectionError,
), pytest.raises(APIConnectionError):
await client.bluetooth_gatt_start_notify(1234, 1, on_bluetooth_gatt_notify)
assert (
len(list(itertools.chain(*connection._get_message_handlers().values())))
== handlers_before
)
@pytest.mark.asyncio @pytest.mark.asyncio