Reduce overhead to send messages that need to wait for responses (#479)

This commit is contained in:
J. Nick Koston 2023-07-15 09:58:45 -10:00 committed by GitHub
parent bbfa761aa0
commit f8ffa6ae83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 22 additions and 25 deletions

View File

@ -4,8 +4,8 @@ import enum
import logging
import socket
import time
from contextlib import suppress
from dataclasses import astuple, dataclass
from functools import partial
from typing import (
Any,
Callable,
@ -170,7 +170,7 @@ class APIConnection:
self._connect_complete = False
# Message handlers currently subscribed to incoming messages
self._message_handlers: Dict[Any, List[Callable[[message.Message], None]]] = {}
self._message_handlers: Dict[Any, Set[Callable[[message.Message], None]]] = {}
# The friendly name to show for this connection in the logs
self.log_name = log_name or params.address
@ -481,11 +481,11 @@ class APIConnection:
async def login(self, check_connected: bool = True) -> None:
"""Send a login (ConnectRequest) and await the response."""
if check_connected:
if check_connected and self._connection_state != ConnectionState.CONNECTED:
# On first connect, we don't want to check if we're connected
# because we don't set the connection state until after login
# is complete
self._check_connected()
raise APIConnectionError("Must be connected!")
if self._is_authenticated:
raise APIConnectionError("Already logged in!")
@ -509,10 +509,6 @@ class APIConnection:
self._is_authenticated = True
def _check_connected(self) -> None:
if self._connection_state != ConnectionState.CONNECTED:
raise APIConnectionError("Must be connected!")
@property
def _is_socket_open(self) -> bool:
return self._connection_state in (
@ -530,7 +526,10 @@ class APIConnection:
def send_message(self, msg: message.Message) -> None:
"""Send a protobuf message to the remote."""
if not self._is_socket_open:
if self._connection_state not in (
ConnectionState.SOCKET_OPENED,
ConnectionState.CONNECTED,
):
if in_do_connect.get(False):
# If we are in the do_connect task, we can't raise an error
# because it would obscure the original exception (ie encrypt error).
@ -565,21 +564,18 @@ class APIConnection:
self, on_message: Callable[[Any], None], msg_types: Iterable[Type[Any]]
) -> Callable[[], None]:
"""Add a message callback."""
message_handlers = self._message_handlers
for msg_type in msg_types:
self._message_handlers.setdefault(msg_type, []).append(on_message)
def unsub() -> None:
for msg_type in msg_types:
self._message_handlers[msg_type].remove(on_message)
return unsub
message_handlers.setdefault(msg_type, set()).add(on_message)
return partial(self.remove_message_callback, on_message, msg_types)
def remove_message_callback(
self, on_message: Callable[[Any], None], msg_types: Iterable[Type[Any]]
) -> None:
"""Remove a message callback."""
message_handlers = self._message_handlers
for msg_type in msg_types:
self._message_handlers[msg_type].remove(on_message)
message_handlers[msg_type].discard(on_message)
def send_message_callback_response(
self,
@ -594,7 +590,7 @@ class APIConnection:
# we can be sure that we will not miss any messages even though
# we register the handler after sending the message
for msg_type in msg_types:
self._message_handlers.setdefault(msg_type, []).append(on_message)
self._message_handlers.setdefault(msg_type, set()).add(on_message)
def _handle_timeout(self, fut: asyncio.Future[None]) -> None:
"""Handle a timeout."""
@ -636,10 +632,12 @@ class APIConnection:
if do_stop is None or do_stop(resp):
fut.set_result(None)
message_handlers = self._message_handlers
read_exception_futures = self._read_exception_futures
for msg_type in msg_types:
self._message_handlers.setdefault(msg_type, []).append(on_message)
message_handlers.setdefault(msg_type, set()).add(on_message)
self._read_exception_futures.add(fut)
read_exception_futures.add(fut)
# Now safe to await since we have registered the handler
# We must not await without a finally or
@ -655,9 +653,8 @@ class APIConnection:
finally:
timeout_handle.cancel()
for msg_type in msg_types:
with suppress(ValueError):
self._message_handlers[msg_type].remove(on_message)
self._read_exception_futures.discard(fut)
message_handlers[msg_type].discard(on_message)
read_exception_futures.discard(fut)
return responses
@ -766,8 +763,8 @@ class APIConnection:
self._send_pending_ping = False
handlers = message_handlers.get(msg_type)
if handlers is not None:
for handler in handlers[:]:
if handlers:
for handler in handlers.copy():
handler(msg)
# Pre-check the message type to avoid awaiting