Reduce protocol overhead (#454)

This commit is contained in:
J. Nick Koston 2023-07-03 11:57:04 -05:00 committed by GitHub
parent 0626fbe45f
commit f3f5bd6b55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 126 additions and 65 deletions

View File

@ -3,7 +3,8 @@ import base64
import logging
from abc import abstractmethod
from enum import Enum
from typing import Any, Callable, Optional, Union, cast
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
import async_timeout
from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable
@ -149,55 +150,69 @@ class APIPlaintextFrameHelper(APIFrameHelper):
"""Perform the handshake."""
await self._connected_event.wait()
def data_received(self, data: bytes) -> None:
def data_received(self, data: bytes) -> None: # pylint: disable=too-many-branches
self._buffer += data
while len(self._buffer) >= 3:
while self._buffer:
# Read preamble, which should always be 0x00
# Also try to get the length and msg type
# to avoid multiple calls to readexactly
# to avoid multiple calls to _read_exactly
init_bytes = self._init_read(3)
assert init_bytes is not None, "Buffer should have at least 3 bytes"
if init_bytes[0] != 0x00:
if init_bytes[0] == 0x01:
if init_bytes is None:
return
msg_type_int: Optional[int] = None
length_int: Optional[int] = None
preamble, length_high, maybe_msg_type = init_bytes
if preamble != 0x00:
if preamble == 0x01:
self._handle_error_and_close(
RequiresEncryptionAPIError("Connection requires encryption")
)
return
self._handle_error_and_close(
ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}")
ProtocolAPIError(f"Invalid preamble {preamble:02x}")
)
return
if init_bytes[1] & 0x80 == 0x80:
# Length is longer than 1 byte
length = init_bytes[1:3]
msg_type = b""
if length_high & 0x80 != 0x80:
# Length is only 1 byte
#
# This is the most common case needing a single byte for
# length and type which means we avoid 2 calls to _read_exactly
length_int = length_high
if maybe_msg_type & 0x80 != 0x80:
# Message type is also only 1 byte
msg_type_int = maybe_msg_type
else:
# Message type is longer than 1 byte
msg_type = bytes(init_bytes[2:3])
else:
# This is the most common case with 99% of messages
# needing a single byte for length and type which means
# we avoid 2 calls to readexactly
length = init_bytes[1:2]
msg_type = init_bytes[2:3]
# Length is longer than 1 byte
length = bytes(init_bytes[1:3])
# If the message is long, we need to read the rest of the length
while length[-1] & 0x80 == 0x80:
add_length = self._read_exactly(1)
if add_length is None:
return
length += add_length
length_int = bytes_to_varuint(length)
# Since the length is longer than 1 byte we do not have the
# message type yet.
msg_type = b""
# If the message is long, we need to read the rest of the length
while length[-1] & 0x80 == 0x80:
add_length = self._read_exactly(1)
if add_length is None:
return
length += add_length
# If we do not have the message type yet because the message
# length was so long it did not fit into the first byte, we need
# to read the (rest of the) message type
if msg_type_int is None:
while not msg_type or msg_type[-1] & 0x80 == 0x80:
add_msg_type = self._read_exactly(1)
if add_msg_type is None:
return
msg_type += add_msg_type
msg_type_int = bytes_to_varuint(msg_type)
# If the message length was longer than 1 byte, we need to read the
# message type
while not msg_type or (msg_type[-1] & 0x80) == 0x80:
add_msg_type = self._read_exactly(1)
if add_msg_type is None:
return
msg_type += add_msg_type
length_int = bytes_to_varuint(bytes(length))
assert length_int is not None
msg_type_int = bytes_to_varuint(bytes(msg_type))
assert msg_type_int is not None
if TYPE_CHECKING:
assert length_int is not None
assert msg_type_int is not None
if length_int == 0:
self._callback_packet(msg_type_int, b"")
@ -205,6 +220,11 @@ class APIPlaintextFrameHelper(APIFrameHelper):
continue
packet_data = self._read_exactly(length_int)
# The packet data is not yet available; wait for more data
# to arrive before continuing. Since callback_packet has not
# been called yet, the buffer will not be cleared and the next
# call to data_received will continue processing the packet
# at the start of the frame.
if packet_data is None:
return
@ -246,6 +266,8 @@ class APINoiseFrameHelper(APIFrameHelper):
"_state",
"_server_name",
"_proto",
"_decrypt",
"_encrypt",
)
def __init__(
@ -262,6 +284,8 @@ class APINoiseFrameHelper(APIFrameHelper):
self._expected_name = expected_name
self._state = NoiseConnectionState.HELLO
self._server_name: Optional[str] = None
self._decrypt: Optional[Callable[[bytes], bytes]] = None
self._encrypt: Optional[Callable[[bytes], bytes]] = None
self._setup_proto()
def _set_ready_future_exception(self, exc: Exception) -> None:
@ -291,12 +315,13 @@ class APINoiseFrameHelper(APIFrameHelper):
if _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug("Sending frame: [%s]", frame.hex())
frame_len = len(frame)
try:
header = bytes(
[
0x01,
(len(frame) >> 8) & 0xFF,
len(frame) & 0xFF,
(frame_len >> 8) & 0xFF,
frame_len & 0xFF,
]
)
self._transport.write(header + frame)
@ -314,17 +339,22 @@ class APINoiseFrameHelper(APIFrameHelper):
def data_received(self, data: bytes) -> None:
self._buffer += data
while len(self._buffer) >= 3:
while self._buffer:
header = self._init_read(3)
assert header is not None, "Buffer should have at least 3 bytes"
if header[0] != 0x01:
if header is None:
return
preamble, msg_size_high, msg_size_low = header
if preamble != 0x01:
self._handle_error_and_close(
ProtocolAPIError(f"Marker byte invalid: {header[0]}")
)
return
msg_size = (header[1] << 8) | header[2]
frame = self._read_exactly(msg_size)
frame = self._read_exactly((msg_size_high << 8) | msg_size_low)
# The complete frame is not yet available; wait for more data
# to arrive before continuing. Since callback_packet has not
# been called yet, the buffer will not be cleared and the next
# call to data_received will continue processing the packet
# at the start of the frame.
if frame is None:
return
@ -415,44 +445,62 @@ class APINoiseFrameHelper(APIFrameHelper):
return
_LOGGER.debug("Handshake complete")
self._state = NoiseConnectionState.READY
noise_protocol = self._proto.noise_protocol
self._decrypt = partial(
noise_protocol.cipher_state_decrypt.decrypt_with_ad, # pylint: disable=no-member
None,
)
self._encrypt = partial(
noise_protocol.cipher_state_encrypt.encrypt_with_ad, # pylint: disable=no-member
None,
)
self._ready_future.set_result(None)
def write_packet(self, type_: int, data: bytes) -> None:
"""Write a packet to the socket."""
if self._state != NoiseConnectionState.READY:
raise HandshakeAPIError("Noise connection is not ready")
if TYPE_CHECKING:
assert self._encrypt is not None, "Handshake should be complete"
data_len = len(data)
self._write_frame(
self._proto.encrypt(
(
bytes(
[
(type_ >> 8) & 0xFF,
(type_ >> 0) & 0xFF,
(len(data) >> 8) & 0xFF,
(len(data) >> 0) & 0xFF,
]
)
+ data
self._encrypt(
bytes(
[
(type_ >> 8) & 0xFF,
(type_ >> 0) & 0xFF,
(data_len >> 8) & 0xFF,
(data_len >> 0) & 0xFF,
]
)
+ data
)
)
def _handle_frame(self, frame: bytearray) -> None:
"""Handle an incoming frame."""
assert self._proto is not None
msg = self._proto.decrypt(bytes(frame))
if len(msg) < 4:
self._handle_error_and_close(ProtocolAPIError(f"Bad packet frame: {msg}"))
return
pkt_type = (msg[0] << 8) | msg[1]
data_len = (msg[2] << 8) | msg[3]
if data_len + 4 > len(msg):
if TYPE_CHECKING:
assert self._decrypt is not None, "Handshake should be complete"
try:
msg = self._decrypt(bytes(frame))
except InvalidTag as ex:
self._handle_error_and_close(
ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
ProtocolAPIError(f"Bad encryption frame: {ex!r}")
)
return
data = msg[4 : 4 + data_len]
self._on_pkt(pkt_type, data)
msg_len = len(msg)
if msg_len < 4:
self._handle_error_and_close(ProtocolAPIError(f"Bad packet frame: {msg!r}"))
return
msg_type_high, msg_type_low, data_len_high, data_len_low = msg[:4]
msg_type = (msg_type_high << 8) | msg_type_low
data_len = (data_len_high << 8) | data_len_low
if data_len + 4 != msg_len:
self._handle_error_and_close(
ProtocolAPIError(f"Bad data len: {data_len} vs {msg_len}")
)
return
self._on_pkt(msg_type, msg[4:])
def _handle_closed( # pylint: disable=unused-argument
self, frame: bytearray

View File

@ -43,6 +43,19 @@ PREAMBLE = b"\x00"
(b"\x42" * 256),
256,
),
(
PREAMBLE + varuint_to_bytes(1) + varuint_to_bytes(32768) + b"\x42",
b"\x42",
32768,
),
(
PREAMBLE
+ varuint_to_bytes(32768)
+ varuint_to_bytes(32768)
+ (b"\x42" * 32768),
(b"\x42" * 32768),
32768,
),
],
)
async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):