mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-02-01 23:22:27 +01:00
Improve performance reassembling fragmented packets (#461)
This commit is contained in:
parent
74ba67d792
commit
b81fe760ba
@ -4,7 +4,7 @@ import logging
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
|
||||
|
||||
import async_timeout
|
||||
from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable
|
||||
@ -59,6 +59,7 @@ class APIFrameHelper(asyncio.Protocol):
|
||||
"_transport",
|
||||
"_connected_event",
|
||||
"_buffer",
|
||||
"_buffer_len",
|
||||
"_pos",
|
||||
)
|
||||
|
||||
@ -73,18 +74,14 @@ class APIFrameHelper(asyncio.Protocol):
|
||||
self._transport: Optional[asyncio.Transport] = None
|
||||
self._connected_event = asyncio.Event()
|
||||
self._buffer = bytearray()
|
||||
self._buffer_len = 0
|
||||
self._pos = 0
|
||||
|
||||
def _init_read(self, length: int) -> Optional[bytearray]:
|
||||
"""Start reading a packet from the buffer."""
|
||||
self._pos = 0
|
||||
return self._read_exactly(length)
|
||||
|
||||
def _read_exactly(self, length: int) -> Optional[bytearray]:
|
||||
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
|
||||
original_pos = self._pos
|
||||
new_pos = original_pos + length
|
||||
if len(self._buffer) < new_pos:
|
||||
if self._buffer_len < new_pos:
|
||||
return None
|
||||
self._pos = new_pos
|
||||
return self._buffer[original_pos:new_pos]
|
||||
@ -126,11 +123,6 @@ class APIFrameHelper(asyncio.Protocol):
|
||||
class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
"""Frame helper for plaintext API connections."""
|
||||
|
||||
def _callback_packet(self, type_: int, data: Union[bytes, bytearray]) -> None:
|
||||
"""Complete reading a packet from the buffer."""
|
||||
del self._buffer[: self._pos]
|
||||
self._on_pkt(type_, data)
|
||||
|
||||
def write_packet(self, type_: int, data: bytes) -> None:
|
||||
"""Write a packet to the socket, the caller should not have the lock.
|
||||
|
||||
@ -153,11 +145,13 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
|
||||
def data_received(self, data: bytes) -> None: # pylint: disable=too-many-branches
|
||||
self._buffer += data
|
||||
self._buffer_len += len(data)
|
||||
while self._buffer:
|
||||
# Read preamble, which should always 0x00
|
||||
# Also try to get the length and msg type
|
||||
# to avoid multiple calls to _read_exactly
|
||||
init_bytes = self._init_read(3)
|
||||
self._pos = 0
|
||||
init_bytes = self._read_exactly(3)
|
||||
if init_bytes is None:
|
||||
return
|
||||
msg_type_int: Optional[int] = None
|
||||
@ -216,20 +210,22 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
assert msg_type_int is not None
|
||||
|
||||
if length_int == 0:
|
||||
self._callback_packet(msg_type_int, b"")
|
||||
# If we have more data, continue processing
|
||||
continue
|
||||
packet_data = b""
|
||||
else:
|
||||
packet_data_bytearray = self._read_exactly(length_int)
|
||||
# The packet data is not yet available, wait for more data
|
||||
# to arrive before continuing, since callback_packet has not
|
||||
# been called yet the buffer will not be cleared and the next
|
||||
# call to data_received will continue processing the packet
|
||||
# at the start of the frame.
|
||||
if packet_data_bytearray is None:
|
||||
return
|
||||
packet_data = bytes(packet_data_bytearray)
|
||||
|
||||
packet_data = self._read_exactly(length_int)
|
||||
# The packet data is not yet available, wait for more data
|
||||
# to arrive before continuing, since callback_packet has not
|
||||
# been called yet the buffer will not be cleared and the next
|
||||
# call to data_received will continue processing the packet
|
||||
# at the start of the frame.
|
||||
if packet_data is None:
|
||||
return
|
||||
|
||||
self._callback_packet(msg_type_int, bytes(packet_data))
|
||||
end_of_frame_pos = self._pos
|
||||
del self._buffer[:end_of_frame_pos]
|
||||
self._buffer_len -= end_of_frame_pos
|
||||
self._on_pkt(msg_type_int, packet_data)
|
||||
# If we have more data, continue processing
|
||||
|
||||
|
||||
@ -340,8 +336,10 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
self._buffer += data
|
||||
self._buffer_len += len(data)
|
||||
while self._buffer:
|
||||
header = self._init_read(3)
|
||||
self._pos = 0
|
||||
header = self._read_exactly(3)
|
||||
if header is None:
|
||||
return
|
||||
preamble, msg_size_high, msg_size_low = header
|
||||
@ -364,7 +362,9 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
self._handle_error_and_close(err)
|
||||
finally:
|
||||
del self._buffer[: self._pos]
|
||||
end_of_frame_pos = self._pos
|
||||
del self._buffer[:end_of_frame_pos]
|
||||
self._buffer_len -= end_of_frame_pos
|
||||
|
||||
def _send_hello(self) -> None:
|
||||
"""Send a ClientHello to the server."""
|
||||
@ -469,9 +469,9 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
bytes(
|
||||
[
|
||||
(type_ >> 8) & 0xFF,
|
||||
(type_ >> 0) & 0xFF,
|
||||
type_ & 0xFF,
|
||||
(data_len >> 8) & 0xFF,
|
||||
(data_len >> 0) & 0xFF,
|
||||
data_len & 0xFF,
|
||||
]
|
||||
)
|
||||
+ data
|
||||
@ -489,19 +489,11 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
ProtocolAPIError(f"Bad encryption frame: {ex!r}")
|
||||
)
|
||||
return
|
||||
msg_len = len(msg)
|
||||
if msg_len < 4:
|
||||
self._handle_error_and_close(ProtocolAPIError(f"Bad packet frame: {msg!r}"))
|
||||
return
|
||||
msg_type_high, msg_type_low, data_len_high, data_len_low = msg[:4]
|
||||
msg_type = (msg_type_high << 8) | msg_type_low
|
||||
data_len = (data_len_high << 8) | data_len_low
|
||||
if data_len + 4 != msg_len:
|
||||
self._handle_error_and_close(
|
||||
ProtocolAPIError(f"Bad data len: {data_len} vs {msg_len}")
|
||||
)
|
||||
return
|
||||
self._on_pkt(msg_type, msg[4:])
|
||||
# Message layout is
|
||||
# 2 bytes: message type
|
||||
# 2 bytes: message length
|
||||
# N bytes: message data
|
||||
self._on_pkt((msg[0] << 8) | msg[1], msg[4:])
|
||||
|
||||
def _handle_closed( # pylint: disable=unused-argument
|
||||
self, frame: bytearray
|
||||
|
@ -59,7 +59,7 @@ PREAMBLE = b"\x00"
|
||||
],
|
||||
)
|
||||
async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
|
||||
for _ in range(5):
|
||||
for _ in range(3):
|
||||
packets = []
|
||||
|
||||
def _packet(type_: int, data: bytes):
|
||||
@ -78,6 +78,16 @@ async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
|
||||
assert type_ == pkt_type
|
||||
assert data == pkt_data
|
||||
|
||||
# Make sure we correctly handle fragments
|
||||
for i in range(len(in_bytes)):
|
||||
helper.data_received(in_bytes[i : i + 1])
|
||||
|
||||
pkt = packets.pop()
|
||||
type_, data = pkt
|
||||
|
||||
assert type_ == pkt_type
|
||||
assert data == pkt_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noise_frame_helper_incorrect_key():
|
||||
@ -117,6 +127,46 @@ async def test_noise_frame_helper_incorrect_key():
|
||||
await helper.perform_handshake()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noise_frame_helper_incorrect_key_fragments():
|
||||
"""Test that the noise frame helper raises InvalidEncryptionKeyAPIError on bad key with fragmented packets."""
|
||||
outgoing_packets = [
|
||||
"010000", # hello packet
|
||||
"010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4",
|
||||
]
|
||||
incoming_packets = [
|
||||
"01000d01736572766963657465737400",
|
||||
"0100160148616e647368616b65204d4143206661696c757265",
|
||||
]
|
||||
packets = []
|
||||
|
||||
def _packet(type_: int, data: bytes):
|
||||
packets.append((type_, data))
|
||||
|
||||
def _on_error(exc: Exception):
|
||||
raise exc
|
||||
|
||||
helper = APINoiseFrameHelper(
|
||||
on_pkt=_packet,
|
||||
on_error=_on_error,
|
||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||
expected_name="servicetest",
|
||||
)
|
||||
helper._transport = MagicMock()
|
||||
|
||||
for pkt in outgoing_packets:
|
||||
helper._write_frame(bytes.fromhex(pkt))
|
||||
|
||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||
for pkt in incoming_packets:
|
||||
in_pkt = bytes.fromhex(pkt)
|
||||
for i in range(len(in_pkt)):
|
||||
helper.data_received(in_pkt[i : i + 1])
|
||||
|
||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||
await helper.perform_handshake()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noise_incorrect_name():
|
||||
"""Test we raise on bad name."""
|
||||
|
Loading…
Reference in New Issue
Block a user