Reduce overhead to handle read exceptions (#472)

This commit is contained in:
J. Nick Koston 2023-07-15 08:18:31 -10:00 committed by GitHub
parent 0dfaa58f07
commit ce07e11e93
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 29 additions and 20 deletions

View File

@ -6,7 +6,18 @@ import socket
import time
from contextlib import suppress
from dataclasses import astuple, dataclass
from typing import Any, Callable, Coroutine, Dict, Iterable, List, Optional, Type, Union
from typing import (
Any,
Callable,
Coroutine,
Dict,
Iterable,
List,
Optional,
Set,
Type,
Union,
)
import async_timeout
from google.protobuf import message
@ -127,7 +138,7 @@ class APIConnection:
"_connect_complete",
"_message_handlers",
"log_name",
"_read_exception_handlers",
"_read_exception_futures",
"_ping_timer",
"_pong_timer",
"_keep_alive_interval",
@ -163,8 +174,8 @@ class APIConnection:
# The friendly name to show for this connection in the logs
self.log_name = log_name or params.address
# Handlers currently subscribed to exceptions in the read task
self._read_exception_handlers: List[Callable[[Exception], None]] = []
# futures currently subscribed to exceptions in the read task
self._read_exception_futures: Set[asyncio.Future[None]] = set()
self._ping_timer: Optional[asyncio.TimerHandle] = None
self._pong_timer: Optional[asyncio.TimerHandle] = None
@ -601,7 +612,7 @@ class APIConnection:
:raises TimeoutAPIError: if a timeout occured
"""
fut = self._loop.create_future()
fut: asyncio.Future[None] = self._loop.create_future()
responses = []
def on_message(resp: message.Message) -> None:
@ -610,19 +621,12 @@ class APIConnection:
if do_append(resp):
responses.append(resp)
if do_stop(resp):
fut.set_result(responses)
def on_read_exception(exc: Exception) -> None:
if not fut.done():
new_exc = exc
if not isinstance(exc, APIConnectionError):
new_exc = ReadFailedAPIError("Read failed")
new_exc.__cause__ = exc
fut.set_exception(new_exc)
fut.set_result(None)
for msg_type in msg_types:
self._message_handlers.setdefault(msg_type, []).append(on_message)
self._read_exception_handlers.append(on_read_exception)
self._read_exception_futures.add(fut)
# We must not await without a finally or
# the message could fail to be removed if the
# the await is cancelled
@ -639,8 +643,7 @@ class APIConnection:
for msg_type in msg_types:
with suppress(ValueError):
self._message_handlers[msg_type].remove(on_message)
with suppress(ValueError):
self._read_exception_handlers.remove(on_read_exception)
self._read_exception_futures.discard(fut)
return responses
@ -678,9 +681,15 @@ class APIConnection:
)
self._fatal_exception = err
self._connection_state = ConnectionState.CLOSED
for handler in self._read_exception_handlers[:]:
handler(err)
self._read_exception_handlers.clear()
for fut in self._read_exception_futures:
if fut.done():
continue
new_exc = err
if not isinstance(err, APIConnectionError):
new_exc = ReadFailedAPIError("Read failed")
new_exc.__cause__ = err
fut.set_exception(new_exc)
self._read_exception_futures.clear()
self._cleanup()
def _process_packet(self, msg_type_proto: int, data: bytes) -> None: