mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-30 18:08:36 +01:00
Fix missed GATT notify if the device responds immediately after subscribe (#669)
This commit is contained in:
parent
cf2fd3c92a
commit
1cc6b3ed52
@ -978,6 +978,17 @@ class APIClient:
|
|||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _on_bluetooth_gatt_notify_data_response(
|
||||||
|
self,
|
||||||
|
address: int,
|
||||||
|
handle: int,
|
||||||
|
on_bluetooth_gatt_notify: Callable[[int, bytearray], None],
|
||||||
|
msg: BluetoothGATTNotifyDataResponse,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a BluetoothGATTNotifyDataResponse message."""
|
||||||
|
if address == msg.address and handle == msg.handle:
|
||||||
|
on_bluetooth_gatt_notify(handle, bytearray(msg.data))
|
||||||
|
|
||||||
async def bluetooth_gatt_start_notify(
|
async def bluetooth_gatt_start_notify(
|
||||||
self,
|
self,
|
||||||
address: int,
|
address: int,
|
||||||
@ -994,22 +1005,26 @@ class APIClient:
|
|||||||
callbacks without stopping the notify session on the remote device, which
|
callbacks without stopping the notify session on the remote device, which
|
||||||
should be used when the connection is lost.
|
should be used when the connection is lost.
|
||||||
"""
|
"""
|
||||||
|
remove_callback = self._get_connection().add_message_callback(
|
||||||
|
partial(
|
||||||
|
self._on_bluetooth_gatt_notify_data_response,
|
||||||
|
address,
|
||||||
|
handle,
|
||||||
|
on_bluetooth_gatt_notify,
|
||||||
|
),
|
||||||
|
(BluetoothGATTNotifyDataResponse,),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
await self._send_bluetooth_message_await_response(
|
await self._send_bluetooth_message_await_response(
|
||||||
address,
|
address,
|
||||||
handle,
|
handle,
|
||||||
BluetoothGATTNotifyRequest(address=address, handle=handle, enable=True),
|
BluetoothGATTNotifyRequest(address=address, handle=handle, enable=True),
|
||||||
BluetoothGATTNotifyResponse,
|
BluetoothGATTNotifyResponse,
|
||||||
)
|
)
|
||||||
|
except Exception:
|
||||||
def _on_bluetooth_gatt_notify_data_response(
|
remove_callback()
|
||||||
msg: BluetoothGATTNotifyDataResponse,
|
raise
|
||||||
) -> None:
|
|
||||||
if address == msg.address and handle == msg.handle:
|
|
||||||
on_bluetooth_gatt_notify(handle, bytearray(msg.data))
|
|
||||||
|
|
||||||
remove_callback = self._get_connection().add_message_callback(
|
|
||||||
_on_bluetooth_gatt_notify_data_response, (BluetoothGATTNotifyDataResponse,)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def stop_notify() -> None:
|
async def stop_notify() -> None:
|
||||||
if self._connection is None:
|
if self._connection is None:
|
||||||
|
@ -946,3 +946,15 @@ class APIConnection:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._cleanup()
|
self._cleanup()
|
||||||
|
|
||||||
|
def _get_message_handlers(
|
||||||
|
self,
|
||||||
|
) -> dict[Any, set[Callable[[message.Message], None]]]:
|
||||||
|
"""Get the message handlers.
|
||||||
|
|
||||||
|
This function is only used for testing for leaks.
|
||||||
|
|
||||||
|
It has to be bound to the real instance to work since
|
||||||
|
_message_handlers is not a public attribute.
|
||||||
|
"""
|
||||||
|
return self._message_handlers
|
||||||
|
@ -19,6 +19,10 @@ from .common import connect, get_mock_async_zeroconf, send_plaintext_hello
|
|||||||
KEEP_ALIVE_INTERVAL = 15.0
|
KEEP_ALIVE_INTERVAL = 15.0
|
||||||
|
|
||||||
|
|
||||||
|
class PatchableAPIConnection(APIConnection):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def async_zeroconf():
|
def async_zeroconf():
|
||||||
return get_mock_async_zeroconf()
|
return get_mock_async_zeroconf()
|
||||||
@ -76,12 +80,12 @@ async def on_stop(expected_disconnect: bool) -> None:
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def conn(connection_params: ConnectionParams) -> APIConnection:
|
def conn(connection_params: ConnectionParams) -> APIConnection:
|
||||||
return APIConnection(connection_params, on_stop)
|
return PatchableAPIConnection(connection_params, on_stop)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def noise_conn(noise_connection_params: ConnectionParams) -> APIConnection:
|
def noise_conn(noise_connection_params: ConnectionParams) -> APIConnection:
|
||||||
return APIConnection(noise_connection_params, on_stop)
|
return PatchableAPIConnection(noise_connection_params, on_stop)
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(name="plaintext_connect_task_no_login")
|
@pytest_asyncio.fixture(name="plaintext_connect_task_no_login")
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import itertools
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||||
|
|
||||||
@ -1130,7 +1131,6 @@ async def test_bluetooth_gatt_get_services_errors(
|
|||||||
await services_task
|
await services_task
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail(reason="There is a race condition here")
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_bluetooth_gatt_start_notify(
|
async def test_bluetooth_gatt_start_notify(
|
||||||
api_client: tuple[
|
api_client: tuple[
|
||||||
@ -1141,6 +1141,10 @@ async def test_bluetooth_gatt_start_notify(
|
|||||||
client, connection, transport, protocol = api_client
|
client, connection, transport, protocol = api_client
|
||||||
notifies = []
|
notifies = []
|
||||||
|
|
||||||
|
handlers_before = len(
|
||||||
|
list(itertools.chain(*connection._get_message_handlers().values()))
|
||||||
|
)
|
||||||
|
|
||||||
def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None:
|
def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None:
|
||||||
notifies.append((handle, data))
|
notifies.append((handle, data))
|
||||||
|
|
||||||
@ -1159,7 +1163,7 @@ async def test_bluetooth_gatt_start_notify(
|
|||||||
+ generate_plaintext_packet(data_response)
|
+ generate_plaintext_packet(data_response)
|
||||||
)
|
)
|
||||||
|
|
||||||
await notify_task
|
cancel_cb, abort_cb = await notify_task
|
||||||
assert notifies == [(1, b"gotit")]
|
assert notifies == [(1, b"gotit")]
|
||||||
|
|
||||||
second_data_response: message.Message = BluetoothGATTNotifyDataResponse(
|
second_data_response: message.Message = BluetoothGATTNotifyDataResponse(
|
||||||
@ -1167,6 +1171,45 @@ async def test_bluetooth_gatt_start_notify(
|
|||||||
)
|
)
|
||||||
protocol.data_received(generate_plaintext_packet(second_data_response))
|
protocol.data_received(generate_plaintext_packet(second_data_response))
|
||||||
assert notifies == [(1, b"gotit"), (1, b"after finished")]
|
assert notifies == [(1, b"gotit"), (1, b"after finished")]
|
||||||
|
await cancel_cb()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
len(list(itertools.chain(*connection._get_message_handlers().values())))
|
||||||
|
== handlers_before
|
||||||
|
)
|
||||||
|
# Ensure abort callback is a no-op after cancel
|
||||||
|
# and doesn't raise
|
||||||
|
abort_cb()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bluetooth_gatt_start_notify_fails(
|
||||||
|
api_client: tuple[
|
||||||
|
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
|
||||||
|
],
|
||||||
|
) -> None:
|
||||||
|
"""Test bluetooth_gatt_start_notify failure does not leak."""
|
||||||
|
client, connection, transport, protocol = api_client
|
||||||
|
notifies = []
|
||||||
|
|
||||||
|
def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None:
|
||||||
|
notifies.append((handle, data))
|
||||||
|
|
||||||
|
handlers_before = len(
|
||||||
|
list(itertools.chain(*connection._get_message_handlers().values()))
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
connection,
|
||||||
|
"send_messages_await_response_complex",
|
||||||
|
side_effect=APIConnectionError,
|
||||||
|
), pytest.raises(APIConnectionError):
|
||||||
|
await client.bluetooth_gatt_start_notify(1234, 1, on_bluetooth_gatt_notify)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
len(list(itertools.chain(*connection._get_message_handlers().values())))
|
||||||
|
== handlers_before
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
Loading…
Reference in New Issue
Block a user