mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-26 12:45:26 +01:00
Reduce protocol overhead (#454)
This commit is contained in:
parent
0626fbe45f
commit
f3f5bd6b55
@ -3,7 +3,8 @@ import base64
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, Union, cast
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
|
||||
|
||||
import async_timeout
|
||||
from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable
|
||||
@ -149,55 +150,69 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
"""Perform the handshake."""
|
||||
await self._connected_event.wait()
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
def data_received(self, data: bytes) -> None: # pylint: disable=too-many-branches
|
||||
self._buffer += data
|
||||
while len(self._buffer) >= 3:
|
||||
while self._buffer:
|
||||
# Read preamble, which should always 0x00
|
||||
# Also try to get the length and msg type
|
||||
# to avoid multiple calls to readexactly
|
||||
# to avoid multiple calls to _read_exactly
|
||||
init_bytes = self._init_read(3)
|
||||
assert init_bytes is not None, "Buffer should have at least 3 bytes"
|
||||
if init_bytes[0] != 0x00:
|
||||
if init_bytes[0] == 0x01:
|
||||
if init_bytes is None:
|
||||
return
|
||||
msg_type_int: Optional[int] = None
|
||||
length_int: Optional[int] = None
|
||||
preamble, length_high, maybe_msg_type = init_bytes
|
||||
if preamble != 0x00:
|
||||
if preamble == 0x01:
|
||||
self._handle_error_and_close(
|
||||
RequiresEncryptionAPIError("Connection requires encryption")
|
||||
)
|
||||
return
|
||||
self._handle_error_and_close(
|
||||
ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}")
|
||||
ProtocolAPIError(f"Invalid preamble {preamble:02x}")
|
||||
)
|
||||
return
|
||||
|
||||
if init_bytes[1] & 0x80 == 0x80:
|
||||
# Length is longer than 1 byte
|
||||
length = init_bytes[1:3]
|
||||
msg_type = b""
|
||||
if length_high & 0x80 != 0x80:
|
||||
# Length is only 1 byte
|
||||
#
|
||||
# This is the most common case needing a single byte for
|
||||
# length and type which means we avoid 2 calls to _read_exactly
|
||||
length_int = length_high
|
||||
if maybe_msg_type & 0x80 != 0x80:
|
||||
# Message type is also only 1 byte
|
||||
msg_type_int = maybe_msg_type
|
||||
else:
|
||||
# Message type is longer than 1 byte
|
||||
msg_type = bytes(init_bytes[2:3])
|
||||
else:
|
||||
# This is the most common case with 99% of messages
|
||||
# needing a single byte for length and type which means
|
||||
# we avoid 2 calls to readexactly
|
||||
length = init_bytes[1:2]
|
||||
msg_type = init_bytes[2:3]
|
||||
# Length is longer than 1 byte
|
||||
length = bytes(init_bytes[1:3])
|
||||
# If the message is long, we need to read the rest of the length
|
||||
while length[-1] & 0x80 == 0x80:
|
||||
add_length = self._read_exactly(1)
|
||||
if add_length is None:
|
||||
return
|
||||
length += add_length
|
||||
length_int = bytes_to_varuint(length)
|
||||
# Since the length is longer than 1 byte we do not have the
|
||||
# message type yet.
|
||||
msg_type = b""
|
||||
|
||||
# If the message is long, we need to read the rest of the length
|
||||
while length[-1] & 0x80 == 0x80:
|
||||
add_length = self._read_exactly(1)
|
||||
if add_length is None:
|
||||
return
|
||||
length += add_length
|
||||
# If the we do not have the message type yet because the message
|
||||
# length was so long it did not fit into the first byte we need
|
||||
# to read the (rest) of the message type
|
||||
if msg_type_int is None:
|
||||
while not msg_type or msg_type[-1] & 0x80 == 0x80:
|
||||
add_msg_type = self._read_exactly(1)
|
||||
if add_msg_type is None:
|
||||
return
|
||||
msg_type += add_msg_type
|
||||
msg_type_int = bytes_to_varuint(msg_type)
|
||||
|
||||
# If the message length was longer than 1 byte, we need to read the
|
||||
# message type
|
||||
while not msg_type or (msg_type[-1] & 0x80) == 0x80:
|
||||
add_msg_type = self._read_exactly(1)
|
||||
if add_msg_type is None:
|
||||
return
|
||||
msg_type += add_msg_type
|
||||
|
||||
length_int = bytes_to_varuint(bytes(length))
|
||||
assert length_int is not None
|
||||
msg_type_int = bytes_to_varuint(bytes(msg_type))
|
||||
assert msg_type_int is not None
|
||||
if TYPE_CHECKING:
|
||||
assert length_int is not None
|
||||
assert msg_type_int is not None
|
||||
|
||||
if length_int == 0:
|
||||
self._callback_packet(msg_type_int, b"")
|
||||
@ -205,6 +220,11 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
continue
|
||||
|
||||
packet_data = self._read_exactly(length_int)
|
||||
# The packet data is not yet available, wait for more data
|
||||
# to arrive before continuing, since callback_packet has not
|
||||
# been called yet the buffer will not be cleared and the next
|
||||
# call to data_received will continue processing the packet
|
||||
# at the start of the frame.
|
||||
if packet_data is None:
|
||||
return
|
||||
|
||||
@ -246,6 +266,8 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
"_state",
|
||||
"_server_name",
|
||||
"_proto",
|
||||
"_decrypt",
|
||||
"_encrypt",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@ -262,6 +284,8 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
self._expected_name = expected_name
|
||||
self._state = NoiseConnectionState.HELLO
|
||||
self._server_name: Optional[str] = None
|
||||
self._decrypt: Optional[Callable[[bytes], bytes]] = None
|
||||
self._encrypt: Optional[Callable[[bytes], bytes]] = None
|
||||
self._setup_proto()
|
||||
|
||||
def _set_ready_future_exception(self, exc: Exception) -> None:
|
||||
@ -291,12 +315,13 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug("Sending frame: [%s]", frame.hex())
|
||||
|
||||
frame_len = len(frame)
|
||||
try:
|
||||
header = bytes(
|
||||
[
|
||||
0x01,
|
||||
(len(frame) >> 8) & 0xFF,
|
||||
len(frame) & 0xFF,
|
||||
(frame_len >> 8) & 0xFF,
|
||||
frame_len & 0xFF,
|
||||
]
|
||||
)
|
||||
self._transport.write(header + frame)
|
||||
@ -314,17 +339,22 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
self._buffer += data
|
||||
while len(self._buffer) >= 3:
|
||||
while self._buffer:
|
||||
header = self._init_read(3)
|
||||
assert header is not None, "Buffer should have at least 3 bytes"
|
||||
if header[0] != 0x01:
|
||||
if header is None:
|
||||
return
|
||||
preamble, msg_size_high, msg_size_low = header
|
||||
if preamble != 0x01:
|
||||
self._handle_error_and_close(
|
||||
ProtocolAPIError(f"Marker byte invalid: {header[0]}")
|
||||
)
|
||||
return
|
||||
msg_size = (header[1] << 8) | header[2]
|
||||
frame = self._read_exactly(msg_size)
|
||||
|
||||
frame = self._read_exactly((msg_size_high << 8) | msg_size_low)
|
||||
# The complete frame is not yet available, wait for more data
|
||||
# to arrive before continuing, since callback_packet has not
|
||||
# been called yet the buffer will not be cleared and the next
|
||||
# call to data_received will continue processing the packet
|
||||
# at the start of the frame.
|
||||
if frame is None:
|
||||
return
|
||||
|
||||
@ -415,44 +445,62 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
return
|
||||
_LOGGER.debug("Handshake complete")
|
||||
self._state = NoiseConnectionState.READY
|
||||
noise_protocol = self._proto.noise_protocol
|
||||
self._decrypt = partial(
|
||||
noise_protocol.cipher_state_decrypt.decrypt_with_ad, # pylint: disable=no-member
|
||||
None,
|
||||
)
|
||||
self._encrypt = partial(
|
||||
noise_protocol.cipher_state_encrypt.encrypt_with_ad, # pylint: disable=no-member
|
||||
None,
|
||||
)
|
||||
self._ready_future.set_result(None)
|
||||
|
||||
def write_packet(self, type_: int, data: bytes) -> None:
|
||||
"""Write a packet to the socket."""
|
||||
if self._state != NoiseConnectionState.READY:
|
||||
raise HandshakeAPIError("Noise connection is not ready")
|
||||
if TYPE_CHECKING:
|
||||
assert self._encrypt is not None, "Handshake should be complete"
|
||||
data_len = len(data)
|
||||
self._write_frame(
|
||||
self._proto.encrypt(
|
||||
(
|
||||
bytes(
|
||||
[
|
||||
(type_ >> 8) & 0xFF,
|
||||
(type_ >> 0) & 0xFF,
|
||||
(len(data) >> 8) & 0xFF,
|
||||
(len(data) >> 0) & 0xFF,
|
||||
]
|
||||
)
|
||||
+ data
|
||||
self._encrypt(
|
||||
bytes(
|
||||
[
|
||||
(type_ >> 8) & 0xFF,
|
||||
(type_ >> 0) & 0xFF,
|
||||
(data_len >> 8) & 0xFF,
|
||||
(data_len >> 0) & 0xFF,
|
||||
]
|
||||
)
|
||||
+ data
|
||||
)
|
||||
)
|
||||
|
||||
def _handle_frame(self, frame: bytearray) -> None:
|
||||
"""Handle an incoming frame."""
|
||||
assert self._proto is not None
|
||||
msg = self._proto.decrypt(bytes(frame))
|
||||
if len(msg) < 4:
|
||||
self._handle_error_and_close(ProtocolAPIError(f"Bad packet frame: {msg}"))
|
||||
return
|
||||
pkt_type = (msg[0] << 8) | msg[1]
|
||||
data_len = (msg[2] << 8) | msg[3]
|
||||
if data_len + 4 > len(msg):
|
||||
if TYPE_CHECKING:
|
||||
assert self._decrypt is not None, "Handshake should be complete"
|
||||
try:
|
||||
msg = self._decrypt(bytes(frame))
|
||||
except InvalidTag as ex:
|
||||
self._handle_error_and_close(
|
||||
ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
|
||||
ProtocolAPIError(f"Bad encryption frame: {ex!r}")
|
||||
)
|
||||
return
|
||||
data = msg[4 : 4 + data_len]
|
||||
self._on_pkt(pkt_type, data)
|
||||
msg_len = len(msg)
|
||||
if msg_len < 4:
|
||||
self._handle_error_and_close(ProtocolAPIError(f"Bad packet frame: {msg!r}"))
|
||||
return
|
||||
msg_type_high, msg_type_low, data_len_high, data_len_low = msg[:4]
|
||||
msg_type = (msg_type_high << 8) | msg_type_low
|
||||
data_len = (data_len_high << 8) | data_len_low
|
||||
if data_len + 4 != msg_len:
|
||||
self._handle_error_and_close(
|
||||
ProtocolAPIError(f"Bad data len: {data_len} vs {msg_len}")
|
||||
)
|
||||
return
|
||||
self._on_pkt(msg_type, msg[4:])
|
||||
|
||||
def _handle_closed( # pylint: disable=unused-argument
|
||||
self, frame: bytearray
|
||||
|
@ -43,6 +43,19 @@ PREAMBLE = b"\x00"
|
||||
(b"\x42" * 256),
|
||||
256,
|
||||
),
|
||||
(
|
||||
PREAMBLE + varuint_to_bytes(1) + varuint_to_bytes(32768) + b"\x42",
|
||||
b"\x42",
|
||||
32768,
|
||||
),
|
||||
(
|
||||
PREAMBLE
|
||||
+ varuint_to_bytes(32768)
|
||||
+ varuint_to_bytes(32768)
|
||||
+ (b"\x42" * 32768),
|
||||
(b"\x42" * 32768),
|
||||
32768,
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
|
||||
|
Loading…
Reference in New Issue
Block a user