Reduce protocol overhead (#454)

This commit is contained in:
J. Nick Koston 2023-07-03 11:57:04 -05:00 committed by GitHub
parent 0626fbe45f
commit f3f5bd6b55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 126 additions and 65 deletions

View File

@ -3,7 +3,8 @@ import base64
import logging import logging
from abc import abstractmethod from abc import abstractmethod
from enum import Enum from enum import Enum
from typing import Any, Callable, Optional, Union, cast from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
import async_timeout import async_timeout
from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable
@ -149,54 +150,68 @@ class APIPlaintextFrameHelper(APIFrameHelper):
"""Perform the handshake.""" """Perform the handshake."""
await self._connected_event.wait() await self._connected_event.wait()
def data_received(self, data: bytes) -> None: def data_received(self, data: bytes) -> None: # pylint: disable=too-many-branches
self._buffer += data self._buffer += data
while len(self._buffer) >= 3: while self._buffer:
# Read preamble, which should always 0x00 # Read preamble, which should always 0x00
# Also try to get the length and msg type # Also try to get the length and msg type
# to avoid multiple calls to readexactly # to avoid multiple calls to _read_exactly
init_bytes = self._init_read(3) init_bytes = self._init_read(3)
assert init_bytes is not None, "Buffer should have at least 3 bytes" if init_bytes is None:
if init_bytes[0] != 0x00: return
if init_bytes[0] == 0x01: msg_type_int: Optional[int] = None
length_int: Optional[int] = None
preamble, length_high, maybe_msg_type = init_bytes
if preamble != 0x00:
if preamble == 0x01:
self._handle_error_and_close( self._handle_error_and_close(
RequiresEncryptionAPIError("Connection requires encryption") RequiresEncryptionAPIError("Connection requires encryption")
) )
return return
self._handle_error_and_close( self._handle_error_and_close(
ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}") ProtocolAPIError(f"Invalid preamble {preamble:02x}")
) )
return return
if init_bytes[1] & 0x80 == 0x80: if length_high & 0x80 != 0x80:
# Length is longer than 1 byte # Length is only 1 byte
length = init_bytes[1:3] #
msg_type = b"" # This is the most common case needing a single byte for
# length and type which means we avoid 2 calls to _read_exactly
length_int = length_high
if maybe_msg_type & 0x80 != 0x80:
# Message type is also only 1 byte
msg_type_int = maybe_msg_type
else: else:
# This is the most common case with 99% of messages # Message type is longer than 1 byte
# needing a single byte for length and type which means msg_type = bytes(init_bytes[2:3])
# we avoid 2 calls to readexactly else:
length = init_bytes[1:2] # Length is longer than 1 byte
msg_type = init_bytes[2:3] length = bytes(init_bytes[1:3])
# If the message is long, we need to read the rest of the length # If the message is long, we need to read the rest of the length
while length[-1] & 0x80 == 0x80: while length[-1] & 0x80 == 0x80:
add_length = self._read_exactly(1) add_length = self._read_exactly(1)
if add_length is None: if add_length is None:
return return
length += add_length length += add_length
length_int = bytes_to_varuint(length)
# Since the length is longer than 1 byte we do not have the
# message type yet.
msg_type = b""
# If the message length was longer than 1 byte, we need to read the # If the we do not have the message type yet because the message
# message type # length was so long it did not fit into the first byte we need
while not msg_type or (msg_type[-1] & 0x80) == 0x80: # to read the (rest) of the message type
if msg_type_int is None:
while not msg_type or msg_type[-1] & 0x80 == 0x80:
add_msg_type = self._read_exactly(1) add_msg_type = self._read_exactly(1)
if add_msg_type is None: if add_msg_type is None:
return return
msg_type += add_msg_type msg_type += add_msg_type
msg_type_int = bytes_to_varuint(msg_type)
length_int = bytes_to_varuint(bytes(length)) if TYPE_CHECKING:
assert length_int is not None assert length_int is not None
msg_type_int = bytes_to_varuint(bytes(msg_type))
assert msg_type_int is not None assert msg_type_int is not None
if length_int == 0: if length_int == 0:
@ -205,6 +220,11 @@ class APIPlaintextFrameHelper(APIFrameHelper):
continue continue
packet_data = self._read_exactly(length_int) packet_data = self._read_exactly(length_int)
# The packet data is not yet available, wait for more data
# to arrive before continuing, since callback_packet has not
# been called yet the buffer will not be cleared and the next
# call to data_received will continue processing the packet
# at the start of the frame.
if packet_data is None: if packet_data is None:
return return
@ -246,6 +266,8 @@ class APINoiseFrameHelper(APIFrameHelper):
"_state", "_state",
"_server_name", "_server_name",
"_proto", "_proto",
"_decrypt",
"_encrypt",
) )
def __init__( def __init__(
@ -262,6 +284,8 @@ class APINoiseFrameHelper(APIFrameHelper):
self._expected_name = expected_name self._expected_name = expected_name
self._state = NoiseConnectionState.HELLO self._state = NoiseConnectionState.HELLO
self._server_name: Optional[str] = None self._server_name: Optional[str] = None
self._decrypt: Optional[Callable[[bytes], bytes]] = None
self._encrypt: Optional[Callable[[bytes], bytes]] = None
self._setup_proto() self._setup_proto()
def _set_ready_future_exception(self, exc: Exception) -> None: def _set_ready_future_exception(self, exc: Exception) -> None:
@ -291,12 +315,13 @@ class APINoiseFrameHelper(APIFrameHelper):
if _LOGGER.isEnabledFor(logging.DEBUG): if _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug("Sending frame: [%s]", frame.hex()) _LOGGER.debug("Sending frame: [%s]", frame.hex())
frame_len = len(frame)
try: try:
header = bytes( header = bytes(
[ [
0x01, 0x01,
(len(frame) >> 8) & 0xFF, (frame_len >> 8) & 0xFF,
len(frame) & 0xFF, frame_len & 0xFF,
] ]
) )
self._transport.write(header + frame) self._transport.write(header + frame)
@ -314,17 +339,22 @@ class APINoiseFrameHelper(APIFrameHelper):
def data_received(self, data: bytes) -> None: def data_received(self, data: bytes) -> None:
self._buffer += data self._buffer += data
while len(self._buffer) >= 3: while self._buffer:
header = self._init_read(3) header = self._init_read(3)
assert header is not None, "Buffer should have at least 3 bytes" if header is None:
if header[0] != 0x01: return
preamble, msg_size_high, msg_size_low = header
if preamble != 0x01:
self._handle_error_and_close( self._handle_error_and_close(
ProtocolAPIError(f"Marker byte invalid: {header[0]}") ProtocolAPIError(f"Marker byte invalid: {header[0]}")
) )
return return
msg_size = (header[1] << 8) | header[2] frame = self._read_exactly((msg_size_high << 8) | msg_size_low)
frame = self._read_exactly(msg_size) # The complete frame is not yet available, wait for more data
# to arrive before continuing, since callback_packet has not
# been called yet the buffer will not be cleared and the next
# call to data_received will continue processing the packet
# at the start of the frame.
if frame is None: if frame is None:
return return
@ -415,44 +445,62 @@ class APINoiseFrameHelper(APIFrameHelper):
return return
_LOGGER.debug("Handshake complete") _LOGGER.debug("Handshake complete")
self._state = NoiseConnectionState.READY self._state = NoiseConnectionState.READY
noise_protocol = self._proto.noise_protocol
self._decrypt = partial(
noise_protocol.cipher_state_decrypt.decrypt_with_ad, # pylint: disable=no-member
None,
)
self._encrypt = partial(
noise_protocol.cipher_state_encrypt.encrypt_with_ad, # pylint: disable=no-member
None,
)
self._ready_future.set_result(None) self._ready_future.set_result(None)
def write_packet(self, type_: int, data: bytes) -> None: def write_packet(self, type_: int, data: bytes) -> None:
"""Write a packet to the socket.""" """Write a packet to the socket."""
if self._state != NoiseConnectionState.READY: if self._state != NoiseConnectionState.READY:
raise HandshakeAPIError("Noise connection is not ready") raise HandshakeAPIError("Noise connection is not ready")
if TYPE_CHECKING:
assert self._encrypt is not None, "Handshake should be complete"
data_len = len(data)
self._write_frame( self._write_frame(
self._proto.encrypt( self._encrypt(
(
bytes( bytes(
[ [
(type_ >> 8) & 0xFF, (type_ >> 8) & 0xFF,
(type_ >> 0) & 0xFF, (type_ >> 0) & 0xFF,
(len(data) >> 8) & 0xFF, (data_len >> 8) & 0xFF,
(len(data) >> 0) & 0xFF, (data_len >> 0) & 0xFF,
] ]
) )
+ data + data
) )
) )
)
def _handle_frame(self, frame: bytearray) -> None: def _handle_frame(self, frame: bytearray) -> None:
"""Handle an incoming frame.""" """Handle an incoming frame."""
assert self._proto is not None if TYPE_CHECKING:
msg = self._proto.decrypt(bytes(frame)) assert self._decrypt is not None, "Handshake should be complete"
if len(msg) < 4: try:
self._handle_error_and_close(ProtocolAPIError(f"Bad packet frame: {msg}")) msg = self._decrypt(bytes(frame))
return except InvalidTag as ex:
pkt_type = (msg[0] << 8) | msg[1]
data_len = (msg[2] << 8) | msg[3]
if data_len + 4 > len(msg):
self._handle_error_and_close( self._handle_error_and_close(
ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}") ProtocolAPIError(f"Bad encryption frame: {ex!r}")
) )
return return
data = msg[4 : 4 + data_len] msg_len = len(msg)
self._on_pkt(pkt_type, data) if msg_len < 4:
self._handle_error_and_close(ProtocolAPIError(f"Bad packet frame: {msg!r}"))
return
msg_type_high, msg_type_low, data_len_high, data_len_low = msg[:4]
msg_type = (msg_type_high << 8) | msg_type_low
data_len = (data_len_high << 8) | data_len_low
if data_len + 4 != msg_len:
self._handle_error_and_close(
ProtocolAPIError(f"Bad data len: {data_len} vs {msg_len}")
)
return
self._on_pkt(msg_type, msg[4:])
def _handle_closed( # pylint: disable=unused-argument def _handle_closed( # pylint: disable=unused-argument
self, frame: bytearray self, frame: bytearray

View File

@ -43,6 +43,19 @@ PREAMBLE = b"\x00"
(b"\x42" * 256), (b"\x42" * 256),
256, 256,
), ),
(
PREAMBLE + varuint_to_bytes(1) + varuint_to_bytes(32768) + b"\x42",
b"\x42",
32768,
),
(
PREAMBLE
+ varuint_to_bytes(32768)
+ varuint_to_bytes(32768)
+ (b"\x42" * 32768),
(b"\x42" * 32768),
32768,
),
], ],
) )
async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type): async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):