Optimize throughput of api to decrease latency (#327)

This commit is contained in:
J. Nick Koston 2022-12-02 09:12:19 -10:00 committed by GitHub
parent e2527878cb
commit 3692478455
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 195 additions and 128 deletions

View File

@ -1,7 +1,7 @@
import asyncio
import base64
import logging
from abc import ABC, abstractmethod
from abc import ABC, abstractmethod, abstractproperty
from dataclasses import dataclass
from typing import Optional
@ -29,101 +29,129 @@ class Packet:
class APIFrameHelper(ABC):
@abstractmethod
async def close(self) -> None:
pass
"""Helper class to handle the API frame protocol."""
@abstractmethod
async def write_packet(self, packet: Packet) -> None:
pass
@abstractmethod
async def read_packet(self) -> Packet:
pass
class APIPlaintextFrameHelper(APIFrameHelper):
def __init__(
self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
):
) -> None:
"""Initialize the API frame helper."""
self._reader = reader
self._writer = writer
self._write_lock = asyncio.Lock()
self._read_lock = asyncio.Lock()
self.read_lock = asyncio.Lock()
self._closed_event = asyncio.Event()
@abstractproperty # pylint: disable=deprecated-decorator
def ready(self) -> bool:
"""Return if the connection is ready."""
@abstractmethod
async def close(self) -> None:
"""Close the connection."""
@abstractmethod
def write_packet(self, packet: Packet) -> None:
"""Write a packet to the socket."""
@abstractmethod
async def read_packet_with_lock(self) -> Packet:
"""Read a packet from the socket, the caller is responsible for having the lock."""
@abstractmethod
async def wait_for_ready(self) -> None:
"""Wait for the connection to be ready."""
class APIPlaintextFrameHelper(APIFrameHelper):
"""Frame helper for plaintext API connections."""
async def close(self) -> None:
"""Close the connection."""
self._closed_event.set()
self._writer.close()
async def write_packet(self, packet: Packet) -> None:
data = b"\0"
data += varuint_to_bytes(len(packet.data))
data += varuint_to_bytes(packet.type)
data += packet.data
@property
def ready(self) -> bool:
"""Return if the connection is ready."""
# Plaintext is always ready
return True
def write_packet(self, packet: Packet) -> None:
"""Write a packet to the socket, the caller should not have the lock.
The entire packet must be written in a single call to write
to avoid locking.
"""
data = (
b"\0"
+ varuint_to_bytes(len(packet.data))
+ varuint_to_bytes(packet.type)
+ packet.data
)
_LOGGER.debug("Sending plaintext frame %s", data.hex())
try:
async with self._write_lock:
_LOGGER.debug("Sending plaintext frame %s", data.hex())
self._writer.write(data)
await self._writer.drain()
self._writer.write(data)
except (ConnectionResetError, OSError) as err:
raise SocketAPIError(f"Error while writing data: {err}") from err
async def read_packet(self) -> Packet:
async with self._read_lock:
try:
# Read preamble, which should always be 0x00
# Also try to get the length and msg type
# to avoid multiple calls to readexactly
init_bytes = await self._reader.readexactly(3)
if init_bytes[0] != 0x00:
if init_bytes[0] == 0x01:
raise RequiresEncryptionAPIError(
"Connection requires encryption"
)
raise ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}")
async def wait_for_ready(self) -> None:
"""Wait for the connection to be ready."""
# No handshake for plaintext
if init_bytes[1] & 0x80 == 0x80:
# Length is longer than 1 byte
length = init_bytes[1:3]
msg_type = b""
else:
# This is the most common case with 99% of messages
# needing a single byte for length and type which means
# we avoid 2 calls to readexactly
length = init_bytes[1:2]
msg_type = init_bytes[2:3]
async def read_packet_with_lock(self) -> Packet:
"""Read a packet from the socket, the caller is responsible for having the lock."""
assert self.read_lock.locked(), "read_packet_with_lock called without lock"
try:
# Read preamble, which should always be 0x00
# Also try to get the length and msg type
# to avoid multiple calls to readexactly
init_bytes = await self._reader.readexactly(3)
if init_bytes[0] != 0x00:
if init_bytes[0] == 0x01:
raise RequiresEncryptionAPIError("Connection requires encryption")
raise ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}")
# If the message is long, we need to read the rest of the length
while length[-1] & 0x80 == 0x80:
length += await self._reader.readexactly(1)
if init_bytes[1] & 0x80 == 0x80:
# Length is longer than 1 byte
length = init_bytes[1:3]
msg_type = b""
else:
# This is the most common case with 99% of messages
# needing a single byte for length and type which means
# we avoid 2 calls to readexactly
length = init_bytes[1:2]
msg_type = init_bytes[2:3]
# If the message length was longer than 1 byte, we need to read the
# message type
while not msg_type or (msg_type[-1] & 0x80) == 0x80:
msg_type += await self._reader.readexactly(1)
# If the message is long, we need to read the rest of the length
while length[-1] & 0x80 == 0x80:
length += await self._reader.readexactly(1)
length_int = bytes_to_varuint(length)
assert length_int is not None
msg_type_int = bytes_to_varuint(msg_type)
assert msg_type_int is not None
# If the message length was longer than 1 byte, we need to read the
# message type
while not msg_type or (msg_type[-1] & 0x80) == 0x80:
msg_type += await self._reader.readexactly(1)
if length_int == 0:
return Packet(type=msg_type_int, data=b"")
length_int = bytes_to_varuint(length)
assert length_int is not None
msg_type_int = bytes_to_varuint(msg_type)
assert msg_type_int is not None
data = await self._reader.readexactly(length_int)
return Packet(type=msg_type_int, data=data)
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
if (
isinstance(err, asyncio.IncompleteReadError)
and self._closed_event.is_set()
):
raise SocketClosedAPIError(
f"Socket closed while reading data: {err}"
) from err
raise SocketAPIError(f"Error while reading data: {err}") from err
if length_int == 0:
return Packet(type=msg_type_int, data=b"")
data = await self._reader.readexactly(length_int)
return Packet(type=msg_type_int, data=data)
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
if (
isinstance(err, asyncio.IncompleteReadError)
and self._closed_event.is_set()
):
raise SocketClosedAPIError(
f"Socket closed while reading data: {err}"
) from err
raise SocketAPIError(f"Error while reading data: {err}") from err
def _decode_noise_psk(psk: str) -> bytes:
@ -142,49 +170,63 @@ def _decode_noise_psk(psk: str) -> bytes:
class APINoiseFrameHelper(APIFrameHelper):
"""Frame helper for noise encrypted connections."""
def __init__(
self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
noise_psk: str,
):
self._reader = reader
self._writer = writer
self._write_lock = asyncio.Lock()
self._read_lock = asyncio.Lock()
) -> None:
"""Initialize the API frame helper."""
super().__init__(reader, writer)
self._ready_event = asyncio.Event()
self._closed_event = asyncio.Event()
self._proto: Optional[NoiseConnection] = None
self._noise_psk = noise_psk
@property
def ready(self) -> bool:
"""Return if the connection is ready."""
return self._ready_event.is_set()
async def close(self) -> None:
"""Close the connection."""
# Make sure we set the ready event if it's not already set
# so that we don't block forever on the ready event if we
# are waiting for the handshake to complete.
self._ready_event.set()
self._closed_event.set()
self._writer.close()
async def _write_frame(self, frame: bytes) -> None:
def _write_frame(self, frame: bytes) -> None:
"""Write a packet to the socket, the caller should not have the lock.
The entire packet must be written in a single call to write
to avoid locking.
"""
_LOGGER.debug("Sending frame %s", frame.hex())
try:
async with self._write_lock:
_LOGGER.debug("Sending frame %s", frame.hex())
header = bytes(
[
0x01,
(len(frame) >> 8) & 0xFF,
len(frame) & 0xFF,
]
)
self._writer.write(header + frame)
await self._writer.drain()
header = bytes(
[
0x01,
(len(frame) >> 8) & 0xFF,
len(frame) & 0xFF,
]
)
self._writer.write(header + frame)
except OSError as err:
raise SocketAPIError(f"Error while writing data: {err}") from err
async def _read_frame(self) -> bytes:
async def _read_frame_with_lock(self) -> bytes:
"""Read a frame from the socket, the caller is responsible for having the lock."""
assert self.read_lock.locked(), "_read_frame_with_lock called without lock"
try:
async with self._read_lock:
header = await self._reader.readexactly(3)
if header[0] != 0x01:
raise ProtocolAPIError(f"Marker byte invalid: {header[0]}")
msg_size = (header[1] << 8) | header[2]
frame = await self._reader.readexactly(msg_size)
header = await self._reader.readexactly(3)
if header[0] != 0x01:
raise ProtocolAPIError(f"Marker byte invalid: {header[0]}")
msg_size = (header[1] << 8) | header[2]
frame = await self._reader.readexactly(msg_size)
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
if (
isinstance(err, asyncio.IncompleteReadError)
@ -199,10 +241,12 @@ class APINoiseFrameHelper(APIFrameHelper):
return frame
async def _perform_handshake(self, expected_name: Optional[str]) -> None:
await self._write_frame(b"") # ClientHello
"""Perform the handshake with the server, the caller is responsible for having the lock."""
assert self.read_lock.locked(), "_perform_handshake called without lock"
self._write_frame(b"") # ClientHello
prologue = b"NoiseAPIInit" + b"\x00\x00"
server_hello = await self._read_frame() # ServerHello
server_hello = await self._read_frame_with_lock() # ServerHello
if not server_hello:
raise HandshakeAPIError("ServerHello is empty")
@ -238,9 +282,9 @@ class APINoiseFrameHelper(APIFrameHelper):
while not self._proto.handshake_finished:
if do_write:
msg = self._proto.write_message()
await self._write_frame(b"\x00" + msg)
self._write_frame(b"\x00" + msg)
else:
msg = await self._read_frame()
msg = await self._read_frame_with_lock()
if not msg:
raise HandshakeAPIError("Handshake message too short")
if msg[0] != 0:
@ -256,16 +300,16 @@ class APINoiseFrameHelper(APIFrameHelper):
self._ready_event.set()
async def perform_handshake(self, expected_name: Optional[str]) -> None:
"""Perform the handshake with the server."""
# Allow up to 60 seconds for handshake
try:
async with async_timeout.timeout(60.0):
async with self.read_lock, async_timeout.timeout(60.0):
await self._perform_handshake(expected_name)
except asyncio.TimeoutError as err:
raise HandshakeAPIError("Timeout during handshake") from err
async def write_packet(self, packet: Packet) -> None:
# Wait for handshake to complete
await self._ready_event.wait()
def write_packet(self, packet: Packet) -> None:
"""Write a packet to the socket."""
padding = 0
data = (
bytes(
@ -281,12 +325,15 @@ class APINoiseFrameHelper(APIFrameHelper):
)
assert self._proto is not None
frame = self._proto.encrypt(data)
await self._write_frame(frame)
self._write_frame(frame)
async def read_packet(self) -> Packet:
# Wait for handshake to complete
async def wait_for_ready(self) -> None:
"""Wait for the connection to be ready."""
await self._ready_event.wait()
frame = await self._read_frame()
async def read_packet_with_lock(self) -> Packet:
"""Read a packet from the socket, the caller is responsible for having the lock."""
frame = await self._read_frame_with_lock()
assert self._proto is not None
msg = self._proto.decrypt(frame)
if len(msg) < 4:

View File

@ -49,6 +49,8 @@ _LOGGER = logging.getLogger(__name__)
BUFFER_SIZE = 1024 * 1024 # Set buffer limit to 1MB
INTERNAL_MESSAGE_TYPES = (GetTimeRequest, PingRequest, DisconnectRequest)
@dataclass
class ConnectionParams:
@ -105,7 +107,7 @@ class APIConnection:
self._ping_stop_event = asyncio.Event()
self._to_process: asyncio.Queue[Packet] = asyncio.Queue()
self._to_process: asyncio.Queue[Optional[Packet]] = asyncio.Queue()
self._process_task: Optional[asyncio.Task[None]] = None
@ -120,6 +122,9 @@ class APIConnection:
async def _do_cleanup() -> None:
async with self._connect_lock:
# Tell the process loop to stop
self._to_process.put_nowait(None)
if self._frame_helper is not None:
await self._frame_helper.close()
self._frame_helper = None
@ -388,10 +393,13 @@ class APIConnection:
encoded = msg.SerializeToString()
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
frame_helper = self._frame_helper
assert frame_helper is not None
if not frame_helper.ready:
await frame_helper.wait_for_ready()
try:
assert self._frame_helper is not None
# pylint: disable=undefined-loop-variable
await self._frame_helper.write_packet(
frame_helper.write_packet(
Packet(
type=message_type,
data=encoded,
@ -512,48 +520,60 @@ class APIConnection:
await self._cleanup()
async def _process_loop(self) -> None:
to_process = self._to_process
while True:
if not self._is_socket_open:
# Socket closed but task isn't cancelled yet
break
try:
pkt = await self._to_process.get()
pkt = await to_process.get()
except RuntimeError:
break
if pkt is None:
# Socket closed but task isn't cancelled yet
break
msg_type = pkt.type
raw_msg = pkt.data
if msg_type not in MESSAGE_TYPE_TO_PROTO:
_LOGGER.debug("%s: Skipping message type %s", self.log_name, msg_type)
continue
msg = MESSAGE_TYPE_TO_PROTO[msg_type]()
try:
msg.ParseFromString(raw_msg)
msg.ParseFromString(pkt.data)
except Exception as e:
await self._report_fatal_error(
ProtocolAPIError(f"Invalid protobuf message: {e}")
)
raise
_LOGGER.debug(
"%s: Got message of type %s: %s", self.log_name, type(msg), msg
)
for handler in self._message_handlers[:]:
handler(msg)
await self._handle_internal_messages(msg)
# Pre-check the message type to avoid awaiting
# since most messages are not internal messages
if isinstance(msg, INTERNAL_MESSAGE_TYPES):
await self._handle_internal_messages(msg)
async def _read_loop(self) -> None:
assert self._frame_helper is not None
frame_helper = self._frame_helper
assert frame_helper is not None
await frame_helper.wait_for_ready()
to_process = self._to_process
try:
while True:
if not self._is_socket_open:
# Socket closed but task isn't cancelled yet
break
self._to_process.put_nowait(await self._frame_helper.read_packet())
# Once it's ready, we hold the lock for the duration of the
# connection so we don't have to keep locking/unlocking
async with frame_helper.read_lock:
while True:
to_process.put_nowait(await frame_helper.read_packet_with_lock())
except SocketClosedAPIError as err:
# Don't log at info level; if the socket was closed, the site that closed the connection should log
if not self._is_socket_open:
# If we expected the socket to be closed, don't log
# the error.
return
_LOGGER.debug(
"%s: Socket closed, stopping read loop",
self.log_name,