From 1cc6b3ed525110e9e1d1d69dc196ab41f76c243b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 23 Nov 2023 16:46:56 +0100 Subject: [PATCH] Fix missed GATT notify if the device responds immediately after subscribe (#669) --- aioesphomeapi/client.py | 43 ++++++++++++++++++++++----------- aioesphomeapi/connection.py | 12 ++++++++++ tests/conftest.py | 8 +++++-- tests/test_client.py | 47 +++++++++++++++++++++++++++++++++++-- 4 files changed, 92 insertions(+), 18 deletions(-) diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 6f0f5ad..4808536 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -978,6 +978,17 @@ class APIClient: timeout=timeout, ) + def _on_bluetooth_gatt_notify_data_response( + self, + address: int, + handle: int, + on_bluetooth_gatt_notify: Callable[[int, bytearray], None], + msg: BluetoothGATTNotifyDataResponse, + ) -> None: + """Handle a BluetoothGATTNotifyDataResponse message.""" + if address == msg.address and handle == msg.handle: + on_bluetooth_gatt_notify(handle, bytearray(msg.data)) + async def bluetooth_gatt_start_notify( self, address: int, @@ -994,23 +1005,27 @@ class APIClient: callbacks without stopping the notify session on the remote device, which should be used when the connection is lost. """ - await self._send_bluetooth_message_await_response( - address, - handle, - BluetoothGATTNotifyRequest(address=address, handle=handle, enable=True), - BluetoothGATTNotifyResponse, - ) - - def _on_bluetooth_gatt_notify_data_response( - msg: BluetoothGATTNotifyDataResponse, - ) -> None: - if address == msg.address and handle == msg.handle: - on_bluetooth_gatt_notify(handle, bytearray(msg.data)) - remove_callback = self._get_connection().add_message_callback( - _on_bluetooth_gatt_notify_data_response, (BluetoothGATTNotifyDataResponse,) + partial( + self._on_bluetooth_gatt_notify_data_response, + address, + handle, + on_bluetooth_gatt_notify, + ), + (BluetoothGATTNotifyDataResponse,), ) + try: + await self._send_bluetooth_message_await_response( + address, + handle, + BluetoothGATTNotifyRequest(address=address, handle=handle, enable=True), + BluetoothGATTNotifyResponse, + ) + except Exception: + remove_callback() + raise + async def stop_notify() -> None: if self._connection is None: return diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index a30a348..4d97065 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -946,3 +946,15 @@ class APIConnection: ) self._cleanup() + + def _get_message_handlers( + self, + ) -> dict[Any, set[Callable[[message.Message], None]]]: + """Get the message handlers. + + This function is only used for testing for leaks. + + It has to be bound to the real instance to work since + _message_handlers is not a public attribute. + """ + return self._message_handlers diff --git a/tests/conftest.py b/tests/conftest.py index bd956b6..42585d6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,6 +19,10 @@ from .common import connect, get_mock_async_zeroconf, send_plaintext_hello KEEP_ALIVE_INTERVAL = 15.0 +class PatchableAPIConnection(APIConnection): + pass + + @pytest.fixture def async_zeroconf(): return get_mock_async_zeroconf() @@ -76,12 +80,12 @@ async def on_stop(expected_disconnect: bool) -> None: @pytest.fixture def conn(connection_params: ConnectionParams) -> APIConnection: - return APIConnection(connection_params, on_stop) + return PatchableAPIConnection(connection_params, on_stop) @pytest.fixture def noise_conn(noise_connection_params: ConnectionParams) -> APIConnection: - return APIConnection(noise_connection_params, on_stop) + return PatchableAPIConnection(noise_connection_params, on_stop) @pytest_asyncio.fixture(name="plaintext_connect_task_no_login") diff --git a/tests/test_client.py b/tests/test_client.py index a3abe3f..faed270 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import itertools from typing import Any from unittest.mock import AsyncMock, MagicMock, call, patch @@ -1130,7 +1131,6 @@ async def test_bluetooth_gatt_get_services_errors( await services_task -@pytest.mark.xfail(reason="There is a race condition here") @pytest.mark.asyncio async def test_bluetooth_gatt_start_notify( api_client: tuple[ @@ -1141,6 +1141,10 @@ async def test_bluetooth_gatt_start_notify( client, connection, transport, protocol = api_client notifies = [] + handlers_before = len( + list(itertools.chain(*connection._get_message_handlers().values())) + ) + def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None: notifies.append((handle, data)) @@ -1159,7 +1163,7 @@ async def test_bluetooth_gatt_start_notify( + generate_plaintext_packet(data_response) ) - await notify_task + cancel_cb, abort_cb = await notify_task assert notifies == [(1, b"gotit")] second_data_response: message.Message = BluetoothGATTNotifyDataResponse( @@ -1167,6 +1171,45 @@ async def test_bluetooth_gatt_start_notify( ) protocol.data_received(generate_plaintext_packet(second_data_response)) assert notifies == [(1, b"gotit"), (1, b"after finished")] + await cancel_cb() + + assert ( + len(list(itertools.chain(*connection._get_message_handlers().values()))) + == handlers_before + ) + # Ensure abort callback is a no-op after cancel + # and doesn't raise + abort_cb() + + +@pytest.mark.asyncio +async def test_bluetooth_gatt_start_notify_fails( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test bluetooth_gatt_start_notify failure does not leak.""" + client, connection, transport, protocol = api_client + notifies = [] + + def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None: + notifies.append((handle, data)) + + handlers_before = len( + list(itertools.chain(*connection._get_message_handlers().values())) + ) + + with patch.object( + connection, + "send_messages_await_response_complex", + side_effect=APIConnectionError, + ), pytest.raises(APIConnectionError): + await client.bluetooth_gatt_start_notify(1234, 1, on_bluetooth_gatt_notify) + + assert ( + len(list(itertools.chain(*connection._get_message_handlers().values()))) + == handlers_before + ) @pytest.mark.asyncio