Optimize throughput of api to decrease latency (#327)

This commit is contained in:
J. Nick Koston 2022-12-02 09:12:19 -10:00 committed by GitHub
parent e2527878cb
commit 3692478455
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 195 additions and 128 deletions

View File

@ -1,7 +1,7 @@
import asyncio import asyncio
import base64 import base64
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod, abstractproperty
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
@ -29,50 +29,80 @@ class Packet:
class APIFrameHelper(ABC): class APIFrameHelper(ABC):
@abstractmethod """Helper class to handle the API frame protocol."""
async def close(self) -> None:
pass
@abstractmethod
async def write_packet(self, packet: Packet) -> None:
pass
@abstractmethod
async def read_packet(self) -> Packet:
pass
class APIPlaintextFrameHelper(APIFrameHelper):
def __init__( def __init__(
self, self,
reader: asyncio.StreamReader, reader: asyncio.StreamReader,
writer: asyncio.StreamWriter, writer: asyncio.StreamWriter,
): ) -> None:
"""Initialize the API frame helper."""
self._reader = reader self._reader = reader
self._writer = writer self._writer = writer
self._write_lock = asyncio.Lock() self.read_lock = asyncio.Lock()
self._read_lock = asyncio.Lock()
self._closed_event = asyncio.Event() self._closed_event = asyncio.Event()
@abstractproperty # pylint: disable=deprecated-decorator
def ready(self) -> bool:
"""Return if the connection is ready."""
@abstractmethod
async def close(self) -> None: async def close(self) -> None:
"""Close the connection."""
@abstractmethod
def write_packet(self, packet: Packet) -> None:
"""Write a packet to the socket."""
@abstractmethod
async def read_packet_with_lock(self) -> Packet:
"""Read a packet from the socket, the caller is responsible for having the lock."""
@abstractmethod
async def wait_for_ready(self) -> None:
"""Wait for the connection to be ready."""
class APIPlaintextFrameHelper(APIFrameHelper):
"""Frame helper for plaintext API connections."""
async def close(self) -> None:
"""Close the connection."""
self._closed_event.set() self._closed_event.set()
self._writer.close() self._writer.close()
async def write_packet(self, packet: Packet) -> None: @property
data = b"\0" def ready(self) -> bool:
data += varuint_to_bytes(len(packet.data)) """Return if the connection is ready."""
data += varuint_to_bytes(packet.type) # Plaintext is always ready
data += packet.data return True
try:
async with self._write_lock: def write_packet(self, packet: Packet) -> None:
"""Write a packet to the socket, the caller should not have the lock.
The entire packet must be written in a single call to write
to avoid locking.
"""
data = (
b"\0"
+ varuint_to_bytes(len(packet.data))
+ varuint_to_bytes(packet.type)
+ packet.data
)
_LOGGER.debug("Sending plaintext frame %s", data.hex()) _LOGGER.debug("Sending plaintext frame %s", data.hex())
try:
self._writer.write(data) self._writer.write(data)
await self._writer.drain()
except (ConnectionResetError, OSError) as err: except (ConnectionResetError, OSError) as err:
raise SocketAPIError(f"Error while writing data: {err}") from err raise SocketAPIError(f"Error while writing data: {err}") from err
async def read_packet(self) -> Packet: async def wait_for_ready(self) -> None:
async with self._read_lock: """Wait for the connection to be ready."""
# No handshake for plaintext
async def read_packet_with_lock(self) -> Packet:
"""Read a packet from the socket, the caller is responsible for having the lock."""
assert self.read_lock.locked(), "read_packet_with_lock called without lock"
try: try:
# Read preamble, which should always 0x00 # Read preamble, which should always 0x00
# Also try to get the length and msg type # Also try to get the length and msg type
@ -80,9 +110,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
init_bytes = await self._reader.readexactly(3) init_bytes = await self._reader.readexactly(3)
if init_bytes[0] != 0x00: if init_bytes[0] != 0x00:
if init_bytes[0] == 0x01: if init_bytes[0] == 0x01:
raise RequiresEncryptionAPIError( raise RequiresEncryptionAPIError("Connection requires encryption")
"Connection requires encryption"
)
raise ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}") raise ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}")
if init_bytes[1] & 0x80 == 0x80: if init_bytes[1] & 0x80 == 0x80:
@ -142,29 +170,43 @@ def _decode_noise_psk(psk: str) -> bytes:
class APINoiseFrameHelper(APIFrameHelper): class APINoiseFrameHelper(APIFrameHelper):
"""Frame helper for noise encrypted connections."""
def __init__( def __init__(
self, self,
reader: asyncio.StreamReader, reader: asyncio.StreamReader,
writer: asyncio.StreamWriter, writer: asyncio.StreamWriter,
noise_psk: str, noise_psk: str,
): ) -> None:
self._reader = reader """Initialize the API frame helper."""
self._writer = writer super().__init__(reader, writer)
self._write_lock = asyncio.Lock()
self._read_lock = asyncio.Lock()
self._ready_event = asyncio.Event() self._ready_event = asyncio.Event()
self._closed_event = asyncio.Event()
self._proto: Optional[NoiseConnection] = None self._proto: Optional[NoiseConnection] = None
self._noise_psk = noise_psk self._noise_psk = noise_psk
@property
def ready(self) -> bool:
"""Return if the connection is ready."""
return self._ready_event.is_set()
async def close(self) -> None: async def close(self) -> None:
"""Close the connection."""
# Make sure we set the ready event if its not already set
# so that we don't block forever on the ready event if we
# are waiting for the handshake to complete.
self._ready_event.set()
self._closed_event.set() self._closed_event.set()
self._writer.close() self._writer.close()
async def _write_frame(self, frame: bytes) -> None: def _write_frame(self, frame: bytes) -> None:
try: """Write a packet to the socket, the caller should not have the lock.
async with self._write_lock:
The entire packet must be written in a single call to write
to avoid locking.
"""
_LOGGER.debug("Sending frame %s", frame.hex()) _LOGGER.debug("Sending frame %s", frame.hex())
try:
header = bytes( header = bytes(
[ [
0x01, 0x01,
@ -173,13 +215,13 @@ class APINoiseFrameHelper(APIFrameHelper):
] ]
) )
self._writer.write(header + frame) self._writer.write(header + frame)
await self._writer.drain()
except OSError as err: except OSError as err:
raise SocketAPIError(f"Error while writing data: {err}") from err raise SocketAPIError(f"Error while writing data: {err}") from err
async def _read_frame(self) -> bytes: async def _read_frame_with_lock(self) -> bytes:
"""Read a frame from the socket, the caller is responsible for having the lock."""
assert self.read_lock.locked(), "_read_frame_with_lock called without lock"
try: try:
async with self._read_lock:
header = await self._reader.readexactly(3) header = await self._reader.readexactly(3)
if header[0] != 0x01: if header[0] != 0x01:
raise ProtocolAPIError(f"Marker byte invalid: {header[0]}") raise ProtocolAPIError(f"Marker byte invalid: {header[0]}")
@ -199,10 +241,12 @@ class APINoiseFrameHelper(APIFrameHelper):
return frame return frame
async def _perform_handshake(self, expected_name: Optional[str]) -> None: async def _perform_handshake(self, expected_name: Optional[str]) -> None:
await self._write_frame(b"") # ClientHello """Perform the handshake with the server, the caller is responsible for having the lock."""
assert self.read_lock.locked(), "_perform_handshake called without lock"
self._write_frame(b"") # ClientHello
prologue = b"NoiseAPIInit" + b"\x00\x00" prologue = b"NoiseAPIInit" + b"\x00\x00"
server_hello = await self._read_frame() # ServerHello server_hello = await self._read_frame_with_lock() # ServerHello
if not server_hello: if not server_hello:
raise HandshakeAPIError("ServerHello is empty") raise HandshakeAPIError("ServerHello is empty")
@ -238,9 +282,9 @@ class APINoiseFrameHelper(APIFrameHelper):
while not self._proto.handshake_finished: while not self._proto.handshake_finished:
if do_write: if do_write:
msg = self._proto.write_message() msg = self._proto.write_message()
await self._write_frame(b"\x00" + msg) self._write_frame(b"\x00" + msg)
else: else:
msg = await self._read_frame() msg = await self._read_frame_with_lock()
if not msg: if not msg:
raise HandshakeAPIError("Handshake message too short") raise HandshakeAPIError("Handshake message too short")
if msg[0] != 0: if msg[0] != 0:
@ -256,16 +300,16 @@ class APINoiseFrameHelper(APIFrameHelper):
self._ready_event.set() self._ready_event.set()
async def perform_handshake(self, expected_name: Optional[str]) -> None: async def perform_handshake(self, expected_name: Optional[str]) -> None:
"""Perform the handshake with the server."""
# Allow up to 60 seconds for handhsake # Allow up to 60 seconds for handhsake
try: try:
async with async_timeout.timeout(60.0): async with self.read_lock, async_timeout.timeout(60.0):
await self._perform_handshake(expected_name) await self._perform_handshake(expected_name)
except asyncio.TimeoutError as err: except asyncio.TimeoutError as err:
raise HandshakeAPIError("Timeout during handshake") from err raise HandshakeAPIError("Timeout during handshake") from err
async def write_packet(self, packet: Packet) -> None: def write_packet(self, packet: Packet) -> None:
# Wait for handshake to complete """Write a packet to the socket."""
await self._ready_event.wait()
padding = 0 padding = 0
data = ( data = (
bytes( bytes(
@ -281,12 +325,15 @@ class APINoiseFrameHelper(APIFrameHelper):
) )
assert self._proto is not None assert self._proto is not None
frame = self._proto.encrypt(data) frame = self._proto.encrypt(data)
await self._write_frame(frame) self._write_frame(frame)
async def read_packet(self) -> Packet: async def wait_for_ready(self) -> None:
# Wait for handshake to complete """Wait for the connection to be ready."""
await self._ready_event.wait() await self._ready_event.wait()
frame = await self._read_frame()
async def read_packet_with_lock(self) -> Packet:
"""Read a packet from the socket, the caller is responsible for having the lock."""
frame = await self._read_frame_with_lock()
assert self._proto is not None assert self._proto is not None
msg = self._proto.decrypt(frame) msg = self._proto.decrypt(frame)
if len(msg) < 4: if len(msg) < 4:

View File

@ -49,6 +49,8 @@ _LOGGER = logging.getLogger(__name__)
BUFFER_SIZE = 1024 * 1024 # Set buffer limit to 1MB BUFFER_SIZE = 1024 * 1024 # Set buffer limit to 1MB
INTERNAL_MESSAGE_TYPES = (GetTimeRequest, PingRequest, DisconnectRequest)
@dataclass @dataclass
class ConnectionParams: class ConnectionParams:
@ -105,7 +107,7 @@ class APIConnection:
self._ping_stop_event = asyncio.Event() self._ping_stop_event = asyncio.Event()
self._to_process: asyncio.Queue[Packet] = asyncio.Queue() self._to_process: asyncio.Queue[Optional[Packet]] = asyncio.Queue()
self._process_task: Optional[asyncio.Task[None]] = None self._process_task: Optional[asyncio.Task[None]] = None
@ -120,6 +122,9 @@ class APIConnection:
async def _do_cleanup() -> None: async def _do_cleanup() -> None:
async with self._connect_lock: async with self._connect_lock:
# Tell the process loop to stop
self._to_process.put_nowait(None)
if self._frame_helper is not None: if self._frame_helper is not None:
await self._frame_helper.close() await self._frame_helper.close()
self._frame_helper = None self._frame_helper = None
@ -388,10 +393,13 @@ class APIConnection:
encoded = msg.SerializeToString() encoded = msg.SerializeToString()
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg)) _LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
frame_helper = self._frame_helper
assert frame_helper is not None
if not frame_helper.ready:
await frame_helper.wait_for_ready()
try: try:
assert self._frame_helper is not None frame_helper.write_packet(
# pylint: disable=undefined-loop-variable
await self._frame_helper.write_packet(
Packet( Packet(
type=message_type, type=message_type,
data=encoded, data=encoded,
@ -512,48 +520,60 @@ class APIConnection:
await self._cleanup() await self._cleanup()
async def _process_loop(self) -> None: async def _process_loop(self) -> None:
to_process = self._to_process
while True: while True:
if not self._is_socket_open:
# Socket closed but task isn't cancelled yet
break
try: try:
pkt = await self._to_process.get() pkt = await to_process.get()
except RuntimeError: except RuntimeError:
break break
if pkt is None:
# Socket closed but task isn't cancelled yet
break
msg_type = pkt.type msg_type = pkt.type
raw_msg = pkt.data
if msg_type not in MESSAGE_TYPE_TO_PROTO: if msg_type not in MESSAGE_TYPE_TO_PROTO:
_LOGGER.debug("%s: Skipping message type %s", self.log_name, msg_type) _LOGGER.debug("%s: Skipping message type %s", self.log_name, msg_type)
continue continue
msg = MESSAGE_TYPE_TO_PROTO[msg_type]() msg = MESSAGE_TYPE_TO_PROTO[msg_type]()
try: try:
msg.ParseFromString(raw_msg) msg.ParseFromString(pkt.data)
except Exception as e: except Exception as e:
await self._report_fatal_error( await self._report_fatal_error(
ProtocolAPIError(f"Invalid protobuf message: {e}") ProtocolAPIError(f"Invalid protobuf message: {e}")
) )
raise raise
_LOGGER.debug( _LOGGER.debug(
"%s: Got message of type %s: %s", self.log_name, type(msg), msg "%s: Got message of type %s: %s", self.log_name, type(msg), msg
) )
for handler in self._message_handlers[:]: for handler in self._message_handlers[:]:
handler(msg) handler(msg)
# Pre-check the message type to avoid awaiting
# since most messages are not internal messages
if isinstance(msg, INTERNAL_MESSAGE_TYPES):
await self._handle_internal_messages(msg) await self._handle_internal_messages(msg)
async def _read_loop(self) -> None: async def _read_loop(self) -> None:
assert self._frame_helper is not None frame_helper = self._frame_helper
assert frame_helper is not None
await frame_helper.wait_for_ready()
to_process = self._to_process
try: try:
# Once its ready, we hold the lock for the duration of the
# connection so we don't have to keep locking/unlocking
async with frame_helper.read_lock:
while True: while True:
if not self._is_socket_open: to_process.put_nowait(await frame_helper.read_packet_with_lock())
# Socket closed but task isn't cancelled yet
break
self._to_process.put_nowait(await self._frame_helper.read_packet())
except SocketClosedAPIError as err: except SocketClosedAPIError as err:
# don't log with info, if closed the site that closed the connection should log # don't log with info, if closed the site that closed the connection should log
if not self._is_socket_open:
# If we expected the socket to be closed, don't log
# the error.
return
_LOGGER.debug( _LOGGER.debug(
"%s: Socket closed, stopping read loop", "%s: Socket closed, stopping read loop",
self.log_name, self.log_name,