Callback messages to listeners by type (#328)

This commit is contained in:
J. Nick Koston 2022-12-02 09:36:58 -10:00 committed by GitHub
parent 056b9a3b79
commit de5cdfa230
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 122 additions and 111 deletions

View File

@ -312,16 +312,17 @@ class APIClient:
ListEntitiesLockResponse: LockInfo,
ListEntitiesMediaPlayerResponse: MediaPlayerInfo,
}
msg_types = (ListEntitiesDoneResponse, *response_types)
def do_append(msg: message.Message) -> bool:
return isinstance(msg, tuple(response_types.keys()))
return not isinstance(msg, ListEntitiesDoneResponse)
def do_stop(msg: message.Message) -> bool:
return isinstance(msg, ListEntitiesDoneResponse)
assert self._connection is not None
resp = await self._connection.send_message_await_response_complex(
ListEntitiesRequest(), do_append, do_stop, timeout=60
ListEntitiesRequest(), do_append, do_stop, msg_types, timeout=60
)
entities: List[EntityInfo] = []
services: List[UserService] = []
@ -329,19 +330,14 @@ class APIClient:
if isinstance(msg, ListEntitiesServicesResponse):
services.append(UserService.from_pb(msg))
continue
cls = None
for resp_type, cls in response_types.items():
if isinstance(msg, resp_type):
break
else:
continue
cls = response_types[type(msg)]
assert cls is not None
entities.append(cls.from_pb(msg))
return entities, services
async def subscribe_states(self, on_state: Callable[[EntityState], None]) -> None:
self._check_authenticated()
image_stream: Dict[int, bytes] = {}
response_types: Dict[Any, Type[EntityState]] = {
BinarySensorStateResponse: BinarySensorState,
CoverStateResponse: CoverState,
@ -357,31 +353,24 @@ class APIClient:
LockStateResponse: LockEntityState,
MediaPlayerStateResponse: MediaPlayerEntityState,
}
image_stream: Dict[int, bytes] = {}
msg_types = (*response_types, CameraImageResponse)
def on_msg(msg: message.Message) -> None:
if isinstance(msg, CameraImageResponse):
msg_type = type(msg)
cls = response_types.get(msg_type)
if cls:
on_state(cls.from_pb(msg))
elif isinstance(msg, CameraImageResponse):
data = image_stream.pop(msg.key, bytes()) + msg.data
if msg.done:
# Return CameraState with the merged data
on_state(CameraState(key=msg.key, data=data))
else:
image_stream[msg.key] = data
return
for resp_type, cls in response_types.items():
if isinstance(msg, resp_type):
break
else:
return
# pylint: disable=undefined-loop-variable
on_state(cls.from_pb(msg))
assert self._connection is not None
await self._connection.send_message_callback_response(
SubscribeStatesRequest(), on_msg
SubscribeStatesRequest(), on_msg, msg_types
)
async def subscribe_logs(
@ -392,8 +381,7 @@ class APIClient:
) -> None:
self._check_authenticated()
def on_msg(msg: message.Message) -> None:
if isinstance(msg, SubscribeLogsResponse):
def on_msg(msg: SubscribeLogsResponse) -> None:
on_log(msg)
req = SubscribeLogsRequest()
@ -402,20 +390,23 @@ class APIClient:
if dump_config is not None:
req.dump_config = dump_config
assert self._connection is not None
await self._connection.send_message_callback_response(req, on_msg)
await self._connection.send_message_callback_response(
req, on_msg, (SubscribeLogsResponse,)
)
async def subscribe_service_calls(
self, on_service_call: Callable[[HomeassistantServiceCall], None]
) -> None:
self._check_authenticated()
def on_msg(msg: message.Message) -> None:
if isinstance(msg, HomeassistantServiceResponse):
def on_msg(msg: HomeassistantServiceResponse) -> None:
on_service_call(HomeassistantServiceCall.from_pb(msg))
assert self._connection is not None
await self._connection.send_message_callback_response(
SubscribeHomeassistantServicesRequest(), on_msg
SubscribeHomeassistantServicesRequest(),
on_msg,
(HomeassistantServiceResponse,),
)
async def _send_bluetooth_message_await_response(
@ -427,17 +418,18 @@ class APIClient:
timeout: float = 10.0,
) -> message.Message:
self._check_authenticated()
msg_types = (response_type, BluetoothGATTErrorResponse)
assert self._connection is not None
def is_response(msg: message.Message) -> bool:
return (
isinstance(msg, (BluetoothGATTErrorResponse, response_type))
isinstance(msg, msg_types)
and msg.address == address # type: ignore[union-attr]
and msg.handle == handle # type: ignore[union-attr]
)
resp = await self._connection.send_message_await_response_complex(
request, is_response, is_response, timeout=timeout
request, is_response, is_response, msg_types, timeout=timeout
)
if isinstance(resp[0], BluetoothGATTErrorResponse):
@ -449,19 +441,19 @@ class APIClient:
self, on_bluetooth_le_advertisement: Callable[[BluetoothLEAdvertisement], None]
) -> Callable[[], None]:
self._check_authenticated()
msg_types = (BluetoothLEAdvertisementResponse,)
def on_msg(msg: message.Message) -> None:
if isinstance(msg, BluetoothLEAdvertisementResponse):
def on_msg(msg: BluetoothLEAdvertisementResponse) -> None:
on_bluetooth_le_advertisement(BluetoothLEAdvertisement.from_pb(msg))
assert self._connection is not None
await self._connection.send_message_callback_response(
SubscribeBluetoothLEAdvertisementsRequest(), on_msg
SubscribeBluetoothLEAdvertisementsRequest(), on_msg, msg_types
)
def unsub() -> None:
if self._connection is not None:
self._connection.remove_message_callback(on_msg)
self._connection.remove_message_callback(on_msg, msg_types)
return unsub
@ -469,24 +461,24 @@ class APIClient:
self, on_bluetooth_connections_free_update: Callable[[int, int], None]
) -> Callable[[], None]:
self._check_authenticated()
msg_types = (BluetoothConnectionsFreeResponse,)
def on_msg(msg: message.Message) -> None:
if isinstance(msg, BluetoothConnectionsFreeResponse):
def on_msg(msg: BluetoothConnectionsFreeResponse) -> None:
resp = BluetoothConnectionsFree.from_pb(msg)
on_bluetooth_connections_free_update(resp.free, resp.limit)
assert self._connection is not None
await self._connection.send_message_callback_response(
SubscribeBluetoothConnectionsFreeRequest(), on_msg
SubscribeBluetoothConnectionsFreeRequest(), on_msg, msg_types
)
def unsub() -> None:
if self._connection is not None:
self._connection.remove_message_callback(on_msg)
self._connection.remove_message_callback(on_msg, msg_types)
return unsub
async def bluetooth_device_connect(
async def bluetooth_device_connect( # pylint: disable=too-many-locals
self,
address: int,
on_bluetooth_connection_state: Callable[[bool, int, int], None],
@ -497,11 +489,11 @@ class APIClient:
address_type: Optional[int] = None,
) -> Callable[[], None]:
self._check_authenticated()
msg_types = (BluetoothDeviceConnectionResponse,)
event = asyncio.Event()
def on_msg(msg: message.Message) -> None:
if isinstance(msg, BluetoothDeviceConnectionResponse):
def on_msg(msg: BluetoothDeviceConnectionResponse) -> None:
resp = BluetoothDeviceConnection.from_pb(msg)
if address == resp.address:
on_bluetooth_connection_state(resp.connected, resp.mtu, resp.error)
@ -530,11 +522,12 @@ class APIClient:
address_type=address_type or 0,
),
on_msg,
msg_types,
)
def unsub() -> None:
if self._connection is not None:
self._connection.remove_message_callback(on_msg)
self._connection.remove_message_callback(on_msg, msg_types)
try:
try:
@ -558,7 +551,7 @@ class APIClient:
)
try:
unsub()
except ValueError:
except (KeyError, ValueError):
_LOGGER.warning(
"%s: Bluetooth device connection timed out but already unsubscribed",
addr,
@ -571,7 +564,7 @@ class APIClient:
except asyncio.CancelledError:
try:
unsub()
except ValueError:
except (KeyError, ValueError):
_LOGGER.warning(
"%s: Bluetooth device connection canceled but already unsubscribed",
addr,
@ -595,29 +588,26 @@ class APIClient:
self, address: int
) -> ESPHomeBluetoothGATTServices:
self._check_authenticated()
msg_types = (
BluetoothGATTGetServicesResponse,
BluetoothGATTGetServicesDoneResponse,
BluetoothGATTErrorResponse,
)
append_types = (BluetoothGATTGetServicesResponse, BluetoothGATTErrorResponse)
stop_types = (BluetoothGATTGetServicesDoneResponse, BluetoothGATTErrorResponse)
def do_append(msg: message.Message) -> bool:
return (
isinstance(
msg, (BluetoothGATTGetServicesResponse, BluetoothGATTErrorResponse)
)
and msg.address == address
)
return isinstance(msg, append_types) and msg.address == address
def do_stop(msg: message.Message) -> bool:
return (
isinstance(
msg,
(BluetoothGATTGetServicesDoneResponse, BluetoothGATTErrorResponse),
)
and msg.address == address
)
return isinstance(msg, stop_types) and msg.address == address
assert self._connection is not None
resp = await self._connection.send_message_await_response_complex(
BluetoothGATTGetServicesRequest(address=address),
do_append,
do_stop,
msg_types,
timeout=DEFAULT_BLE_TIMEOUT,
)
services = []
@ -740,14 +730,15 @@ class APIClient:
BluetoothGATTNotifyResponse,
)
def on_msg(msg: message.Message) -> None:
if isinstance(msg, BluetoothGATTNotifyDataResponse):
def on_msg(msg: BluetoothGATTNotifyDataResponse) -> None:
notify = BluetoothGATTRead.from_pb(msg)
if address == notify.address and handle == notify.handle:
on_bluetooth_gatt_notify(handle, bytearray(notify.data))
assert self._connection is not None
remove_callback = self._connection.add_message_callback(on_msg)
remove_callback = self._connection.add_message_callback(
on_msg, (BluetoothGATTNotifyDataResponse,)
)
async def stop_notify() -> None:
if self._connection is None:
@ -768,13 +759,14 @@ class APIClient:
) -> None:
self._check_authenticated()
def on_msg(msg: message.Message) -> None:
if isinstance(msg, SubscribeHomeAssistantStateResponse):
def on_msg(msg: SubscribeHomeAssistantStateResponse) -> None:
on_state_sub(msg.entity_id, msg.attribute)
assert self._connection is not None
await self._connection.send_message_callback_response(
SubscribeHomeAssistantStatesRequest(), on_msg
SubscribeHomeAssistantStatesRequest(),
on_msg,
(SubscribeHomeAssistantStateResponse,),
)
async def send_home_assistant_state(

View File

@ -5,7 +5,7 @@ import socket
import time
from contextlib import suppress
from dataclasses import astuple, dataclass
from typing import Any, Callable, Coroutine, List, Optional
from typing import Any, Callable, Coroutine, Dict, Iterable, List, Optional, Type
import async_timeout
from google.protobuf import message
@ -49,7 +49,9 @@ _LOGGER = logging.getLogger(__name__)
BUFFER_SIZE = 1024 * 1024 # Set buffer limit to 1MB
INTERNAL_MESSAGE_TYPES = (GetTimeRequest, PingRequest, DisconnectRequest)
INTERNAL_MESSAGE_TYPES = {GetTimeRequest, PingRequest, DisconnectRequest}
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
@dataclass
@ -98,7 +100,7 @@ class APIConnection:
self._connect_complete = False
# Message handlers currently subscribed to incoming messages
self._message_handlers: List[Callable[[message.Message], None]] = []
self._message_handlers: Dict[Any, List[Callable[[message.Message], None]]] = {}
# The friendly name to show for this connection in the logs
self.log_name = params.address
@ -384,12 +386,9 @@ class APIConnection:
if not self._is_socket_open:
raise APIConnectionError("Connection isn't established yet")
for message_type, klass in MESSAGE_TYPE_TO_PROTO.items():
if isinstance(msg, klass):
break
else:
message_type = PROTO_TO_MESSAGE_TYPE.get(type(msg))
if not message_type:
raise ValueError(f"Message type id not found for type {type(msg)}")
encoded = msg.SerializeToString()
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
@ -412,29 +411,39 @@ class APIConnection:
raise
def add_message_callback(
self, on_message: Callable[[Any], None]
self, on_message: Callable[[Any], None], msg_types: Iterable[Type[Any]]
) -> Callable[[], None]:
"""Add a message callback."""
self._message_handlers.append(on_message)
for msg_type in msg_types:
self._message_handlers.setdefault(msg_type, []).append(on_message)
def unsub() -> None:
self._message_handlers.remove(on_message)
for msg_type in msg_types:
self._message_handlers[msg_type].remove(on_message)
return unsub
def remove_message_callback(self, on_message: Callable[[Any], None]) -> None:
def remove_message_callback(
self, on_message: Callable[[Any], None], msg_types: Iterable[Type[Any]]
) -> None:
"""Remove a message callback."""
self._message_handlers.remove(on_message)
for msg_type in msg_types:
self._message_handlers[msg_type].remove(on_message)
async def send_message_callback_response(
self, send_msg: message.Message, on_message: Callable[[Any], None]
self,
send_msg: message.Message,
on_message: Callable[[Any], None],
msg_types: Iterable[Type[Any]],
) -> None:
"""Send a message to the remote and register the given message handler."""
self._message_handlers.append(on_message)
for msg_type in msg_types:
self._message_handlers.setdefault(msg_type, []).append(on_message)
try:
await self.send_message(send_msg)
except asyncio.CancelledError:
self._message_handlers.remove(on_message)
for msg_type in msg_types:
self._message_handlers[msg_type].remove(on_message)
raise
async def send_message_await_response_complex(
@ -442,6 +451,7 @@ class APIConnection:
send_msg: message.Message,
do_append: Callable[[message.Message], bool],
do_stop: Callable[[message.Message], bool],
msg_types: Iterable[Type[Any]],
timeout: float = 10.0,
) -> List[message.Message]:
"""Send a message to the remote and build up a list response.
@ -472,11 +482,15 @@ class APIConnection:
new_exc.__cause__ = exc
fut.set_exception(new_exc)
self._message_handlers.append(on_message)
for msg_type in msg_types:
self._message_handlers.setdefault(msg_type, []).append(on_message)
self._read_exception_handlers.append(on_read_exception)
await self.send_message(send_msg)
# We must not await without a finally or
# the message could fail to be removed if the
# the await is cancelled
try:
await self.send_message(send_msg)
async with async_timeout.timeout(timeout):
await fut
except asyncio.TimeoutError as err:
@ -485,7 +499,8 @@ class APIConnection:
) from err
finally:
with suppress(ValueError):
self._message_handlers.remove(on_message)
for msg_type in msg_types:
self._message_handlers[msg_type].remove(on_message)
with suppress(ValueError):
self._read_exception_handlers.remove(on_read_exception)
@ -494,11 +509,12 @@ class APIConnection:
async def send_message_await_response(
self, send_msg: message.Message, response_type: Any, timeout: float = 10.0
) -> Any:
def is_response(msg: message.Message) -> bool:
return isinstance(msg, response_type)
res = await self.send_message_await_response_complex(
send_msg, is_response, is_response, timeout=timeout
send_msg,
lambda msg: True, # we will only get responses of `response_type`
lambda msg: True, # we will only get responses of `response_type`
(response_type,),
timeout=timeout,
)
if len(res) != 1:
raise APIConnectionError(f"Expected one result, got {len(res)}")
@ -531,12 +547,14 @@ class APIConnection:
# Socket closed but task isn't cancelled yet
break
msg_type = pkt.type
if msg_type not in MESSAGE_TYPE_TO_PROTO:
_LOGGER.debug("%s: Skipping message type %s", self.log_name, msg_type)
msg_type_proto = pkt.type
if msg_type_proto not in MESSAGE_TYPE_TO_PROTO:
_LOGGER.debug(
"%s: Skipping message type %s", self.log_name, msg_type_proto
)
continue
msg = MESSAGE_TYPE_TO_PROTO[msg_type]()
msg = MESSAGE_TYPE_TO_PROTO[msg_type_proto]()
try:
msg.ParseFromString(pkt.data)
except Exception as e:
@ -545,16 +563,18 @@ class APIConnection:
)
raise
msg_type = type(msg)
_LOGGER.debug(
"%s: Got message of type %s: %s", self.log_name, type(msg), msg
"%s: Got message of type %s: %s", self.log_name, msg_type, msg
)
for handler in self._message_handlers[:]:
for handler in self._message_handlers.get(msg_type, [])[:]:
handler(msg)
# Pre-check the message type to avoid awaiting
# since most messages are not internal messages
if isinstance(msg, INTERNAL_MESSAGE_TYPES):
if msg_type in INTERNAL_MESSAGE_TYPES:
await self._handle_internal_messages(msg)
async def _read_loop(self) -> None:

View File

@ -78,8 +78,7 @@ class APIModelBase:
@classmethod
def from_pb(cls: Type[_V], data: Any) -> _V:
init_args = {f.name: getattr(data, f.name) for f in fields(cls)}
return cls(**init_args)
return cls(**{f.name: getattr(data, f.name) for f in fields(cls)})
def converter_field(*, converter: Callable[[Any], _V], **kwargs: Any) -> _V:

View File

@ -57,7 +57,7 @@ def auth_client():
def patch_response_complex(client: APIClient, messages):
async def patched(req, app, stop, timeout=5.0):
async def patched(req, app, stop, msg_types, timeout=5.0):
resp = []
for msg in messages:
if app(msg):
@ -74,7 +74,7 @@ def patch_response_complex(client: APIClient, messages):
def patch_response_callback(client: APIClient):
on_message = None
async def patched(req, callback):
async def patched(req, callback, msg_types):
nonlocal on_message
on_message = callback