mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-27 17:37:39 +01:00
Fix missed GATT notify if the device responds immediately after subscribe (#669)
This commit is contained in:
parent
cf2fd3c92a
commit
1cc6b3ed52
@ -978,6 +978,17 @@ class APIClient:
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
def _on_bluetooth_gatt_notify_data_response(
|
||||
self,
|
||||
address: int,
|
||||
handle: int,
|
||||
on_bluetooth_gatt_notify: Callable[[int, bytearray], None],
|
||||
msg: BluetoothGATTNotifyDataResponse,
|
||||
) -> None:
|
||||
"""Handle a BluetoothGATTNotifyDataResponse message."""
|
||||
if address == msg.address and handle == msg.handle:
|
||||
on_bluetooth_gatt_notify(handle, bytearray(msg.data))
|
||||
|
||||
async def bluetooth_gatt_start_notify(
|
||||
self,
|
||||
address: int,
|
||||
@ -994,23 +1005,27 @@ class APIClient:
|
||||
callbacks without stopping the notify session on the remote device, which
|
||||
should be used when the connection is lost.
|
||||
"""
|
||||
await self._send_bluetooth_message_await_response(
|
||||
address,
|
||||
handle,
|
||||
BluetoothGATTNotifyRequest(address=address, handle=handle, enable=True),
|
||||
BluetoothGATTNotifyResponse,
|
||||
)
|
||||
|
||||
def _on_bluetooth_gatt_notify_data_response(
|
||||
msg: BluetoothGATTNotifyDataResponse,
|
||||
) -> None:
|
||||
if address == msg.address and handle == msg.handle:
|
||||
on_bluetooth_gatt_notify(handle, bytearray(msg.data))
|
||||
|
||||
remove_callback = self._get_connection().add_message_callback(
|
||||
_on_bluetooth_gatt_notify_data_response, (BluetoothGATTNotifyDataResponse,)
|
||||
partial(
|
||||
self._on_bluetooth_gatt_notify_data_response,
|
||||
address,
|
||||
handle,
|
||||
on_bluetooth_gatt_notify,
|
||||
),
|
||||
(BluetoothGATTNotifyDataResponse,),
|
||||
)
|
||||
|
||||
try:
|
||||
await self._send_bluetooth_message_await_response(
|
||||
address,
|
||||
handle,
|
||||
BluetoothGATTNotifyRequest(address=address, handle=handle, enable=True),
|
||||
BluetoothGATTNotifyResponse,
|
||||
)
|
||||
except Exception:
|
||||
remove_callback()
|
||||
raise
|
||||
|
||||
async def stop_notify() -> None:
|
||||
if self._connection is None:
|
||||
return
|
||||
|
@ -946,3 +946,15 @@ class APIConnection:
|
||||
)
|
||||
|
||||
self._cleanup()
|
||||
|
||||
def _get_message_handlers(
|
||||
self,
|
||||
) -> dict[Any, set[Callable[[message.Message], None]]]:
|
||||
"""Get the message handlers.
|
||||
|
||||
This function is only used for testing for leaks.
|
||||
|
||||
It has to be bound to the real instance to work since
|
||||
_message_handlers is not a public attribute.
|
||||
"""
|
||||
return self._message_handlers
|
||||
|
@ -19,6 +19,10 @@ from .common import connect, get_mock_async_zeroconf, send_plaintext_hello
|
||||
KEEP_ALIVE_INTERVAL = 15.0
|
||||
|
||||
|
||||
class PatchableAPIConnection(APIConnection):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_zeroconf():
|
||||
return get_mock_async_zeroconf()
|
||||
@ -76,12 +80,12 @@ async def on_stop(expected_disconnect: bool) -> None:
|
||||
|
||||
@pytest.fixture
|
||||
def conn(connection_params: ConnectionParams) -> APIConnection:
|
||||
return APIConnection(connection_params, on_stop)
|
||||
return PatchableAPIConnection(connection_params, on_stop)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def noise_conn(noise_connection_params: ConnectionParams) -> APIConnection:
|
||||
return APIConnection(noise_connection_params, on_stop)
|
||||
return PatchableAPIConnection(noise_connection_params, on_stop)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(name="plaintext_connect_task_no_login")
|
||||
|
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import itertools
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||
|
||||
@ -1130,7 +1131,6 @@ async def test_bluetooth_gatt_get_services_errors(
|
||||
await services_task
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="There is a race condition here")
|
||||
@pytest.mark.asyncio
|
||||
async def test_bluetooth_gatt_start_notify(
|
||||
api_client: tuple[
|
||||
@ -1141,6 +1141,10 @@ async def test_bluetooth_gatt_start_notify(
|
||||
client, connection, transport, protocol = api_client
|
||||
notifies = []
|
||||
|
||||
handlers_before = len(
|
||||
list(itertools.chain(*connection._get_message_handlers().values()))
|
||||
)
|
||||
|
||||
def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None:
|
||||
notifies.append((handle, data))
|
||||
|
||||
@ -1159,7 +1163,7 @@ async def test_bluetooth_gatt_start_notify(
|
||||
+ generate_plaintext_packet(data_response)
|
||||
)
|
||||
|
||||
await notify_task
|
||||
cancel_cb, abort_cb = await notify_task
|
||||
assert notifies == [(1, b"gotit")]
|
||||
|
||||
second_data_response: message.Message = BluetoothGATTNotifyDataResponse(
|
||||
@ -1167,6 +1171,45 @@ async def test_bluetooth_gatt_start_notify(
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(second_data_response))
|
||||
assert notifies == [(1, b"gotit"), (1, b"after finished")]
|
||||
await cancel_cb()
|
||||
|
||||
assert (
|
||||
len(list(itertools.chain(*connection._get_message_handlers().values())))
|
||||
== handlers_before
|
||||
)
|
||||
# Ensure abort callback is a no-op after cancel
|
||||
# and doesn't raise
|
||||
abort_cb()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bluetooth_gatt_start_notify_fails(
|
||||
api_client: tuple[
|
||||
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
|
||||
],
|
||||
) -> None:
|
||||
"""Test bluetooth_gatt_start_notify failure does not leak."""
|
||||
client, connection, transport, protocol = api_client
|
||||
notifies = []
|
||||
|
||||
def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None:
|
||||
notifies.append((handle, data))
|
||||
|
||||
handlers_before = len(
|
||||
list(itertools.chain(*connection._get_message_handlers().values()))
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
connection,
|
||||
"send_messages_await_response_complex",
|
||||
side_effect=APIConnectionError,
|
||||
), pytest.raises(APIConnectionError):
|
||||
await client.bluetooth_gatt_start_notify(1234, 1, on_bluetooth_gatt_notify)
|
||||
|
||||
assert (
|
||||
len(list(itertools.chain(*connection._get_message_handlers().values())))
|
||||
== handlers_before
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
Loading…
Reference in New Issue
Block a user