Reduce duplicate Bluetooth message filtering code (#777)

This commit is contained in:
J. Nick Koston 2023-11-28 10:49:58 -06:00 committed by GitHub
parent 5c063e2269
commit d40acb1f85
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 41 deletions

View File

@ -76,8 +76,9 @@ from .client_callbacks import (
on_bluetooth_connections_free_response, on_bluetooth_connections_free_response,
on_bluetooth_device_connection_response, on_bluetooth_device_connection_response,
on_bluetooth_gatt_notify_data_response, on_bluetooth_gatt_notify_data_response,
on_bluetooth_handle_message,
on_bluetooth_le_advertising_response, on_bluetooth_le_advertising_response,
on_bluetooth_message, on_bluetooth_message_types,
on_home_assistant_service_response, on_home_assistant_service_response,
on_state_msg, on_state_msg,
on_subscribe_home_assistant_state_response, on_subscribe_home_assistant_state_response,
@ -465,7 +466,7 @@ class APIClient:
), ),
timeout: float = 10.0, timeout: float = 10.0,
) -> message.Message: ) -> message.Message:
message_filter = partial(on_bluetooth_message, address, handle) message_filter = partial(on_bluetooth_handle_message, address, handle)
msg_types = (response_type, BluetoothGATTErrorResponse) msg_types = (response_type, BluetoothGATTErrorResponse)
[resp] = await self._get_connection().send_messages_await_response_complex( [resp] = await self._get_connection().send_messages_await_response_complex(
(request,), (request,),
@ -684,11 +685,12 @@ class APIClient:
timeout: float, timeout: float,
) -> message.Message: ) -> message.Message:
"""Send a BluetoothDeviceRequest watch for the connection state to change.""" """Send a BluetoothDeviceRequest watch for the connection state to change."""
types_with_response = (BluetoothDeviceConnectionResponse, *msg_types)
response = await self._bluetooth_device_request( response = await self._bluetooth_device_request(
address, address,
request_type, request_type,
lambda msg: msg.address == address, partial(on_bluetooth_message_types, address, types_with_response),
(BluetoothDeviceConnectionResponse, *msg_types), types_with_response,
timeout, timeout,
) )
self._raise_for_ble_connection_change(address, response, msg_types) self._raise_for_ble_connection_change(address, response, msg_types)
@ -720,13 +722,9 @@ class APIClient:
timeout: float, timeout: float,
) -> message.Message: ) -> message.Message:
"""Send a BluetoothDeviceRequest and wait for a response.""" """Send a BluetoothDeviceRequest and wait for a response."""
req = BluetoothDeviceRequest(address=address, request_type=request_type)
[response] = await self._get_connection().send_messages_await_response_complex( [response] = await self._get_connection().send_messages_await_response_complex(
( (req,),
BluetoothDeviceRequest(
address=address,
request_type=request_type,
),
),
predicate_func, predicate_func,
predicate_func, predicate_func,
msg_types, msg_types,
@ -737,42 +735,18 @@ class APIClient:
async def bluetooth_gatt_get_services( async def bluetooth_gatt_get_services(
self, address: int self, address: int
) -> ESPHomeBluetoothGATTServices: ) -> ESPHomeBluetoothGATTServices:
append_types = ( error_types = (BluetoothGATTErrorResponse, BluetoothDeviceConnectionResponse)
BluetoothDeviceConnectionResponse, append_types = (*error_types, BluetoothGATTGetServicesResponse)
BluetoothGATTGetServicesResponse, stop_types = (*error_types, BluetoothGATTGetServicesDoneResponse)
BluetoothGATTErrorResponse,
)
stop_types = (
BluetoothDeviceConnectionResponse,
BluetoothGATTGetServicesDoneResponse,
BluetoothGATTErrorResponse,
)
msg_types = ( msg_types = (
BluetoothGATTGetServicesResponse, BluetoothGATTGetServicesResponse,
BluetoothGATTGetServicesDoneResponse, BluetoothGATTGetServicesDoneResponse,
BluetoothGATTErrorResponse, BluetoothGATTErrorResponse,
) )
def do_append(
msg: BluetoothDeviceConnectionResponse
| BluetoothGATTGetServicesResponse
| BluetoothGATTGetServicesDoneResponse
| BluetoothGATTErrorResponse,
) -> bool:
return type(msg) in append_types and msg.address == address
def do_stop(
msg: BluetoothDeviceConnectionResponse
| BluetoothGATTGetServicesResponse
| BluetoothGATTGetServicesDoneResponse
| BluetoothGATTErrorResponse,
) -> bool:
return type(msg) in stop_types and msg.address == address
resp = await self._get_connection().send_messages_await_response_complex( resp = await self._get_connection().send_messages_await_response_complex(
(BluetoothGATTGetServicesRequest(address=address),), (BluetoothGATTGetServicesRequest(address=address),),
do_append, partial(on_bluetooth_message_types, address, append_types),
do_stop, partial(on_bluetooth_message_types, address, stop_types),
(*msg_types, BluetoothDeviceConnectionResponse), (*msg_types, BluetoothDeviceConnectionResponse),
DEFAULT_BLE_TIMEOUT, DEFAULT_BLE_TIMEOUT,
) )

View File

@ -10,6 +10,8 @@ from .api_pb2 import ( # type: ignore
BluetoothConnectionsFreeResponse, BluetoothConnectionsFreeResponse,
BluetoothDeviceConnectionResponse, BluetoothDeviceConnectionResponse,
BluetoothGATTErrorResponse, BluetoothGATTErrorResponse,
BluetoothGATTGetServicesDoneResponse,
BluetoothGATTGetServicesResponse,
BluetoothGATTNotifyDataResponse, BluetoothGATTNotifyDataResponse,
BluetoothGATTNotifyResponse, BluetoothGATTNotifyResponse,
BluetoothGATTReadResponse, BluetoothGATTReadResponse,
@ -118,7 +120,7 @@ def on_bluetooth_device_connection_response(
connect_future.set_result(None) connect_future.set_result(None)
def on_bluetooth_message( def on_bluetooth_handle_message(
address: int, address: int,
handle: int, handle: int,
msg: BluetoothGATTErrorResponse msg: BluetoothGATTErrorResponse
@ -127,7 +129,23 @@ def on_bluetooth_message(
| BluetoothGATTWriteResponse | BluetoothGATTWriteResponse
| BluetoothDeviceConnectionResponse, | BluetoothDeviceConnectionResponse,
) -> bool: ) -> bool:
"""Handle a Bluetooth message.""" """Filter a Bluetooth message for an address and handle."""
if type(msg) is BluetoothDeviceConnectionResponse: if type(msg) is BluetoothDeviceConnectionResponse:
return bool(msg.address == address) return bool(msg.address == address)
return bool(msg.address == address and msg.handle == handle) return bool(msg.address == address and msg.handle == handle)
def on_bluetooth_message_types(
address: int,
msg_types: tuple[type[message.Message]],
msg: BluetoothGATTErrorResponse
| BluetoothGATTNotifyResponse
| BluetoothGATTReadResponse
| BluetoothGATTWriteResponse
| BluetoothDeviceConnectionResponse
| BluetoothGATTGetServicesResponse
| BluetoothGATTGetServicesDoneResponse
| BluetoothGATTErrorResponse,
) -> bool:
"""Filter Bluetooth messages of a specific type and address."""
return type(msg) in msg_types and bool(msg.address == address)