mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-21 16:37:41 +01:00
Callback messages to listeners by type (#328)
This commit is contained in:
parent
056b9a3b79
commit
de5cdfa230
@ -312,16 +312,17 @@ class APIClient:
|
||||
ListEntitiesLockResponse: LockInfo,
|
||||
ListEntitiesMediaPlayerResponse: MediaPlayerInfo,
|
||||
}
|
||||
msg_types = (ListEntitiesDoneResponse, *response_types)
|
||||
|
||||
def do_append(msg: message.Message) -> bool:
|
||||
return isinstance(msg, tuple(response_types.keys()))
|
||||
return not isinstance(msg, ListEntitiesDoneResponse)
|
||||
|
||||
def do_stop(msg: message.Message) -> bool:
|
||||
return isinstance(msg, ListEntitiesDoneResponse)
|
||||
|
||||
assert self._connection is not None
|
||||
resp = await self._connection.send_message_await_response_complex(
|
||||
ListEntitiesRequest(), do_append, do_stop, timeout=60
|
||||
ListEntitiesRequest(), do_append, do_stop, msg_types, timeout=60
|
||||
)
|
||||
entities: List[EntityInfo] = []
|
||||
services: List[UserService] = []
|
||||
@ -329,19 +330,14 @@ class APIClient:
|
||||
if isinstance(msg, ListEntitiesServicesResponse):
|
||||
services.append(UserService.from_pb(msg))
|
||||
continue
|
||||
cls = None
|
||||
for resp_type, cls in response_types.items():
|
||||
if isinstance(msg, resp_type):
|
||||
break
|
||||
else:
|
||||
continue
|
||||
cls = response_types[type(msg)]
|
||||
assert cls is not None
|
||||
entities.append(cls.from_pb(msg))
|
||||
return entities, services
|
||||
|
||||
async def subscribe_states(self, on_state: Callable[[EntityState], None]) -> None:
|
||||
self._check_authenticated()
|
||||
|
||||
image_stream: Dict[int, bytes] = {}
|
||||
response_types: Dict[Any, Type[EntityState]] = {
|
||||
BinarySensorStateResponse: BinarySensorState,
|
||||
CoverStateResponse: CoverState,
|
||||
@ -357,31 +353,24 @@ class APIClient:
|
||||
LockStateResponse: LockEntityState,
|
||||
MediaPlayerStateResponse: MediaPlayerEntityState,
|
||||
}
|
||||
|
||||
image_stream: Dict[int, bytes] = {}
|
||||
msg_types = (*response_types, CameraImageResponse)
|
||||
|
||||
def on_msg(msg: message.Message) -> None:
|
||||
if isinstance(msg, CameraImageResponse):
|
||||
msg_type = type(msg)
|
||||
cls = response_types.get(msg_type)
|
||||
if cls:
|
||||
on_state(cls.from_pb(msg))
|
||||
elif isinstance(msg, CameraImageResponse):
|
||||
data = image_stream.pop(msg.key, bytes()) + msg.data
|
||||
if msg.done:
|
||||
# Return CameraState with the merged data
|
||||
on_state(CameraState(key=msg.key, data=data))
|
||||
else:
|
||||
image_stream[msg.key] = data
|
||||
return
|
||||
|
||||
for resp_type, cls in response_types.items():
|
||||
if isinstance(msg, resp_type):
|
||||
break
|
||||
else:
|
||||
return
|
||||
|
||||
# pylint: disable=undefined-loop-variable
|
||||
on_state(cls.from_pb(msg))
|
||||
|
||||
assert self._connection is not None
|
||||
await self._connection.send_message_callback_response(
|
||||
SubscribeStatesRequest(), on_msg
|
||||
SubscribeStatesRequest(), on_msg, msg_types
|
||||
)
|
||||
|
||||
async def subscribe_logs(
|
||||
@ -392,8 +381,7 @@ class APIClient:
|
||||
) -> None:
|
||||
self._check_authenticated()
|
||||
|
||||
def on_msg(msg: message.Message) -> None:
|
||||
if isinstance(msg, SubscribeLogsResponse):
|
||||
def on_msg(msg: SubscribeLogsResponse) -> None:
|
||||
on_log(msg)
|
||||
|
||||
req = SubscribeLogsRequest()
|
||||
@ -402,20 +390,23 @@ class APIClient:
|
||||
if dump_config is not None:
|
||||
req.dump_config = dump_config
|
||||
assert self._connection is not None
|
||||
await self._connection.send_message_callback_response(req, on_msg)
|
||||
await self._connection.send_message_callback_response(
|
||||
req, on_msg, (SubscribeLogsResponse,)
|
||||
)
|
||||
|
||||
async def subscribe_service_calls(
|
||||
self, on_service_call: Callable[[HomeassistantServiceCall], None]
|
||||
) -> None:
|
||||
self._check_authenticated()
|
||||
|
||||
def on_msg(msg: message.Message) -> None:
|
||||
if isinstance(msg, HomeassistantServiceResponse):
|
||||
def on_msg(msg: HomeassistantServiceResponse) -> None:
|
||||
on_service_call(HomeassistantServiceCall.from_pb(msg))
|
||||
|
||||
assert self._connection is not None
|
||||
await self._connection.send_message_callback_response(
|
||||
SubscribeHomeassistantServicesRequest(), on_msg
|
||||
SubscribeHomeassistantServicesRequest(),
|
||||
on_msg,
|
||||
(HomeassistantServiceResponse,),
|
||||
)
|
||||
|
||||
async def _send_bluetooth_message_await_response(
|
||||
@ -427,17 +418,18 @@ class APIClient:
|
||||
timeout: float = 10.0,
|
||||
) -> message.Message:
|
||||
self._check_authenticated()
|
||||
msg_types = (response_type, BluetoothGATTErrorResponse)
|
||||
assert self._connection is not None
|
||||
|
||||
def is_response(msg: message.Message) -> bool:
|
||||
return (
|
||||
isinstance(msg, (BluetoothGATTErrorResponse, response_type))
|
||||
isinstance(msg, msg_types)
|
||||
and msg.address == address # type: ignore[union-attr]
|
||||
and msg.handle == handle # type: ignore[union-attr]
|
||||
)
|
||||
|
||||
resp = await self._connection.send_message_await_response_complex(
|
||||
request, is_response, is_response, timeout=timeout
|
||||
request, is_response, is_response, msg_types, timeout=timeout
|
||||
)
|
||||
|
||||
if isinstance(resp[0], BluetoothGATTErrorResponse):
|
||||
@ -449,19 +441,19 @@ class APIClient:
|
||||
self, on_bluetooth_le_advertisement: Callable[[BluetoothLEAdvertisement], None]
|
||||
) -> Callable[[], None]:
|
||||
self._check_authenticated()
|
||||
msg_types = (BluetoothLEAdvertisementResponse,)
|
||||
|
||||
def on_msg(msg: message.Message) -> None:
|
||||
if isinstance(msg, BluetoothLEAdvertisementResponse):
|
||||
def on_msg(msg: BluetoothLEAdvertisementResponse) -> None:
|
||||
on_bluetooth_le_advertisement(BluetoothLEAdvertisement.from_pb(msg))
|
||||
|
||||
assert self._connection is not None
|
||||
await self._connection.send_message_callback_response(
|
||||
SubscribeBluetoothLEAdvertisementsRequest(), on_msg
|
||||
SubscribeBluetoothLEAdvertisementsRequest(), on_msg, msg_types
|
||||
)
|
||||
|
||||
def unsub() -> None:
|
||||
if self._connection is not None:
|
||||
self._connection.remove_message_callback(on_msg)
|
||||
self._connection.remove_message_callback(on_msg, msg_types)
|
||||
|
||||
return unsub
|
||||
|
||||
@ -469,24 +461,24 @@ class APIClient:
|
||||
self, on_bluetooth_connections_free_update: Callable[[int, int], None]
|
||||
) -> Callable[[], None]:
|
||||
self._check_authenticated()
|
||||
msg_types = (BluetoothConnectionsFreeResponse,)
|
||||
|
||||
def on_msg(msg: message.Message) -> None:
|
||||
if isinstance(msg, BluetoothConnectionsFreeResponse):
|
||||
def on_msg(msg: BluetoothConnectionsFreeResponse) -> None:
|
||||
resp = BluetoothConnectionsFree.from_pb(msg)
|
||||
on_bluetooth_connections_free_update(resp.free, resp.limit)
|
||||
|
||||
assert self._connection is not None
|
||||
await self._connection.send_message_callback_response(
|
||||
SubscribeBluetoothConnectionsFreeRequest(), on_msg
|
||||
SubscribeBluetoothConnectionsFreeRequest(), on_msg, msg_types
|
||||
)
|
||||
|
||||
def unsub() -> None:
|
||||
if self._connection is not None:
|
||||
self._connection.remove_message_callback(on_msg)
|
||||
self._connection.remove_message_callback(on_msg, msg_types)
|
||||
|
||||
return unsub
|
||||
|
||||
async def bluetooth_device_connect(
|
||||
async def bluetooth_device_connect( # pylint: disable=too-many-locals
|
||||
self,
|
||||
address: int,
|
||||
on_bluetooth_connection_state: Callable[[bool, int, int], None],
|
||||
@ -497,11 +489,11 @@ class APIClient:
|
||||
address_type: Optional[int] = None,
|
||||
) -> Callable[[], None]:
|
||||
self._check_authenticated()
|
||||
msg_types = (BluetoothDeviceConnectionResponse,)
|
||||
|
||||
event = asyncio.Event()
|
||||
|
||||
def on_msg(msg: message.Message) -> None:
|
||||
if isinstance(msg, BluetoothDeviceConnectionResponse):
|
||||
def on_msg(msg: BluetoothDeviceConnectionResponse) -> None:
|
||||
resp = BluetoothDeviceConnection.from_pb(msg)
|
||||
if address == resp.address:
|
||||
on_bluetooth_connection_state(resp.connected, resp.mtu, resp.error)
|
||||
@ -530,11 +522,12 @@ class APIClient:
|
||||
address_type=address_type or 0,
|
||||
),
|
||||
on_msg,
|
||||
msg_types,
|
||||
)
|
||||
|
||||
def unsub() -> None:
|
||||
if self._connection is not None:
|
||||
self._connection.remove_message_callback(on_msg)
|
||||
self._connection.remove_message_callback(on_msg, msg_types)
|
||||
|
||||
try:
|
||||
try:
|
||||
@ -558,7 +551,7 @@ class APIClient:
|
||||
)
|
||||
try:
|
||||
unsub()
|
||||
except ValueError:
|
||||
except (KeyError, ValueError):
|
||||
_LOGGER.warning(
|
||||
"%s: Bluetooth device connection timed out but already unsubscribed",
|
||||
addr,
|
||||
@ -571,7 +564,7 @@ class APIClient:
|
||||
except asyncio.CancelledError:
|
||||
try:
|
||||
unsub()
|
||||
except ValueError:
|
||||
except (KeyError, ValueError):
|
||||
_LOGGER.warning(
|
||||
"%s: Bluetooth device connection canceled but already unsubscribed",
|
||||
addr,
|
||||
@ -595,29 +588,26 @@ class APIClient:
|
||||
self, address: int
|
||||
) -> ESPHomeBluetoothGATTServices:
|
||||
self._check_authenticated()
|
||||
msg_types = (
|
||||
BluetoothGATTGetServicesResponse,
|
||||
BluetoothGATTGetServicesDoneResponse,
|
||||
BluetoothGATTErrorResponse,
|
||||
)
|
||||
append_types = (BluetoothGATTGetServicesResponse, BluetoothGATTErrorResponse)
|
||||
stop_types = (BluetoothGATTGetServicesDoneResponse, BluetoothGATTErrorResponse)
|
||||
|
||||
def do_append(msg: message.Message) -> bool:
|
||||
return (
|
||||
isinstance(
|
||||
msg, (BluetoothGATTGetServicesResponse, BluetoothGATTErrorResponse)
|
||||
)
|
||||
and msg.address == address
|
||||
)
|
||||
return isinstance(msg, append_types) and msg.address == address
|
||||
|
||||
def do_stop(msg: message.Message) -> bool:
|
||||
return (
|
||||
isinstance(
|
||||
msg,
|
||||
(BluetoothGATTGetServicesDoneResponse, BluetoothGATTErrorResponse),
|
||||
)
|
||||
and msg.address == address
|
||||
)
|
||||
return isinstance(msg, stop_types) and msg.address == address
|
||||
|
||||
assert self._connection is not None
|
||||
resp = await self._connection.send_message_await_response_complex(
|
||||
BluetoothGATTGetServicesRequest(address=address),
|
||||
do_append,
|
||||
do_stop,
|
||||
msg_types,
|
||||
timeout=DEFAULT_BLE_TIMEOUT,
|
||||
)
|
||||
services = []
|
||||
@ -740,14 +730,15 @@ class APIClient:
|
||||
BluetoothGATTNotifyResponse,
|
||||
)
|
||||
|
||||
def on_msg(msg: message.Message) -> None:
|
||||
if isinstance(msg, BluetoothGATTNotifyDataResponse):
|
||||
def on_msg(msg: BluetoothGATTNotifyDataResponse) -> None:
|
||||
notify = BluetoothGATTRead.from_pb(msg)
|
||||
if address == notify.address and handle == notify.handle:
|
||||
on_bluetooth_gatt_notify(handle, bytearray(notify.data))
|
||||
|
||||
assert self._connection is not None
|
||||
remove_callback = self._connection.add_message_callback(on_msg)
|
||||
remove_callback = self._connection.add_message_callback(
|
||||
on_msg, (BluetoothGATTNotifyDataResponse,)
|
||||
)
|
||||
|
||||
async def stop_notify() -> None:
|
||||
if self._connection is None:
|
||||
@ -768,13 +759,14 @@ class APIClient:
|
||||
) -> None:
|
||||
self._check_authenticated()
|
||||
|
||||
def on_msg(msg: message.Message) -> None:
|
||||
if isinstance(msg, SubscribeHomeAssistantStateResponse):
|
||||
def on_msg(msg: SubscribeHomeAssistantStateResponse) -> None:
|
||||
on_state_sub(msg.entity_id, msg.attribute)
|
||||
|
||||
assert self._connection is not None
|
||||
await self._connection.send_message_callback_response(
|
||||
SubscribeHomeAssistantStatesRequest(), on_msg
|
||||
SubscribeHomeAssistantStatesRequest(),
|
||||
on_msg,
|
||||
(SubscribeHomeAssistantStateResponse,),
|
||||
)
|
||||
|
||||
async def send_home_assistant_state(
|
||||
|
@ -5,7 +5,7 @@ import socket
|
||||
import time
|
||||
from contextlib import suppress
|
||||
from dataclasses import astuple, dataclass
|
||||
from typing import Any, Callable, Coroutine, List, Optional
|
||||
from typing import Any, Callable, Coroutine, Dict, Iterable, List, Optional, Type
|
||||
|
||||
import async_timeout
|
||||
from google.protobuf import message
|
||||
@ -49,7 +49,9 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
BUFFER_SIZE = 1024 * 1024 # Set buffer limit to 1MB
|
||||
|
||||
INTERNAL_MESSAGE_TYPES = (GetTimeRequest, PingRequest, DisconnectRequest)
|
||||
INTERNAL_MESSAGE_TYPES = {GetTimeRequest, PingRequest, DisconnectRequest}
|
||||
|
||||
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -98,7 +100,7 @@ class APIConnection:
|
||||
self._connect_complete = False
|
||||
|
||||
# Message handlers currently subscribed to incoming messages
|
||||
self._message_handlers: List[Callable[[message.Message], None]] = []
|
||||
self._message_handlers: Dict[Any, List[Callable[[message.Message], None]]] = {}
|
||||
# The friendly name to show for this connection in the logs
|
||||
self.log_name = params.address
|
||||
|
||||
@ -384,12 +386,9 @@ class APIConnection:
|
||||
if not self._is_socket_open:
|
||||
raise APIConnectionError("Connection isn't established yet")
|
||||
|
||||
for message_type, klass in MESSAGE_TYPE_TO_PROTO.items():
|
||||
if isinstance(msg, klass):
|
||||
break
|
||||
else:
|
||||
message_type = PROTO_TO_MESSAGE_TYPE.get(type(msg))
|
||||
if not message_type:
|
||||
raise ValueError(f"Message type id not found for type {type(msg)}")
|
||||
|
||||
encoded = msg.SerializeToString()
|
||||
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
|
||||
|
||||
@ -412,29 +411,39 @@ class APIConnection:
|
||||
raise
|
||||
|
||||
def add_message_callback(
|
||||
self, on_message: Callable[[Any], None]
|
||||
self, on_message: Callable[[Any], None], msg_types: Iterable[Type[Any]]
|
||||
) -> Callable[[], None]:
|
||||
"""Add a message callback."""
|
||||
self._message_handlers.append(on_message)
|
||||
for msg_type in msg_types:
|
||||
self._message_handlers.setdefault(msg_type, []).append(on_message)
|
||||
|
||||
def unsub() -> None:
|
||||
self._message_handlers.remove(on_message)
|
||||
for msg_type in msg_types:
|
||||
self._message_handlers[msg_type].remove(on_message)
|
||||
|
||||
return unsub
|
||||
|
||||
def remove_message_callback(self, on_message: Callable[[Any], None]) -> None:
|
||||
def remove_message_callback(
|
||||
self, on_message: Callable[[Any], None], msg_types: Iterable[Type[Any]]
|
||||
) -> None:
|
||||
"""Remove a message callback."""
|
||||
self._message_handlers.remove(on_message)
|
||||
for msg_type in msg_types:
|
||||
self._message_handlers[msg_type].remove(on_message)
|
||||
|
||||
async def send_message_callback_response(
|
||||
self, send_msg: message.Message, on_message: Callable[[Any], None]
|
||||
self,
|
||||
send_msg: message.Message,
|
||||
on_message: Callable[[Any], None],
|
||||
msg_types: Iterable[Type[Any]],
|
||||
) -> None:
|
||||
"""Send a message to the remote and register the given message handler."""
|
||||
self._message_handlers.append(on_message)
|
||||
for msg_type in msg_types:
|
||||
self._message_handlers.setdefault(msg_type, []).append(on_message)
|
||||
try:
|
||||
await self.send_message(send_msg)
|
||||
except asyncio.CancelledError:
|
||||
self._message_handlers.remove(on_message)
|
||||
for msg_type in msg_types:
|
||||
self._message_handlers[msg_type].remove(on_message)
|
||||
raise
|
||||
|
||||
async def send_message_await_response_complex(
|
||||
@ -442,6 +451,7 @@ class APIConnection:
|
||||
send_msg: message.Message,
|
||||
do_append: Callable[[message.Message], bool],
|
||||
do_stop: Callable[[message.Message], bool],
|
||||
msg_types: Iterable[Type[Any]],
|
||||
timeout: float = 10.0,
|
||||
) -> List[message.Message]:
|
||||
"""Send a message to the remote and build up a list response.
|
||||
@ -472,11 +482,15 @@ class APIConnection:
|
||||
new_exc.__cause__ = exc
|
||||
fut.set_exception(new_exc)
|
||||
|
||||
self._message_handlers.append(on_message)
|
||||
for msg_type in msg_types:
|
||||
self._message_handlers.setdefault(msg_type, []).append(on_message)
|
||||
self._read_exception_handlers.append(on_read_exception)
|
||||
await self.send_message(send_msg)
|
||||
# We must not await without a finally or
|
||||
# the message could fail to be removed if the
|
||||
# the await is cancelled
|
||||
|
||||
try:
|
||||
await self.send_message(send_msg)
|
||||
async with async_timeout.timeout(timeout):
|
||||
await fut
|
||||
except asyncio.TimeoutError as err:
|
||||
@ -485,7 +499,8 @@ class APIConnection:
|
||||
) from err
|
||||
finally:
|
||||
with suppress(ValueError):
|
||||
self._message_handlers.remove(on_message)
|
||||
for msg_type in msg_types:
|
||||
self._message_handlers[msg_type].remove(on_message)
|
||||
with suppress(ValueError):
|
||||
self._read_exception_handlers.remove(on_read_exception)
|
||||
|
||||
@ -494,11 +509,12 @@ class APIConnection:
|
||||
async def send_message_await_response(
|
||||
self, send_msg: message.Message, response_type: Any, timeout: float = 10.0
|
||||
) -> Any:
|
||||
def is_response(msg: message.Message) -> bool:
|
||||
return isinstance(msg, response_type)
|
||||
|
||||
res = await self.send_message_await_response_complex(
|
||||
send_msg, is_response, is_response, timeout=timeout
|
||||
send_msg,
|
||||
lambda msg: True, # we will only get responses of `response_type`
|
||||
lambda msg: True, # we will only get responses of `response_type`
|
||||
(response_type,),
|
||||
timeout=timeout,
|
||||
)
|
||||
if len(res) != 1:
|
||||
raise APIConnectionError(f"Expected one result, got {len(res)}")
|
||||
@ -531,12 +547,14 @@ class APIConnection:
|
||||
# Socket closed but task isn't cancelled yet
|
||||
break
|
||||
|
||||
msg_type = pkt.type
|
||||
if msg_type not in MESSAGE_TYPE_TO_PROTO:
|
||||
_LOGGER.debug("%s: Skipping message type %s", self.log_name, msg_type)
|
||||
msg_type_proto = pkt.type
|
||||
if msg_type_proto not in MESSAGE_TYPE_TO_PROTO:
|
||||
_LOGGER.debug(
|
||||
"%s: Skipping message type %s", self.log_name, msg_type_proto
|
||||
)
|
||||
continue
|
||||
|
||||
msg = MESSAGE_TYPE_TO_PROTO[msg_type]()
|
||||
msg = MESSAGE_TYPE_TO_PROTO[msg_type_proto]()
|
||||
try:
|
||||
msg.ParseFromString(pkt.data)
|
||||
except Exception as e:
|
||||
@ -545,16 +563,18 @@ class APIConnection:
|
||||
)
|
||||
raise
|
||||
|
||||
msg_type = type(msg)
|
||||
|
||||
_LOGGER.debug(
|
||||
"%s: Got message of type %s: %s", self.log_name, type(msg), msg
|
||||
"%s: Got message of type %s: %s", self.log_name, msg_type, msg
|
||||
)
|
||||
|
||||
for handler in self._message_handlers[:]:
|
||||
for handler in self._message_handlers.get(msg_type, [])[:]:
|
||||
handler(msg)
|
||||
|
||||
# Pre-check the message type to avoid awaiting
|
||||
# since most messages are not internal messages
|
||||
if isinstance(msg, INTERNAL_MESSAGE_TYPES):
|
||||
if msg_type in INTERNAL_MESSAGE_TYPES:
|
||||
await self._handle_internal_messages(msg)
|
||||
|
||||
async def _read_loop(self) -> None:
|
||||
|
@ -78,8 +78,7 @@ class APIModelBase:
|
||||
|
||||
@classmethod
|
||||
def from_pb(cls: Type[_V], data: Any) -> _V:
|
||||
init_args = {f.name: getattr(data, f.name) for f in fields(cls)}
|
||||
return cls(**init_args)
|
||||
return cls(**{f.name: getattr(data, f.name) for f in fields(cls)})
|
||||
|
||||
|
||||
def converter_field(*, converter: Callable[[Any], _V], **kwargs: Any) -> _V:
|
||||
|
@ -57,7 +57,7 @@ def auth_client():
|
||||
|
||||
|
||||
def patch_response_complex(client: APIClient, messages):
|
||||
async def patched(req, app, stop, timeout=5.0):
|
||||
async def patched(req, app, stop, msg_types, timeout=5.0):
|
||||
resp = []
|
||||
for msg in messages:
|
||||
if app(msg):
|
||||
@ -74,7 +74,7 @@ def patch_response_complex(client: APIClient, messages):
|
||||
def patch_response_callback(client: APIClient):
|
||||
on_message = None
|
||||
|
||||
async def patched(req, callback):
|
||||
async def patched(req, callback, msg_types):
|
||||
nonlocal on_message
|
||||
on_message = callback
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user