mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-29 13:15:10 +01:00
Reduce protocol overhead (#454)
This commit is contained in:
parent
0626fbe45f
commit
f3f5bd6b55
@ -3,7 +3,8 @@ import base64
|
|||||||
import logging
|
import logging
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Callable, Optional, Union, cast
|
from functools import partial
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
|
||||||
|
|
||||||
import async_timeout
|
import async_timeout
|
||||||
from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable
|
from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable
|
||||||
@ -149,54 +150,68 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
"""Perform the handshake."""
|
"""Perform the handshake."""
|
||||||
await self._connected_event.wait()
|
await self._connected_event.wait()
|
||||||
|
|
||||||
def data_received(self, data: bytes) -> None:
|
def data_received(self, data: bytes) -> None: # pylint: disable=too-many-branches
|
||||||
self._buffer += data
|
self._buffer += data
|
||||||
while len(self._buffer) >= 3:
|
while self._buffer:
|
||||||
# Read preamble, which should always 0x00
|
# Read preamble, which should always 0x00
|
||||||
# Also try to get the length and msg type
|
# Also try to get the length and msg type
|
||||||
# to avoid multiple calls to readexactly
|
# to avoid multiple calls to _read_exactly
|
||||||
init_bytes = self._init_read(3)
|
init_bytes = self._init_read(3)
|
||||||
assert init_bytes is not None, "Buffer should have at least 3 bytes"
|
if init_bytes is None:
|
||||||
if init_bytes[0] != 0x00:
|
return
|
||||||
if init_bytes[0] == 0x01:
|
msg_type_int: Optional[int] = None
|
||||||
|
length_int: Optional[int] = None
|
||||||
|
preamble, length_high, maybe_msg_type = init_bytes
|
||||||
|
if preamble != 0x00:
|
||||||
|
if preamble == 0x01:
|
||||||
self._handle_error_and_close(
|
self._handle_error_and_close(
|
||||||
RequiresEncryptionAPIError("Connection requires encryption")
|
RequiresEncryptionAPIError("Connection requires encryption")
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
self._handle_error_and_close(
|
self._handle_error_and_close(
|
||||||
ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}")
|
ProtocolAPIError(f"Invalid preamble {preamble:02x}")
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if init_bytes[1] & 0x80 == 0x80:
|
if length_high & 0x80 != 0x80:
|
||||||
# Length is longer than 1 byte
|
# Length is only 1 byte
|
||||||
length = init_bytes[1:3]
|
#
|
||||||
msg_type = b""
|
# This is the most common case needing a single byte for
|
||||||
|
# length and type which means we avoid 2 calls to _read_exactly
|
||||||
|
length_int = length_high
|
||||||
|
if maybe_msg_type & 0x80 != 0x80:
|
||||||
|
# Message type is also only 1 byte
|
||||||
|
msg_type_int = maybe_msg_type
|
||||||
else:
|
else:
|
||||||
# This is the most common case with 99% of messages
|
# Message type is longer than 1 byte
|
||||||
# needing a single byte for length and type which means
|
msg_type = bytes(init_bytes[2:3])
|
||||||
# we avoid 2 calls to readexactly
|
else:
|
||||||
length = init_bytes[1:2]
|
# Length is longer than 1 byte
|
||||||
msg_type = init_bytes[2:3]
|
length = bytes(init_bytes[1:3])
|
||||||
|
|
||||||
# If the message is long, we need to read the rest of the length
|
# If the message is long, we need to read the rest of the length
|
||||||
while length[-1] & 0x80 == 0x80:
|
while length[-1] & 0x80 == 0x80:
|
||||||
add_length = self._read_exactly(1)
|
add_length = self._read_exactly(1)
|
||||||
if add_length is None:
|
if add_length is None:
|
||||||
return
|
return
|
||||||
length += add_length
|
length += add_length
|
||||||
|
length_int = bytes_to_varuint(length)
|
||||||
|
# Since the length is longer than 1 byte we do not have the
|
||||||
|
# message type yet.
|
||||||
|
msg_type = b""
|
||||||
|
|
||||||
# If the message length was longer than 1 byte, we need to read the
|
# If the we do not have the message type yet because the message
|
||||||
# message type
|
# length was so long it did not fit into the first byte we need
|
||||||
while not msg_type or (msg_type[-1] & 0x80) == 0x80:
|
# to read the (rest) of the message type
|
||||||
|
if msg_type_int is None:
|
||||||
|
while not msg_type or msg_type[-1] & 0x80 == 0x80:
|
||||||
add_msg_type = self._read_exactly(1)
|
add_msg_type = self._read_exactly(1)
|
||||||
if add_msg_type is None:
|
if add_msg_type is None:
|
||||||
return
|
return
|
||||||
msg_type += add_msg_type
|
msg_type += add_msg_type
|
||||||
|
msg_type_int = bytes_to_varuint(msg_type)
|
||||||
|
|
||||||
length_int = bytes_to_varuint(bytes(length))
|
if TYPE_CHECKING:
|
||||||
assert length_int is not None
|
assert length_int is not None
|
||||||
msg_type_int = bytes_to_varuint(bytes(msg_type))
|
|
||||||
assert msg_type_int is not None
|
assert msg_type_int is not None
|
||||||
|
|
||||||
if length_int == 0:
|
if length_int == 0:
|
||||||
@ -205,6 +220,11 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
packet_data = self._read_exactly(length_int)
|
packet_data = self._read_exactly(length_int)
|
||||||
|
# The packet data is not yet available, wait for more data
|
||||||
|
# to arrive before continuing, since callback_packet has not
|
||||||
|
# been called yet the buffer will not be cleared and the next
|
||||||
|
# call to data_received will continue processing the packet
|
||||||
|
# at the start of the frame.
|
||||||
if packet_data is None:
|
if packet_data is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -246,6 +266,8 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
"_state",
|
"_state",
|
||||||
"_server_name",
|
"_server_name",
|
||||||
"_proto",
|
"_proto",
|
||||||
|
"_decrypt",
|
||||||
|
"_encrypt",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -262,6 +284,8 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
self._expected_name = expected_name
|
self._expected_name = expected_name
|
||||||
self._state = NoiseConnectionState.HELLO
|
self._state = NoiseConnectionState.HELLO
|
||||||
self._server_name: Optional[str] = None
|
self._server_name: Optional[str] = None
|
||||||
|
self._decrypt: Optional[Callable[[bytes], bytes]] = None
|
||||||
|
self._encrypt: Optional[Callable[[bytes], bytes]] = None
|
||||||
self._setup_proto()
|
self._setup_proto()
|
||||||
|
|
||||||
def _set_ready_future_exception(self, exc: Exception) -> None:
|
def _set_ready_future_exception(self, exc: Exception) -> None:
|
||||||
@ -291,12 +315,13 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||||
_LOGGER.debug("Sending frame: [%s]", frame.hex())
|
_LOGGER.debug("Sending frame: [%s]", frame.hex())
|
||||||
|
|
||||||
|
frame_len = len(frame)
|
||||||
try:
|
try:
|
||||||
header = bytes(
|
header = bytes(
|
||||||
[
|
[
|
||||||
0x01,
|
0x01,
|
||||||
(len(frame) >> 8) & 0xFF,
|
(frame_len >> 8) & 0xFF,
|
||||||
len(frame) & 0xFF,
|
frame_len & 0xFF,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self._transport.write(header + frame)
|
self._transport.write(header + frame)
|
||||||
@ -314,17 +339,22 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
|
|
||||||
def data_received(self, data: bytes) -> None:
|
def data_received(self, data: bytes) -> None:
|
||||||
self._buffer += data
|
self._buffer += data
|
||||||
while len(self._buffer) >= 3:
|
while self._buffer:
|
||||||
header = self._init_read(3)
|
header = self._init_read(3)
|
||||||
assert header is not None, "Buffer should have at least 3 bytes"
|
if header is None:
|
||||||
if header[0] != 0x01:
|
return
|
||||||
|
preamble, msg_size_high, msg_size_low = header
|
||||||
|
if preamble != 0x01:
|
||||||
self._handle_error_and_close(
|
self._handle_error_and_close(
|
||||||
ProtocolAPIError(f"Marker byte invalid: {header[0]}")
|
ProtocolAPIError(f"Marker byte invalid: {header[0]}")
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
msg_size = (header[1] << 8) | header[2]
|
frame = self._read_exactly((msg_size_high << 8) | msg_size_low)
|
||||||
frame = self._read_exactly(msg_size)
|
# The complete frame is not yet available, wait for more data
|
||||||
|
# to arrive before continuing, since callback_packet has not
|
||||||
|
# been called yet the buffer will not be cleared and the next
|
||||||
|
# call to data_received will continue processing the packet
|
||||||
|
# at the start of the frame.
|
||||||
if frame is None:
|
if frame is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -415,44 +445,62 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
return
|
return
|
||||||
_LOGGER.debug("Handshake complete")
|
_LOGGER.debug("Handshake complete")
|
||||||
self._state = NoiseConnectionState.READY
|
self._state = NoiseConnectionState.READY
|
||||||
|
noise_protocol = self._proto.noise_protocol
|
||||||
|
self._decrypt = partial(
|
||||||
|
noise_protocol.cipher_state_decrypt.decrypt_with_ad, # pylint: disable=no-member
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
self._encrypt = partial(
|
||||||
|
noise_protocol.cipher_state_encrypt.encrypt_with_ad, # pylint: disable=no-member
|
||||||
|
None,
|
||||||
|
)
|
||||||
self._ready_future.set_result(None)
|
self._ready_future.set_result(None)
|
||||||
|
|
||||||
def write_packet(self, type_: int, data: bytes) -> None:
|
def write_packet(self, type_: int, data: bytes) -> None:
|
||||||
"""Write a packet to the socket."""
|
"""Write a packet to the socket."""
|
||||||
if self._state != NoiseConnectionState.READY:
|
if self._state != NoiseConnectionState.READY:
|
||||||
raise HandshakeAPIError("Noise connection is not ready")
|
raise HandshakeAPIError("Noise connection is not ready")
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
assert self._encrypt is not None, "Handshake should be complete"
|
||||||
|
data_len = len(data)
|
||||||
self._write_frame(
|
self._write_frame(
|
||||||
self._proto.encrypt(
|
self._encrypt(
|
||||||
(
|
|
||||||
bytes(
|
bytes(
|
||||||
[
|
[
|
||||||
(type_ >> 8) & 0xFF,
|
(type_ >> 8) & 0xFF,
|
||||||
(type_ >> 0) & 0xFF,
|
(type_ >> 0) & 0xFF,
|
||||||
(len(data) >> 8) & 0xFF,
|
(data_len >> 8) & 0xFF,
|
||||||
(len(data) >> 0) & 0xFF,
|
(data_len >> 0) & 0xFF,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
+ data
|
+ data
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
def _handle_frame(self, frame: bytearray) -> None:
|
def _handle_frame(self, frame: bytearray) -> None:
|
||||||
"""Handle an incoming frame."""
|
"""Handle an incoming frame."""
|
||||||
assert self._proto is not None
|
if TYPE_CHECKING:
|
||||||
msg = self._proto.decrypt(bytes(frame))
|
assert self._decrypt is not None, "Handshake should be complete"
|
||||||
if len(msg) < 4:
|
try:
|
||||||
self._handle_error_and_close(ProtocolAPIError(f"Bad packet frame: {msg}"))
|
msg = self._decrypt(bytes(frame))
|
||||||
return
|
except InvalidTag as ex:
|
||||||
pkt_type = (msg[0] << 8) | msg[1]
|
|
||||||
data_len = (msg[2] << 8) | msg[3]
|
|
||||||
if data_len + 4 > len(msg):
|
|
||||||
self._handle_error_and_close(
|
self._handle_error_and_close(
|
||||||
ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
|
ProtocolAPIError(f"Bad encryption frame: {ex!r}")
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
data = msg[4 : 4 + data_len]
|
msg_len = len(msg)
|
||||||
self._on_pkt(pkt_type, data)
|
if msg_len < 4:
|
||||||
|
self._handle_error_and_close(ProtocolAPIError(f"Bad packet frame: {msg!r}"))
|
||||||
|
return
|
||||||
|
msg_type_high, msg_type_low, data_len_high, data_len_low = msg[:4]
|
||||||
|
msg_type = (msg_type_high << 8) | msg_type_low
|
||||||
|
data_len = (data_len_high << 8) | data_len_low
|
||||||
|
if data_len + 4 != msg_len:
|
||||||
|
self._handle_error_and_close(
|
||||||
|
ProtocolAPIError(f"Bad data len: {data_len} vs {msg_len}")
|
||||||
|
)
|
||||||
|
return
|
||||||
|
self._on_pkt(msg_type, msg[4:])
|
||||||
|
|
||||||
def _handle_closed( # pylint: disable=unused-argument
|
def _handle_closed( # pylint: disable=unused-argument
|
||||||
self, frame: bytearray
|
self, frame: bytearray
|
||||||
|
@ -43,6 +43,19 @@ PREAMBLE = b"\x00"
|
|||||||
(b"\x42" * 256),
|
(b"\x42" * 256),
|
||||||
256,
|
256,
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
PREAMBLE + varuint_to_bytes(1) + varuint_to_bytes(32768) + b"\x42",
|
||||||
|
b"\x42",
|
||||||
|
32768,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
PREAMBLE
|
||||||
|
+ varuint_to_bytes(32768)
|
||||||
|
+ varuint_to_bytes(32768)
|
||||||
|
+ (b"\x42" * 32768),
|
||||||
|
(b"\x42" * 32768),
|
||||||
|
32768,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
|
async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
|
||||||
|
Loading…
Reference in New Issue
Block a user