Speed up encrypted handshake (#487)
This commit is contained in:
parent
f83d3f4e6f
commit
7196ca6ee8
|
@ -36,6 +36,8 @@ SOCKET_ERRORS = (
|
|||
|
||||
PACK_NONCE = partial(Struct("<LQ").pack, 0)
|
||||
|
||||
WRITE_EXCEPTIONS = (RuntimeError, ConnectionResetError, OSError)
|
||||
|
||||
|
||||
class ChaCha20CipherReuseable(ChaCha20Cipher): # type: ignore[misc]
|
||||
"""ChaCha20 cipher that can be reused."""
|
||||
|
@ -71,6 +73,7 @@ class APIFrameHelper(asyncio.Protocol):
|
|||
"_pos",
|
||||
"_client_info",
|
||||
"_log_name",
|
||||
"_debug_enabled",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
|
@ -93,6 +96,7 @@ class APIFrameHelper(asyncio.Protocol):
|
|||
self._pos = 0
|
||||
self._client_info = client_info
|
||||
self._log_name = log_name
|
||||
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
|
||||
|
||||
def _read_exactly(self, length: int) -> Optional[bytearray]:
|
||||
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
|
||||
|
@ -154,12 +158,12 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||
assert self._writer is not None, "Writer should be set"
|
||||
|
||||
data = b"\0" + varuint_to_bytes(len(data)) + varuint_to_bytes(type_) + data
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
if self._debug_enabled():
|
||||
_LOGGER.debug("%s: Sending plaintext frame %s", self._log_name, data.hex())
|
||||
|
||||
try:
|
||||
self._writer(data)
|
||||
except (RuntimeError, ConnectionResetError, OSError) as err:
|
||||
except WRITE_EXCEPTIONS as err:
|
||||
raise SocketAPIError(
|
||||
f"{self._log_name}: Error while writing data: {err}"
|
||||
) from err
|
||||
|
@ -267,6 +271,9 @@ class NoiseConnectionState(Enum):
|
|||
CLOSED = 4
|
||||
|
||||
|
||||
NOISE_HELLO = b"\x01\x00\x00"
|
||||
|
||||
|
||||
class APINoiseFrameHelper(APIFrameHelper):
|
||||
"""Frame helper for noise encrypted connections."""
|
||||
|
||||
|
@ -347,37 +354,9 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||
exc.__cause__ = original_exc
|
||||
super()._handle_error(exc)
|
||||
|
||||
def _write_frame(self, frame: bytes) -> None:
|
||||
"""Write a packet to the socket.
|
||||
|
||||
The entire packet must be written in a single call to write.
|
||||
"""
|
||||
if TYPE_CHECKING:
|
||||
assert self._writer is not None, "Writer is not set"
|
||||
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, frame.hex())
|
||||
|
||||
frame_len = len(frame)
|
||||
try:
|
||||
self._writer(
|
||||
bytes(
|
||||
(
|
||||
0x01,
|
||||
(frame_len >> 8) & 0xFF,
|
||||
frame_len & 0xFF,
|
||||
)
|
||||
)
|
||||
+ frame
|
||||
)
|
||||
except (RuntimeError, ConnectionResetError, OSError) as err:
|
||||
raise SocketAPIError(
|
||||
f"{self._log_name}: Error while writing data: {err}"
|
||||
) from err
|
||||
|
||||
async def perform_handshake(self) -> None:
|
||||
"""Perform the handshake with the server."""
|
||||
self._send_hello()
|
||||
self._send_hello_handshake()
|
||||
handshake_handle = self._loop.call_later(
|
||||
60, self._set_ready_future_exception, asyncio.TimeoutError()
|
||||
)
|
||||
|
@ -424,9 +403,29 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||
del self._buffer[:end_of_frame_pos]
|
||||
self._buffer_len -= end_of_frame_pos
|
||||
|
||||
def _send_hello(self) -> None:
|
||||
def _send_hello_handshake(self) -> None:
|
||||
"""Send a ClientHello to the server."""
|
||||
self._write_frame(b"") # ClientHello
|
||||
if TYPE_CHECKING:
|
||||
assert self._writer is not None, "Writer is not set"
|
||||
|
||||
handshake_frame = b"\x00" + self._proto.write_message()
|
||||
frame_len = len(handshake_frame)
|
||||
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
||||
hello_handshake = NOISE_HELLO + header + handshake_frame
|
||||
|
||||
if self._debug_enabled():
|
||||
_LOGGER.debug(
|
||||
"%s: Sending encrypted hello handshake: [%s]",
|
||||
self._log_name,
|
||||
hello_handshake.hex(),
|
||||
)
|
||||
|
||||
try:
|
||||
self._writer(hello_handshake)
|
||||
except WRITE_EXCEPTIONS as err:
|
||||
raise SocketAPIError(
|
||||
f"{self._log_name}: Error while writing data: {err}"
|
||||
) from err
|
||||
|
||||
def _handle_hello(self, server_hello: bytearray) -> None:
|
||||
"""Perform the handshake with the server."""
|
||||
|
@ -467,7 +466,6 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||
return
|
||||
|
||||
self._set_state(NoiseConnectionState.HANDSHAKE)
|
||||
self._send_handshake()
|
||||
|
||||
def _decode_noise_psk(self) -> bytes:
|
||||
"""Decode the given noise psk from base64 format to raw bytes."""
|
||||
|
@ -500,10 +498,6 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||
proto.start_handshake()
|
||||
self._proto = proto
|
||||
|
||||
def _send_handshake(self) -> None:
|
||||
"""Send the handshake message."""
|
||||
self._write_frame(b"\x00" + self._proto.write_message())
|
||||
|
||||
def _handle_handshake(self, msg: bytearray) -> None:
|
||||
_LOGGER.debug("Starting handshake...")
|
||||
if msg[0] != 0:
|
||||
|
@ -548,21 +542,25 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||
|
||||
if TYPE_CHECKING:
|
||||
assert self._encrypt is not None, "Handshake should be complete"
|
||||
assert self._writer is not None, "Writer is not set"
|
||||
|
||||
data_len = len(data)
|
||||
self._write_frame(
|
||||
self._encrypt(
|
||||
bytes(
|
||||
[
|
||||
(type_ >> 8) & 0xFF,
|
||||
type_ & 0xFF,
|
||||
(data_len >> 8) & 0xFF,
|
||||
data_len & 0xFF,
|
||||
]
|
||||
)
|
||||
+ data
|
||||
)
|
||||
type_len = bytes(
|
||||
((type_ >> 8) & 0xFF, type_ & 0xFF, (data_len >> 8) & 0xFF, data_len & 0xFF)
|
||||
)
|
||||
frame = self._encrypt(type_len + data)
|
||||
|
||||
if self._debug_enabled():
|
||||
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, frame.hex())
|
||||
|
||||
frame_len = len(frame)
|
||||
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
||||
try:
|
||||
self._writer(header + frame)
|
||||
except WRITE_EXCEPTIONS as err:
|
||||
raise SocketAPIError(
|
||||
f"{self._log_name}: Error while writing data: {err}"
|
||||
) from err
|
||||
|
||||
def _handle_frame(self, frame: bytearray) -> None:
|
||||
"""Handle an incoming frame."""
|
||||
|
|
|
@ -154,6 +154,7 @@ class APIConnection:
|
|||
"is_connected",
|
||||
"is_authenticated",
|
||||
"_is_socket_open",
|
||||
"_debug_enabled",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
|
@ -195,6 +196,7 @@ class APIConnection:
|
|||
self.is_connected = False
|
||||
self.is_authenticated = False
|
||||
self._is_socket_open = False
|
||||
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
|
||||
|
||||
@property
|
||||
def connection_state(self) -> ConnectionState:
|
||||
|
@ -284,7 +286,7 @@ class APIConnection:
|
|||
err,
|
||||
)
|
||||
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
if self._debug_enabled():
|
||||
_LOGGER.debug(
|
||||
"%s: Connecting to %s:%s (%s)",
|
||||
self.log_name,
|
||||
|
@ -545,7 +547,7 @@ class APIConnection:
|
|||
raise ValueError(f"Message type id not found for type {_msg_type}")
|
||||
encoded = msg.SerializeToString()
|
||||
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
if self._debug_enabled():
|
||||
_LOGGER.debug("%s: Sending %s: %s", self.log_name, _msg_type.__name__, msg)
|
||||
|
||||
try:
|
||||
|
@ -715,7 +717,7 @@ class APIConnection:
|
|||
def _process_packet_factory(self) -> Callable[[int, bytes], None]:
|
||||
"""Factory to make a packet processor."""
|
||||
message_type_to_proto = MESSAGE_TYPE_TO_PROTO
|
||||
debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
|
||||
debug_enabled = self._debug_enabled
|
||||
message_handlers = self._message_handlers
|
||||
internal_message_types = INTERNAL_MESSAGE_TYPES
|
||||
|
||||
|
|
|
@ -1,15 +1,38 @@
|
|||
import asyncio
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
|
||||
from aioesphomeapi.core import BadNameAPIError, InvalidEncryptionKeyAPIError
|
||||
from aioesphomeapi._frame_helper import (
|
||||
WRITE_EXCEPTIONS,
|
||||
APINoiseFrameHelper,
|
||||
APIPlaintextFrameHelper,
|
||||
)
|
||||
from aioesphomeapi.core import (
|
||||
BadNameAPIError,
|
||||
InvalidEncryptionKeyAPIError,
|
||||
SocketAPIError,
|
||||
)
|
||||
from aioesphomeapi.util import varuint_to_bytes
|
||||
|
||||
PREAMBLE = b"\x00"
|
||||
|
||||
|
||||
class MockAPINoiseFrameHelper(APINoiseFrameHelper):
|
||||
def mock_write_frame(self, frame: bytes) -> None:
|
||||
"""Write a packet to the socket.
|
||||
|
||||
The entire packet must be written in a single call to write.
|
||||
"""
|
||||
frame_len = len(frame)
|
||||
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
||||
try:
|
||||
self._writer(header + frame)
|
||||
except WRITE_EXCEPTIONS as err:
|
||||
raise SocketAPIError(
|
||||
f"{self._log_name}: Error while writing data: {err}"
|
||||
) from err
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"in_bytes, pkt_data, pkt_type",
|
||||
|
@ -110,7 +133,7 @@ async def test_noise_frame_helper_incorrect_key():
|
|||
def _on_error(exc: Exception):
|
||||
raise exc
|
||||
|
||||
helper = APINoiseFrameHelper(
|
||||
helper = MockAPINoiseFrameHelper(
|
||||
on_pkt=_packet,
|
||||
on_error=_on_error,
|
||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||
|
@ -122,7 +145,7 @@ async def test_noise_frame_helper_incorrect_key():
|
|||
helper._writer = MagicMock()
|
||||
|
||||
for pkt in outgoing_packets:
|
||||
helper._write_frame(bytes.fromhex(pkt))
|
||||
helper.mock_write_frame(bytes.fromhex(pkt))
|
||||
|
||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||
for pkt in incoming_packets:
|
||||
|
@ -151,7 +174,7 @@ async def test_noise_frame_helper_incorrect_key_fragments():
|
|||
def _on_error(exc: Exception):
|
||||
raise exc
|
||||
|
||||
helper = APINoiseFrameHelper(
|
||||
helper = MockAPINoiseFrameHelper(
|
||||
on_pkt=_packet,
|
||||
on_error=_on_error,
|
||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||
|
@ -163,7 +186,7 @@ async def test_noise_frame_helper_incorrect_key_fragments():
|
|||
helper._writer = MagicMock()
|
||||
|
||||
for pkt in outgoing_packets:
|
||||
helper._write_frame(bytes.fromhex(pkt))
|
||||
helper.mock_write_frame(bytes.fromhex(pkt))
|
||||
|
||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||
for pkt in incoming_packets:
|
||||
|
@ -194,7 +217,7 @@ async def test_noise_incorrect_name():
|
|||
def _on_error(exc: Exception):
|
||||
raise exc
|
||||
|
||||
helper = APINoiseFrameHelper(
|
||||
helper = MockAPINoiseFrameHelper(
|
||||
on_pkt=_packet,
|
||||
on_error=_on_error,
|
||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||
|
@ -206,7 +229,7 @@ async def test_noise_incorrect_name():
|
|||
helper._writer = MagicMock()
|
||||
|
||||
for pkt in outgoing_packets:
|
||||
helper._write_frame(bytes.fromhex(pkt))
|
||||
helper.mock_write_frame(bytes.fromhex(pkt))
|
||||
|
||||
with pytest.raises(BadNameAPIError):
|
||||
for pkt in incoming_packets:
|
||||
|
|
Loading…
Reference in New Issue