mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-25 17:17:42 +01:00
Reduce overhead to send messages that need to wait for responses (#479)
This commit is contained in:
parent
bbfa761aa0
commit
f8ffa6ae83
@ -4,8 +4,8 @@ import enum
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from contextlib import suppress
|
||||
from dataclasses import astuple, dataclass
|
||||
from functools import partial
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
@ -170,7 +170,7 @@ class APIConnection:
|
||||
self._connect_complete = False
|
||||
|
||||
# Message handlers currently subscribed to incoming messages
|
||||
self._message_handlers: Dict[Any, List[Callable[[message.Message], None]]] = {}
|
||||
self._message_handlers: Dict[Any, Set[Callable[[message.Message], None]]] = {}
|
||||
# The friendly name to show for this connection in the logs
|
||||
self.log_name = log_name or params.address
|
||||
|
||||
@ -481,11 +481,11 @@ class APIConnection:
|
||||
|
||||
async def login(self, check_connected: bool = True) -> None:
|
||||
"""Send a login (ConnectRequest) and await the response."""
|
||||
if check_connected:
|
||||
if check_connected and self._connection_state != ConnectionState.CONNECTED:
|
||||
# On first connect, we don't want to check if we're connected
|
||||
# because we don't set the connection state until after login
|
||||
# is complete
|
||||
self._check_connected()
|
||||
raise APIConnectionError("Must be connected!")
|
||||
if self._is_authenticated:
|
||||
raise APIConnectionError("Already logged in!")
|
||||
|
||||
@ -509,10 +509,6 @@ class APIConnection:
|
||||
|
||||
self._is_authenticated = True
|
||||
|
||||
def _check_connected(self) -> None:
|
||||
if self._connection_state != ConnectionState.CONNECTED:
|
||||
raise APIConnectionError("Must be connected!")
|
||||
|
||||
@property
|
||||
def _is_socket_open(self) -> bool:
|
||||
return self._connection_state in (
|
||||
@ -530,7 +526,10 @@ class APIConnection:
|
||||
|
||||
def send_message(self, msg: message.Message) -> None:
|
||||
"""Send a protobuf message to the remote."""
|
||||
if not self._is_socket_open:
|
||||
if self._connection_state not in (
|
||||
ConnectionState.SOCKET_OPENED,
|
||||
ConnectionState.CONNECTED,
|
||||
):
|
||||
if in_do_connect.get(False):
|
||||
# If we are in the do_connect task, we can't raise an error
|
||||
# because it would obscure the original exception (ie encrypt error).
|
||||
@ -565,21 +564,18 @@ class APIConnection:
|
||||
self, on_message: Callable[[Any], None], msg_types: Iterable[Type[Any]]
|
||||
) -> Callable[[], None]:
|
||||
"""Add a message callback."""
|
||||
message_handlers = self._message_handlers
|
||||
for msg_type in msg_types:
|
||||
self._message_handlers.setdefault(msg_type, []).append(on_message)
|
||||
|
||||
def unsub() -> None:
|
||||
for msg_type in msg_types:
|
||||
self._message_handlers[msg_type].remove(on_message)
|
||||
|
||||
return unsub
|
||||
message_handlers.setdefault(msg_type, set()).add(on_message)
|
||||
return partial(self.remove_message_callback, on_message, msg_types)
|
||||
|
||||
def remove_message_callback(
|
||||
self, on_message: Callable[[Any], None], msg_types: Iterable[Type[Any]]
|
||||
) -> None:
|
||||
"""Remove a message callback."""
|
||||
message_handlers = self._message_handlers
|
||||
for msg_type in msg_types:
|
||||
self._message_handlers[msg_type].remove(on_message)
|
||||
message_handlers[msg_type].discard(on_message)
|
||||
|
||||
def send_message_callback_response(
|
||||
self,
|
||||
@ -594,7 +590,7 @@ class APIConnection:
|
||||
# we can be sure that we will not miss any messages even though
|
||||
# we register the handler after sending the message
|
||||
for msg_type in msg_types:
|
||||
self._message_handlers.setdefault(msg_type, []).append(on_message)
|
||||
self._message_handlers.setdefault(msg_type, set()).add(on_message)
|
||||
|
||||
def _handle_timeout(self, fut: asyncio.Future[None]) -> None:
|
||||
"""Handle a timeout."""
|
||||
@ -636,10 +632,12 @@ class APIConnection:
|
||||
if do_stop is None or do_stop(resp):
|
||||
fut.set_result(None)
|
||||
|
||||
message_handlers = self._message_handlers
|
||||
read_exception_futures = self._read_exception_futures
|
||||
for msg_type in msg_types:
|
||||
self._message_handlers.setdefault(msg_type, []).append(on_message)
|
||||
message_handlers.setdefault(msg_type, set()).add(on_message)
|
||||
|
||||
self._read_exception_futures.add(fut)
|
||||
read_exception_futures.add(fut)
|
||||
# Now safe to await since we have registered the handler
|
||||
|
||||
# We must not await without a finally or
|
||||
@ -655,9 +653,8 @@ class APIConnection:
|
||||
finally:
|
||||
timeout_handle.cancel()
|
||||
for msg_type in msg_types:
|
||||
with suppress(ValueError):
|
||||
self._message_handlers[msg_type].remove(on_message)
|
||||
self._read_exception_futures.discard(fut)
|
||||
message_handlers[msg_type].discard(on_message)
|
||||
read_exception_futures.discard(fut)
|
||||
|
||||
return responses
|
||||
|
||||
@ -766,8 +763,8 @@ class APIConnection:
|
||||
self._send_pending_ping = False
|
||||
|
||||
handlers = message_handlers.get(msg_type)
|
||||
if handlers is not None:
|
||||
for handler in handlers[:]:
|
||||
if handlers:
|
||||
for handler in handlers.copy():
|
||||
handler(msg)
|
||||
|
||||
# Pre-check the message type to avoid awaiting
|
||||
|
Loading…
Reference in New Issue
Block a user