Fix handling incoming data from protractor event loop (#642)

This commit is contained in:
J. Nick Koston 2023-11-16 17:50:54 -06:00 committed by GitHub
parent a60e54d438
commit f94ddf8e6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 149 additions and 49 deletions

View File

@ -23,7 +23,8 @@ cdef class APIFrameHelper:
@cython.locals(original_pos="unsigned int", new_pos="unsigned int")
cdef bytes _read_exactly(self, int length)
cdef _add_to_buffer(self, bytes data)
@cython.locals(bytes_data=bytes)
cdef _add_to_buffer(self, object data)
@cython.locals(end_of_frame_pos="unsigned int")
cdef _remove_from_buffer(self)

View File

@ -66,21 +66,29 @@ class APIFrameHelper:
if not self._ready_future.done():
self._ready_future.set_exception(exc)
def _add_to_buffer(self, data: bytes) -> None:
def _add_to_buffer(self, data: bytes | bytearray | memoryview) -> None:
"""Add data to the buffer."""
# This should not be isinstance(data, bytes) because we want to
# to explicitly check for bytes and not for subclasses of bytes
if type(data) is not bytes: # pylint: disable=unidiomatic-typecheck
# Protractor sends a bytearray, so we need to convert it to bytes
# https://github.com/esphome/issues/issues/5117
bytes_data = bytes(data)
else:
bytes_data = data
if self._buffer_len == 0:
# This is the best case scenario, we don't have to copy the data
# and can just use the buffer directly. This is the most common
# case as well.
self._buffer = data
self._buffer = bytes_data
else:
if TYPE_CHECKING:
assert self._buffer is not None, "Buffer should be set"
# This is the worst case scenario, we have to copy the data
# This is the worst case scenario, we have to copy the bytes_data
# and can't just use the buffer directly. This is also very
# uncommon since we usually read the entire frame at once.
self._buffer += data
self._buffer_len += len(data)
self._buffer += bytes_data
self._buffer_len += len(bytes_data)
def _remove_from_buffer(self) -> None:
"""Remove data from the buffer."""

View File

@ -6,17 +6,21 @@ from .base cimport APIFrameHelper
cdef bint TYPE_CHECKING
cdef unsigned int NOISE_STATE_HELLO
cdef unsigned int NOISE_STATE_HANDSHAKE
cdef unsigned int NOISE_STATE_READY
cdef unsigned int NOISE_STATE_CLOSED
cdef class APINoiseFrameHelper(APIFrameHelper):
cdef object _noise_psk
cdef object _expected_name
cdef object _state
cdef unsigned int _state
cdef object _dispatch
cdef object _server_name
cdef object _proto
cdef object _decrypt
cdef object _encrypt
cdef bint _is_ready
@cython.locals(
header=bytes,
@ -24,13 +28,19 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
msg_size_high=cython.uint,
msg_size_low=cython.uint,
)
cpdef data_received(self, bytes data)
cpdef data_received(self, object data)
@cython.locals(
type_high=cython.uint,
type_low=cython.uint
)
cpdef _handle_frame(self, bytes data)
cdef _handle_frame(self, bytes frame)
cdef _handle_hello(self, bytes server_hello)
cdef _handle_handshake(self, bytes msg)
cdef _handle_closed(self, bytes frame)
@cython.locals(
type_="unsigned int",

View File

@ -2,7 +2,6 @@ from __future__ import annotations
import binascii
import logging
from enum import Enum
from functools import partial
from struct import Struct
from typing import TYPE_CHECKING, Any, Callable
@ -53,13 +52,17 @@ class ESPHomeNoiseBackend(DefaultNoiseBackend): # type: ignore[misc]
ESPHOME_NOISE_BACKEND = ESPHomeNoiseBackend()
class NoiseConnectionState(Enum):
"""Noise connection state."""
HELLO = 1
HANDSHAKE = 2
READY = 3
CLOSED = 4
# This is effectively an enum but we don't want to use an enum
# because we have a simple dispatch in the data_received method
# that would be more complicated with an enum and we want to add
# cdefs for each different state so we have a good test for each
# state receiving data since we found that the protractor event
# loop will send use a bytearray instead of bytes was not handled
# correctly.
NOISE_STATE_HELLO = 1
NOISE_STATE_HANDSHAKE = 2
NOISE_STATE_READY = 3
NOISE_STATE_CLOSED = 4
NOISE_HELLO = b"\x01\x00\x00"
@ -79,7 +82,6 @@ class APINoiseFrameHelper(APIFrameHelper):
"_proto",
"_decrypt",
"_encrypt",
"_is_ready",
)
def __init__(
@ -94,18 +96,11 @@ class APINoiseFrameHelper(APIFrameHelper):
super().__init__(connection, client_info, log_name)
self._noise_psk = noise_psk
self._expected_name = expected_name
self._set_state(NoiseConnectionState.HELLO)
self._state = NOISE_STATE_HELLO
self._server_name: str | None = None
self._decrypt: Callable[[bytes], bytes] | None = None
self._encrypt: Callable[[bytes], bytes] | None = None
self._setup_proto()
self._is_ready = False
def _set_state(self, state: NoiseConnectionState) -> None:
"""Set the current state."""
self._state = state
self._is_ready = state == NoiseConnectionState.READY
self._dispatch = self.STATE_TO_CALLABLE[state]
def close(self) -> None:
"""Close the connection."""
@ -115,7 +110,7 @@ class APINoiseFrameHelper(APIFrameHelper):
self._set_ready_future_exception(
APIConnectionError(f"{self._log_name}: Connection closed")
)
self._set_state(NoiseConnectionState.CLOSED)
self._state = NOISE_STATE_CLOSED
super().close()
def _handle_error_and_close(self, exc: Exception) -> None:
@ -124,10 +119,7 @@ class APINoiseFrameHelper(APIFrameHelper):
def _handle_error(self, exc: Exception) -> None:
"""Handle an error, and provide a good message when during hello."""
if (
isinstance(exc, ConnectionResetError)
and self._state == NoiseConnectionState.HELLO
):
if isinstance(exc, ConnectionResetError) and self._state == NOISE_STATE_HELLO:
original_exc = exc
exc = HandshakeAPIError(
f"{self._log_name}: The connection dropped immediately after encrypted hello; "
@ -142,7 +134,7 @@ class APINoiseFrameHelper(APIFrameHelper):
self._send_hello_handshake()
await super().perform_handshake(timeout)
def data_received(self, data: bytes) -> None:
def data_received(self, data: bytes | bytearray | memoryview) -> None:
self._add_to_buffer(data)
while self._buffer:
self._pos = 0
@ -168,7 +160,14 @@ class APINoiseFrameHelper(APIFrameHelper):
return
try:
self._dispatch(self, frame)
if self._state == NOISE_STATE_READY:
self._handle_frame(frame)
elif self._state == NOISE_STATE_HELLO:
self._handle_hello(frame)
elif self._state == NOISE_STATE_HANDSHAKE:
self._handle_handshake(frame)
else:
self._handle_closed(frame)
except Exception as err: # pylint: disable=broad-except
self._handle_error_and_close(err)
finally:
@ -236,7 +235,7 @@ class APINoiseFrameHelper(APIFrameHelper):
)
return
self._set_state(NoiseConnectionState.HANDSHAKE)
self._state = NOISE_STATE_HANDSHAKE
def _decode_noise_psk(self) -> bytes:
"""Decode the given noise psk from base64 format to raw bytes."""
@ -294,7 +293,7 @@ class APINoiseFrameHelper(APIFrameHelper):
self._handle_error_and_close(ex)
return
_LOGGER.debug("Handshake complete")
self._set_state(NoiseConnectionState.READY)
self._state = NOISE_STATE_READY
noise_protocol = self._proto.noise_protocol
self._decrypt = partial(
noise_protocol.cipher_state_decrypt.decrypt_with_ad, # pylint: disable=no-member
@ -311,7 +310,7 @@ class APINoiseFrameHelper(APIFrameHelper):
Packets are in the format of tuple[protobuf_type, protobuf_data]
"""
if not self._is_ready:
if self._state != NOISE_STATE_READY:
raise HandshakeAPIError(f"{self._log_name}: Noise connection is not ready")
if TYPE_CHECKING:
@ -371,10 +370,3 @@ class APINoiseFrameHelper(APIFrameHelper):
def _handle_closed(self, frame: bytes) -> None: # pylint: disable=unused-argument
"""Handle a closed frame."""
self._handle_error(ProtocolAPIError(f"{self._log_name}: Connection closed"))
STATE_TO_CALLABLE = {
NoiseConnectionState.HELLO: _handle_hello,
NoiseConnectionState.HANDSHAKE: _handle_handshake,
NoiseConnectionState.READY: _handle_frame,
NoiseConnectionState.CLOSED: _handle_closed,
}

View File

@ -26,9 +26,9 @@ cdef class APIPlaintextFrameHelper(APIFrameHelper):
length_high=cython.uint,
maybe_msg_type=cython.uint
)
cpdef data_received(self, bytes data)
cpdef data_received(self, object data)
cpdef _error_on_incorrect_preamble(self, object preamble)
cdef _error_on_incorrect_preamble(self, object preamble)
@cython.locals(
type_="unsigned int",

View File

@ -91,7 +91,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
) from err
def data_received( # pylint: disable=too-many-branches,too-many-return-statements
self, data: bytes
self, data: bytes | bytearray | memoryview
) -> None:
self._add_to_buffer(data)
while self._buffer:

View File

@ -25,6 +25,7 @@ from aioesphomeapi.core import (
BadNameAPIError,
HandshakeAPIError,
InvalidEncryptionKeyAPIError,
ProtocolAPIError,
SocketAPIError,
)
@ -67,7 +68,6 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
) from err
@pytest.mark.asyncio
@pytest.mark.parametrize(
"in_bytes, pkt_data, pkt_type",
[
@ -115,7 +115,9 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
),
],
)
async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
def test_plaintext_frame_helper(
in_bytes: bytes, pkt_data: bytes, pkt_type: int
) -> None:
for _ in range(3):
connection, packets = _make_mock_connection()
helper = APIPlaintextFrameHelper(
@ -141,6 +143,90 @@ async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
assert data == pkt_data
@pytest.mark.parametrize(
"byte_type",
(bytes, bytearray, memoryview),
)
def test_plaintext_frame_helper_protractor_event_loop(byte_type: Any) -> None:
"""Test the plaintext frame helper with the protractor event loop.
With the protractor event loop, data_received is called with a bytearray
instead of bytes.
https://github.com/esphome/issues/issues/5117
"""
for _ in range(3):
connection, packets = _make_mock_connection()
helper = APIPlaintextFrameHelper(
connection=connection, client_info="my client", log_name="test"
)
in_bytes = byte_type(
PREAMBLE + varuint_to_bytes(4) + varuint_to_bytes(100) + (b"\x42" * 4)
)
helper.data_received(in_bytes)
pkt = packets.pop()
type_, data = pkt
assert type_ == 100
assert data == b"\x42" * 4
# Make sure we correctly handle fragments
for i in range(len(in_bytes)):
helper.data_received(in_bytes[i : i + 1])
pkt = packets.pop()
type_, data = pkt
assert type_ == 100
assert data == b"\x42" * 4
@pytest.mark.asyncio
@pytest.mark.parametrize(
"byte_type",
(bytes, bytearray, memoryview),
)
async def test_noise_protector_event_loop(byte_type: Any) -> None:
"""Test the noise frame helper with the protractor event loop.
With the protractor event loop, data_received is called with a bytearray
instead of bytes.
https://github.com/esphome/issues/issues/5117
"""
outgoing_packets = [
"010000", # hello packet
"010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4",
]
incoming_packets = [
"01000d01736572766963657465737400",
"0100160148616e647368616b65204d4143206661696c757265",
]
connection, _ = _make_mock_connection()
helper = MockAPINoiseFrameHelper(
connection=connection,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="servicetest",
client_info="my client",
log_name="test",
)
helper._transport = MagicMock()
helper._writer = MagicMock()
for pkt in outgoing_packets:
helper.mock_write_frame(byte_type(bytes.fromhex(pkt)))
with pytest.raises(InvalidEncryptionKeyAPIError):
for pkt in incoming_packets:
helper.data_received(byte_type(bytes.fromhex(pkt)))
with pytest.raises(InvalidEncryptionKeyAPIError):
await helper.perform_handshake(30)
@pytest.mark.asyncio
async def test_noise_frame_helper_incorrect_key():
"""Test that the noise frame helper raises InvalidEncryptionKeyAPIError on bad key."""
@ -478,3 +564,6 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
assert packets == [(42, b"from device")]
helper.close()
with pytest.raises(ProtocolAPIError, match="Connection closed"):
helper.data_received(encrypted_header + encrypted_payload)