mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-21 16:37:41 +01:00
Callback messages to listeners by type (#328)
This commit is contained in:
parent
056b9a3b79
commit
de5cdfa230
@ -312,16 +312,17 @@ class APIClient:
|
|||||||
ListEntitiesLockResponse: LockInfo,
|
ListEntitiesLockResponse: LockInfo,
|
||||||
ListEntitiesMediaPlayerResponse: MediaPlayerInfo,
|
ListEntitiesMediaPlayerResponse: MediaPlayerInfo,
|
||||||
}
|
}
|
||||||
|
msg_types = (ListEntitiesDoneResponse, *response_types)
|
||||||
|
|
||||||
def do_append(msg: message.Message) -> bool:
|
def do_append(msg: message.Message) -> bool:
|
||||||
return isinstance(msg, tuple(response_types.keys()))
|
return not isinstance(msg, ListEntitiesDoneResponse)
|
||||||
|
|
||||||
def do_stop(msg: message.Message) -> bool:
|
def do_stop(msg: message.Message) -> bool:
|
||||||
return isinstance(msg, ListEntitiesDoneResponse)
|
return isinstance(msg, ListEntitiesDoneResponse)
|
||||||
|
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
resp = await self._connection.send_message_await_response_complex(
|
resp = await self._connection.send_message_await_response_complex(
|
||||||
ListEntitiesRequest(), do_append, do_stop, timeout=60
|
ListEntitiesRequest(), do_append, do_stop, msg_types, timeout=60
|
||||||
)
|
)
|
||||||
entities: List[EntityInfo] = []
|
entities: List[EntityInfo] = []
|
||||||
services: List[UserService] = []
|
services: List[UserService] = []
|
||||||
@ -329,19 +330,14 @@ class APIClient:
|
|||||||
if isinstance(msg, ListEntitiesServicesResponse):
|
if isinstance(msg, ListEntitiesServicesResponse):
|
||||||
services.append(UserService.from_pb(msg))
|
services.append(UserService.from_pb(msg))
|
||||||
continue
|
continue
|
||||||
cls = None
|
cls = response_types[type(msg)]
|
||||||
for resp_type, cls in response_types.items():
|
|
||||||
if isinstance(msg, resp_type):
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
assert cls is not None
|
assert cls is not None
|
||||||
entities.append(cls.from_pb(msg))
|
entities.append(cls.from_pb(msg))
|
||||||
return entities, services
|
return entities, services
|
||||||
|
|
||||||
async def subscribe_states(self, on_state: Callable[[EntityState], None]) -> None:
|
async def subscribe_states(self, on_state: Callable[[EntityState], None]) -> None:
|
||||||
self._check_authenticated()
|
self._check_authenticated()
|
||||||
|
image_stream: Dict[int, bytes] = {}
|
||||||
response_types: Dict[Any, Type[EntityState]] = {
|
response_types: Dict[Any, Type[EntityState]] = {
|
||||||
BinarySensorStateResponse: BinarySensorState,
|
BinarySensorStateResponse: BinarySensorState,
|
||||||
CoverStateResponse: CoverState,
|
CoverStateResponse: CoverState,
|
||||||
@ -357,31 +353,24 @@ class APIClient:
|
|||||||
LockStateResponse: LockEntityState,
|
LockStateResponse: LockEntityState,
|
||||||
MediaPlayerStateResponse: MediaPlayerEntityState,
|
MediaPlayerStateResponse: MediaPlayerEntityState,
|
||||||
}
|
}
|
||||||
|
msg_types = (*response_types, CameraImageResponse)
|
||||||
image_stream: Dict[int, bytes] = {}
|
|
||||||
|
|
||||||
def on_msg(msg: message.Message) -> None:
|
def on_msg(msg: message.Message) -> None:
|
||||||
if isinstance(msg, CameraImageResponse):
|
msg_type = type(msg)
|
||||||
|
cls = response_types.get(msg_type)
|
||||||
|
if cls:
|
||||||
|
on_state(cls.from_pb(msg))
|
||||||
|
elif isinstance(msg, CameraImageResponse):
|
||||||
data = image_stream.pop(msg.key, bytes()) + msg.data
|
data = image_stream.pop(msg.key, bytes()) + msg.data
|
||||||
if msg.done:
|
if msg.done:
|
||||||
# Return CameraState with the merged data
|
# Return CameraState with the merged data
|
||||||
on_state(CameraState(key=msg.key, data=data))
|
on_state(CameraState(key=msg.key, data=data))
|
||||||
else:
|
else:
|
||||||
image_stream[msg.key] = data
|
image_stream[msg.key] = data
|
||||||
return
|
|
||||||
|
|
||||||
for resp_type, cls in response_types.items():
|
|
||||||
if isinstance(msg, resp_type):
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
return
|
|
||||||
|
|
||||||
# pylint: disable=undefined-loop-variable
|
|
||||||
on_state(cls.from_pb(msg))
|
|
||||||
|
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message_callback_response(
|
await self._connection.send_message_callback_response(
|
||||||
SubscribeStatesRequest(), on_msg
|
SubscribeStatesRequest(), on_msg, msg_types
|
||||||
)
|
)
|
||||||
|
|
||||||
async def subscribe_logs(
|
async def subscribe_logs(
|
||||||
@ -392,9 +381,8 @@ class APIClient:
|
|||||||
) -> None:
|
) -> None:
|
||||||
self._check_authenticated()
|
self._check_authenticated()
|
||||||
|
|
||||||
def on_msg(msg: message.Message) -> None:
|
def on_msg(msg: SubscribeLogsResponse) -> None:
|
||||||
if isinstance(msg, SubscribeLogsResponse):
|
on_log(msg)
|
||||||
on_log(msg)
|
|
||||||
|
|
||||||
req = SubscribeLogsRequest()
|
req = SubscribeLogsRequest()
|
||||||
if log_level is not None:
|
if log_level is not None:
|
||||||
@ -402,20 +390,23 @@ class APIClient:
|
|||||||
if dump_config is not None:
|
if dump_config is not None:
|
||||||
req.dump_config = dump_config
|
req.dump_config = dump_config
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message_callback_response(req, on_msg)
|
await self._connection.send_message_callback_response(
|
||||||
|
req, on_msg, (SubscribeLogsResponse,)
|
||||||
|
)
|
||||||
|
|
||||||
async def subscribe_service_calls(
|
async def subscribe_service_calls(
|
||||||
self, on_service_call: Callable[[HomeassistantServiceCall], None]
|
self, on_service_call: Callable[[HomeassistantServiceCall], None]
|
||||||
) -> None:
|
) -> None:
|
||||||
self._check_authenticated()
|
self._check_authenticated()
|
||||||
|
|
||||||
def on_msg(msg: message.Message) -> None:
|
def on_msg(msg: HomeassistantServiceResponse) -> None:
|
||||||
if isinstance(msg, HomeassistantServiceResponse):
|
on_service_call(HomeassistantServiceCall.from_pb(msg))
|
||||||
on_service_call(HomeassistantServiceCall.from_pb(msg))
|
|
||||||
|
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message_callback_response(
|
await self._connection.send_message_callback_response(
|
||||||
SubscribeHomeassistantServicesRequest(), on_msg
|
SubscribeHomeassistantServicesRequest(),
|
||||||
|
on_msg,
|
||||||
|
(HomeassistantServiceResponse,),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _send_bluetooth_message_await_response(
|
async def _send_bluetooth_message_await_response(
|
||||||
@ -427,17 +418,18 @@ class APIClient:
|
|||||||
timeout: float = 10.0,
|
timeout: float = 10.0,
|
||||||
) -> message.Message:
|
) -> message.Message:
|
||||||
self._check_authenticated()
|
self._check_authenticated()
|
||||||
|
msg_types = (response_type, BluetoothGATTErrorResponse)
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
|
|
||||||
def is_response(msg: message.Message) -> bool:
|
def is_response(msg: message.Message) -> bool:
|
||||||
return (
|
return (
|
||||||
isinstance(msg, (BluetoothGATTErrorResponse, response_type))
|
isinstance(msg, msg_types)
|
||||||
and msg.address == address # type: ignore[union-attr]
|
and msg.address == address # type: ignore[union-attr]
|
||||||
and msg.handle == handle # type: ignore[union-attr]
|
and msg.handle == handle # type: ignore[union-attr]
|
||||||
)
|
)
|
||||||
|
|
||||||
resp = await self._connection.send_message_await_response_complex(
|
resp = await self._connection.send_message_await_response_complex(
|
||||||
request, is_response, is_response, timeout=timeout
|
request, is_response, is_response, msg_types, timeout=timeout
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(resp[0], BluetoothGATTErrorResponse):
|
if isinstance(resp[0], BluetoothGATTErrorResponse):
|
||||||
@ -449,19 +441,19 @@ class APIClient:
|
|||||||
self, on_bluetooth_le_advertisement: Callable[[BluetoothLEAdvertisement], None]
|
self, on_bluetooth_le_advertisement: Callable[[BluetoothLEAdvertisement], None]
|
||||||
) -> Callable[[], None]:
|
) -> Callable[[], None]:
|
||||||
self._check_authenticated()
|
self._check_authenticated()
|
||||||
|
msg_types = (BluetoothLEAdvertisementResponse,)
|
||||||
|
|
||||||
def on_msg(msg: message.Message) -> None:
|
def on_msg(msg: BluetoothLEAdvertisementResponse) -> None:
|
||||||
if isinstance(msg, BluetoothLEAdvertisementResponse):
|
on_bluetooth_le_advertisement(BluetoothLEAdvertisement.from_pb(msg))
|
||||||
on_bluetooth_le_advertisement(BluetoothLEAdvertisement.from_pb(msg))
|
|
||||||
|
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message_callback_response(
|
await self._connection.send_message_callback_response(
|
||||||
SubscribeBluetoothLEAdvertisementsRequest(), on_msg
|
SubscribeBluetoothLEAdvertisementsRequest(), on_msg, msg_types
|
||||||
)
|
)
|
||||||
|
|
||||||
def unsub() -> None:
|
def unsub() -> None:
|
||||||
if self._connection is not None:
|
if self._connection is not None:
|
||||||
self._connection.remove_message_callback(on_msg)
|
self._connection.remove_message_callback(on_msg, msg_types)
|
||||||
|
|
||||||
return unsub
|
return unsub
|
||||||
|
|
||||||
@ -469,24 +461,24 @@ class APIClient:
|
|||||||
self, on_bluetooth_connections_free_update: Callable[[int, int], None]
|
self, on_bluetooth_connections_free_update: Callable[[int, int], None]
|
||||||
) -> Callable[[], None]:
|
) -> Callable[[], None]:
|
||||||
self._check_authenticated()
|
self._check_authenticated()
|
||||||
|
msg_types = (BluetoothConnectionsFreeResponse,)
|
||||||
|
|
||||||
def on_msg(msg: message.Message) -> None:
|
def on_msg(msg: BluetoothConnectionsFreeResponse) -> None:
|
||||||
if isinstance(msg, BluetoothConnectionsFreeResponse):
|
resp = BluetoothConnectionsFree.from_pb(msg)
|
||||||
resp = BluetoothConnectionsFree.from_pb(msg)
|
on_bluetooth_connections_free_update(resp.free, resp.limit)
|
||||||
on_bluetooth_connections_free_update(resp.free, resp.limit)
|
|
||||||
|
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message_callback_response(
|
await self._connection.send_message_callback_response(
|
||||||
SubscribeBluetoothConnectionsFreeRequest(), on_msg
|
SubscribeBluetoothConnectionsFreeRequest(), on_msg, msg_types
|
||||||
)
|
)
|
||||||
|
|
||||||
def unsub() -> None:
|
def unsub() -> None:
|
||||||
if self._connection is not None:
|
if self._connection is not None:
|
||||||
self._connection.remove_message_callback(on_msg)
|
self._connection.remove_message_callback(on_msg, msg_types)
|
||||||
|
|
||||||
return unsub
|
return unsub
|
||||||
|
|
||||||
async def bluetooth_device_connect(
|
async def bluetooth_device_connect( # pylint: disable=too-many-locals
|
||||||
self,
|
self,
|
||||||
address: int,
|
address: int,
|
||||||
on_bluetooth_connection_state: Callable[[bool, int, int], None],
|
on_bluetooth_connection_state: Callable[[bool, int, int], None],
|
||||||
@ -497,15 +489,15 @@ class APIClient:
|
|||||||
address_type: Optional[int] = None,
|
address_type: Optional[int] = None,
|
||||||
) -> Callable[[], None]:
|
) -> Callable[[], None]:
|
||||||
self._check_authenticated()
|
self._check_authenticated()
|
||||||
|
msg_types = (BluetoothDeviceConnectionResponse,)
|
||||||
|
|
||||||
event = asyncio.Event()
|
event = asyncio.Event()
|
||||||
|
|
||||||
def on_msg(msg: message.Message) -> None:
|
def on_msg(msg: BluetoothDeviceConnectionResponse) -> None:
|
||||||
if isinstance(msg, BluetoothDeviceConnectionResponse):
|
resp = BluetoothDeviceConnection.from_pb(msg)
|
||||||
resp = BluetoothDeviceConnection.from_pb(msg)
|
if address == resp.address:
|
||||||
if address == resp.address:
|
on_bluetooth_connection_state(resp.connected, resp.mtu, resp.error)
|
||||||
on_bluetooth_connection_state(resp.connected, resp.mtu, resp.error)
|
event.set()
|
||||||
event.set()
|
|
||||||
|
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
if has_cache:
|
if has_cache:
|
||||||
@ -530,11 +522,12 @@ class APIClient:
|
|||||||
address_type=address_type or 0,
|
address_type=address_type or 0,
|
||||||
),
|
),
|
||||||
on_msg,
|
on_msg,
|
||||||
|
msg_types,
|
||||||
)
|
)
|
||||||
|
|
||||||
def unsub() -> None:
|
def unsub() -> None:
|
||||||
if self._connection is not None:
|
if self._connection is not None:
|
||||||
self._connection.remove_message_callback(on_msg)
|
self._connection.remove_message_callback(on_msg, msg_types)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
@ -558,7 +551,7 @@ class APIClient:
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
unsub()
|
unsub()
|
||||||
except ValueError:
|
except (KeyError, ValueError):
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
"%s: Bluetooth device connection timed out but already unsubscribed",
|
"%s: Bluetooth device connection timed out but already unsubscribed",
|
||||||
addr,
|
addr,
|
||||||
@ -571,7 +564,7 @@ class APIClient:
|
|||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
try:
|
try:
|
||||||
unsub()
|
unsub()
|
||||||
except ValueError:
|
except (KeyError, ValueError):
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
"%s: Bluetooth device connection canceled but already unsubscribed",
|
"%s: Bluetooth device connection canceled but already unsubscribed",
|
||||||
addr,
|
addr,
|
||||||
@ -595,29 +588,26 @@ class APIClient:
|
|||||||
self, address: int
|
self, address: int
|
||||||
) -> ESPHomeBluetoothGATTServices:
|
) -> ESPHomeBluetoothGATTServices:
|
||||||
self._check_authenticated()
|
self._check_authenticated()
|
||||||
|
msg_types = (
|
||||||
|
BluetoothGATTGetServicesResponse,
|
||||||
|
BluetoothGATTGetServicesDoneResponse,
|
||||||
|
BluetoothGATTErrorResponse,
|
||||||
|
)
|
||||||
|
append_types = (BluetoothGATTGetServicesResponse, BluetoothGATTErrorResponse)
|
||||||
|
stop_types = (BluetoothGATTGetServicesDoneResponse, BluetoothGATTErrorResponse)
|
||||||
|
|
||||||
def do_append(msg: message.Message) -> bool:
|
def do_append(msg: message.Message) -> bool:
|
||||||
return (
|
return isinstance(msg, append_types) and msg.address == address
|
||||||
isinstance(
|
|
||||||
msg, (BluetoothGATTGetServicesResponse, BluetoothGATTErrorResponse)
|
|
||||||
)
|
|
||||||
and msg.address == address
|
|
||||||
)
|
|
||||||
|
|
||||||
def do_stop(msg: message.Message) -> bool:
|
def do_stop(msg: message.Message) -> bool:
|
||||||
return (
|
return isinstance(msg, stop_types) and msg.address == address
|
||||||
isinstance(
|
|
||||||
msg,
|
|
||||||
(BluetoothGATTGetServicesDoneResponse, BluetoothGATTErrorResponse),
|
|
||||||
)
|
|
||||||
and msg.address == address
|
|
||||||
)
|
|
||||||
|
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
resp = await self._connection.send_message_await_response_complex(
|
resp = await self._connection.send_message_await_response_complex(
|
||||||
BluetoothGATTGetServicesRequest(address=address),
|
BluetoothGATTGetServicesRequest(address=address),
|
||||||
do_append,
|
do_append,
|
||||||
do_stop,
|
do_stop,
|
||||||
|
msg_types,
|
||||||
timeout=DEFAULT_BLE_TIMEOUT,
|
timeout=DEFAULT_BLE_TIMEOUT,
|
||||||
)
|
)
|
||||||
services = []
|
services = []
|
||||||
@ -740,14 +730,15 @@ class APIClient:
|
|||||||
BluetoothGATTNotifyResponse,
|
BluetoothGATTNotifyResponse,
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_msg(msg: message.Message) -> None:
|
def on_msg(msg: BluetoothGATTNotifyDataResponse) -> None:
|
||||||
if isinstance(msg, BluetoothGATTNotifyDataResponse):
|
notify = BluetoothGATTRead.from_pb(msg)
|
||||||
notify = BluetoothGATTRead.from_pb(msg)
|
if address == notify.address and handle == notify.handle:
|
||||||
if address == notify.address and handle == notify.handle:
|
on_bluetooth_gatt_notify(handle, bytearray(notify.data))
|
||||||
on_bluetooth_gatt_notify(handle, bytearray(notify.data))
|
|
||||||
|
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
remove_callback = self._connection.add_message_callback(on_msg)
|
remove_callback = self._connection.add_message_callback(
|
||||||
|
on_msg, (BluetoothGATTNotifyDataResponse,)
|
||||||
|
)
|
||||||
|
|
||||||
async def stop_notify() -> None:
|
async def stop_notify() -> None:
|
||||||
if self._connection is None:
|
if self._connection is None:
|
||||||
@ -768,13 +759,14 @@ class APIClient:
|
|||||||
) -> None:
|
) -> None:
|
||||||
self._check_authenticated()
|
self._check_authenticated()
|
||||||
|
|
||||||
def on_msg(msg: message.Message) -> None:
|
def on_msg(msg: SubscribeHomeAssistantStateResponse) -> None:
|
||||||
if isinstance(msg, SubscribeHomeAssistantStateResponse):
|
on_state_sub(msg.entity_id, msg.attribute)
|
||||||
on_state_sub(msg.entity_id, msg.attribute)
|
|
||||||
|
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message_callback_response(
|
await self._connection.send_message_callback_response(
|
||||||
SubscribeHomeAssistantStatesRequest(), on_msg
|
SubscribeHomeAssistantStatesRequest(),
|
||||||
|
on_msg,
|
||||||
|
(SubscribeHomeAssistantStateResponse,),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def send_home_assistant_state(
|
async def send_home_assistant_state(
|
||||||
|
@ -5,7 +5,7 @@ import socket
|
|||||||
import time
|
import time
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from dataclasses import astuple, dataclass
|
from dataclasses import astuple, dataclass
|
||||||
from typing import Any, Callable, Coroutine, List, Optional
|
from typing import Any, Callable, Coroutine, Dict, Iterable, List, Optional, Type
|
||||||
|
|
||||||
import async_timeout
|
import async_timeout
|
||||||
from google.protobuf import message
|
from google.protobuf import message
|
||||||
@ -49,7 +49,9 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
|
|
||||||
BUFFER_SIZE = 1024 * 1024 # Set buffer limit to 1MB
|
BUFFER_SIZE = 1024 * 1024 # Set buffer limit to 1MB
|
||||||
|
|
||||||
INTERNAL_MESSAGE_TYPES = (GetTimeRequest, PingRequest, DisconnectRequest)
|
INTERNAL_MESSAGE_TYPES = {GetTimeRequest, PingRequest, DisconnectRequest}
|
||||||
|
|
||||||
|
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -98,7 +100,7 @@ class APIConnection:
|
|||||||
self._connect_complete = False
|
self._connect_complete = False
|
||||||
|
|
||||||
# Message handlers currently subscribed to incoming messages
|
# Message handlers currently subscribed to incoming messages
|
||||||
self._message_handlers: List[Callable[[message.Message], None]] = []
|
self._message_handlers: Dict[Any, List[Callable[[message.Message], None]]] = {}
|
||||||
# The friendly name to show for this connection in the logs
|
# The friendly name to show for this connection in the logs
|
||||||
self.log_name = params.address
|
self.log_name = params.address
|
||||||
|
|
||||||
@ -384,12 +386,9 @@ class APIConnection:
|
|||||||
if not self._is_socket_open:
|
if not self._is_socket_open:
|
||||||
raise APIConnectionError("Connection isn't established yet")
|
raise APIConnectionError("Connection isn't established yet")
|
||||||
|
|
||||||
for message_type, klass in MESSAGE_TYPE_TO_PROTO.items():
|
message_type = PROTO_TO_MESSAGE_TYPE.get(type(msg))
|
||||||
if isinstance(msg, klass):
|
if not message_type:
|
||||||
break
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Message type id not found for type {type(msg)}")
|
raise ValueError(f"Message type id not found for type {type(msg)}")
|
||||||
|
|
||||||
encoded = msg.SerializeToString()
|
encoded = msg.SerializeToString()
|
||||||
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
|
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
|
||||||
|
|
||||||
@ -412,29 +411,39 @@ class APIConnection:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def add_message_callback(
|
def add_message_callback(
|
||||||
self, on_message: Callable[[Any], None]
|
self, on_message: Callable[[Any], None], msg_types: Iterable[Type[Any]]
|
||||||
) -> Callable[[], None]:
|
) -> Callable[[], None]:
|
||||||
"""Add a message callback."""
|
"""Add a message callback."""
|
||||||
self._message_handlers.append(on_message)
|
for msg_type in msg_types:
|
||||||
|
self._message_handlers.setdefault(msg_type, []).append(on_message)
|
||||||
|
|
||||||
def unsub() -> None:
|
def unsub() -> None:
|
||||||
self._message_handlers.remove(on_message)
|
for msg_type in msg_types:
|
||||||
|
self._message_handlers[msg_type].remove(on_message)
|
||||||
|
|
||||||
return unsub
|
return unsub
|
||||||
|
|
||||||
def remove_message_callback(self, on_message: Callable[[Any], None]) -> None:
|
def remove_message_callback(
|
||||||
|
self, on_message: Callable[[Any], None], msg_types: Iterable[Type[Any]]
|
||||||
|
) -> None:
|
||||||
"""Remove a message callback."""
|
"""Remove a message callback."""
|
||||||
self._message_handlers.remove(on_message)
|
for msg_type in msg_types:
|
||||||
|
self._message_handlers[msg_type].remove(on_message)
|
||||||
|
|
||||||
async def send_message_callback_response(
|
async def send_message_callback_response(
|
||||||
self, send_msg: message.Message, on_message: Callable[[Any], None]
|
self,
|
||||||
|
send_msg: message.Message,
|
||||||
|
on_message: Callable[[Any], None],
|
||||||
|
msg_types: Iterable[Type[Any]],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Send a message to the remote and register the given message handler."""
|
"""Send a message to the remote and register the given message handler."""
|
||||||
self._message_handlers.append(on_message)
|
for msg_type in msg_types:
|
||||||
|
self._message_handlers.setdefault(msg_type, []).append(on_message)
|
||||||
try:
|
try:
|
||||||
await self.send_message(send_msg)
|
await self.send_message(send_msg)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
self._message_handlers.remove(on_message)
|
for msg_type in msg_types:
|
||||||
|
self._message_handlers[msg_type].remove(on_message)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def send_message_await_response_complex(
|
async def send_message_await_response_complex(
|
||||||
@ -442,6 +451,7 @@ class APIConnection:
|
|||||||
send_msg: message.Message,
|
send_msg: message.Message,
|
||||||
do_append: Callable[[message.Message], bool],
|
do_append: Callable[[message.Message], bool],
|
||||||
do_stop: Callable[[message.Message], bool],
|
do_stop: Callable[[message.Message], bool],
|
||||||
|
msg_types: Iterable[Type[Any]],
|
||||||
timeout: float = 10.0,
|
timeout: float = 10.0,
|
||||||
) -> List[message.Message]:
|
) -> List[message.Message]:
|
||||||
"""Send a message to the remote and build up a list response.
|
"""Send a message to the remote and build up a list response.
|
||||||
@ -472,11 +482,15 @@ class APIConnection:
|
|||||||
new_exc.__cause__ = exc
|
new_exc.__cause__ = exc
|
||||||
fut.set_exception(new_exc)
|
fut.set_exception(new_exc)
|
||||||
|
|
||||||
self._message_handlers.append(on_message)
|
for msg_type in msg_types:
|
||||||
|
self._message_handlers.setdefault(msg_type, []).append(on_message)
|
||||||
self._read_exception_handlers.append(on_read_exception)
|
self._read_exception_handlers.append(on_read_exception)
|
||||||
await self.send_message(send_msg)
|
# We must not await without a finally or
|
||||||
|
# the message could fail to be removed if the
|
||||||
|
# the await is cancelled
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
await self.send_message(send_msg)
|
||||||
async with async_timeout.timeout(timeout):
|
async with async_timeout.timeout(timeout):
|
||||||
await fut
|
await fut
|
||||||
except asyncio.TimeoutError as err:
|
except asyncio.TimeoutError as err:
|
||||||
@ -485,7 +499,8 @@ class APIConnection:
|
|||||||
) from err
|
) from err
|
||||||
finally:
|
finally:
|
||||||
with suppress(ValueError):
|
with suppress(ValueError):
|
||||||
self._message_handlers.remove(on_message)
|
for msg_type in msg_types:
|
||||||
|
self._message_handlers[msg_type].remove(on_message)
|
||||||
with suppress(ValueError):
|
with suppress(ValueError):
|
||||||
self._read_exception_handlers.remove(on_read_exception)
|
self._read_exception_handlers.remove(on_read_exception)
|
||||||
|
|
||||||
@ -494,11 +509,12 @@ class APIConnection:
|
|||||||
async def send_message_await_response(
|
async def send_message_await_response(
|
||||||
self, send_msg: message.Message, response_type: Any, timeout: float = 10.0
|
self, send_msg: message.Message, response_type: Any, timeout: float = 10.0
|
||||||
) -> Any:
|
) -> Any:
|
||||||
def is_response(msg: message.Message) -> bool:
|
|
||||||
return isinstance(msg, response_type)
|
|
||||||
|
|
||||||
res = await self.send_message_await_response_complex(
|
res = await self.send_message_await_response_complex(
|
||||||
send_msg, is_response, is_response, timeout=timeout
|
send_msg,
|
||||||
|
lambda msg: True, # we will only get responses of `response_type`
|
||||||
|
lambda msg: True, # we will only get responses of `response_type`
|
||||||
|
(response_type,),
|
||||||
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
if len(res) != 1:
|
if len(res) != 1:
|
||||||
raise APIConnectionError(f"Expected one result, got {len(res)}")
|
raise APIConnectionError(f"Expected one result, got {len(res)}")
|
||||||
@ -531,12 +547,14 @@ class APIConnection:
|
|||||||
# Socket closed but task isn't cancelled yet
|
# Socket closed but task isn't cancelled yet
|
||||||
break
|
break
|
||||||
|
|
||||||
msg_type = pkt.type
|
msg_type_proto = pkt.type
|
||||||
if msg_type not in MESSAGE_TYPE_TO_PROTO:
|
if msg_type_proto not in MESSAGE_TYPE_TO_PROTO:
|
||||||
_LOGGER.debug("%s: Skipping message type %s", self.log_name, msg_type)
|
_LOGGER.debug(
|
||||||
|
"%s: Skipping message type %s", self.log_name, msg_type_proto
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
msg = MESSAGE_TYPE_TO_PROTO[msg_type]()
|
msg = MESSAGE_TYPE_TO_PROTO[msg_type_proto]()
|
||||||
try:
|
try:
|
||||||
msg.ParseFromString(pkt.data)
|
msg.ParseFromString(pkt.data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -545,16 +563,18 @@ class APIConnection:
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
msg_type = type(msg)
|
||||||
|
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"%s: Got message of type %s: %s", self.log_name, type(msg), msg
|
"%s: Got message of type %s: %s", self.log_name, msg_type, msg
|
||||||
)
|
)
|
||||||
|
|
||||||
for handler in self._message_handlers[:]:
|
for handler in self._message_handlers.get(msg_type, [])[:]:
|
||||||
handler(msg)
|
handler(msg)
|
||||||
|
|
||||||
# Pre-check the message type to avoid awaiting
|
# Pre-check the message type to avoid awaiting
|
||||||
# since most messages are not internal messages
|
# since most messages are not internal messages
|
||||||
if isinstance(msg, INTERNAL_MESSAGE_TYPES):
|
if msg_type in INTERNAL_MESSAGE_TYPES:
|
||||||
await self._handle_internal_messages(msg)
|
await self._handle_internal_messages(msg)
|
||||||
|
|
||||||
async def _read_loop(self) -> None:
|
async def _read_loop(self) -> None:
|
||||||
|
@ -78,8 +78,7 @@ class APIModelBase:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pb(cls: Type[_V], data: Any) -> _V:
|
def from_pb(cls: Type[_V], data: Any) -> _V:
|
||||||
init_args = {f.name: getattr(data, f.name) for f in fields(cls)}
|
return cls(**{f.name: getattr(data, f.name) for f in fields(cls)})
|
||||||
return cls(**init_args)
|
|
||||||
|
|
||||||
|
|
||||||
def converter_field(*, converter: Callable[[Any], _V], **kwargs: Any) -> _V:
|
def converter_field(*, converter: Callable[[Any], _V], **kwargs: Any) -> _V:
|
||||||
|
@ -57,7 +57,7 @@ def auth_client():
|
|||||||
|
|
||||||
|
|
||||||
def patch_response_complex(client: APIClient, messages):
|
def patch_response_complex(client: APIClient, messages):
|
||||||
async def patched(req, app, stop, timeout=5.0):
|
async def patched(req, app, stop, msg_types, timeout=5.0):
|
||||||
resp = []
|
resp = []
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if app(msg):
|
if app(msg):
|
||||||
@ -74,7 +74,7 @@ def patch_response_complex(client: APIClient, messages):
|
|||||||
def patch_response_callback(client: APIClient):
|
def patch_response_callback(client: APIClient):
|
||||||
on_message = None
|
on_message = None
|
||||||
|
|
||||||
async def patched(req, callback):
|
async def patched(req, callback, msg_types):
|
||||||
nonlocal on_message
|
nonlocal on_message
|
||||||
on_message = callback
|
on_message = callback
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user