mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-02-27 03:32:15 +01:00
Refactor frame helper to allow sending multiple packets at once (#640)
This commit is contained in:
parent
e6257a8627
commit
d6293d9177
@ -26,3 +26,5 @@ cdef class APIFrameHelper:
|
|||||||
|
|
||||||
@cython.locals(end_of_frame_pos=cython.uint)
|
@cython.locals(end_of_frame_pos=cython.uint)
|
||||||
cdef _remove_from_buffer(self)
|
cdef _remove_from_buffer(self)
|
||||||
|
|
||||||
|
cpdef write_packets(self, list packets)
|
||||||
|
@ -127,8 +127,11 @@ class APIFrameHelper:
|
|||||||
handshake_handle.cancel()
|
handshake_handle.cancel()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def write_packet(self, type_: int, data: bytes) -> None:
|
def write_packets(self, packets: list[tuple[int, bytes]]) -> None:
|
||||||
"""Write a packet to the socket."""
|
"""Write a packets to the socket.
|
||||||
|
|
||||||
|
Packets are in the format of tuple[protobuf_type, protobuf_data]
|
||||||
|
"""
|
||||||
|
|
||||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||||
"""Handle a new connection."""
|
"""Handle a new connection."""
|
||||||
|
@ -32,7 +32,11 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
cpdef _handle_frame(self, bytes data)
|
cpdef _handle_frame(self, bytes data)
|
||||||
|
|
||||||
@cython.locals(
|
@cython.locals(
|
||||||
|
type_="unsigned int",
|
||||||
|
data=bytes,
|
||||||
|
packet=tuple,
|
||||||
data_len=cython.uint,
|
data_len=cython.uint,
|
||||||
|
frame=bytes,
|
||||||
frame_len=cython.uint
|
frame_len=cython.uint
|
||||||
)
|
)
|
||||||
cpdef write_packet(self, cython.uint type_, bytes data)
|
cpdef write_packets(self, list packets)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import base64
|
import binascii
|
||||||
import logging
|
import logging
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@ -241,16 +241,16 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
psk = self._noise_psk
|
psk = self._noise_psk
|
||||||
server_name = self._server_name
|
server_name = self._server_name
|
||||||
try:
|
try:
|
||||||
psk_bytes = base64.b64decode(psk)
|
psk_bytes = binascii.a2b_base64(psk)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise InvalidEncryptionKeyAPIError(
|
raise InvalidEncryptionKeyAPIError(
|
||||||
f"{self._log_name}: Malformed PSK {psk}, expected "
|
f"{self._log_name}: Malformed PSK `{psk}`, expected "
|
||||||
"base64-encoded value",
|
"base64-encoded value",
|
||||||
server_name,
|
server_name,
|
||||||
)
|
)
|
||||||
if len(psk_bytes) != 32:
|
if len(psk_bytes) != 32:
|
||||||
raise InvalidEncryptionKeyAPIError(
|
raise InvalidEncryptionKeyAPIError(
|
||||||
f"{self._log_name}:Malformed PSK {psk}, expected"
|
f"{self._log_name}:Malformed PSK `{psk}`, expected"
|
||||||
f" 32-bytes of base64 data",
|
f" 32-bytes of base64 data",
|
||||||
server_name,
|
server_name,
|
||||||
)
|
)
|
||||||
@ -304,8 +304,11 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
)
|
)
|
||||||
self._ready_future.set_result(None)
|
self._ready_future.set_result(None)
|
||||||
|
|
||||||
def write_packet(self, type_: int_, data: bytes) -> None:
|
def write_packets(self, packets: list[tuple[int, bytes]]) -> None:
|
||||||
"""Write a packet to the socket."""
|
"""Write a packets to the socket.
|
||||||
|
|
||||||
|
Packets are in the format of tuple[protobuf_type, protobuf_data]
|
||||||
|
"""
|
||||||
if not self._is_ready:
|
if not self._is_ready:
|
||||||
raise HandshakeAPIError(f"{self._log_name}: Noise connection is not ready")
|
raise HandshakeAPIError(f"{self._log_name}: Noise connection is not ready")
|
||||||
|
|
||||||
@ -313,19 +316,32 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
assert self._encrypt is not None, "Handshake should be complete"
|
assert self._encrypt is not None, "Handshake should be complete"
|
||||||
assert self._writer is not None, "Writer is not set"
|
assert self._writer is not None, "Writer is not set"
|
||||||
|
|
||||||
data_len = len(data)
|
out: list[bytes] = []
|
||||||
data_header = bytes(
|
debug_enabled = self._debug_enabled()
|
||||||
((type_ >> 8) & 0xFF, type_ & 0xFF, (data_len >> 8) & 0xFF, data_len & 0xFF)
|
for packet in packets:
|
||||||
)
|
type_: int = packet[0]
|
||||||
frame = self._encrypt(data_header + data)
|
data: bytes = packet[1]
|
||||||
|
data_len = len(data)
|
||||||
|
data_header = bytes(
|
||||||
|
(
|
||||||
|
(type_ >> 8) & 0xFF,
|
||||||
|
type_ & 0xFF,
|
||||||
|
(data_len >> 8) & 0xFF,
|
||||||
|
data_len & 0xFF,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
frame = self._encrypt(data_header + data)
|
||||||
|
|
||||||
if self._debug_enabled():
|
if debug_enabled is True:
|
||||||
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, frame.hex())
|
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, frame.hex())
|
||||||
|
|
||||||
|
frame_len = len(frame)
|
||||||
|
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
||||||
|
out.append(header)
|
||||||
|
out.append(frame)
|
||||||
|
|
||||||
frame_len = len(frame)
|
|
||||||
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
|
||||||
try:
|
try:
|
||||||
self._writer(header + frame)
|
self._writer(b"".join(out))
|
||||||
except WRITE_EXCEPTIONS as err:
|
except WRITE_EXCEPTIONS as err:
|
||||||
raise SocketAPIError(
|
raise SocketAPIError(
|
||||||
f"{self._log_name}: Error while writing data: {err}"
|
f"{self._log_name}: Error while writing data: {err}"
|
||||||
|
@ -28,3 +28,10 @@ cdef class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
cpdef data_received(self, bytes data)
|
cpdef data_received(self, bytes data)
|
||||||
|
|
||||||
cpdef _error_on_incorrect_preamble(self, object preamble)
|
cpdef _error_on_incorrect_preamble(self, object preamble)
|
||||||
|
|
||||||
|
@cython.locals(
|
||||||
|
type_="unsigned int",
|
||||||
|
data=bytes,
|
||||||
|
packet=tuple
|
||||||
|
)
|
||||||
|
cpdef write_packets(self, list packets)
|
||||||
|
@ -59,20 +59,32 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
super().connection_made(transport)
|
super().connection_made(transport)
|
||||||
self._ready_future.set_result(None)
|
self._ready_future.set_result(None)
|
||||||
|
|
||||||
def write_packet(self, type_: int, data: bytes) -> None:
|
def write_packets(self, packets: list[tuple[int, bytes]]) -> None:
|
||||||
"""Write a packet to the socket.
|
"""Write a packets to the socket.
|
||||||
|
|
||||||
|
Packets are in the format of tuple[protobuf_type, protobuf_data]
|
||||||
|
|
||||||
The entire packet must be written in a single call.
|
The entire packet must be written in a single call.
|
||||||
"""
|
"""
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
assert self._writer is not None, "Writer should be set"
|
assert self._writer is not None, "Writer should be set"
|
||||||
|
|
||||||
data = b"\0" + varuint_to_bytes(len(data)) + varuint_to_bytes(type_) + data
|
out: list[bytes] = []
|
||||||
if self._debug_enabled():
|
debug_enabled = self._debug_enabled()
|
||||||
_LOGGER.debug("%s: Sending plaintext frame %s", self._log_name, data.hex())
|
for packet in packets:
|
||||||
|
type_: int = packet[0]
|
||||||
|
data: bytes = packet[1]
|
||||||
|
out.append(b"\0")
|
||||||
|
out.append(varuint_to_bytes(len(data)))
|
||||||
|
out.append(varuint_to_bytes(type_))
|
||||||
|
out.append(data)
|
||||||
|
if debug_enabled is True:
|
||||||
|
_LOGGER.debug(
|
||||||
|
"%s: Sending plaintext frame %s", self._log_name, data.hex()
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._writer(data)
|
self._writer(b"".join(out))
|
||||||
except WRITE_EXCEPTIONS as err:
|
except WRITE_EXCEPTIONS as err:
|
||||||
raise SocketAPIError(
|
raise SocketAPIError(
|
||||||
f"{self._log_name}: Error while writing data: {err}"
|
f"{self._log_name}: Error while writing data: {err}"
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
import cython
|
import cython
|
||||||
|
|
||||||
|
from ._frame_helper.base cimport APIFrameHelper
|
||||||
|
|
||||||
|
|
||||||
cdef dict MESSAGE_TYPE_TO_PROTO
|
cdef dict MESSAGE_TYPE_TO_PROTO
|
||||||
cdef dict PROTO_TO_MESSAGE_TYPE
|
cdef dict PROTO_TO_MESSAGE_TYPE
|
||||||
@ -47,7 +49,7 @@ cdef class APIConnection:
|
|||||||
cdef public object on_stop
|
cdef public object on_stop
|
||||||
cdef object _on_stop_task
|
cdef object _on_stop_task
|
||||||
cdef public object _socket
|
cdef public object _socket
|
||||||
cdef public object _frame_helper
|
cdef public APIFrameHelper _frame_helper
|
||||||
cdef public object api_version
|
cdef public object api_version
|
||||||
cdef public object connection_state
|
cdef public object connection_state
|
||||||
cdef dict _message_handlers
|
cdef dict _message_handlers
|
||||||
@ -69,6 +71,8 @@ cdef class APIConnection:
|
|||||||
|
|
||||||
cpdef send_message(self, object msg)
|
cpdef send_message(self, object msg)
|
||||||
|
|
||||||
|
cdef send_messages(self, tuple messages)
|
||||||
|
|
||||||
@cython.locals(handlers=set, handlers_copy=set)
|
@cython.locals(handlers=set, handlers_copy=set)
|
||||||
cpdef _process_packet(self, object msg_type_proto, object data)
|
cpdef _process_packet(self, object msg_type_proto, object data)
|
||||||
|
|
||||||
@ -89,5 +93,3 @@ cdef class APIConnection:
|
|||||||
|
|
||||||
@cython.locals(handlers=set)
|
@cython.locals(handlers=set)
|
||||||
cpdef _remove_message_callback(self, object on_message, tuple msg_types)
|
cpdef _remove_message_callback(self, object on_message, tuple msg_types)
|
||||||
|
|
||||||
cdef _send_messages(self, tuple messages)
|
|
||||||
|
@ -613,17 +613,11 @@ class APIConnection:
|
|||||||
connect.password = self._params.password
|
connect.password = self._params.password
|
||||||
return connect
|
return connect
|
||||||
|
|
||||||
def _send_messages(self, messages: tuple[message.Message, ...]) -> None:
|
|
||||||
"""Send a message to the remote.
|
|
||||||
|
|
||||||
Currently this is a wrapper around send_message
|
|
||||||
but may be changed in the future to batch messages
|
|
||||||
together.
|
|
||||||
"""
|
|
||||||
for msg in messages:
|
|
||||||
self.send_message(msg)
|
|
||||||
|
|
||||||
def send_message(self, msg: message.Message) -> None:
|
def send_message(self, msg: message.Message) -> None:
|
||||||
|
"""Send a message to the remote."""
|
||||||
|
self.send_messages((msg,))
|
||||||
|
|
||||||
|
def send_messages(self, msgs: tuple[message.Message, ...]) -> None:
|
||||||
"""Send a protobuf message to the remote."""
|
"""Send a protobuf message to the remote."""
|
||||||
if not self._handshake_complete:
|
if not self._handshake_complete:
|
||||||
if in_do_connect.get(False):
|
if in_do_connect.get(False):
|
||||||
@ -635,23 +629,30 @@ class APIConnection:
|
|||||||
f"Connection isn't established yet ({self.connection_state})"
|
f"Connection isn't established yet ({self.connection_state})"
|
||||||
)
|
)
|
||||||
|
|
||||||
msg_type = type(msg)
|
packets: list[tuple[int, bytes]] = []
|
||||||
if (message_type := PROTO_TO_MESSAGE_TYPE.get(msg_type)) is None:
|
debug_enabled = self._debug_enabled()
|
||||||
raise ValueError(f"Message type id not found for type {msg_type}")
|
|
||||||
|
|
||||||
if self._debug_enabled() is True:
|
for msg in msgs:
|
||||||
_LOGGER.debug("%s: Sending %s: %s", self.log_name, msg_type.__name__, msg)
|
msg_type = type(msg)
|
||||||
|
if (message_type := PROTO_TO_MESSAGE_TYPE.get(msg_type)) is None:
|
||||||
|
raise ValueError(f"Message type id not found for type {msg_type}")
|
||||||
|
|
||||||
|
if debug_enabled is True:
|
||||||
|
_LOGGER.debug(
|
||||||
|
"%s: Sending %s: %s", self.log_name, msg_type.__name__, msg
|
||||||
|
)
|
||||||
|
|
||||||
|
packets.append((message_type, msg.SerializeToString()))
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
assert self._frame_helper is not None
|
assert self._frame_helper is not None
|
||||||
|
|
||||||
encoded = msg.SerializeToString()
|
|
||||||
try:
|
try:
|
||||||
self._frame_helper.write_packet(message_type, encoded)
|
self._frame_helper.write_packets(packets)
|
||||||
except SocketAPIError as err:
|
except SocketAPIError as err:
|
||||||
# If writing packet fails, we don't know what state the frames
|
# If writing packet fails, we don't know what state the frames
|
||||||
# are in anymore and we have to close the connection
|
# are in anymore and we have to close the connection
|
||||||
_LOGGER.info("%s: Error writing packet: %s", self.log_name, err)
|
_LOGGER.info("%s: Error writing packets: %s", self.log_name, err)
|
||||||
self._report_fatal_error(err)
|
self._report_fatal_error(err)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@ -738,7 +739,7 @@ class APIConnection:
|
|||||||
# Send the message right away to reduce latency.
|
# Send the message right away to reduce latency.
|
||||||
# This is safe because we are not awaiting between
|
# This is safe because we are not awaiting between
|
||||||
# sending the message and registering the handler
|
# sending the message and registering the handler
|
||||||
self._send_messages(messages)
|
self.send_messages(messages)
|
||||||
loop = self._loop
|
loop = self._loop
|
||||||
# Unsafe to await between sending the message and registering the handler
|
# Unsafe to await between sending the message and registering the handler
|
||||||
fut: asyncio.Future[None] = loop.create_future()
|
fut: asyncio.Future[None] = loop.create_future()
|
||||||
|
@ -475,7 +475,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
|||||||
assert not writes
|
assert not writes
|
||||||
|
|
||||||
await handshake_task
|
await handshake_task
|
||||||
helper.write_packet(1, b"to device")
|
helper.write_packets([(1, b"to device")])
|
||||||
encrypted_packet = writes.pop()
|
encrypted_packet = writes.pop()
|
||||||
header = encrypted_packet[0:1]
|
header = encrypted_packet[0:1]
|
||||||
assert header == b"\x01"
|
assert header == b"\x01"
|
||||||
|
Loading…
Reference in New Issue
Block a user