mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-01-08 19:38:09 +01:00
Reduce memory copy when buffering is not required (#572)
This commit is contained in:
parent
7323ce8987
commit
20ddb972e7
@ -2,6 +2,8 @@
|
|||||||
import cython
|
import cython
|
||||||
|
|
||||||
|
|
||||||
|
cdef bint TYPE_CHECKING
|
||||||
|
|
||||||
cdef class APIFrameHelper:
|
cdef class APIFrameHelper:
|
||||||
|
|
||||||
cdef object _loop
|
cdef object _loop
|
||||||
@ -10,7 +12,7 @@ cdef class APIFrameHelper:
|
|||||||
cdef object _transport
|
cdef object _transport
|
||||||
cdef public object _writer
|
cdef public object _writer
|
||||||
cdef public object _ready_future
|
cdef public object _ready_future
|
||||||
cdef bytearray _buffer
|
cdef bytes _buffer
|
||||||
cdef cython.uint _buffer_len
|
cdef cython.uint _buffer_len
|
||||||
cdef cython.uint _pos
|
cdef cython.uint _pos
|
||||||
cdef object _client_info
|
cdef object _client_info
|
||||||
@ -18,4 +20,9 @@ cdef class APIFrameHelper:
|
|||||||
cdef object _debug_enabled
|
cdef object _debug_enabled
|
||||||
|
|
||||||
@cython.locals(original_pos=cython.uint, new_pos=cython.uint)
|
@cython.locals(original_pos=cython.uint, new_pos=cython.uint)
|
||||||
cdef _read_exactly(self, int length)
|
cdef bytes _read_exactly(self, int length)
|
||||||
|
|
||||||
|
cdef _add_to_buffer(self, bytes data)
|
||||||
|
|
||||||
|
@cython.locals(end_of_frame_pos=cython.uint)
|
||||||
|
cdef _remove_from_buffer(self)
|
@ -4,7 +4,7 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Callable, cast
|
from typing import TYPE_CHECKING, Callable, cast
|
||||||
|
|
||||||
from ..core import HandshakeAPIError, SocketClosedAPIError
|
from ..core import HandshakeAPIError, SocketClosedAPIError
|
||||||
|
|
||||||
@ -55,7 +55,7 @@ class APIFrameHelper:
|
|||||||
self._transport: asyncio.Transport | None = None
|
self._transport: asyncio.Transport | None = None
|
||||||
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
|
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
|
||||||
self._ready_future = self._loop.create_future()
|
self._ready_future = self._loop.create_future()
|
||||||
self._buffer = bytearray()
|
self._buffer: bytes | None = None
|
||||||
self._buffer_len = 0
|
self._buffer_len = 0
|
||||||
self._pos = 0
|
self._pos = 0
|
||||||
self._client_info = client_info
|
self._client_info = client_info
|
||||||
@ -66,13 +66,48 @@ class APIFrameHelper:
|
|||||||
if not self._ready_future.done():
|
if not self._ready_future.done():
|
||||||
self._ready_future.set_exception(exc)
|
self._ready_future.set_exception(exc)
|
||||||
|
|
||||||
def _read_exactly(self, length: _int) -> bytearray | None:
|
def _add_to_buffer(self, data: bytes) -> None:
|
||||||
|
"""Add data to the buffer."""
|
||||||
|
if self._buffer_len == 0:
|
||||||
|
# This is the best case scenario, we don't have to copy the data
|
||||||
|
# and can just use the buffer directly. This is the most common
|
||||||
|
# case as well.
|
||||||
|
self._buffer = data
|
||||||
|
else:
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
assert self._buffer is not None, "Buffer should be set"
|
||||||
|
# This is the worst case scenario, we have to copy the data
|
||||||
|
# and can't just use the buffer directly. This is also very
|
||||||
|
# uncommon since we usually read the entire frame at once.
|
||||||
|
self._buffer += data
|
||||||
|
self._buffer_len += len(data)
|
||||||
|
|
||||||
|
def _remove_from_buffer(self) -> None:
|
||||||
|
"""Remove data from the buffer."""
|
||||||
|
end_of_frame_pos = self._pos
|
||||||
|
self._buffer_len -= end_of_frame_pos
|
||||||
|
if self._buffer_len == 0:
|
||||||
|
# This is the best case scenario, we can just set the buffer to None
|
||||||
|
# and don't have to copy the data. This is the most common case as well.
|
||||||
|
self._buffer = None
|
||||||
|
return
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
assert self._buffer is not None, "Buffer should be set"
|
||||||
|
# This is the worst case scenario, we have to copy the data
|
||||||
|
# and can't just use the buffer directly. This should only happen
|
||||||
|
# when we read multiple frames at once because the event loop
|
||||||
|
# is blocked and we cannot pull the data out of the buffer fast enough.
|
||||||
|
self._buffer = self._buffer[end_of_frame_pos:]
|
||||||
|
|
||||||
|
def _read_exactly(self, length: _int) -> bytes | None:
|
||||||
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
|
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
|
||||||
original_pos = self._pos
|
original_pos = self._pos
|
||||||
new_pos = original_pos + length
|
new_pos = original_pos + length
|
||||||
if self._buffer_len < new_pos:
|
if self._buffer_len < new_pos:
|
||||||
return None
|
return None
|
||||||
self._pos = new_pos
|
self._pos = new_pos
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
assert self._buffer is not None, "Buffer should be set"
|
||||||
return self._buffer[original_pos:new_pos]
|
return self._buffer[original_pos:new_pos]
|
||||||
|
|
||||||
async def perform_handshake(self, timeout: float) -> None:
|
async def perform_handshake(self, timeout: float) -> None:
|
||||||
|
@ -3,7 +3,7 @@ import cython
|
|||||||
from .base cimport APIFrameHelper
|
from .base cimport APIFrameHelper
|
||||||
|
|
||||||
|
|
||||||
cdef object TYPE_CHECKING
|
cdef bint TYPE_CHECKING
|
||||||
|
|
||||||
cdef class APINoiseFrameHelper(APIFrameHelper):
|
cdef class APINoiseFrameHelper(APIFrameHelper):
|
||||||
|
|
||||||
@ -18,10 +18,15 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
cdef bint _is_ready
|
cdef bint _is_ready
|
||||||
|
|
||||||
@cython.locals(
|
@cython.locals(
|
||||||
header=bytearray,
|
header=bytes,
|
||||||
preamble=cython.uint,
|
preamble=cython.uint,
|
||||||
msg_size_high=cython.uint,
|
msg_size_high=cython.uint,
|
||||||
msg_size_low=cython.uint,
|
msg_size_low=cython.uint,
|
||||||
end_of_frame_pos=cython.uint,
|
|
||||||
)
|
)
|
||||||
cpdef data_received(self, bytes data)
|
cpdef data_received(self, bytes data)
|
||||||
|
|
||||||
|
@cython.locals(
|
||||||
|
type_high=cython.uint,
|
||||||
|
type_low=cython.uint
|
||||||
|
)
|
||||||
|
cpdef _handle_frame(self, bytes data)
|
||||||
|
@ -139,8 +139,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
await super().perform_handshake(timeout)
|
await super().perform_handshake(timeout)
|
||||||
|
|
||||||
def data_received(self, data: bytes) -> None:
|
def data_received(self, data: bytes) -> None:
|
||||||
self._buffer += data
|
self._add_to_buffer(data)
|
||||||
self._buffer_len += len(data)
|
|
||||||
while self._buffer:
|
while self._buffer:
|
||||||
self._pos = 0
|
self._pos = 0
|
||||||
header = self._read_exactly(3)
|
header = self._read_exactly(3)
|
||||||
@ -170,9 +169,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
except Exception as err: # pylint: disable=broad-except
|
except Exception as err: # pylint: disable=broad-except
|
||||||
self._handle_error_and_close(err)
|
self._handle_error_and_close(err)
|
||||||
finally:
|
finally:
|
||||||
end_of_frame_pos = self._pos
|
self._remove_from_buffer()
|
||||||
del self._buffer[:end_of_frame_pos]
|
|
||||||
self._buffer_len -= end_of_frame_pos
|
|
||||||
|
|
||||||
def _send_hello_handshake(self) -> None:
|
def _send_hello_handshake(self) -> None:
|
||||||
"""Send a ClientHello to the server."""
|
"""Send a ClientHello to the server."""
|
||||||
@ -198,7 +195,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
f"{self._log_name}: Error while writing data: {err}"
|
f"{self._log_name}: Error while writing data: {err}"
|
||||||
) from err
|
) from err
|
||||||
|
|
||||||
def _handle_hello(self, server_hello: bytearray) -> None:
|
def _handle_hello(self, server_hello: bytes) -> None:
|
||||||
"""Perform the handshake with the server."""
|
"""Perform the handshake with the server."""
|
||||||
if not server_hello:
|
if not server_hello:
|
||||||
self._handle_error_and_close(
|
self._handle_error_and_close(
|
||||||
@ -269,7 +266,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
proto.start_handshake()
|
proto.start_handshake()
|
||||||
self._proto = proto
|
self._proto = proto
|
||||||
|
|
||||||
def _handle_handshake(self, msg: bytearray) -> None:
|
def _handle_handshake(self, msg: bytes) -> None:
|
||||||
_LOGGER.debug("Starting handshake...")
|
_LOGGER.debug("Starting handshake...")
|
||||||
if msg[0] != 0:
|
if msg[0] != 0:
|
||||||
explanation = msg[1:].decode()
|
explanation = msg[1:].decode()
|
||||||
@ -333,12 +330,12 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
f"{self._log_name}: Error while writing data: {err}"
|
f"{self._log_name}: Error while writing data: {err}"
|
||||||
) from err
|
) from err
|
||||||
|
|
||||||
def _handle_frame(self, frame: bytearray) -> None:
|
def _handle_frame(self, frame: bytes) -> None:
|
||||||
"""Handle an incoming frame."""
|
"""Handle an incoming frame."""
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
assert self._decrypt is not None, "Handshake should be complete"
|
assert self._decrypt is not None, "Handshake should be complete"
|
||||||
try:
|
try:
|
||||||
msg = self._decrypt(bytes(frame))
|
msg = self._decrypt(frame)
|
||||||
except InvalidTag as ex:
|
except InvalidTag as ex:
|
||||||
self._handle_error_and_close(
|
self._handle_error_and_close(
|
||||||
ProtocolAPIError(f"{self._log_name}: Bad encryption frame: {ex!r}")
|
ProtocolAPIError(f"{self._log_name}: Bad encryption frame: {ex!r}")
|
||||||
@ -348,11 +345,11 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
# 2 bytes: message type
|
# 2 bytes: message type
|
||||||
# 2 bytes: message length
|
# 2 bytes: message length
|
||||||
# N bytes: message data
|
# N bytes: message data
|
||||||
self._on_pkt((msg[0] << 8) | msg[1], msg[4:])
|
type_high = msg[0]
|
||||||
|
type_low = msg[1]
|
||||||
|
self._on_pkt((type_high << 8) | type_low, msg[4:])
|
||||||
|
|
||||||
def _handle_closed( # pylint: disable=unused-argument
|
def _handle_closed(self, frame: bytes) -> None: # pylint: disable=unused-argument
|
||||||
self, frame: bytearray
|
|
||||||
) -> None:
|
|
||||||
"""Handle a closed frame."""
|
"""Handle a closed frame."""
|
||||||
self._handle_error(ProtocolAPIError(f"{self._log_name}: Connection closed"))
|
self._handle_error(ProtocolAPIError(f"{self._log_name}: Connection closed"))
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ import cython
|
|||||||
from .base cimport APIFrameHelper
|
from .base cimport APIFrameHelper
|
||||||
|
|
||||||
|
|
||||||
cdef object TYPE_CHECKING
|
cdef bint TYPE_CHECKING
|
||||||
cdef object WRITE_EXCEPTIONS
|
cdef object WRITE_EXCEPTIONS
|
||||||
cdef object bytes_to_varuint, varuint_to_bytes
|
cdef object bytes_to_varuint, varuint_to_bytes
|
||||||
|
|
||||||
@ -12,8 +12,8 @@ cdef class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
@cython.locals(
|
@cython.locals(
|
||||||
msg_type=bytes,
|
msg_type=bytes,
|
||||||
length=bytes,
|
length=bytes,
|
||||||
init_bytes=bytearray,
|
init_bytes=bytes,
|
||||||
add_length=bytearray,
|
add_length=bytes,
|
||||||
end_of_frame_pos=cython.uint,
|
end_of_frame_pos=cython.uint,
|
||||||
length_int=cython.uint,
|
length_int=cython.uint,
|
||||||
preamble=cython.uint,
|
preamble=cython.uint,
|
||||||
|
@ -39,8 +39,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
) from err
|
) from err
|
||||||
|
|
||||||
def data_received(self, data: bytes) -> None: # pylint: disable=too-many-branches
|
def data_received(self, data: bytes) -> None: # pylint: disable=too-many-branches
|
||||||
self._buffer += data
|
self._add_to_buffer(data)
|
||||||
self._buffer_len += len(data)
|
|
||||||
while self._buffer:
|
while self._buffer:
|
||||||
# Read preamble, which should always 0x00
|
# Read preamble, which should always 0x00
|
||||||
# Also try to get the length and msg type
|
# Also try to get the length and msg type
|
||||||
@ -80,10 +79,10 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
msg_type_int = maybe_msg_type
|
msg_type_int = maybe_msg_type
|
||||||
else:
|
else:
|
||||||
# Message type is longer than 1 byte
|
# Message type is longer than 1 byte
|
||||||
msg_type = bytes(init_bytes[2:3])
|
msg_type = init_bytes[2:3]
|
||||||
else:
|
else:
|
||||||
# Length is longer than 1 byte
|
# Length is longer than 1 byte
|
||||||
length = bytes(init_bytes[1:3])
|
length = init_bytes[1:3]
|
||||||
# If the message is long, we need to read the rest of the length
|
# If the message is long, we need to read the rest of the length
|
||||||
while length[-1] & 0x80 == 0x80:
|
while length[-1] & 0x80 == 0x80:
|
||||||
add_length = self._read_exactly(1)
|
add_length = self._read_exactly(1)
|
||||||
@ -112,18 +111,16 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
if length_int == 0:
|
if length_int == 0:
|
||||||
packet_data = b""
|
packet_data = b""
|
||||||
else:
|
else:
|
||||||
packet_data_bytearray = self._read_exactly(length_int)
|
maybe_packet_data = self._read_exactly(length_int)
|
||||||
# The packet data is not yet available, wait for more data
|
# The packet data is not yet available, wait for more data
|
||||||
# to arrive before continuing, since callback_packet has not
|
# to arrive before continuing, since callback_packet has not
|
||||||
# been called yet the buffer will not be cleared and the next
|
# been called yet the buffer will not be cleared and the next
|
||||||
# call to data_received will continue processing the packet
|
# call to data_received will continue processing the packet
|
||||||
# at the start of the frame.
|
# at the start of the frame.
|
||||||
if packet_data_bytearray is None:
|
if maybe_packet_data is None:
|
||||||
return
|
return
|
||||||
packet_data = bytes(packet_data_bytearray)
|
packet_data = maybe_packet_data
|
||||||
|
|
||||||
end_of_frame_pos = self._pos
|
self._remove_from_buffer()
|
||||||
del self._buffer[:end_of_frame_pos]
|
|
||||||
self._buffer_len -= end_of_frame_pos
|
|
||||||
self._on_pkt(msg_type_int, packet_data)
|
self._on_pkt(msg_type_int, packet_data)
|
||||||
# If we have more data, continue processing
|
# If we have more data, continue processing
|
||||||
|
Loading…
Reference in New Issue
Block a user