Reduce memory copy when buffering is not required (#572)

This commit is contained in:
J. Nick Koston 2023-10-12 15:17:46 -10:00 committed by GitHub
parent 7323ce8987
commit 20ddb972e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 77 additions and 36 deletions

View File

@ -2,6 +2,8 @@
import cython import cython
cdef bint TYPE_CHECKING
cdef class APIFrameHelper: cdef class APIFrameHelper:
cdef object _loop cdef object _loop
@ -10,7 +12,7 @@ cdef class APIFrameHelper:
cdef object _transport cdef object _transport
cdef public object _writer cdef public object _writer
cdef public object _ready_future cdef public object _ready_future
cdef bytearray _buffer cdef bytes _buffer
cdef cython.uint _buffer_len cdef cython.uint _buffer_len
cdef cython.uint _pos cdef cython.uint _pos
cdef object _client_info cdef object _client_info
@ -18,4 +20,9 @@ cdef class APIFrameHelper:
cdef object _debug_enabled cdef object _debug_enabled
@cython.locals(original_pos=cython.uint, new_pos=cython.uint) @cython.locals(original_pos=cython.uint, new_pos=cython.uint)
cdef _read_exactly(self, int length) cdef bytes _read_exactly(self, int length)
cdef _add_to_buffer(self, bytes data)
@cython.locals(end_of_frame_pos=cython.uint)
cdef _remove_from_buffer(self)

View File

@ -4,7 +4,7 @@ import asyncio
import logging import logging
from abc import abstractmethod from abc import abstractmethod
from functools import partial from functools import partial
from typing import Callable, cast from typing import TYPE_CHECKING, Callable, cast
from ..core import HandshakeAPIError, SocketClosedAPIError from ..core import HandshakeAPIError, SocketClosedAPIError
@ -55,7 +55,7 @@ class APIFrameHelper:
self._transport: asyncio.Transport | None = None self._transport: asyncio.Transport | None = None
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
self._ready_future = self._loop.create_future() self._ready_future = self._loop.create_future()
self._buffer = bytearray() self._buffer: bytes | None = None
self._buffer_len = 0 self._buffer_len = 0
self._pos = 0 self._pos = 0
self._client_info = client_info self._client_info = client_info
@ -66,13 +66,48 @@ class APIFrameHelper:
if not self._ready_future.done(): if not self._ready_future.done():
self._ready_future.set_exception(exc) self._ready_future.set_exception(exc)
def _read_exactly(self, length: _int) -> bytearray | None: def _add_to_buffer(self, data: bytes) -> None:
"""Add data to the buffer."""
if self._buffer_len == 0:
# This is the best case scenario, we don't have to copy the data
# and can just use the buffer directly. This is the most common
# case as well.
self._buffer = data
else:
if TYPE_CHECKING:
assert self._buffer is not None, "Buffer should be set"
# This is the worst case scenario, we have to copy the data
# and can't just use the buffer directly. This is also very
# uncommon since we usually read the entire frame at once.
self._buffer += data
self._buffer_len += len(data)
def _remove_from_buffer(self) -> None:
"""Remove data from the buffer."""
end_of_frame_pos = self._pos
self._buffer_len -= end_of_frame_pos
if self._buffer_len == 0:
# This is the best case scenario, we can just set the buffer to None
# and don't have to copy the data. This is the most common case as well.
self._buffer = None
return
if TYPE_CHECKING:
assert self._buffer is not None, "Buffer should be set"
# This is the worst case scenario, we have to copy the data
# and can't just use the buffer directly. This should only happen
# when we read multiple frames at once because the event loop
# is blocked and we cannot pull the data out of the buffer fast enough.
self._buffer = self._buffer[end_of_frame_pos:]
def _read_exactly(self, length: _int) -> bytes | None:
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available.""" """Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
original_pos = self._pos original_pos = self._pos
new_pos = original_pos + length new_pos = original_pos + length
if self._buffer_len < new_pos: if self._buffer_len < new_pos:
return None return None
self._pos = new_pos self._pos = new_pos
if TYPE_CHECKING:
assert self._buffer is not None, "Buffer should be set"
return self._buffer[original_pos:new_pos] return self._buffer[original_pos:new_pos]
async def perform_handshake(self, timeout: float) -> None: async def perform_handshake(self, timeout: float) -> None:

View File

@ -3,7 +3,7 @@ import cython
from .base cimport APIFrameHelper from .base cimport APIFrameHelper
cdef object TYPE_CHECKING cdef bint TYPE_CHECKING
cdef class APINoiseFrameHelper(APIFrameHelper): cdef class APINoiseFrameHelper(APIFrameHelper):
@ -18,10 +18,15 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
cdef bint _is_ready cdef bint _is_ready
@cython.locals( @cython.locals(
header=bytearray, header=bytes,
preamble=cython.uint, preamble=cython.uint,
msg_size_high=cython.uint, msg_size_high=cython.uint,
msg_size_low=cython.uint, msg_size_low=cython.uint,
end_of_frame_pos=cython.uint,
) )
cpdef data_received(self, bytes data) cpdef data_received(self, bytes data)
@cython.locals(
type_high=cython.uint,
type_low=cython.uint
)
cpdef _handle_frame(self, bytes data)

View File

@ -139,8 +139,7 @@ class APINoiseFrameHelper(APIFrameHelper):
await super().perform_handshake(timeout) await super().perform_handshake(timeout)
def data_received(self, data: bytes) -> None: def data_received(self, data: bytes) -> None:
self._buffer += data self._add_to_buffer(data)
self._buffer_len += len(data)
while self._buffer: while self._buffer:
self._pos = 0 self._pos = 0
header = self._read_exactly(3) header = self._read_exactly(3)
@ -170,9 +169,7 @@ class APINoiseFrameHelper(APIFrameHelper):
except Exception as err: # pylint: disable=broad-except except Exception as err: # pylint: disable=broad-except
self._handle_error_and_close(err) self._handle_error_and_close(err)
finally: finally:
end_of_frame_pos = self._pos self._remove_from_buffer()
del self._buffer[:end_of_frame_pos]
self._buffer_len -= end_of_frame_pos
def _send_hello_handshake(self) -> None: def _send_hello_handshake(self) -> None:
"""Send a ClientHello to the server.""" """Send a ClientHello to the server."""
@ -198,7 +195,7 @@ class APINoiseFrameHelper(APIFrameHelper):
f"{self._log_name}: Error while writing data: {err}" f"{self._log_name}: Error while writing data: {err}"
) from err ) from err
def _handle_hello(self, server_hello: bytearray) -> None: def _handle_hello(self, server_hello: bytes) -> None:
"""Perform the handshake with the server.""" """Perform the handshake with the server."""
if not server_hello: if not server_hello:
self._handle_error_and_close( self._handle_error_and_close(
@ -269,7 +266,7 @@ class APINoiseFrameHelper(APIFrameHelper):
proto.start_handshake() proto.start_handshake()
self._proto = proto self._proto = proto
def _handle_handshake(self, msg: bytearray) -> None: def _handle_handshake(self, msg: bytes) -> None:
_LOGGER.debug("Starting handshake...") _LOGGER.debug("Starting handshake...")
if msg[0] != 0: if msg[0] != 0:
explanation = msg[1:].decode() explanation = msg[1:].decode()
@ -333,12 +330,12 @@ class APINoiseFrameHelper(APIFrameHelper):
f"{self._log_name}: Error while writing data: {err}" f"{self._log_name}: Error while writing data: {err}"
) from err ) from err
def _handle_frame(self, frame: bytearray) -> None: def _handle_frame(self, frame: bytes) -> None:
"""Handle an incoming frame.""" """Handle an incoming frame."""
if TYPE_CHECKING: if TYPE_CHECKING:
assert self._decrypt is not None, "Handshake should be complete" assert self._decrypt is not None, "Handshake should be complete"
try: try:
msg = self._decrypt(bytes(frame)) msg = self._decrypt(frame)
except InvalidTag as ex: except InvalidTag as ex:
self._handle_error_and_close( self._handle_error_and_close(
ProtocolAPIError(f"{self._log_name}: Bad encryption frame: {ex!r}") ProtocolAPIError(f"{self._log_name}: Bad encryption frame: {ex!r}")
@ -348,11 +345,11 @@ class APINoiseFrameHelper(APIFrameHelper):
# 2 bytes: message type # 2 bytes: message type
# 2 bytes: message length # 2 bytes: message length
# N bytes: message data # N bytes: message data
self._on_pkt((msg[0] << 8) | msg[1], msg[4:]) type_high = msg[0]
type_low = msg[1]
self._on_pkt((type_high << 8) | type_low, msg[4:])
def _handle_closed(self, frame: bytes) -> None:  # pylint: disable=unused-argument
    """Handle a frame that arrives after the connection was closed.

    The frame content is ignored; a ``ProtocolAPIError`` is passed to
    ``self._handle_error``.
    """
    closed_error = ProtocolAPIError(f"{self._log_name}: Connection closed")
    self._handle_error(closed_error)

View File

@ -3,7 +3,7 @@ import cython
from .base cimport APIFrameHelper from .base cimport APIFrameHelper
cdef object TYPE_CHECKING cdef bint TYPE_CHECKING
cdef object WRITE_EXCEPTIONS cdef object WRITE_EXCEPTIONS
cdef object bytes_to_varuint, varuint_to_bytes cdef object bytes_to_varuint, varuint_to_bytes
@ -12,8 +12,8 @@ cdef class APIPlaintextFrameHelper(APIFrameHelper):
@cython.locals( @cython.locals(
msg_type=bytes, msg_type=bytes,
length=bytes, length=bytes,
init_bytes=bytearray, init_bytes=bytes,
add_length=bytearray, add_length=bytes,
end_of_frame_pos=cython.uint, end_of_frame_pos=cython.uint,
length_int=cython.uint, length_int=cython.uint,
preamble=cython.uint, preamble=cython.uint,

View File

@ -39,8 +39,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
) from err ) from err
def data_received(self, data: bytes) -> None: # pylint: disable=too-many-branches def data_received(self, data: bytes) -> None: # pylint: disable=too-many-branches
self._buffer += data self._add_to_buffer(data)
self._buffer_len += len(data)
while self._buffer: while self._buffer:
# Read preamble, which should always be 0x00 # Read preamble, which should always be 0x00
# Also try to get the length and msg type # Also try to get the length and msg type
@ -80,10 +79,10 @@ class APIPlaintextFrameHelper(APIFrameHelper):
msg_type_int = maybe_msg_type msg_type_int = maybe_msg_type
else: else:
# Message type is longer than 1 byte # Message type is longer than 1 byte
msg_type = bytes(init_bytes[2:3]) msg_type = init_bytes[2:3]
else: else:
# Length is longer than 1 byte # Length is longer than 1 byte
length = bytes(init_bytes[1:3]) length = init_bytes[1:3]
# If the message is long, we need to read the rest of the length # If the message is long, we need to read the rest of the length
while length[-1] & 0x80 == 0x80: while length[-1] & 0x80 == 0x80:
add_length = self._read_exactly(1) add_length = self._read_exactly(1)
@ -112,18 +111,16 @@ class APIPlaintextFrameHelper(APIFrameHelper):
if length_int == 0: if length_int == 0:
packet_data = b"" packet_data = b""
else: else:
packet_data_bytearray = self._read_exactly(length_int) maybe_packet_data = self._read_exactly(length_int)
# The packet data is not yet available, wait for more data # The packet data is not yet available, wait for more data
# to arrive before continuing, since callback_packet has not # to arrive before continuing, since callback_packet has not
# been called yet the buffer will not be cleared and the next # been called yet the buffer will not be cleared and the next
# call to data_received will continue processing the packet # call to data_received will continue processing the packet
# at the start of the frame. # at the start of the frame.
if packet_data_bytearray is None: if maybe_packet_data is None:
return return
packet_data = bytes(packet_data_bytearray) packet_data = maybe_packet_data
end_of_frame_pos = self._pos self._remove_from_buffer()
del self._buffer[:end_of_frame_pos]
self._buffer_len -= end_of_frame_pos
self._on_pkt(msg_type_int, packet_data) self._on_pkt(msg_type_int, packet_data)
# If we have more data, continue processing # If we have more data, continue processing