mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-12 10:33:57 +01:00
Refactor frame helper to allow sending multiple packets at once (#640)
This commit is contained in:
parent
e6257a8627
commit
d6293d9177
@ -26,3 +26,5 @@ cdef class APIFrameHelper:
|
||||
|
||||
@cython.locals(end_of_frame_pos=cython.uint)
|
||||
cdef _remove_from_buffer(self)
|
||||
|
||||
cpdef write_packets(self, list packets)
|
||||
|
@ -127,8 +127,11 @@ class APIFrameHelper:
|
||||
handshake_handle.cancel()
|
||||
|
||||
@abstractmethod
|
||||
def write_packet(self, type_: int, data: bytes) -> None:
|
||||
"""Write a packet to the socket."""
|
||||
def write_packets(self, packets: list[tuple[int, bytes]]) -> None:
|
||||
"""Write a packets to the socket.
|
||||
|
||||
Packets are in the format of tuple[protobuf_type, protobuf_data]
|
||||
"""
|
||||
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
"""Handle a new connection."""
|
||||
|
@ -32,7 +32,11 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
|
||||
cpdef _handle_frame(self, bytes data)
|
||||
|
||||
@cython.locals(
|
||||
type_="unsigned int",
|
||||
data=bytes,
|
||||
packet=tuple,
|
||||
data_len=cython.uint,
|
||||
frame=bytes,
|
||||
frame_len=cython.uint
|
||||
)
|
||||
cpdef write_packet(self, cython.uint type_, bytes data)
|
||||
cpdef write_packets(self, list packets)
|
||||
|
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import logging
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
@ -241,16 +241,16 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
psk = self._noise_psk
|
||||
server_name = self._server_name
|
||||
try:
|
||||
psk_bytes = base64.b64decode(psk)
|
||||
psk_bytes = binascii.a2b_base64(psk)
|
||||
except ValueError:
|
||||
raise InvalidEncryptionKeyAPIError(
|
||||
f"{self._log_name}: Malformed PSK {psk}, expected "
|
||||
f"{self._log_name}: Malformed PSK `{psk}`, expected "
|
||||
"base64-encoded value",
|
||||
server_name,
|
||||
)
|
||||
if len(psk_bytes) != 32:
|
||||
raise InvalidEncryptionKeyAPIError(
|
||||
f"{self._log_name}:Malformed PSK {psk}, expected"
|
||||
f"{self._log_name}:Malformed PSK `{psk}`, expected"
|
||||
f" 32-bytes of base64 data",
|
||||
server_name,
|
||||
)
|
||||
@ -304,8 +304,11 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
)
|
||||
self._ready_future.set_result(None)
|
||||
|
||||
def write_packet(self, type_: int_, data: bytes) -> None:
|
||||
"""Write a packet to the socket."""
|
||||
def write_packets(self, packets: list[tuple[int, bytes]]) -> None:
|
||||
"""Write a packets to the socket.
|
||||
|
||||
Packets are in the format of tuple[protobuf_type, protobuf_data]
|
||||
"""
|
||||
if not self._is_ready:
|
||||
raise HandshakeAPIError(f"{self._log_name}: Noise connection is not ready")
|
||||
|
||||
@ -313,19 +316,32 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
assert self._encrypt is not None, "Handshake should be complete"
|
||||
assert self._writer is not None, "Writer is not set"
|
||||
|
||||
data_len = len(data)
|
||||
data_header = bytes(
|
||||
((type_ >> 8) & 0xFF, type_ & 0xFF, (data_len >> 8) & 0xFF, data_len & 0xFF)
|
||||
)
|
||||
frame = self._encrypt(data_header + data)
|
||||
out: list[bytes] = []
|
||||
debug_enabled = self._debug_enabled()
|
||||
for packet in packets:
|
||||
type_: int = packet[0]
|
||||
data: bytes = packet[1]
|
||||
data_len = len(data)
|
||||
data_header = bytes(
|
||||
(
|
||||
(type_ >> 8) & 0xFF,
|
||||
type_ & 0xFF,
|
||||
(data_len >> 8) & 0xFF,
|
||||
data_len & 0xFF,
|
||||
)
|
||||
)
|
||||
frame = self._encrypt(data_header + data)
|
||||
|
||||
if self._debug_enabled():
|
||||
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, frame.hex())
|
||||
if debug_enabled is True:
|
||||
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, frame.hex())
|
||||
|
||||
frame_len = len(frame)
|
||||
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
||||
out.append(header)
|
||||
out.append(frame)
|
||||
|
||||
frame_len = len(frame)
|
||||
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
||||
try:
|
||||
self._writer(header + frame)
|
||||
self._writer(b"".join(out))
|
||||
except WRITE_EXCEPTIONS as err:
|
||||
raise SocketAPIError(
|
||||
f"{self._log_name}: Error while writing data: {err}"
|
||||
|
@ -28,3 +28,10 @@ cdef class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
cpdef data_received(self, bytes data)
|
||||
|
||||
cpdef _error_on_incorrect_preamble(self, object preamble)
|
||||
|
||||
@cython.locals(
|
||||
type_="unsigned int",
|
||||
data=bytes,
|
||||
packet=tuple
|
||||
)
|
||||
cpdef write_packets(self, list packets)
|
||||
|
@ -59,20 +59,32 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
super().connection_made(transport)
|
||||
self._ready_future.set_result(None)
|
||||
|
||||
def write_packet(self, type_: int, data: bytes) -> None:
|
||||
"""Write a packet to the socket.
|
||||
def write_packets(self, packets: list[tuple[int, bytes]]) -> None:
|
||||
"""Write a packets to the socket.
|
||||
|
||||
Packets are in the format of tuple[protobuf_type, protobuf_data]
|
||||
|
||||
The entire packet must be written in a single call.
|
||||
"""
|
||||
if TYPE_CHECKING:
|
||||
assert self._writer is not None, "Writer should be set"
|
||||
|
||||
data = b"\0" + varuint_to_bytes(len(data)) + varuint_to_bytes(type_) + data
|
||||
if self._debug_enabled():
|
||||
_LOGGER.debug("%s: Sending plaintext frame %s", self._log_name, data.hex())
|
||||
out: list[bytes] = []
|
||||
debug_enabled = self._debug_enabled()
|
||||
for packet in packets:
|
||||
type_: int = packet[0]
|
||||
data: bytes = packet[1]
|
||||
out.append(b"\0")
|
||||
out.append(varuint_to_bytes(len(data)))
|
||||
out.append(varuint_to_bytes(type_))
|
||||
out.append(data)
|
||||
if debug_enabled is True:
|
||||
_LOGGER.debug(
|
||||
"%s: Sending plaintext frame %s", self._log_name, data.hex()
|
||||
)
|
||||
|
||||
try:
|
||||
self._writer(data)
|
||||
self._writer(b"".join(out))
|
||||
except WRITE_EXCEPTIONS as err:
|
||||
raise SocketAPIError(
|
||||
f"{self._log_name}: Error while writing data: {err}"
|
||||
|
@ -1,5 +1,7 @@
|
||||
import cython
|
||||
|
||||
from ._frame_helper.base cimport APIFrameHelper
|
||||
|
||||
|
||||
cdef dict MESSAGE_TYPE_TO_PROTO
|
||||
cdef dict PROTO_TO_MESSAGE_TYPE
|
||||
@ -47,7 +49,7 @@ cdef class APIConnection:
|
||||
cdef public object on_stop
|
||||
cdef object _on_stop_task
|
||||
cdef public object _socket
|
||||
cdef public object _frame_helper
|
||||
cdef public APIFrameHelper _frame_helper
|
||||
cdef public object api_version
|
||||
cdef public object connection_state
|
||||
cdef dict _message_handlers
|
||||
@ -69,6 +71,8 @@ cdef class APIConnection:
|
||||
|
||||
cpdef send_message(self, object msg)
|
||||
|
||||
cdef send_messages(self, tuple messages)
|
||||
|
||||
@cython.locals(handlers=set, handlers_copy=set)
|
||||
cpdef _process_packet(self, object msg_type_proto, object data)
|
||||
|
||||
@ -89,5 +93,3 @@ cdef class APIConnection:
|
||||
|
||||
@cython.locals(handlers=set)
|
||||
cpdef _remove_message_callback(self, object on_message, tuple msg_types)
|
||||
|
||||
cdef _send_messages(self, tuple messages)
|
||||
|
@ -613,17 +613,11 @@ class APIConnection:
|
||||
connect.password = self._params.password
|
||||
return connect
|
||||
|
||||
def _send_messages(self, messages: tuple[message.Message, ...]) -> None:
|
||||
"""Send a message to the remote.
|
||||
|
||||
Currently this is a wrapper around send_message
|
||||
but may be changed in the future to batch messages
|
||||
together.
|
||||
"""
|
||||
for msg in messages:
|
||||
self.send_message(msg)
|
||||
|
||||
def send_message(self, msg: message.Message) -> None:
|
||||
"""Send a message to the remote."""
|
||||
self.send_messages((msg,))
|
||||
|
||||
def send_messages(self, msgs: tuple[message.Message, ...]) -> None:
|
||||
"""Send a protobuf message to the remote."""
|
||||
if not self._handshake_complete:
|
||||
if in_do_connect.get(False):
|
||||
@ -635,23 +629,30 @@ class APIConnection:
|
||||
f"Connection isn't established yet ({self.connection_state})"
|
||||
)
|
||||
|
||||
msg_type = type(msg)
|
||||
if (message_type := PROTO_TO_MESSAGE_TYPE.get(msg_type)) is None:
|
||||
raise ValueError(f"Message type id not found for type {msg_type}")
|
||||
packets: list[tuple[int, bytes]] = []
|
||||
debug_enabled = self._debug_enabled()
|
||||
|
||||
if self._debug_enabled() is True:
|
||||
_LOGGER.debug("%s: Sending %s: %s", self.log_name, msg_type.__name__, msg)
|
||||
for msg in msgs:
|
||||
msg_type = type(msg)
|
||||
if (message_type := PROTO_TO_MESSAGE_TYPE.get(msg_type)) is None:
|
||||
raise ValueError(f"Message type id not found for type {msg_type}")
|
||||
|
||||
if debug_enabled is True:
|
||||
_LOGGER.debug(
|
||||
"%s: Sending %s: %s", self.log_name, msg_type.__name__, msg
|
||||
)
|
||||
|
||||
packets.append((message_type, msg.SerializeToString()))
|
||||
|
||||
if TYPE_CHECKING:
|
||||
assert self._frame_helper is not None
|
||||
|
||||
encoded = msg.SerializeToString()
|
||||
try:
|
||||
self._frame_helper.write_packet(message_type, encoded)
|
||||
self._frame_helper.write_packets(packets)
|
||||
except SocketAPIError as err:
|
||||
# If writing packet fails, we don't know what state the frames
|
||||
# are in anymore and we have to close the connection
|
||||
_LOGGER.info("%s: Error writing packet: %s", self.log_name, err)
|
||||
_LOGGER.info("%s: Error writing packets: %s", self.log_name, err)
|
||||
self._report_fatal_error(err)
|
||||
raise
|
||||
|
||||
@ -738,7 +739,7 @@ class APIConnection:
|
||||
# Send the message right away to reduce latency.
|
||||
# This is safe because we are not awaiting between
|
||||
# sending the message and registering the handler
|
||||
self._send_messages(messages)
|
||||
self.send_messages(messages)
|
||||
loop = self._loop
|
||||
# Unsafe to await between sending the message and registering the handler
|
||||
fut: asyncio.Future[None] = loop.create_future()
|
||||
|
@ -475,7 +475,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
||||
assert not writes
|
||||
|
||||
await handshake_task
|
||||
helper.write_packet(1, b"to device")
|
||||
helper.write_packets([(1, b"to device")])
|
||||
encrypted_packet = writes.pop()
|
||||
header = encrypted_packet[0:1]
|
||||
assert header == b"\x01"
|
||||
|
Loading…
Reference in New Issue
Block a user