Speed up encrypted handshake (#487)

This commit is contained in:
J. Nick Koston 2023-07-17 14:51:47 -10:00 committed by GitHub
parent f83d3f4e6f
commit 7196ca6ee8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 85 additions and 62 deletions

View File

@ -36,6 +36,8 @@ SOCKET_ERRORS = (
PACK_NONCE = partial(Struct("<LQ").pack, 0)
WRITE_EXCEPTIONS = (RuntimeError, ConnectionResetError, OSError)
class ChaCha20CipherReuseable(ChaCha20Cipher): # type: ignore[misc]
"""ChaCha20 cipher that can be reused."""
@ -71,6 +73,7 @@ class APIFrameHelper(asyncio.Protocol):
"_pos",
"_client_info",
"_log_name",
"_debug_enabled",
)
def __init__(
@ -93,6 +96,7 @@ class APIFrameHelper(asyncio.Protocol):
self._pos = 0
self._client_info = client_info
self._log_name = log_name
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
def _read_exactly(self, length: int) -> Optional[bytearray]:
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
@ -154,12 +158,12 @@ class APIPlaintextFrameHelper(APIFrameHelper):
assert self._writer is not None, "Writer should be set"
data = b"\0" + varuint_to_bytes(len(data)) + varuint_to_bytes(type_) + data
if _LOGGER.isEnabledFor(logging.DEBUG):
if self._debug_enabled():
_LOGGER.debug("%s: Sending plaintext frame %s", self._log_name, data.hex())
try:
self._writer(data)
except (RuntimeError, ConnectionResetError, OSError) as err:
except WRITE_EXCEPTIONS as err:
raise SocketAPIError(
f"{self._log_name}: Error while writing data: {err}"
) from err
@ -267,6 +271,9 @@ class NoiseConnectionState(Enum):
CLOSED = 4
NOISE_HELLO = b"\x01\x00\x00"
class APINoiseFrameHelper(APIFrameHelper):
"""Frame helper for noise encrypted connections."""
@ -347,37 +354,9 @@ class APINoiseFrameHelper(APIFrameHelper):
exc.__cause__ = original_exc
super()._handle_error(exc)
def _write_frame(self, frame: bytes) -> None:
"""Write a packet to the socket.
The entire packet must be written in a single call to write.
"""
if TYPE_CHECKING:
assert self._writer is not None, "Writer is not set"
if _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, frame.hex())
frame_len = len(frame)
try:
self._writer(
bytes(
(
0x01,
(frame_len >> 8) & 0xFF,
frame_len & 0xFF,
)
)
+ frame
)
except (RuntimeError, ConnectionResetError, OSError) as err:
raise SocketAPIError(
f"{self._log_name}: Error while writing data: {err}"
) from err
async def perform_handshake(self) -> None:
"""Perform the handshake with the server."""
self._send_hello()
self._send_hello_handshake()
handshake_handle = self._loop.call_later(
60, self._set_ready_future_exception, asyncio.TimeoutError()
)
@ -424,9 +403,29 @@ class APINoiseFrameHelper(APIFrameHelper):
del self._buffer[:end_of_frame_pos]
self._buffer_len -= end_of_frame_pos
def _send_hello(self) -> None:
def _send_hello_handshake(self) -> None:
"""Send a ClientHello to the server."""
self._write_frame(b"") # ClientHello
if TYPE_CHECKING:
assert self._writer is not None, "Writer is not set"
handshake_frame = b"\x00" + self._proto.write_message()
frame_len = len(handshake_frame)
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
hello_handshake = NOISE_HELLO + header + handshake_frame
if self._debug_enabled():
_LOGGER.debug(
"%s: Sending encrypted hello handshake: [%s]",
self._log_name,
hello_handshake.hex(),
)
try:
self._writer(hello_handshake)
except WRITE_EXCEPTIONS as err:
raise SocketAPIError(
f"{self._log_name}: Error while writing data: {err}"
) from err
def _handle_hello(self, server_hello: bytearray) -> None:
"""Perform the handshake with the server."""
@ -467,7 +466,6 @@ class APINoiseFrameHelper(APIFrameHelper):
return
self._set_state(NoiseConnectionState.HANDSHAKE)
self._send_handshake()
def _decode_noise_psk(self) -> bytes:
"""Decode the given noise psk from base64 format to raw bytes."""
@ -500,10 +498,6 @@ class APINoiseFrameHelper(APIFrameHelper):
proto.start_handshake()
self._proto = proto
def _send_handshake(self) -> None:
"""Send the handshake message."""
self._write_frame(b"\x00" + self._proto.write_message())
def _handle_handshake(self, msg: bytearray) -> None:
_LOGGER.debug("Starting handshake...")
if msg[0] != 0:
@ -548,21 +542,25 @@ class APINoiseFrameHelper(APIFrameHelper):
if TYPE_CHECKING:
assert self._encrypt is not None, "Handshake should be complete"
assert self._writer is not None, "Writer is not set"
data_len = len(data)
self._write_frame(
self._encrypt(
bytes(
[
(type_ >> 8) & 0xFF,
type_ & 0xFF,
(data_len >> 8) & 0xFF,
data_len & 0xFF,
]
)
+ data
)
type_len = bytes(
((type_ >> 8) & 0xFF, type_ & 0xFF, (data_len >> 8) & 0xFF, data_len & 0xFF)
)
frame = self._encrypt(type_len + data)
if self._debug_enabled():
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, frame.hex())
frame_len = len(frame)
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
try:
self._writer(header + frame)
except WRITE_EXCEPTIONS as err:
raise SocketAPIError(
f"{self._log_name}: Error while writing data: {err}"
) from err
def _handle_frame(self, frame: bytearray) -> None:
"""Handle an incoming frame."""

View File

@ -154,6 +154,7 @@ class APIConnection:
"is_connected",
"is_authenticated",
"_is_socket_open",
"_debug_enabled",
)
def __init__(
@ -195,6 +196,7 @@ class APIConnection:
self.is_connected = False
self.is_authenticated = False
self._is_socket_open = False
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
@property
def connection_state(self) -> ConnectionState:
@ -284,7 +286,7 @@ class APIConnection:
err,
)
if _LOGGER.isEnabledFor(logging.DEBUG):
if self._debug_enabled():
_LOGGER.debug(
"%s: Connecting to %s:%s (%s)",
self.log_name,
@ -545,7 +547,7 @@ class APIConnection:
raise ValueError(f"Message type id not found for type {_msg_type}")
encoded = msg.SerializeToString()
if _LOGGER.isEnabledFor(logging.DEBUG):
if self._debug_enabled():
_LOGGER.debug("%s: Sending %s: %s", self.log_name, _msg_type.__name__, msg)
try:
@ -715,7 +717,7 @@ class APIConnection:
def _process_packet_factory(self) -> Callable[[int, bytes], None]:
"""Factory to make a packet processor."""
message_type_to_proto = MESSAGE_TYPE_TO_PROTO
debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
debug_enabled = self._debug_enabled
message_handlers = self._message_handlers
internal_message_types = INTERNAL_MESSAGE_TYPES

View File

@ -1,15 +1,38 @@
import asyncio
from unittest.mock import MagicMock
import pytest
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
from aioesphomeapi.core import BadNameAPIError, InvalidEncryptionKeyAPIError
from aioesphomeapi._frame_helper import (
WRITE_EXCEPTIONS,
APINoiseFrameHelper,
APIPlaintextFrameHelper,
)
from aioesphomeapi.core import (
BadNameAPIError,
InvalidEncryptionKeyAPIError,
SocketAPIError,
)
from aioesphomeapi.util import varuint_to_bytes
PREAMBLE = b"\x00"
class MockAPINoiseFrameHelper(APINoiseFrameHelper):
def mock_write_frame(self, frame: bytes) -> None:
"""Write a packet to the socket.
The entire packet must be written in a single call to write.
"""
frame_len = len(frame)
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
try:
self._writer(header + frame)
except WRITE_EXCEPTIONS as err:
raise SocketAPIError(
f"{self._log_name}: Error while writing data: {err}"
) from err
@pytest.mark.asyncio
@pytest.mark.parametrize(
"in_bytes, pkt_data, pkt_type",
@ -110,7 +133,7 @@ async def test_noise_frame_helper_incorrect_key():
def _on_error(exc: Exception):
raise exc
helper = APINoiseFrameHelper(
helper = MockAPINoiseFrameHelper(
on_pkt=_packet,
on_error=_on_error,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
@ -122,7 +145,7 @@ async def test_noise_frame_helper_incorrect_key():
helper._writer = MagicMock()
for pkt in outgoing_packets:
helper._write_frame(bytes.fromhex(pkt))
helper.mock_write_frame(bytes.fromhex(pkt))
with pytest.raises(InvalidEncryptionKeyAPIError):
for pkt in incoming_packets:
@ -151,7 +174,7 @@ async def test_noise_frame_helper_incorrect_key_fragments():
def _on_error(exc: Exception):
raise exc
helper = APINoiseFrameHelper(
helper = MockAPINoiseFrameHelper(
on_pkt=_packet,
on_error=_on_error,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
@ -163,7 +186,7 @@ async def test_noise_frame_helper_incorrect_key_fragments():
helper._writer = MagicMock()
for pkt in outgoing_packets:
helper._write_frame(bytes.fromhex(pkt))
helper.mock_write_frame(bytes.fromhex(pkt))
with pytest.raises(InvalidEncryptionKeyAPIError):
for pkt in incoming_packets:
@ -194,7 +217,7 @@ async def test_noise_incorrect_name():
def _on_error(exc: Exception):
raise exc
helper = APINoiseFrameHelper(
helper = MockAPINoiseFrameHelper(
on_pkt=_packet,
on_error=_on_error,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
@ -206,7 +229,7 @@ async def test_noise_incorrect_name():
helper._writer = MagicMock()
for pkt in outgoing_packets:
helper._write_frame(bytes.fromhex(pkt))
helper.mock_write_frame(bytes.fromhex(pkt))
with pytest.raises(BadNameAPIError):
for pkt in incoming_packets: