Reduce overhead to handle read exceptions (#472)

This commit is contained in:
J. Nick Koston 2023-07-15 08:18:31 -10:00 committed by GitHub
parent 0dfaa58f07
commit ce07e11e93
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -6,7 +6,18 @@ import socket
import time import time
from contextlib import suppress from contextlib import suppress
from dataclasses import astuple, dataclass from dataclasses import astuple, dataclass
from typing import Any, Callable, Coroutine, Dict, Iterable, List, Optional, Type, Union from typing import (
Any,
Callable,
Coroutine,
Dict,
Iterable,
List,
Optional,
Set,
Type,
Union,
)
import async_timeout import async_timeout
from google.protobuf import message from google.protobuf import message
@ -127,7 +138,7 @@ class APIConnection:
"_connect_complete", "_connect_complete",
"_message_handlers", "_message_handlers",
"log_name", "log_name",
"_read_exception_handlers", "_read_exception_futures",
"_ping_timer", "_ping_timer",
"_pong_timer", "_pong_timer",
"_keep_alive_interval", "_keep_alive_interval",
@ -163,8 +174,8 @@ class APIConnection:
# The friendly name to show for this connection in the logs # The friendly name to show for this connection in the logs
self.log_name = log_name or params.address self.log_name = log_name or params.address
# Handlers currently subscribed to exceptions in the read task # futures currently subscribed to exceptions in the read task
self._read_exception_handlers: List[Callable[[Exception], None]] = [] self._read_exception_futures: Set[asyncio.Future[None]] = set()
self._ping_timer: Optional[asyncio.TimerHandle] = None self._ping_timer: Optional[asyncio.TimerHandle] = None
self._pong_timer: Optional[asyncio.TimerHandle] = None self._pong_timer: Optional[asyncio.TimerHandle] = None
@ -601,7 +612,7 @@ class APIConnection:
:raises TimeoutAPIError: if a timeout occured :raises TimeoutAPIError: if a timeout occured
""" """
fut = self._loop.create_future() fut: asyncio.Future[None] = self._loop.create_future()
responses = [] responses = []
def on_message(resp: message.Message) -> None: def on_message(resp: message.Message) -> None:
@ -610,19 +621,12 @@ class APIConnection:
if do_append(resp): if do_append(resp):
responses.append(resp) responses.append(resp)
if do_stop(resp): if do_stop(resp):
fut.set_result(responses) fut.set_result(None)
def on_read_exception(exc: Exception) -> None:
if not fut.done():
new_exc = exc
if not isinstance(exc, APIConnectionError):
new_exc = ReadFailedAPIError("Read failed")
new_exc.__cause__ = exc
fut.set_exception(new_exc)
for msg_type in msg_types: for msg_type in msg_types:
self._message_handlers.setdefault(msg_type, []).append(on_message) self._message_handlers.setdefault(msg_type, []).append(on_message)
self._read_exception_handlers.append(on_read_exception)
self._read_exception_futures.add(fut)
# We must not await without a finally or # We must not await without a finally or
# the message could fail to be removed if the # the message could fail to be removed if the
# the await is cancelled # the await is cancelled
@ -639,8 +643,7 @@ class APIConnection:
for msg_type in msg_types: for msg_type in msg_types:
with suppress(ValueError): with suppress(ValueError):
self._message_handlers[msg_type].remove(on_message) self._message_handlers[msg_type].remove(on_message)
with suppress(ValueError): self._read_exception_futures.discard(fut)
self._read_exception_handlers.remove(on_read_exception)
return responses return responses
@ -678,9 +681,15 @@ class APIConnection:
) )
self._fatal_exception = err self._fatal_exception = err
self._connection_state = ConnectionState.CLOSED self._connection_state = ConnectionState.CLOSED
for handler in self._read_exception_handlers[:]: for fut in self._read_exception_futures:
handler(err) if fut.done():
self._read_exception_handlers.clear() continue
new_exc = err
if not isinstance(err, APIConnectionError):
new_exc = ReadFailedAPIError("Read failed")
new_exc.__cause__ = err
fut.set_exception(new_exc)
self._read_exception_futures.clear()
self._cleanup() self._cleanup()
def _process_packet(self, msg_type_proto: int, data: bytes) -> None: def _process_packet(self, msg_type_proto: int, data: bytes) -> None: