mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-22 16:48:04 +01:00
Reduce memory copy when buffering is not required (#572)
This commit is contained in:
parent
7323ce8987
commit
20ddb972e7
@ -2,6 +2,8 @@
|
||||
import cython
|
||||
|
||||
|
||||
cdef bint TYPE_CHECKING
|
||||
|
||||
cdef class APIFrameHelper:
|
||||
|
||||
cdef object _loop
|
||||
@ -10,7 +12,7 @@ cdef class APIFrameHelper:
|
||||
cdef object _transport
|
||||
cdef public object _writer
|
||||
cdef public object _ready_future
|
||||
cdef bytearray _buffer
|
||||
cdef bytes _buffer
|
||||
cdef cython.uint _buffer_len
|
||||
cdef cython.uint _pos
|
||||
cdef object _client_info
|
||||
@ -18,4 +20,9 @@ cdef class APIFrameHelper:
|
||||
cdef object _debug_enabled
|
||||
|
||||
@cython.locals(original_pos=cython.uint, new_pos=cython.uint)
|
||||
cdef _read_exactly(self, int length)
|
||||
cdef bytes _read_exactly(self, int length)
|
||||
|
||||
cdef _add_to_buffer(self, bytes data)
|
||||
|
||||
@cython.locals(end_of_frame_pos=cython.uint)
|
||||
cdef _remove_from_buffer(self)
|
@ -4,7 +4,7 @@ import asyncio
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
from typing import Callable, cast
|
||||
from typing import TYPE_CHECKING, Callable, cast
|
||||
|
||||
from ..core import HandshakeAPIError, SocketClosedAPIError
|
||||
|
||||
@ -55,7 +55,7 @@ class APIFrameHelper:
|
||||
self._transport: asyncio.Transport | None = None
|
||||
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
|
||||
self._ready_future = self._loop.create_future()
|
||||
self._buffer = bytearray()
|
||||
self._buffer: bytes | None = None
|
||||
self._buffer_len = 0
|
||||
self._pos = 0
|
||||
self._client_info = client_info
|
||||
@ -66,13 +66,48 @@ class APIFrameHelper:
|
||||
if not self._ready_future.done():
|
||||
self._ready_future.set_exception(exc)
|
||||
|
||||
def _read_exactly(self, length: _int) -> bytearray | None:
|
||||
def _add_to_buffer(self, data: bytes) -> None:
|
||||
"""Add data to the buffer."""
|
||||
if self._buffer_len == 0:
|
||||
# This is the best case scenario, we don't have to copy the data
|
||||
# and can just use the buffer directly. This is the most common
|
||||
# case as well.
|
||||
self._buffer = data
|
||||
else:
|
||||
if TYPE_CHECKING:
|
||||
assert self._buffer is not None, "Buffer should be set"
|
||||
# This is the worst case scenario, we have to copy the data
|
||||
# and can't just use the buffer directly. This is also very
|
||||
# uncommon since we usually read the entire frame at once.
|
||||
self._buffer += data
|
||||
self._buffer_len += len(data)
|
||||
|
||||
def _remove_from_buffer(self) -> None:
|
||||
"""Remove data from the buffer."""
|
||||
end_of_frame_pos = self._pos
|
||||
self._buffer_len -= end_of_frame_pos
|
||||
if self._buffer_len == 0:
|
||||
# This is the best case scenario, we can just set the buffer to None
|
||||
# and don't have to copy the data. This is the most common case as well.
|
||||
self._buffer = None
|
||||
return
|
||||
if TYPE_CHECKING:
|
||||
assert self._buffer is not None, "Buffer should be set"
|
||||
# This is the worst case scenario, we have to copy the data
|
||||
# and can't just use the buffer directly. This should only happen
|
||||
# when we read multiple frames at once because the event loop
|
||||
# is blocked and we cannot pull the data out of the buffer fast enough.
|
||||
self._buffer = self._buffer[end_of_frame_pos:]
|
||||
|
||||
def _read_exactly(self, length: _int) -> bytes | None:
|
||||
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
|
||||
original_pos = self._pos
|
||||
new_pos = original_pos + length
|
||||
if self._buffer_len < new_pos:
|
||||
return None
|
||||
self._pos = new_pos
|
||||
if TYPE_CHECKING:
|
||||
assert self._buffer is not None, "Buffer should be set"
|
||||
return self._buffer[original_pos:new_pos]
|
||||
|
||||
async def perform_handshake(self, timeout: float) -> None:
|
||||
|
@ -3,7 +3,7 @@ import cython
|
||||
from .base cimport APIFrameHelper
|
||||
|
||||
|
||||
cdef object TYPE_CHECKING
|
||||
cdef bint TYPE_CHECKING
|
||||
|
||||
cdef class APINoiseFrameHelper(APIFrameHelper):
|
||||
|
||||
@ -18,10 +18,15 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
|
||||
cdef bint _is_ready
|
||||
|
||||
@cython.locals(
|
||||
header=bytearray,
|
||||
header=bytes,
|
||||
preamble=cython.uint,
|
||||
msg_size_high=cython.uint,
|
||||
msg_size_low=cython.uint,
|
||||
end_of_frame_pos=cython.uint,
|
||||
)
|
||||
cpdef data_received(self, bytes data)
|
||||
|
||||
@cython.locals(
|
||||
type_high=cython.uint,
|
||||
type_low=cython.uint
|
||||
)
|
||||
cpdef _handle_frame(self, bytes data)
|
||||
|
@ -139,8 +139,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
await super().perform_handshake(timeout)
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
self._buffer += data
|
||||
self._buffer_len += len(data)
|
||||
self._add_to_buffer(data)
|
||||
while self._buffer:
|
||||
self._pos = 0
|
||||
header = self._read_exactly(3)
|
||||
@ -170,9 +169,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
self._handle_error_and_close(err)
|
||||
finally:
|
||||
end_of_frame_pos = self._pos
|
||||
del self._buffer[:end_of_frame_pos]
|
||||
self._buffer_len -= end_of_frame_pos
|
||||
self._remove_from_buffer()
|
||||
|
||||
def _send_hello_handshake(self) -> None:
|
||||
"""Send a ClientHello to the server."""
|
||||
@ -198,7 +195,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
f"{self._log_name}: Error while writing data: {err}"
|
||||
) from err
|
||||
|
||||
def _handle_hello(self, server_hello: bytearray) -> None:
|
||||
def _handle_hello(self, server_hello: bytes) -> None:
|
||||
"""Perform the handshake with the server."""
|
||||
if not server_hello:
|
||||
self._handle_error_and_close(
|
||||
@ -269,7 +266,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
proto.start_handshake()
|
||||
self._proto = proto
|
||||
|
||||
def _handle_handshake(self, msg: bytearray) -> None:
|
||||
def _handle_handshake(self, msg: bytes) -> None:
|
||||
_LOGGER.debug("Starting handshake...")
|
||||
if msg[0] != 0:
|
||||
explanation = msg[1:].decode()
|
||||
@ -333,12 +330,12 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
f"{self._log_name}: Error while writing data: {err}"
|
||||
) from err
|
||||
|
||||
def _handle_frame(self, frame: bytearray) -> None:
|
||||
def _handle_frame(self, frame: bytes) -> None:
|
||||
"""Handle an incoming frame."""
|
||||
if TYPE_CHECKING:
|
||||
assert self._decrypt is not None, "Handshake should be complete"
|
||||
try:
|
||||
msg = self._decrypt(bytes(frame))
|
||||
msg = self._decrypt(frame)
|
||||
except InvalidTag as ex:
|
||||
self._handle_error_and_close(
|
||||
ProtocolAPIError(f"{self._log_name}: Bad encryption frame: {ex!r}")
|
||||
@ -348,11 +345,11 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
# 2 bytes: message type
|
||||
# 2 bytes: message length
|
||||
# N bytes: message data
|
||||
self._on_pkt((msg[0] << 8) | msg[1], msg[4:])
|
||||
type_high = msg[0]
|
||||
type_low = msg[1]
|
||||
self._on_pkt((type_high << 8) | type_low, msg[4:])
|
||||
|
||||
def _handle_closed( # pylint: disable=unused-argument
|
||||
self, frame: bytearray
|
||||
) -> None:
|
||||
def _handle_closed(self, frame: bytes) -> None: # pylint: disable=unused-argument
|
||||
"""Handle a closed frame."""
|
||||
self._handle_error(ProtocolAPIError(f"{self._log_name}: Connection closed"))
|
||||
|
||||
|
@ -3,7 +3,7 @@ import cython
|
||||
from .base cimport APIFrameHelper
|
||||
|
||||
|
||||
cdef object TYPE_CHECKING
|
||||
cdef bint TYPE_CHECKING
|
||||
cdef object WRITE_EXCEPTIONS
|
||||
cdef object bytes_to_varuint, varuint_to_bytes
|
||||
|
||||
@ -12,8 +12,8 @@ cdef class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
@cython.locals(
|
||||
msg_type=bytes,
|
||||
length=bytes,
|
||||
init_bytes=bytearray,
|
||||
add_length=bytearray,
|
||||
init_bytes=bytes,
|
||||
add_length=bytes,
|
||||
end_of_frame_pos=cython.uint,
|
||||
length_int=cython.uint,
|
||||
preamble=cython.uint,
|
||||
|
@ -39,8 +39,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
) from err
|
||||
|
||||
def data_received(self, data: bytes) -> None: # pylint: disable=too-many-branches
|
||||
self._buffer += data
|
||||
self._buffer_len += len(data)
|
||||
self._add_to_buffer(data)
|
||||
while self._buffer:
|
||||
# Read preamble, which should always 0x00
|
||||
# Also try to get the length and msg type
|
||||
@ -80,10 +79,10 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
msg_type_int = maybe_msg_type
|
||||
else:
|
||||
# Message type is longer than 1 byte
|
||||
msg_type = bytes(init_bytes[2:3])
|
||||
msg_type = init_bytes[2:3]
|
||||
else:
|
||||
# Length is longer than 1 byte
|
||||
length = bytes(init_bytes[1:3])
|
||||
length = init_bytes[1:3]
|
||||
# If the message is long, we need to read the rest of the length
|
||||
while length[-1] & 0x80 == 0x80:
|
||||
add_length = self._read_exactly(1)
|
||||
@ -112,18 +111,16 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
if length_int == 0:
|
||||
packet_data = b""
|
||||
else:
|
||||
packet_data_bytearray = self._read_exactly(length_int)
|
||||
maybe_packet_data = self._read_exactly(length_int)
|
||||
# The packet data is not yet available, wait for more data
|
||||
# to arrive before continuing, since callback_packet has not
|
||||
# been called yet the buffer will not be cleared and the next
|
||||
# call to data_received will continue processing the packet
|
||||
# at the start of the frame.
|
||||
if packet_data_bytearray is None:
|
||||
if maybe_packet_data is None:
|
||||
return
|
||||
packet_data = bytes(packet_data_bytearray)
|
||||
packet_data = maybe_packet_data
|
||||
|
||||
end_of_frame_pos = self._pos
|
||||
del self._buffer[:end_of_frame_pos]
|
||||
self._buffer_len -= end_of_frame_pos
|
||||
self._remove_from_buffer()
|
||||
self._on_pkt(msg_type_int, packet_data)
|
||||
# If we have more data, continue processing
|
||||
|
Loading…
Reference in New Issue
Block a user