Fix missed GATT notify if the device responds immediately after subscribe (#669)

This commit is contained in:
J. Nick Koston 2023-11-23 16:46:56 +01:00 committed by GitHub
parent cf2fd3c92a
commit 1cc6b3ed52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 92 additions and 18 deletions

View File

@ -978,6 +978,17 @@ class APIClient:
timeout=timeout,
)
def _on_bluetooth_gatt_notify_data_response(
self,
address: int,
handle: int,
on_bluetooth_gatt_notify: Callable[[int, bytearray], None],
msg: BluetoothGATTNotifyDataResponse,
) -> None:
"""Handle a BluetoothGATTNotifyDataResponse message."""
if address == msg.address and handle == msg.handle:
on_bluetooth_gatt_notify(handle, bytearray(msg.data))
async def bluetooth_gatt_start_notify(
self,
address: int,
@ -994,23 +1005,27 @@ class APIClient:
callbacks without stopping the notify session on the remote device, which
should be used when the connection is lost.
"""
await self._send_bluetooth_message_await_response(
address,
handle,
BluetoothGATTNotifyRequest(address=address, handle=handle, enable=True),
BluetoothGATTNotifyResponse,
)
def _on_bluetooth_gatt_notify_data_response(
msg: BluetoothGATTNotifyDataResponse,
) -> None:
if address == msg.address and handle == msg.handle:
on_bluetooth_gatt_notify(handle, bytearray(msg.data))
remove_callback = self._get_connection().add_message_callback(
_on_bluetooth_gatt_notify_data_response, (BluetoothGATTNotifyDataResponse,)
partial(
self._on_bluetooth_gatt_notify_data_response,
address,
handle,
on_bluetooth_gatt_notify,
),
(BluetoothGATTNotifyDataResponse,),
)
try:
await self._send_bluetooth_message_await_response(
address,
handle,
BluetoothGATTNotifyRequest(address=address, handle=handle, enable=True),
BluetoothGATTNotifyResponse,
)
except Exception:
remove_callback()
raise
async def stop_notify() -> None:
if self._connection is None:
return

View File

@ -946,3 +946,15 @@ class APIConnection:
)
self._cleanup()
def _get_message_handlers(
self,
) -> dict[Any, set[Callable[[message.Message], None]]]:
"""Get the message handlers.
This function is only used for testing for leaks.
It has to be bound to the real instance to work since
_message_handlers is not a public attribute.
"""
return self._message_handlers

View File

@ -19,6 +19,10 @@ from .common import connect, get_mock_async_zeroconf, send_plaintext_hello
KEEP_ALIVE_INTERVAL = 15.0
class PatchableAPIConnection(APIConnection):
pass
@pytest.fixture
def async_zeroconf():
return get_mock_async_zeroconf()
@ -76,12 +80,12 @@ async def on_stop(expected_disconnect: bool) -> None:
@pytest.fixture
def conn(connection_params: ConnectionParams) -> APIConnection:
return APIConnection(connection_params, on_stop)
return PatchableAPIConnection(connection_params, on_stop)
@pytest.fixture
def noise_conn(noise_connection_params: ConnectionParams) -> APIConnection:
return APIConnection(noise_connection_params, on_stop)
return PatchableAPIConnection(noise_connection_params, on_stop)
@pytest_asyncio.fixture(name="plaintext_connect_task_no_login")

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import itertools
from typing import Any
from unittest.mock import AsyncMock, MagicMock, call, patch
@ -1130,7 +1131,6 @@ async def test_bluetooth_gatt_get_services_errors(
await services_task
@pytest.mark.xfail(reason="There is a race condition here")
@pytest.mark.asyncio
async def test_bluetooth_gatt_start_notify(
api_client: tuple[
@ -1141,6 +1141,10 @@ async def test_bluetooth_gatt_start_notify(
client, connection, transport, protocol = api_client
notifies = []
handlers_before = len(
list(itertools.chain(*connection._get_message_handlers().values()))
)
def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None:
notifies.append((handle, data))
@ -1159,7 +1163,7 @@ async def test_bluetooth_gatt_start_notify(
+ generate_plaintext_packet(data_response)
)
await notify_task
cancel_cb, abort_cb = await notify_task
assert notifies == [(1, b"gotit")]
second_data_response: message.Message = BluetoothGATTNotifyDataResponse(
@ -1167,6 +1171,45 @@ async def test_bluetooth_gatt_start_notify(
)
protocol.data_received(generate_plaintext_packet(second_data_response))
assert notifies == [(1, b"gotit"), (1, b"after finished")]
await cancel_cb()
assert (
len(list(itertools.chain(*connection._get_message_handlers().values())))
== handlers_before
)
# Ensure abort callback is a no-op after cancel
# and doesn't raise
abort_cb()
@pytest.mark.asyncio
async def test_bluetooth_gatt_start_notify_fails(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test bluetooth_gatt_start_notify failure does not leak."""
client, connection, transport, protocol = api_client
notifies = []
def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None:
notifies.append((handle, data))
handlers_before = len(
list(itertools.chain(*connection._get_message_handlers().values()))
)
with patch.object(
connection,
"send_messages_await_response_complex",
side_effect=APIConnectionError,
), pytest.raises(APIConnectionError):
await client.bluetooth_gatt_start_notify(1234, 1, on_bluetooth_gatt_notify)
assert (
len(list(itertools.chain(*connection._get_message_handlers().values())))
== handlers_before
)
@pytest.mark.asyncio