mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-12 10:33:57 +01:00
Fix handling incoming data from protractor event loop (#642)
This commit is contained in:
parent
a60e54d438
commit
f94ddf8e6a
@ -23,7 +23,8 @@ cdef class APIFrameHelper:
|
||||
@cython.locals(original_pos="unsigned int", new_pos="unsigned int")
|
||||
cdef bytes _read_exactly(self, int length)
|
||||
|
||||
cdef _add_to_buffer(self, bytes data)
|
||||
@cython.locals(bytes_data=bytes)
|
||||
cdef _add_to_buffer(self, object data)
|
||||
|
||||
@cython.locals(end_of_frame_pos="unsigned int")
|
||||
cdef _remove_from_buffer(self)
|
||||
|
@ -66,21 +66,29 @@ class APIFrameHelper:
|
||||
if not self._ready_future.done():
|
||||
self._ready_future.set_exception(exc)
|
||||
|
||||
def _add_to_buffer(self, data: bytes) -> None:
|
||||
def _add_to_buffer(self, data: bytes | bytearray | memoryview) -> None:
|
||||
"""Add data to the buffer."""
|
||||
# This should not be isinstance(data, bytes) because we want to
|
||||
# to explicitly check for bytes and not for subclasses of bytes
|
||||
if type(data) is not bytes: # pylint: disable=unidiomatic-typecheck
|
||||
# Protractor sends a bytearray, so we need to convert it to bytes
|
||||
# https://github.com/esphome/issues/issues/5117
|
||||
bytes_data = bytes(data)
|
||||
else:
|
||||
bytes_data = data
|
||||
if self._buffer_len == 0:
|
||||
# This is the best case scenario, we don't have to copy the data
|
||||
# and can just use the buffer directly. This is the most common
|
||||
# case as well.
|
||||
self._buffer = data
|
||||
self._buffer = bytes_data
|
||||
else:
|
||||
if TYPE_CHECKING:
|
||||
assert self._buffer is not None, "Buffer should be set"
|
||||
# This is the worst case scenario, we have to copy the data
|
||||
# This is the worst case scenario, we have to copy the bytes_data
|
||||
# and can't just use the buffer directly. This is also very
|
||||
# uncommon since we usually read the entire frame at once.
|
||||
self._buffer += data
|
||||
self._buffer_len += len(data)
|
||||
self._buffer += bytes_data
|
||||
self._buffer_len += len(bytes_data)
|
||||
|
||||
def _remove_from_buffer(self) -> None:
|
||||
"""Remove data from the buffer."""
|
||||
|
@ -6,17 +6,21 @@ from .base cimport APIFrameHelper
|
||||
|
||||
cdef bint TYPE_CHECKING
|
||||
|
||||
cdef unsigned int NOISE_STATE_HELLO
|
||||
cdef unsigned int NOISE_STATE_HANDSHAKE
|
||||
cdef unsigned int NOISE_STATE_READY
|
||||
cdef unsigned int NOISE_STATE_CLOSED
|
||||
|
||||
cdef class APINoiseFrameHelper(APIFrameHelper):
|
||||
|
||||
cdef object _noise_psk
|
||||
cdef object _expected_name
|
||||
cdef object _state
|
||||
cdef unsigned int _state
|
||||
cdef object _dispatch
|
||||
cdef object _server_name
|
||||
cdef object _proto
|
||||
cdef object _decrypt
|
||||
cdef object _encrypt
|
||||
cdef bint _is_ready
|
||||
|
||||
@cython.locals(
|
||||
header=bytes,
|
||||
@ -24,13 +28,19 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
|
||||
msg_size_high=cython.uint,
|
||||
msg_size_low=cython.uint,
|
||||
)
|
||||
cpdef data_received(self, bytes data)
|
||||
cpdef data_received(self, object data)
|
||||
|
||||
@cython.locals(
|
||||
type_high=cython.uint,
|
||||
type_low=cython.uint
|
||||
)
|
||||
cpdef _handle_frame(self, bytes data)
|
||||
cdef _handle_frame(self, bytes frame)
|
||||
|
||||
cdef _handle_hello(self, bytes server_hello)
|
||||
|
||||
cdef _handle_handshake(self, bytes msg)
|
||||
|
||||
cdef _handle_closed(self, bytes frame)
|
||||
|
||||
@cython.locals(
|
||||
type_="unsigned int",
|
||||
|
@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
import binascii
|
||||
import logging
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from struct import Struct
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
@ -53,13 +52,17 @@ class ESPHomeNoiseBackend(DefaultNoiseBackend): # type: ignore[misc]
|
||||
ESPHOME_NOISE_BACKEND = ESPHomeNoiseBackend()
|
||||
|
||||
|
||||
class NoiseConnectionState(Enum):
|
||||
"""Noise connection state."""
|
||||
|
||||
HELLO = 1
|
||||
HANDSHAKE = 2
|
||||
READY = 3
|
||||
CLOSED = 4
|
||||
# This is effectively an enum but we don't want to use an enum
|
||||
# because we have a simple dispatch in the data_received method
|
||||
# that would be more complicated with an enum and we want to add
|
||||
# cdefs for each different state so we have a good test for each
|
||||
# state receiving data since we found that the protractor event
|
||||
# loop will send use a bytearray instead of bytes was not handled
|
||||
# correctly.
|
||||
NOISE_STATE_HELLO = 1
|
||||
NOISE_STATE_HANDSHAKE = 2
|
||||
NOISE_STATE_READY = 3
|
||||
NOISE_STATE_CLOSED = 4
|
||||
|
||||
|
||||
NOISE_HELLO = b"\x01\x00\x00"
|
||||
@ -79,7 +82,6 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
"_proto",
|
||||
"_decrypt",
|
||||
"_encrypt",
|
||||
"_is_ready",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@ -94,18 +96,11 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
super().__init__(connection, client_info, log_name)
|
||||
self._noise_psk = noise_psk
|
||||
self._expected_name = expected_name
|
||||
self._set_state(NoiseConnectionState.HELLO)
|
||||
self._state = NOISE_STATE_HELLO
|
||||
self._server_name: str | None = None
|
||||
self._decrypt: Callable[[bytes], bytes] | None = None
|
||||
self._encrypt: Callable[[bytes], bytes] | None = None
|
||||
self._setup_proto()
|
||||
self._is_ready = False
|
||||
|
||||
def _set_state(self, state: NoiseConnectionState) -> None:
|
||||
"""Set the current state."""
|
||||
self._state = state
|
||||
self._is_ready = state == NoiseConnectionState.READY
|
||||
self._dispatch = self.STATE_TO_CALLABLE[state]
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the connection."""
|
||||
@ -115,7 +110,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
self._set_ready_future_exception(
|
||||
APIConnectionError(f"{self._log_name}: Connection closed")
|
||||
)
|
||||
self._set_state(NoiseConnectionState.CLOSED)
|
||||
self._state = NOISE_STATE_CLOSED
|
||||
super().close()
|
||||
|
||||
def _handle_error_and_close(self, exc: Exception) -> None:
|
||||
@ -124,10 +119,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
|
||||
def _handle_error(self, exc: Exception) -> None:
|
||||
"""Handle an error, and provide a good message when during hello."""
|
||||
if (
|
||||
isinstance(exc, ConnectionResetError)
|
||||
and self._state == NoiseConnectionState.HELLO
|
||||
):
|
||||
if isinstance(exc, ConnectionResetError) and self._state == NOISE_STATE_HELLO:
|
||||
original_exc = exc
|
||||
exc = HandshakeAPIError(
|
||||
f"{self._log_name}: The connection dropped immediately after encrypted hello; "
|
||||
@ -142,7 +134,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
self._send_hello_handshake()
|
||||
await super().perform_handshake(timeout)
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
def data_received(self, data: bytes | bytearray | memoryview) -> None:
|
||||
self._add_to_buffer(data)
|
||||
while self._buffer:
|
||||
self._pos = 0
|
||||
@ -168,7 +160,14 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
return
|
||||
|
||||
try:
|
||||
self._dispatch(self, frame)
|
||||
if self._state == NOISE_STATE_READY:
|
||||
self._handle_frame(frame)
|
||||
elif self._state == NOISE_STATE_HELLO:
|
||||
self._handle_hello(frame)
|
||||
elif self._state == NOISE_STATE_HANDSHAKE:
|
||||
self._handle_handshake(frame)
|
||||
else:
|
||||
self._handle_closed(frame)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
self._handle_error_and_close(err)
|
||||
finally:
|
||||
@ -236,7 +235,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
)
|
||||
return
|
||||
|
||||
self._set_state(NoiseConnectionState.HANDSHAKE)
|
||||
self._state = NOISE_STATE_HANDSHAKE
|
||||
|
||||
def _decode_noise_psk(self) -> bytes:
|
||||
"""Decode the given noise psk from base64 format to raw bytes."""
|
||||
@ -294,7 +293,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
self._handle_error_and_close(ex)
|
||||
return
|
||||
_LOGGER.debug("Handshake complete")
|
||||
self._set_state(NoiseConnectionState.READY)
|
||||
self._state = NOISE_STATE_READY
|
||||
noise_protocol = self._proto.noise_protocol
|
||||
self._decrypt = partial(
|
||||
noise_protocol.cipher_state_decrypt.decrypt_with_ad, # pylint: disable=no-member
|
||||
@ -311,7 +310,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
|
||||
Packets are in the format of tuple[protobuf_type, protobuf_data]
|
||||
"""
|
||||
if not self._is_ready:
|
||||
if self._state != NOISE_STATE_READY:
|
||||
raise HandshakeAPIError(f"{self._log_name}: Noise connection is not ready")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -371,10 +370,3 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
def _handle_closed(self, frame: bytes) -> None: # pylint: disable=unused-argument
|
||||
"""Handle a closed frame."""
|
||||
self._handle_error(ProtocolAPIError(f"{self._log_name}: Connection closed"))
|
||||
|
||||
STATE_TO_CALLABLE = {
|
||||
NoiseConnectionState.HELLO: _handle_hello,
|
||||
NoiseConnectionState.HANDSHAKE: _handle_handshake,
|
||||
NoiseConnectionState.READY: _handle_frame,
|
||||
NoiseConnectionState.CLOSED: _handle_closed,
|
||||
}
|
||||
|
@ -26,9 +26,9 @@ cdef class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
length_high=cython.uint,
|
||||
maybe_msg_type=cython.uint
|
||||
)
|
||||
cpdef data_received(self, bytes data)
|
||||
cpdef data_received(self, object data)
|
||||
|
||||
cpdef _error_on_incorrect_preamble(self, object preamble)
|
||||
cdef _error_on_incorrect_preamble(self, object preamble)
|
||||
|
||||
@cython.locals(
|
||||
type_="unsigned int",
|
||||
|
@ -91,7 +91,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
) from err
|
||||
|
||||
def data_received( # pylint: disable=too-many-branches,too-many-return-statements
|
||||
self, data: bytes
|
||||
self, data: bytes | bytearray | memoryview
|
||||
) -> None:
|
||||
self._add_to_buffer(data)
|
||||
while self._buffer:
|
||||
|
@ -25,6 +25,7 @@ from aioesphomeapi.core import (
|
||||
BadNameAPIError,
|
||||
HandshakeAPIError,
|
||||
InvalidEncryptionKeyAPIError,
|
||||
ProtocolAPIError,
|
||||
SocketAPIError,
|
||||
)
|
||||
|
||||
@ -67,7 +68,6 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
|
||||
) from err
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"in_bytes, pkt_data, pkt_type",
|
||||
[
|
||||
@ -115,7 +115,9 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
|
||||
def test_plaintext_frame_helper(
|
||||
in_bytes: bytes, pkt_data: bytes, pkt_type: int
|
||||
) -> None:
|
||||
for _ in range(3):
|
||||
connection, packets = _make_mock_connection()
|
||||
helper = APIPlaintextFrameHelper(
|
||||
@ -141,6 +143,90 @@ async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
|
||||
assert data == pkt_data
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"byte_type",
|
||||
(bytes, bytearray, memoryview),
|
||||
)
|
||||
def test_plaintext_frame_helper_protractor_event_loop(byte_type: Any) -> None:
|
||||
"""Test the plaintext frame helper with the protractor event loop.
|
||||
|
||||
With the protractor event loop, data_received is called with a bytearray
|
||||
instead of bytes.
|
||||
|
||||
https://github.com/esphome/issues/issues/5117
|
||||
"""
|
||||
for _ in range(3):
|
||||
connection, packets = _make_mock_connection()
|
||||
helper = APIPlaintextFrameHelper(
|
||||
connection=connection, client_info="my client", log_name="test"
|
||||
)
|
||||
in_bytes = byte_type(
|
||||
PREAMBLE + varuint_to_bytes(4) + varuint_to_bytes(100) + (b"\x42" * 4)
|
||||
)
|
||||
|
||||
helper.data_received(in_bytes)
|
||||
|
||||
pkt = packets.pop()
|
||||
type_, data = pkt
|
||||
|
||||
assert type_ == 100
|
||||
assert data == b"\x42" * 4
|
||||
|
||||
# Make sure we correctly handle fragments
|
||||
for i in range(len(in_bytes)):
|
||||
helper.data_received(in_bytes[i : i + 1])
|
||||
|
||||
pkt = packets.pop()
|
||||
type_, data = pkt
|
||||
|
||||
assert type_ == 100
|
||||
assert data == b"\x42" * 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"byte_type",
|
||||
(bytes, bytearray, memoryview),
|
||||
)
|
||||
async def test_noise_protector_event_loop(byte_type: Any) -> None:
|
||||
"""Test the noise frame helper with the protractor event loop.
|
||||
|
||||
With the protractor event loop, data_received is called with a bytearray
|
||||
instead of bytes.
|
||||
|
||||
https://github.com/esphome/issues/issues/5117
|
||||
"""
|
||||
outgoing_packets = [
|
||||
"010000", # hello packet
|
||||
"010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4",
|
||||
]
|
||||
incoming_packets = [
|
||||
"01000d01736572766963657465737400",
|
||||
"0100160148616e647368616b65204d4143206661696c757265",
|
||||
]
|
||||
connection, _ = _make_mock_connection()
|
||||
|
||||
helper = MockAPINoiseFrameHelper(
|
||||
connection=connection,
|
||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||
expected_name="servicetest",
|
||||
client_info="my client",
|
||||
log_name="test",
|
||||
)
|
||||
helper._transport = MagicMock()
|
||||
helper._writer = MagicMock()
|
||||
|
||||
for pkt in outgoing_packets:
|
||||
helper.mock_write_frame(byte_type(bytes.fromhex(pkt)))
|
||||
|
||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||
for pkt in incoming_packets:
|
||||
helper.data_received(byte_type(bytes.fromhex(pkt)))
|
||||
|
||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||
await helper.perform_handshake(30)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noise_frame_helper_incorrect_key():
|
||||
"""Test that the noise frame helper raises InvalidEncryptionKeyAPIError on bad key."""
|
||||
@ -478,3 +564,6 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
||||
|
||||
assert packets == [(42, b"from device")]
|
||||
helper.close()
|
||||
|
||||
with pytest.raises(ProtocolAPIError, match="Connection closed"):
|
||||
helper.data_received(encrypted_header + encrypted_payload)
|
||||
|
Loading…
Reference in New Issue
Block a user