Callback messages to listeners by type (#328)

This commit is contained in:
J. Nick Koston 2022-12-02 09:36:58 -10:00 committed by GitHub
parent 056b9a3b79
commit de5cdfa230
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 122 additions and 111 deletions

View File

@ -312,16 +312,17 @@ class APIClient:
ListEntitiesLockResponse: LockInfo, ListEntitiesLockResponse: LockInfo,
ListEntitiesMediaPlayerResponse: MediaPlayerInfo, ListEntitiesMediaPlayerResponse: MediaPlayerInfo,
} }
msg_types = (ListEntitiesDoneResponse, *response_types)
def do_append(msg: message.Message) -> bool: def do_append(msg: message.Message) -> bool:
return isinstance(msg, tuple(response_types.keys())) return not isinstance(msg, ListEntitiesDoneResponse)
def do_stop(msg: message.Message) -> bool: def do_stop(msg: message.Message) -> bool:
return isinstance(msg, ListEntitiesDoneResponse) return isinstance(msg, ListEntitiesDoneResponse)
assert self._connection is not None assert self._connection is not None
resp = await self._connection.send_message_await_response_complex( resp = await self._connection.send_message_await_response_complex(
ListEntitiesRequest(), do_append, do_stop, timeout=60 ListEntitiesRequest(), do_append, do_stop, msg_types, timeout=60
) )
entities: List[EntityInfo] = [] entities: List[EntityInfo] = []
services: List[UserService] = [] services: List[UserService] = []
@ -329,19 +330,14 @@ class APIClient:
if isinstance(msg, ListEntitiesServicesResponse): if isinstance(msg, ListEntitiesServicesResponse):
services.append(UserService.from_pb(msg)) services.append(UserService.from_pb(msg))
continue continue
cls = None cls = response_types[type(msg)]
for resp_type, cls in response_types.items():
if isinstance(msg, resp_type):
break
else:
continue
assert cls is not None assert cls is not None
entities.append(cls.from_pb(msg)) entities.append(cls.from_pb(msg))
return entities, services return entities, services
async def subscribe_states(self, on_state: Callable[[EntityState], None]) -> None: async def subscribe_states(self, on_state: Callable[[EntityState], None]) -> None:
self._check_authenticated() self._check_authenticated()
image_stream: Dict[int, bytes] = {}
response_types: Dict[Any, Type[EntityState]] = { response_types: Dict[Any, Type[EntityState]] = {
BinarySensorStateResponse: BinarySensorState, BinarySensorStateResponse: BinarySensorState,
CoverStateResponse: CoverState, CoverStateResponse: CoverState,
@ -357,31 +353,24 @@ class APIClient:
LockStateResponse: LockEntityState, LockStateResponse: LockEntityState,
MediaPlayerStateResponse: MediaPlayerEntityState, MediaPlayerStateResponse: MediaPlayerEntityState,
} }
msg_types = (*response_types, CameraImageResponse)
image_stream: Dict[int, bytes] = {}
def on_msg(msg: message.Message) -> None: def on_msg(msg: message.Message) -> None:
if isinstance(msg, CameraImageResponse): msg_type = type(msg)
cls = response_types.get(msg_type)
if cls:
on_state(cls.from_pb(msg))
elif isinstance(msg, CameraImageResponse):
data = image_stream.pop(msg.key, bytes()) + msg.data data = image_stream.pop(msg.key, bytes()) + msg.data
if msg.done: if msg.done:
# Return CameraState with the merged data # Return CameraState with the merged data
on_state(CameraState(key=msg.key, data=data)) on_state(CameraState(key=msg.key, data=data))
else: else:
image_stream[msg.key] = data image_stream[msg.key] = data
return
for resp_type, cls in response_types.items():
if isinstance(msg, resp_type):
break
else:
return
# pylint: disable=undefined-loop-variable
on_state(cls.from_pb(msg))
assert self._connection is not None assert self._connection is not None
await self._connection.send_message_callback_response( await self._connection.send_message_callback_response(
SubscribeStatesRequest(), on_msg SubscribeStatesRequest(), on_msg, msg_types
) )
async def subscribe_logs( async def subscribe_logs(
@ -392,9 +381,8 @@ class APIClient:
) -> None: ) -> None:
self._check_authenticated() self._check_authenticated()
def on_msg(msg: message.Message) -> None: def on_msg(msg: SubscribeLogsResponse) -> None:
if isinstance(msg, SubscribeLogsResponse): on_log(msg)
on_log(msg)
req = SubscribeLogsRequest() req = SubscribeLogsRequest()
if log_level is not None: if log_level is not None:
@ -402,20 +390,23 @@ class APIClient:
if dump_config is not None: if dump_config is not None:
req.dump_config = dump_config req.dump_config = dump_config
assert self._connection is not None assert self._connection is not None
await self._connection.send_message_callback_response(req, on_msg) await self._connection.send_message_callback_response(
req, on_msg, (SubscribeLogsResponse,)
)
async def subscribe_service_calls( async def subscribe_service_calls(
self, on_service_call: Callable[[HomeassistantServiceCall], None] self, on_service_call: Callable[[HomeassistantServiceCall], None]
) -> None: ) -> None:
self._check_authenticated() self._check_authenticated()
def on_msg(msg: message.Message) -> None: def on_msg(msg: HomeassistantServiceResponse) -> None:
if isinstance(msg, HomeassistantServiceResponse): on_service_call(HomeassistantServiceCall.from_pb(msg))
on_service_call(HomeassistantServiceCall.from_pb(msg))
assert self._connection is not None assert self._connection is not None
await self._connection.send_message_callback_response( await self._connection.send_message_callback_response(
SubscribeHomeassistantServicesRequest(), on_msg SubscribeHomeassistantServicesRequest(),
on_msg,
(HomeassistantServiceResponse,),
) )
async def _send_bluetooth_message_await_response( async def _send_bluetooth_message_await_response(
@ -427,17 +418,18 @@ class APIClient:
timeout: float = 10.0, timeout: float = 10.0,
) -> message.Message: ) -> message.Message:
self._check_authenticated() self._check_authenticated()
msg_types = (response_type, BluetoothGATTErrorResponse)
assert self._connection is not None assert self._connection is not None
def is_response(msg: message.Message) -> bool: def is_response(msg: message.Message) -> bool:
return ( return (
isinstance(msg, (BluetoothGATTErrorResponse, response_type)) isinstance(msg, msg_types)
and msg.address == address # type: ignore[union-attr] and msg.address == address # type: ignore[union-attr]
and msg.handle == handle # type: ignore[union-attr] and msg.handle == handle # type: ignore[union-attr]
) )
resp = await self._connection.send_message_await_response_complex( resp = await self._connection.send_message_await_response_complex(
request, is_response, is_response, timeout=timeout request, is_response, is_response, msg_types, timeout=timeout
) )
if isinstance(resp[0], BluetoothGATTErrorResponse): if isinstance(resp[0], BluetoothGATTErrorResponse):
@ -449,19 +441,19 @@ class APIClient:
self, on_bluetooth_le_advertisement: Callable[[BluetoothLEAdvertisement], None] self, on_bluetooth_le_advertisement: Callable[[BluetoothLEAdvertisement], None]
) -> Callable[[], None]: ) -> Callable[[], None]:
self._check_authenticated() self._check_authenticated()
msg_types = (BluetoothLEAdvertisementResponse,)
def on_msg(msg: message.Message) -> None: def on_msg(msg: BluetoothLEAdvertisementResponse) -> None:
if isinstance(msg, BluetoothLEAdvertisementResponse): on_bluetooth_le_advertisement(BluetoothLEAdvertisement.from_pb(msg))
on_bluetooth_le_advertisement(BluetoothLEAdvertisement.from_pb(msg))
assert self._connection is not None assert self._connection is not None
await self._connection.send_message_callback_response( await self._connection.send_message_callback_response(
SubscribeBluetoothLEAdvertisementsRequest(), on_msg SubscribeBluetoothLEAdvertisementsRequest(), on_msg, msg_types
) )
def unsub() -> None: def unsub() -> None:
if self._connection is not None: if self._connection is not None:
self._connection.remove_message_callback(on_msg) self._connection.remove_message_callback(on_msg, msg_types)
return unsub return unsub
@ -469,24 +461,24 @@ class APIClient:
self, on_bluetooth_connections_free_update: Callable[[int, int], None] self, on_bluetooth_connections_free_update: Callable[[int, int], None]
) -> Callable[[], None]: ) -> Callable[[], None]:
self._check_authenticated() self._check_authenticated()
msg_types = (BluetoothConnectionsFreeResponse,)
def on_msg(msg: message.Message) -> None: def on_msg(msg: BluetoothConnectionsFreeResponse) -> None:
if isinstance(msg, BluetoothConnectionsFreeResponse): resp = BluetoothConnectionsFree.from_pb(msg)
resp = BluetoothConnectionsFree.from_pb(msg) on_bluetooth_connections_free_update(resp.free, resp.limit)
on_bluetooth_connections_free_update(resp.free, resp.limit)
assert self._connection is not None assert self._connection is not None
await self._connection.send_message_callback_response( await self._connection.send_message_callback_response(
SubscribeBluetoothConnectionsFreeRequest(), on_msg SubscribeBluetoothConnectionsFreeRequest(), on_msg, msg_types
) )
def unsub() -> None: def unsub() -> None:
if self._connection is not None: if self._connection is not None:
self._connection.remove_message_callback(on_msg) self._connection.remove_message_callback(on_msg, msg_types)
return unsub return unsub
async def bluetooth_device_connect( async def bluetooth_device_connect( # pylint: disable=too-many-locals
self, self,
address: int, address: int,
on_bluetooth_connection_state: Callable[[bool, int, int], None], on_bluetooth_connection_state: Callable[[bool, int, int], None],
@ -497,15 +489,15 @@ class APIClient:
address_type: Optional[int] = None, address_type: Optional[int] = None,
) -> Callable[[], None]: ) -> Callable[[], None]:
self._check_authenticated() self._check_authenticated()
msg_types = (BluetoothDeviceConnectionResponse,)
event = asyncio.Event() event = asyncio.Event()
def on_msg(msg: message.Message) -> None: def on_msg(msg: BluetoothDeviceConnectionResponse) -> None:
if isinstance(msg, BluetoothDeviceConnectionResponse): resp = BluetoothDeviceConnection.from_pb(msg)
resp = BluetoothDeviceConnection.from_pb(msg) if address == resp.address:
if address == resp.address: on_bluetooth_connection_state(resp.connected, resp.mtu, resp.error)
on_bluetooth_connection_state(resp.connected, resp.mtu, resp.error) event.set()
event.set()
assert self._connection is not None assert self._connection is not None
if has_cache: if has_cache:
@ -530,11 +522,12 @@ class APIClient:
address_type=address_type or 0, address_type=address_type or 0,
), ),
on_msg, on_msg,
msg_types,
) )
def unsub() -> None: def unsub() -> None:
if self._connection is not None: if self._connection is not None:
self._connection.remove_message_callback(on_msg) self._connection.remove_message_callback(on_msg, msg_types)
try: try:
try: try:
@ -558,7 +551,7 @@ class APIClient:
) )
try: try:
unsub() unsub()
except ValueError: except (KeyError, ValueError):
_LOGGER.warning( _LOGGER.warning(
"%s: Bluetooth device connection timed out but already unsubscribed", "%s: Bluetooth device connection timed out but already unsubscribed",
addr, addr,
@ -571,7 +564,7 @@ class APIClient:
except asyncio.CancelledError: except asyncio.CancelledError:
try: try:
unsub() unsub()
except ValueError: except (KeyError, ValueError):
_LOGGER.warning( _LOGGER.warning(
"%s: Bluetooth device connection canceled but already unsubscribed", "%s: Bluetooth device connection canceled but already unsubscribed",
addr, addr,
@ -595,29 +588,26 @@ class APIClient:
self, address: int self, address: int
) -> ESPHomeBluetoothGATTServices: ) -> ESPHomeBluetoothGATTServices:
self._check_authenticated() self._check_authenticated()
msg_types = (
BluetoothGATTGetServicesResponse,
BluetoothGATTGetServicesDoneResponse,
BluetoothGATTErrorResponse,
)
append_types = (BluetoothGATTGetServicesResponse, BluetoothGATTErrorResponse)
stop_types = (BluetoothGATTGetServicesDoneResponse, BluetoothGATTErrorResponse)
def do_append(msg: message.Message) -> bool: def do_append(msg: message.Message) -> bool:
return ( return isinstance(msg, append_types) and msg.address == address
isinstance(
msg, (BluetoothGATTGetServicesResponse, BluetoothGATTErrorResponse)
)
and msg.address == address
)
def do_stop(msg: message.Message) -> bool: def do_stop(msg: message.Message) -> bool:
return ( return isinstance(msg, stop_types) and msg.address == address
isinstance(
msg,
(BluetoothGATTGetServicesDoneResponse, BluetoothGATTErrorResponse),
)
and msg.address == address
)
assert self._connection is not None assert self._connection is not None
resp = await self._connection.send_message_await_response_complex( resp = await self._connection.send_message_await_response_complex(
BluetoothGATTGetServicesRequest(address=address), BluetoothGATTGetServicesRequest(address=address),
do_append, do_append,
do_stop, do_stop,
msg_types,
timeout=DEFAULT_BLE_TIMEOUT, timeout=DEFAULT_BLE_TIMEOUT,
) )
services = [] services = []
@ -740,14 +730,15 @@ class APIClient:
BluetoothGATTNotifyResponse, BluetoothGATTNotifyResponse,
) )
def on_msg(msg: message.Message) -> None: def on_msg(msg: BluetoothGATTNotifyDataResponse) -> None:
if isinstance(msg, BluetoothGATTNotifyDataResponse): notify = BluetoothGATTRead.from_pb(msg)
notify = BluetoothGATTRead.from_pb(msg) if address == notify.address and handle == notify.handle:
if address == notify.address and handle == notify.handle: on_bluetooth_gatt_notify(handle, bytearray(notify.data))
on_bluetooth_gatt_notify(handle, bytearray(notify.data))
assert self._connection is not None assert self._connection is not None
remove_callback = self._connection.add_message_callback(on_msg) remove_callback = self._connection.add_message_callback(
on_msg, (BluetoothGATTNotifyDataResponse,)
)
async def stop_notify() -> None: async def stop_notify() -> None:
if self._connection is None: if self._connection is None:
@ -768,13 +759,14 @@ class APIClient:
) -> None: ) -> None:
self._check_authenticated() self._check_authenticated()
def on_msg(msg: message.Message) -> None: def on_msg(msg: SubscribeHomeAssistantStateResponse) -> None:
if isinstance(msg, SubscribeHomeAssistantStateResponse): on_state_sub(msg.entity_id, msg.attribute)
on_state_sub(msg.entity_id, msg.attribute)
assert self._connection is not None assert self._connection is not None
await self._connection.send_message_callback_response( await self._connection.send_message_callback_response(
SubscribeHomeAssistantStatesRequest(), on_msg SubscribeHomeAssistantStatesRequest(),
on_msg,
(SubscribeHomeAssistantStateResponse,),
) )
async def send_home_assistant_state( async def send_home_assistant_state(

View File

@ -5,7 +5,7 @@ import socket
import time import time
from contextlib import suppress from contextlib import suppress
from dataclasses import astuple, dataclass from dataclasses import astuple, dataclass
from typing import Any, Callable, Coroutine, List, Optional from typing import Any, Callable, Coroutine, Dict, Iterable, List, Optional, Type
import async_timeout import async_timeout
from google.protobuf import message from google.protobuf import message
@ -49,7 +49,9 @@ _LOGGER = logging.getLogger(__name__)
BUFFER_SIZE = 1024 * 1024 # Set buffer limit to 1MB BUFFER_SIZE = 1024 * 1024 # Set buffer limit to 1MB
INTERNAL_MESSAGE_TYPES = (GetTimeRequest, PingRequest, DisconnectRequest) INTERNAL_MESSAGE_TYPES = {GetTimeRequest, PingRequest, DisconnectRequest}
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
@dataclass @dataclass
@ -98,7 +100,7 @@ class APIConnection:
self._connect_complete = False self._connect_complete = False
# Message handlers currently subscribed to incoming messages # Message handlers currently subscribed to incoming messages
self._message_handlers: List[Callable[[message.Message], None]] = [] self._message_handlers: Dict[Any, List[Callable[[message.Message], None]]] = {}
# The friendly name to show for this connection in the logs # The friendly name to show for this connection in the logs
self.log_name = params.address self.log_name = params.address
@ -384,12 +386,9 @@ class APIConnection:
if not self._is_socket_open: if not self._is_socket_open:
raise APIConnectionError("Connection isn't established yet") raise APIConnectionError("Connection isn't established yet")
for message_type, klass in MESSAGE_TYPE_TO_PROTO.items(): message_type = PROTO_TO_MESSAGE_TYPE.get(type(msg))
if isinstance(msg, klass): if not message_type:
break
else:
raise ValueError(f"Message type id not found for type {type(msg)}") raise ValueError(f"Message type id not found for type {type(msg)}")
encoded = msg.SerializeToString() encoded = msg.SerializeToString()
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg)) _LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
@ -412,29 +411,39 @@ class APIConnection:
raise raise
def add_message_callback( def add_message_callback(
self, on_message: Callable[[Any], None] self, on_message: Callable[[Any], None], msg_types: Iterable[Type[Any]]
) -> Callable[[], None]: ) -> Callable[[], None]:
"""Add a message callback.""" """Add a message callback."""
self._message_handlers.append(on_message) for msg_type in msg_types:
self._message_handlers.setdefault(msg_type, []).append(on_message)
def unsub() -> None: def unsub() -> None:
self._message_handlers.remove(on_message) for msg_type in msg_types:
self._message_handlers[msg_type].remove(on_message)
return unsub return unsub
def remove_message_callback(self, on_message: Callable[[Any], None]) -> None: def remove_message_callback(
self, on_message: Callable[[Any], None], msg_types: Iterable[Type[Any]]
) -> None:
"""Remove a message callback.""" """Remove a message callback."""
self._message_handlers.remove(on_message) for msg_type in msg_types:
self._message_handlers[msg_type].remove(on_message)
async def send_message_callback_response( async def send_message_callback_response(
self, send_msg: message.Message, on_message: Callable[[Any], None] self,
send_msg: message.Message,
on_message: Callable[[Any], None],
msg_types: Iterable[Type[Any]],
) -> None: ) -> None:
"""Send a message to the remote and register the given message handler.""" """Send a message to the remote and register the given message handler."""
self._message_handlers.append(on_message) for msg_type in msg_types:
self._message_handlers.setdefault(msg_type, []).append(on_message)
try: try:
await self.send_message(send_msg) await self.send_message(send_msg)
except asyncio.CancelledError: except asyncio.CancelledError:
self._message_handlers.remove(on_message) for msg_type in msg_types:
self._message_handlers[msg_type].remove(on_message)
raise raise
async def send_message_await_response_complex( async def send_message_await_response_complex(
@ -442,6 +451,7 @@ class APIConnection:
send_msg: message.Message, send_msg: message.Message,
do_append: Callable[[message.Message], bool], do_append: Callable[[message.Message], bool],
do_stop: Callable[[message.Message], bool], do_stop: Callable[[message.Message], bool],
msg_types: Iterable[Type[Any]],
timeout: float = 10.0, timeout: float = 10.0,
) -> List[message.Message]: ) -> List[message.Message]:
"""Send a message to the remote and build up a list response. """Send a message to the remote and build up a list response.
@ -472,11 +482,15 @@ class APIConnection:
new_exc.__cause__ = exc new_exc.__cause__ = exc
fut.set_exception(new_exc) fut.set_exception(new_exc)
self._message_handlers.append(on_message) for msg_type in msg_types:
self._message_handlers.setdefault(msg_type, []).append(on_message)
self._read_exception_handlers.append(on_read_exception) self._read_exception_handlers.append(on_read_exception)
await self.send_message(send_msg) # We must not await without a finally or
# the message could fail to be removed if the
# the await is cancelled
try: try:
await self.send_message(send_msg)
async with async_timeout.timeout(timeout): async with async_timeout.timeout(timeout):
await fut await fut
except asyncio.TimeoutError as err: except asyncio.TimeoutError as err:
@ -485,7 +499,8 @@ class APIConnection:
) from err ) from err
finally: finally:
with suppress(ValueError): with suppress(ValueError):
self._message_handlers.remove(on_message) for msg_type in msg_types:
self._message_handlers[msg_type].remove(on_message)
with suppress(ValueError): with suppress(ValueError):
self._read_exception_handlers.remove(on_read_exception) self._read_exception_handlers.remove(on_read_exception)
@ -494,11 +509,12 @@ class APIConnection:
async def send_message_await_response( async def send_message_await_response(
self, send_msg: message.Message, response_type: Any, timeout: float = 10.0 self, send_msg: message.Message, response_type: Any, timeout: float = 10.0
) -> Any: ) -> Any:
def is_response(msg: message.Message) -> bool:
return isinstance(msg, response_type)
res = await self.send_message_await_response_complex( res = await self.send_message_await_response_complex(
send_msg, is_response, is_response, timeout=timeout send_msg,
lambda msg: True, # we will only get responses of `response_type`
lambda msg: True, # we will only get responses of `response_type`
(response_type,),
timeout=timeout,
) )
if len(res) != 1: if len(res) != 1:
raise APIConnectionError(f"Expected one result, got {len(res)}") raise APIConnectionError(f"Expected one result, got {len(res)}")
@ -531,12 +547,14 @@ class APIConnection:
# Socket closed but task isn't cancelled yet # Socket closed but task isn't cancelled yet
break break
msg_type = pkt.type msg_type_proto = pkt.type
if msg_type not in MESSAGE_TYPE_TO_PROTO: if msg_type_proto not in MESSAGE_TYPE_TO_PROTO:
_LOGGER.debug("%s: Skipping message type %s", self.log_name, msg_type) _LOGGER.debug(
"%s: Skipping message type %s", self.log_name, msg_type_proto
)
continue continue
msg = MESSAGE_TYPE_TO_PROTO[msg_type]() msg = MESSAGE_TYPE_TO_PROTO[msg_type_proto]()
try: try:
msg.ParseFromString(pkt.data) msg.ParseFromString(pkt.data)
except Exception as e: except Exception as e:
@ -545,16 +563,18 @@ class APIConnection:
) )
raise raise
msg_type = type(msg)
_LOGGER.debug( _LOGGER.debug(
"%s: Got message of type %s: %s", self.log_name, type(msg), msg "%s: Got message of type %s: %s", self.log_name, msg_type, msg
) )
for handler in self._message_handlers[:]: for handler in self._message_handlers.get(msg_type, [])[:]:
handler(msg) handler(msg)
# Pre-check the message type to avoid awaiting # Pre-check the message type to avoid awaiting
# since most messages are not internal messages # since most messages are not internal messages
if isinstance(msg, INTERNAL_MESSAGE_TYPES): if msg_type in INTERNAL_MESSAGE_TYPES:
await self._handle_internal_messages(msg) await self._handle_internal_messages(msg)
async def _read_loop(self) -> None: async def _read_loop(self) -> None:

