2023-07-19 22:33:28 +02:00
|
|
|
from __future__ import annotations
|
|
|
|
|
2023-07-21 10:11:04 +02:00
|
|
|
import asyncio
|
2023-10-17 05:24:03 +02:00
|
|
|
from functools import lru_cache
|
2023-07-19 22:33:28 +02:00
|
|
|
from typing import TYPE_CHECKING
|
2023-07-18 21:28:56 +02:00
|
|
|
|
2023-11-18 22:10:40 +01:00
|
|
|
from ..core import ProtocolAPIError, RequiresEncryptionAPIError
|
|
|
|
from .base import APIFrameHelper
|
2023-07-18 21:28:56 +02:00
|
|
|
|
2023-10-17 05:24:03 +02:00
|
|
|
# Module-level aliases for the builtin ``int`` and ``bytes`` types; the
# annotations in this module refer to these names instead of the builtins.
# NOTE(review): presumably kept so an external typing/compiled layer can
# retype them -- confirm before removing or renaming.
_int = int
_bytes = bytes
|
|
|
|
|
|
|
|
|
|
|
|
def _varuint_to_bytes(value: _int) -> bytes:
    """Encode an unsigned integer as a little-endian base-128 varuint.

    Every output byte carries seven payload bits; the high bit is set on
    all bytes except the last one to signal that more bytes follow.
    """
    # Fast path: anything that fits in seven bits is a single byte.
    if value <= 0x7F:
        return bytes((value,))

    result = []
    while True:
        temp = value & 0x7F
        value >>= 7
        if not value:
            # No bits left: emit the final byte without the continuation flag
            result.append(temp)
            return bytes(result)
        result.append(temp | 0x80)


# Encoded values are memoized; the public name is the cached variant.
_cached_varuint_to_bytes = lru_cache(maxsize=1024)(_varuint_to_bytes)
varuint_to_bytes = _cached_varuint_to_bytes
|
|
|
|
|
|
|
|
|
|
|
|
def _bytes_to_varuint(value: _bytes) -> _int | None:
    """Decode a little-endian base-128 varuint from the start of ``value``.

    Returns the decoded integer as soon as a byte without the continuation
    bit is seen; any bytes after it are ignored. Returns ``None`` when the
    input ends before such a terminating byte (including empty input).
    """
    result = 0
    bitpos = 0
    for val in value:
        result |= (val & 0x7F) << bitpos
        if val & 0x80:
            # Continuation bit set: the next byte holds the next seven bits
            bitpos += 7
            continue
        return result
    return None


# Decoded values are memoized; the public name is the cached variant.
_cached_bytes_to_varuint = lru_cache(maxsize=1024)(_bytes_to_varuint)
bytes_to_varuint = _cached_bytes_to_varuint
|
|
|
|
|
2023-07-18 21:28:56 +02:00
|
|
|
|
|
|
|
class APIPlaintextFrameHelper(APIFrameHelper):
    """Frame helper for unencrypted (plaintext) API connections.

    A frame on the wire is a ``0x00`` preamble byte, the payload length as
    a varuint, the message type as a varuint, and then the payload itself.
    """

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Register the new transport and resolve the ready future."""
        super().connection_made(transport)
        self._ready_future.set_result(None)

    def write_packets(self, packets: list[tuple[int, bytes]]) -> None:
        """Serialize packets into frames and write them to the socket.

        Each packet is a ``(protobuf_type, protobuf_data)`` tuple. All
        frames are joined into one buffer and passed to ``_write_bytes``
        in a single call, since the entire packet must be written at once.
        """
        out: list[bytes] = []
        for packet in packets:
            type_: int = packet[0]
            data: bytes = packet[1]
            # Frame layout: preamble, varuint length, varuint type, payload
            out.extend(
                (b"\0", varuint_to_bytes(len(data)), varuint_to_bytes(type_), data)
            )
        self._write_bytes(b"".join(out))

    def data_received(  # pylint: disable=too-many-branches,too-many-return-statements
        self, data: bytes | bytearray | memoryview
    ) -> None:
        """Append received bytes to the buffer and dispatch complete frames.

        Parsing restarts at the beginning of the buffer on every pass. When
        a frame is still incomplete the method returns before the buffer is
        trimmed, so the next call parses that frame again from its start
        once more data has arrived.
        """
        self._add_to_buffer(data)
        while self._buffer:
            self._pos = 0
            # Fetch the preamble, the first length byte and (usually) the
            # message type in one read; in the common case this is the only
            # header read that is needed.
            init_bytes = self._read_exactly(3)
            if init_bytes is None:
                return

            preamble = init_bytes[0]
            length_high = init_bytes[1]
            maybe_msg_type = init_bytes[2]
            if preamble != 0x00:
                self._error_on_incorrect_preamble(preamble)
                return

            msg_type_int: int | None = None
            if length_high & 0x80:
                # The length varuint spans more than one byte, so the third
                # header byte belongs to the length, not the message type.
                length = init_bytes[1:3]
                while length[-1] & 0x80:
                    add_length = self._read_exactly(1)
                    if add_length is None:
                        return
                    length += add_length
                length_int = bytes_to_varuint(length) or 0
                # The message type has not been read yet at this point.
                msg_type_byte = self._read_exactly(1)
                if msg_type_byte is None:
                    return
                msg_type = msg_type_byte
                if not (msg_type[-1] & 0x80):
                    # Single-byte message type
                    msg_type_int = msg_type[0]
            else:
                # Most common case: the length fits in one byte.
                length_int = length_high
                if maybe_msg_type & 0x80:
                    # The message type continues past the header read
                    msg_type = init_bytes[2:3]
                else:
                    # The message type fits in one byte as well
                    msg_type_int = maybe_msg_type

            if msg_type_int is None:
                # Multi-byte message type: keep reading until the byte
                # without the continuation bit, then decode the varuint.
                while msg_type[-1] & 0x80:
                    add_msg_type = self._read_exactly(1)
                    if add_msg_type is None:
                        return
                    msg_type += add_msg_type
                msg_type_int = bytes_to_varuint(msg_type)

            if TYPE_CHECKING:
                assert msg_type_int is not None

            if length_int:
                # If the payload has not fully arrived yet, stop here; the
                # buffer is left untouched so the frame is parsed again
                # from its start on the next call.
                maybe_packet_data = self._read_exactly(length_int)
                if maybe_packet_data is None:
                    return
                packet_data = maybe_packet_data
            else:
                packet_data = b""

            self._remove_from_buffer()
            self._connection.process_packet(msg_type_int, packet_data)
            # Loop again in case the buffer already holds another frame

    def _error_on_incorrect_preamble(self, preamble: _int) -> None:
        """Report an unexpected preamble byte and close the connection.

        A preamble of ``0x01`` is reported as a connection that requires
        encryption; any other value is reported as a protocol error.
        """
        err: Exception
        if preamble == 0x01:
            err = RequiresEncryptionAPIError(
                f"{self._log_name}: Connection requires encryption"
            )
        else:
            err = ProtocolAPIError(
                f"{self._log_name}: Invalid preamble {preamble:02x}"
            )
        self._handle_error_and_close(err)
|