Refactor frame helper to allow sending multiple packets at once (#640)

This commit is contained in:
J. Nick Koston 2023-11-16 10:31:02 -06:00 committed by GitHub
parent e6257a8627
commit d6293d9177
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 95 additions and 48 deletions

View File

@ -26,3 +26,5 @@ cdef class APIFrameHelper:
@cython.locals(end_of_frame_pos=cython.uint)
cdef _remove_from_buffer(self)
cpdef write_packets(self, list packets)

View File

@ -127,8 +127,11 @@ class APIFrameHelper:
handshake_handle.cancel()
@abstractmethod
def write_packet(self, type_: int, data: bytes) -> None:
"""Write a packet to the socket."""
def write_packets(self, packets: list[tuple[int, bytes]]) -> None:
"""Write a packets to the socket.
Packets are in the format of tuple[protobuf_type, protobuf_data]
"""
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Handle a new connection."""

View File

@ -32,7 +32,11 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
cpdef _handle_frame(self, bytes data)
@cython.locals(
type_="unsigned int",
data=bytes,
packet=tuple,
data_len=cython.uint,
frame=bytes,
frame_len=cython.uint
)
cpdef write_packet(self, cython.uint type_, bytes data)
cpdef write_packets(self, list packets)

View File

@ -1,6 +1,6 @@
from __future__ import annotations
import base64
import binascii
import logging
from enum import Enum
from functools import partial
@ -241,16 +241,16 @@ class APINoiseFrameHelper(APIFrameHelper):
psk = self._noise_psk
server_name = self._server_name
try:
psk_bytes = base64.b64decode(psk)
psk_bytes = binascii.a2b_base64(psk)
except ValueError:
raise InvalidEncryptionKeyAPIError(
f"{self._log_name}: Malformed PSK {psk}, expected "
f"{self._log_name}: Malformed PSK `{psk}`, expected "
"base64-encoded value",
server_name,
)
if len(psk_bytes) != 32:
raise InvalidEncryptionKeyAPIError(
f"{self._log_name}:Malformed PSK {psk}, expected"
f"{self._log_name}:Malformed PSK `{psk}`, expected"
f" 32-bytes of base64 data",
server_name,
)
@ -304,8 +304,11 @@ class APINoiseFrameHelper(APIFrameHelper):
)
self._ready_future.set_result(None)
def write_packet(self, type_: int_, data: bytes) -> None:
"""Write a packet to the socket."""
def write_packets(self, packets: list[tuple[int, bytes]]) -> None:
"""Write a packets to the socket.
Packets are in the format of tuple[protobuf_type, protobuf_data]
"""
if not self._is_ready:
raise HandshakeAPIError(f"{self._log_name}: Noise connection is not ready")
@ -313,19 +316,32 @@ class APINoiseFrameHelper(APIFrameHelper):
assert self._encrypt is not None, "Handshake should be complete"
assert self._writer is not None, "Writer is not set"
data_len = len(data)
data_header = bytes(
((type_ >> 8) & 0xFF, type_ & 0xFF, (data_len >> 8) & 0xFF, data_len & 0xFF)
)
frame = self._encrypt(data_header + data)
out: list[bytes] = []
debug_enabled = self._debug_enabled()
for packet in packets:
type_: int = packet[0]
data: bytes = packet[1]
data_len = len(data)
data_header = bytes(
(
(type_ >> 8) & 0xFF,
type_ & 0xFF,
(data_len >> 8) & 0xFF,
data_len & 0xFF,
)
)
frame = self._encrypt(data_header + data)
if self._debug_enabled():
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, frame.hex())
if debug_enabled is True:
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, frame.hex())
frame_len = len(frame)
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
out.append(header)
out.append(frame)
frame_len = len(frame)
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
try:
self._writer(header + frame)
self._writer(b"".join(out))
except WRITE_EXCEPTIONS as err:
raise SocketAPIError(
f"{self._log_name}: Error while writing data: {err}"

View File

@ -28,3 +28,10 @@ cdef class APIPlaintextFrameHelper(APIFrameHelper):
cpdef data_received(self, bytes data)
cpdef _error_on_incorrect_preamble(self, object preamble)
@cython.locals(
type_="unsigned int",
data=bytes,
packet=tuple
)
cpdef write_packets(self, list packets)

View File

@ -59,20 +59,32 @@ class APIPlaintextFrameHelper(APIFrameHelper):
super().connection_made(transport)
self._ready_future.set_result(None)
def write_packet(self, type_: int, data: bytes) -> None:
"""Write a packet to the socket.
def write_packets(self, packets: list[tuple[int, bytes]]) -> None:
"""Write a packets to the socket.
Packets are in the format of tuple[protobuf_type, protobuf_data]
The entire packet must be written in a single call.
"""
if TYPE_CHECKING:
assert self._writer is not None, "Writer should be set"
data = b"\0" + varuint_to_bytes(len(data)) + varuint_to_bytes(type_) + data
if self._debug_enabled():
_LOGGER.debug("%s: Sending plaintext frame %s", self._log_name, data.hex())
out: list[bytes] = []
debug_enabled = self._debug_enabled()
for packet in packets:
type_: int = packet[0]
data: bytes = packet[1]
out.append(b"\0")
out.append(varuint_to_bytes(len(data)))
out.append(varuint_to_bytes(type_))
out.append(data)
if debug_enabled is True:
_LOGGER.debug(
"%s: Sending plaintext frame %s", self._log_name, data.hex()
)
try:
self._writer(data)
self._writer(b"".join(out))
except WRITE_EXCEPTIONS as err:
raise SocketAPIError(
f"{self._log_name}: Error while writing data: {err}"

View File

@ -1,5 +1,7 @@
import cython
from ._frame_helper.base cimport APIFrameHelper
cdef dict MESSAGE_TYPE_TO_PROTO
cdef dict PROTO_TO_MESSAGE_TYPE
@ -47,7 +49,7 @@ cdef class APIConnection:
cdef public object on_stop
cdef object _on_stop_task
cdef public object _socket
cdef public object _frame_helper
cdef public APIFrameHelper _frame_helper
cdef public object api_version
cdef public object connection_state
cdef dict _message_handlers
@ -69,6 +71,8 @@ cdef class APIConnection:
cpdef send_message(self, object msg)
cdef send_messages(self, tuple messages)
@cython.locals(handlers=set, handlers_copy=set)
cpdef _process_packet(self, object msg_type_proto, object data)
@ -89,5 +93,3 @@ cdef class APIConnection:
@cython.locals(handlers=set)
cpdef _remove_message_callback(self, object on_message, tuple msg_types)
cdef _send_messages(self, tuple messages)

View File

@ -613,17 +613,11 @@ class APIConnection:
connect.password = self._params.password
return connect
def _send_messages(self, messages: tuple[message.Message, ...]) -> None:
"""Send a message to the remote.
Currently this is a wrapper around send_message
but may be changed in the future to batch messages
together.
"""
for msg in messages:
self.send_message(msg)
def send_message(self, msg: message.Message) -> None:
"""Send a message to the remote."""
self.send_messages((msg,))
def send_messages(self, msgs: tuple[message.Message, ...]) -> None:
"""Send a protobuf message to the remote."""
if not self._handshake_complete:
if in_do_connect.get(False):
@ -635,23 +629,30 @@ class APIConnection:
f"Connection isn't established yet ({self.connection_state})"
)
msg_type = type(msg)
if (message_type := PROTO_TO_MESSAGE_TYPE.get(msg_type)) is None:
raise ValueError(f"Message type id not found for type {msg_type}")
packets: list[tuple[int, bytes]] = []
debug_enabled = self._debug_enabled()
if self._debug_enabled() is True:
_LOGGER.debug("%s: Sending %s: %s", self.log_name, msg_type.__name__, msg)
for msg in msgs:
msg_type = type(msg)
if (message_type := PROTO_TO_MESSAGE_TYPE.get(msg_type)) is None:
raise ValueError(f"Message type id not found for type {msg_type}")
if debug_enabled is True:
_LOGGER.debug(
"%s: Sending %s: %s", self.log_name, msg_type.__name__, msg
)
packets.append((message_type, msg.SerializeToString()))
if TYPE_CHECKING:
assert self._frame_helper is not None
encoded = msg.SerializeToString()
try:
self._frame_helper.write_packet(message_type, encoded)
self._frame_helper.write_packets(packets)
except SocketAPIError as err:
# If writing packet fails, we don't know what state the frames
# are in anymore and we have to close the connection
_LOGGER.info("%s: Error writing packet: %s", self.log_name, err)
_LOGGER.info("%s: Error writing packets: %s", self.log_name, err)
self._report_fatal_error(err)
raise
@ -738,7 +739,7 @@ class APIConnection:
# Send the message right away to reduce latency.
# This is safe because we are not awaiting between
# sending the message and registering the handler
self._send_messages(messages)
self.send_messages(messages)
loop = self._loop
# Unsafe to await between sending the message and registering the handler
fut: asyncio.Future[None] = loop.create_future()

View File

@ -475,7 +475,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
assert not writes
await handshake_task
helper.write_packet(1, b"to device")
helper.write_packets([(1, b"to device")])
encrypted_packet = writes.pop()
header = encrypted_packet[0:1]
assert header == b"\x01"