View File

@ -78,8 +78,7 @@ class APIModelBase:
@classmethod @classmethod
def from_pb(cls: Type[_V], data: Any) -> _V: def from_pb(cls: Type[_V], data: Any) -> _V:
init_args = {f.name: getattr(data, f.name) for f in fields(cls)} return cls(**{f.name: getattr(data, f.name) for f in fields(cls)})
return cls(**init_args)
def converter_field(*, converter: Callable[[Any], _V], **kwargs: Any) -> _V: def converter_field(*, converter: Callable[[Any], _V], **kwargs: Any) -> _V:

View File

@ -57,7 +57,7 @@ def auth_client():
def patch_response_complex(client: APIClient, messages): def patch_response_complex(client: APIClient, messages):
async def patched(req, app, stop, timeout=5.0): async def patched(req, app, stop, msg_types, timeout=5.0):
resp = [] resp = []
for msg in messages: for msg in messages:
if app(msg): if app(msg):
@ -74,7 +74,7 @@ def patch_response_complex(client: APIClient, messages):
def patch_response_callback(client: APIClient): def patch_response_callback(client: APIClient):
on_message = None on_message = None
async def patched(req, callback): async def patched(req, callback, msg_types):
nonlocal on_message nonlocal on_message
on_message = callback on_message = callback