Improve performance reassembling fragmented packets (#461)

This commit is contained in:
J. Nick Koston 2023-07-09 12:10:33 -10:00 committed by GitHub
parent 74ba67d792
commit b81fe760ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 86 additions and 44 deletions

View File

@ -4,7 +4,7 @@ import logging
from abc import abstractmethod
from enum import Enum
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
import async_timeout
from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable
@ -59,6 +59,7 @@ class APIFrameHelper(asyncio.Protocol):
"_transport",
"_connected_event",
"_buffer",
"_buffer_len",
"_pos",
)
@ -73,18 +74,14 @@ class APIFrameHelper(asyncio.Protocol):
self._transport: Optional[asyncio.Transport] = None
self._connected_event = asyncio.Event()
self._buffer = bytearray()
self._buffer_len = 0
self._pos = 0
def _init_read(self, length: int) -> Optional[bytearray]:
"""Start reading a packet from the buffer."""
self._pos = 0
return self._read_exactly(length)
def _read_exactly(self, length: int) -> Optional[bytearray]:
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
original_pos = self._pos
new_pos = original_pos + length
if len(self._buffer) < new_pos:
if self._buffer_len < new_pos:
return None
self._pos = new_pos
return self._buffer[original_pos:new_pos]
@ -126,11 +123,6 @@ class APIFrameHelper(asyncio.Protocol):
class APIPlaintextFrameHelper(APIFrameHelper):
"""Frame helper for plaintext API connections."""
def _callback_packet(self, type_: int, data: Union[bytes, bytearray]) -> None:
"""Complete reading a packet from the buffer."""
del self._buffer[: self._pos]
self._on_pkt(type_, data)
def write_packet(self, type_: int, data: bytes) -> None:
"""Write a packet to the socket, the caller should not have the lock.
@ -153,11 +145,13 @@ class APIPlaintextFrameHelper(APIFrameHelper):
def data_received(self, data: bytes) -> None: # pylint: disable=too-many-branches
self._buffer += data
self._buffer_len += len(data)
while self._buffer:
# Read preamble, which should always 0x00
# Also try to get the length and msg type
# to avoid multiple calls to _read_exactly
init_bytes = self._init_read(3)
self._pos = 0
init_bytes = self._read_exactly(3)
if init_bytes is None:
return
msg_type_int: Optional[int] = None
@ -216,20 +210,22 @@ class APIPlaintextFrameHelper(APIFrameHelper):
assert msg_type_int is not None
if length_int == 0:
self._callback_packet(msg_type_int, b"")
# If we have more data, continue processing
continue
packet_data = b""
else:
packet_data_bytearray = self._read_exactly(length_int)
# The packet data is not yet available, wait for more data
# to arrive before continuing, since callback_packet has not
# been called yet the buffer will not be cleared and the next
# call to data_received will continue processing the packet
# at the start of the frame.
if packet_data_bytearray is None:
return
packet_data = bytes(packet_data_bytearray)
packet_data = self._read_exactly(length_int)
# The packet data is not yet available, wait for more data
# to arrive before continuing, since callback_packet has not
# been called yet the buffer will not be cleared and the next
# call to data_received will continue processing the packet
# at the start of the frame.
if packet_data is None:
return
self._callback_packet(msg_type_int, bytes(packet_data))
end_of_frame_pos = self._pos
del self._buffer[:end_of_frame_pos]
self._buffer_len -= end_of_frame_pos
self._on_pkt(msg_type_int, packet_data)
# If we have more data, continue processing
@ -340,8 +336,10 @@ class APINoiseFrameHelper(APIFrameHelper):
def data_received(self, data: bytes) -> None:
self._buffer += data
self._buffer_len += len(data)
while self._buffer:
header = self._init_read(3)
self._pos = 0
header = self._read_exactly(3)
if header is None:
return
preamble, msg_size_high, msg_size_low = header
@ -364,7 +362,9 @@ class APINoiseFrameHelper(APIFrameHelper):
except Exception as err: # pylint: disable=broad-except
self._handle_error_and_close(err)
finally:
del self._buffer[: self._pos]
end_of_frame_pos = self._pos
del self._buffer[:end_of_frame_pos]
self._buffer_len -= end_of_frame_pos
def _send_hello(self) -> None:
"""Send a ClientHello to the server."""
@ -469,9 +469,9 @@ class APINoiseFrameHelper(APIFrameHelper):
bytes(
[
(type_ >> 8) & 0xFF,
(type_ >> 0) & 0xFF,
type_ & 0xFF,
(data_len >> 8) & 0xFF,
(data_len >> 0) & 0xFF,
data_len & 0xFF,
]
)
+ data
@ -489,19 +489,11 @@ class APINoiseFrameHelper(APIFrameHelper):
ProtocolAPIError(f"Bad encryption frame: {ex!r}")
)
return
msg_len = len(msg)
if msg_len < 4:
self._handle_error_and_close(ProtocolAPIError(f"Bad packet frame: {msg!r}"))
return
msg_type_high, msg_type_low, data_len_high, data_len_low = msg[:4]
msg_type = (msg_type_high << 8) | msg_type_low
data_len = (data_len_high << 8) | data_len_low
if data_len + 4 != msg_len:
self._handle_error_and_close(
ProtocolAPIError(f"Bad data len: {data_len} vs {msg_len}")
)
return
self._on_pkt(msg_type, msg[4:])
# Message layout is
# 2 bytes: message type
# 2 bytes: message length
# N bytes: message data
self._on_pkt((msg[0] << 8) | msg[1], msg[4:])
def _handle_closed( # pylint: disable=unused-argument
self, frame: bytearray

View File

@ -59,7 +59,7 @@ PREAMBLE = b"\x00"
],
)
async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
for _ in range(5):
for _ in range(3):
packets = []
def _packet(type_: int, data: bytes):
@ -78,6 +78,16 @@ async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
assert type_ == pkt_type
assert data == pkt_data
# Make sure we correctly handle fragments
for i in range(len(in_bytes)):
helper.data_received(in_bytes[i : i + 1])
pkt = packets.pop()
type_, data = pkt
assert type_ == pkt_type
assert data == pkt_data
@pytest.mark.asyncio
async def test_noise_frame_helper_incorrect_key():
@ -117,6 +127,46 @@ async def test_noise_frame_helper_incorrect_key():
await helper.perform_handshake()
@pytest.mark.asyncio
async def test_noise_frame_helper_incorrect_key_fragments():
"""Test that the noise frame helper raises InvalidEncryptionKeyAPIError on bad key with fragmented packets."""
outgoing_packets = [
"010000", # hello packet
"010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4",
]
incoming_packets = [
"01000d01736572766963657465737400",
"0100160148616e647368616b65204d4143206661696c757265",
]
packets = []
def _packet(type_: int, data: bytes):
packets.append((type_, data))
def _on_error(exc: Exception):
raise exc
helper = APINoiseFrameHelper(
on_pkt=_packet,
on_error=_on_error,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="servicetest",
)
helper._transport = MagicMock()
for pkt in outgoing_packets:
helper._write_frame(bytes.fromhex(pkt))
with pytest.raises(InvalidEncryptionKeyAPIError):
for pkt in incoming_packets:
in_pkt = bytes.fromhex(pkt)
for i in range(len(in_pkt)):
helper.data_received(in_pkt[i : i + 1])
with pytest.raises(InvalidEncryptionKeyAPIError):
await helper.perform_handshake()
@pytest.mark.asyncio
async def test_noise_incorrect_name():
"""Test we raise on bad name."""