Reduce memory copy when buffering is not required (#572)

This commit is contained in:
J. Nick Koston 2023-10-12 15:17:46 -10:00 committed by GitHub
parent 7323ce8987
commit 20ddb972e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 77 additions and 36 deletions

View File

@ -2,6 +2,8 @@
import cython
cdef bint TYPE_CHECKING
cdef class APIFrameHelper:
cdef object _loop
@ -10,7 +12,7 @@ cdef class APIFrameHelper:
cdef object _transport
cdef public object _writer
cdef public object _ready_future
cdef bytearray _buffer
cdef bytes _buffer
cdef cython.uint _buffer_len
cdef cython.uint _pos
cdef object _client_info
@ -18,4 +20,9 @@ cdef class APIFrameHelper:
cdef object _debug_enabled
@cython.locals(original_pos=cython.uint, new_pos=cython.uint)
cdef _read_exactly(self, int length)
cdef bytes _read_exactly(self, int length)
cdef _add_to_buffer(self, bytes data)
@cython.locals(end_of_frame_pos=cython.uint)
cdef _remove_from_buffer(self)

View File

@ -4,7 +4,7 @@ import asyncio
import logging
from abc import abstractmethod
from functools import partial
from typing import Callable, cast
from typing import TYPE_CHECKING, Callable, cast
from ..core import HandshakeAPIError, SocketClosedAPIError
@ -55,7 +55,7 @@ class APIFrameHelper:
self._transport: asyncio.Transport | None = None
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
self._ready_future = self._loop.create_future()
self._buffer = bytearray()
self._buffer: bytes | None = None
self._buffer_len = 0
self._pos = 0
self._client_info = client_info
@ -66,13 +66,48 @@ class APIFrameHelper:
if not self._ready_future.done():
self._ready_future.set_exception(exc)
def _read_exactly(self, length: _int) -> bytearray | None:
def _add_to_buffer(self, data: bytes) -> None:
"""Add data to the buffer."""
if self._buffer_len == 0:
# This is the best case scenario, we don't have to copy the data
# and can just use the buffer directly. This is the most common
# case as well.
self._buffer = data
else:
if TYPE_CHECKING:
assert self._buffer is not None, "Buffer should be set"
# This is the worst case scenario, we have to copy the data
# and can't just use the buffer directly. This is also very
# uncommon since we usually read the entire frame at once.
self._buffer += data
self._buffer_len += len(data)
def _remove_from_buffer(self) -> None:
"""Remove data from the buffer."""
end_of_frame_pos = self._pos
self._buffer_len -= end_of_frame_pos
if self._buffer_len == 0:
# This is the best case scenario, we can just set the buffer to None
# and don't have to copy the data. This is the most common case as well.
self._buffer = None
return
if TYPE_CHECKING:
assert self._buffer is not None, "Buffer should be set"
# This is the worst case scenario, we have to copy the data
# and can't just use the buffer directly. This should only happen
# when we read multiple frames at once because the event loop
# is blocked and we cannot pull the data out of the buffer fast enough.
self._buffer = self._buffer[end_of_frame_pos:]
def _read_exactly(self, length: _int) -> bytes | None:
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
original_pos = self._pos
new_pos = original_pos + length
if self._buffer_len < new_pos:
return None
self._pos = new_pos
if TYPE_CHECKING:
assert self._buffer is not None, "Buffer should be set"
return self._buffer[original_pos:new_pos]
async def perform_handshake(self, timeout: float) -> None:

View File

@ -3,7 +3,7 @@ import cython
from .base cimport APIFrameHelper
cdef object TYPE_CHECKING
cdef bint TYPE_CHECKING
cdef class APINoiseFrameHelper(APIFrameHelper):
@ -18,10 +18,15 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
cdef bint _is_ready
@cython.locals(
header=bytearray,
header=bytes,
preamble=cython.uint,
msg_size_high=cython.uint,
msg_size_low=cython.uint,
end_of_frame_pos=cython.uint,
)
cpdef data_received(self, bytes data)
cpdef data_received(self, bytes data)
@cython.locals(
type_high=cython.uint,
type_low=cython.uint
)
cpdef _handle_frame(self, bytes data)

View File

@ -139,8 +139,7 @@ class APINoiseFrameHelper(APIFrameHelper):
await super().perform_handshake(timeout)
def data_received(self, data: bytes) -> None:
self._buffer += data
self._buffer_len += len(data)
self._add_to_buffer(data)
while self._buffer:
self._pos = 0
header = self._read_exactly(3)
@ -170,9 +169,7 @@ class APINoiseFrameHelper(APIFrameHelper):
except Exception as err: # pylint: disable=broad-except
self._handle_error_and_close(err)
finally:
end_of_frame_pos = self._pos
del self._buffer[:end_of_frame_pos]
self._buffer_len -= end_of_frame_pos
self._remove_from_buffer()
def _send_hello_handshake(self) -> None:
"""Send a ClientHello to the server."""
@ -198,7 +195,7 @@ class APINoiseFrameHelper(APIFrameHelper):
f"{self._log_name}: Error while writing data: {err}"
) from err
def _handle_hello(self, server_hello: bytearray) -> None:
def _handle_hello(self, server_hello: bytes) -> None:
"""Perform the handshake with the server."""
if not server_hello:
self._handle_error_and_close(
@ -269,7 +266,7 @@ class APINoiseFrameHelper(APIFrameHelper):
proto.start_handshake()
self._proto = proto
def _handle_handshake(self, msg: bytearray) -> None:
def _handle_handshake(self, msg: bytes) -> None:
_LOGGER.debug("Starting handshake...")
if msg[0] != 0:
explanation = msg[1:].decode()
@ -333,12 +330,12 @@ class APINoiseFrameHelper(APIFrameHelper):
f"{self._log_name}: Error while writing data: {err}"
) from err
def _handle_frame(self, frame: bytearray) -> None:
def _handle_frame(self, frame: bytes) -> None:
"""Handle an incoming frame."""
if TYPE_CHECKING:
assert self._decrypt is not None, "Handshake should be complete"
try:
msg = self._decrypt(bytes(frame))
msg = self._decrypt(frame)
except InvalidTag as ex:
self._handle_error_and_close(
ProtocolAPIError(f"{self._log_name}: Bad encryption frame: {ex!r}")
@ -348,11 +345,11 @@ class APINoiseFrameHelper(APIFrameHelper):
# 2 bytes: message type
# 2 bytes: message length
# N bytes: message data
self._on_pkt((msg[0] << 8) | msg[1], msg[4:])
type_high = msg[0]
type_low = msg[1]
self._on_pkt((type_high << 8) | type_low, msg[4:])
def _handle_closed( # pylint: disable=unused-argument
self, frame: bytearray
) -> None:
def _handle_closed(self, frame: bytes) -> None:  # pylint: disable=unused-argument
    """Report a protocol error for a frame that arrives after the connection closed.

    The ``frame`` argument is accepted but not used.
    """
    closed_error = ProtocolAPIError(f"{self._log_name}: Connection closed")
    self._handle_error(closed_error)

View File

@ -3,7 +3,7 @@ import cython
from .base cimport APIFrameHelper
cdef object TYPE_CHECKING
cdef bint TYPE_CHECKING
cdef object WRITE_EXCEPTIONS
cdef object bytes_to_varuint, varuint_to_bytes
@ -12,12 +12,12 @@ cdef class APIPlaintextFrameHelper(APIFrameHelper):
@cython.locals(
msg_type=bytes,
length=bytes,
init_bytes=bytearray,
add_length=bytearray,
init_bytes=bytes,
add_length=bytes,
end_of_frame_pos=cython.uint,
length_int=cython.uint,
preamble=cython.uint,
length_high=cython.uint,
maybe_msg_type=cython.uint
)
cpdef data_received(self, bytes data)
cpdef data_received(self, bytes data)

View File

@ -39,8 +39,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
) from err
def data_received(self, data: bytes) -> None: # pylint: disable=too-many-branches
self._buffer += data
self._buffer_len += len(data)
self._add_to_buffer(data)
while self._buffer:
# Read preamble, which should always be 0x00
# Also try to get the length and msg type
@ -80,10 +79,10 @@ class APIPlaintextFrameHelper(APIFrameHelper):
msg_type_int = maybe_msg_type
else:
# Message type is longer than 1 byte
msg_type = bytes(init_bytes[2:3])
msg_type = init_bytes[2:3]
else:
# Length is longer than 1 byte
length = bytes(init_bytes[1:3])
length = init_bytes[1:3]
# If the message is long, we need to read the rest of the length
while length[-1] & 0x80 == 0x80:
add_length = self._read_exactly(1)
@ -112,18 +111,16 @@ class APIPlaintextFrameHelper(APIFrameHelper):
if length_int == 0:
packet_data = b""
else:
packet_data_bytearray = self._read_exactly(length_int)
maybe_packet_data = self._read_exactly(length_int)
# The packet data is not yet available, wait for more data
# to arrive before continuing, since callback_packet has not
# been called yet the buffer will not be cleared and the next
# call to data_received will continue processing the packet
# at the start of the frame.
if packet_data_bytearray is None:
if maybe_packet_data is None:
return
packet_data = bytes(packet_data_bytearray)
packet_data = maybe_packet_data
end_of_frame_pos = self._pos
del self._buffer[:end_of_frame_pos]
self._buffer_len -= end_of_frame_pos
self._remove_from_buffer()
self._on_pkt(msg_type_int, packet_data)
# If we have more data, continue processing