Speed up noise handshake (#486)

This commit is contained in:
J. Nick Koston 2023-07-17 14:13:58 -10:00 committed by GitHub
parent 0c1f710869
commit ab3c096c9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 72 additions and 39 deletions

View File

@ -5,9 +5,8 @@ from abc import abstractmethod
from enum import Enum
from functools import partial
from struct import Struct
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
import async_timeout
from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable
from cryptography.exceptions import InvalidTag
from noise.backends.default import DefaultNoiseBackend # type: ignore[import]
@ -65,6 +64,7 @@ class APIFrameHelper(asyncio.Protocol):
"_on_pkt",
"_on_error",
"_transport",
"_writer",
"_connected_event",
"_buffer",
"_buffer_len",
@ -84,6 +84,9 @@ class APIFrameHelper(asyncio.Protocol):
self._on_pkt = on_pkt
self._on_error = on_error
self._transport: Optional[asyncio.Transport] = None
self._writer: Optional[
Callable[[Union[bytes, bytearray, memoryview]], None]
] = None
self._connected_event = asyncio.Event()
self._buffer = bytearray()
self._buffer_len = 0
@ -111,6 +114,7 @@ class APIFrameHelper(asyncio.Protocol):
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Handle a new connection."""
self._transport = cast(asyncio.Transport, transport)
self._writer = self._transport.write
self._connected_event.set()
def _handle_error_and_close(self, exc: Exception) -> None:
@ -134,6 +138,8 @@ class APIFrameHelper(asyncio.Protocol):
"""Close the connection."""
if self._transport:
self._transport.close()
self._transport = None
self._writer = None
class APIPlaintextFrameHelper(APIFrameHelper):
@ -144,13 +150,15 @@ class APIPlaintextFrameHelper(APIFrameHelper):
The entire packet must be written in a single call.
"""
assert self._transport is not None, "Transport should be set"
if TYPE_CHECKING:
assert self._writer is not None, "Writer should be set"
data = b"\0" + varuint_to_bytes(len(data)) + varuint_to_bytes(type_) + data
if _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug("%s: Sending plaintext frame %s", self._log_name, data.hex())
try:
self._transport.write(data)
self._writer(data)
except (RuntimeError, ConnectionResetError, OSError) as err:
raise SocketAPIError(
f"{self._log_name}: Error while writing data: {err}"
@ -250,21 +258,6 @@ class APIPlaintextFrameHelper(APIFrameHelper):
# If we have more data, continue processing
def _decode_noise_psk(psk: str, server_name: Optional[str]) -> bytes:
"""Decode the given noise psk from base64 format to raw bytes."""
try:
psk_bytes = base64.b64decode(psk)
except ValueError:
raise InvalidEncryptionKeyAPIError(
f"Malformed PSK {psk}, expected base64-encoded value", server_name
)
if len(psk_bytes) != 32:
raise InvalidEncryptionKeyAPIError(
f"Malformed PSK {psk}, expected 32-bytes of base64 data", server_name
)
return psk_bytes
class NoiseConnectionState(Enum):
"""Noise connection state."""
@ -287,6 +280,8 @@ class APINoiseFrameHelper(APIFrameHelper):
"_proto",
"_decrypt",
"_encrypt",
"_is_ready",
"_loop",
)
def __init__(
@ -300,7 +295,9 @@ class APINoiseFrameHelper(APIFrameHelper):
) -> None:
"""Initialize the API frame helper."""
super().__init__(on_pkt, on_error, client_info, log_name)
self._ready_future = asyncio.get_event_loop().create_future()
loop = asyncio.get_event_loop()
self._loop = loop
self._ready_future = loop.create_future()
self._noise_psk = noise_psk
self._expected_name = expected_name
self._set_state(NoiseConnectionState.HELLO)
@ -308,6 +305,7 @@ class APINoiseFrameHelper(APIFrameHelper):
self._decrypt: Optional[Callable[[bytes], bytes]] = None
self._encrypt: Optional[Callable[[bytes], bytes]] = None
self._setup_proto()
self._is_ready = False
def _set_ready_future_exception(self, exc: Exception) -> None:
if not self._ready_future.done():
@ -316,6 +314,7 @@ class APINoiseFrameHelper(APIFrameHelper):
def _set_state(self, state: NoiseConnectionState) -> None:
"""Set the current state."""
self._state = state
self._is_ready = state == NoiseConnectionState.READY
self._dispatch = self.STATE_TO_CALLABLE[state]
def close(self) -> None:
@ -353,20 +352,24 @@ class APINoiseFrameHelper(APIFrameHelper):
The entire packet must be written in a single call to write.
"""
assert self._transport is not None, "Transport is not set"
if TYPE_CHECKING:
assert self._writer is not None, "Writer is not set"
if _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, frame.hex())
frame_len = len(frame)
try:
header = bytes(
[
0x01,
(frame_len >> 8) & 0xFF,
frame_len & 0xFF,
]
self._writer(
bytes(
(
0x01,
(frame_len >> 8) & 0xFF,
frame_len & 0xFF,
)
)
+ frame
)
self._transport.write(header + frame)
except (RuntimeError, ConnectionResetError, OSError) as err:
raise SocketAPIError(
f"{self._log_name}: Error while writing data: {err}"
@ -375,13 +378,17 @@ class APINoiseFrameHelper(APIFrameHelper):
async def perform_handshake(self) -> None:
"""Perform the handshake with the server."""
self._send_hello()
handshake_handle = self._loop.call_later(
60, self._set_ready_future_exception, asyncio.TimeoutError()
)
try:
async with async_timeout.timeout(60.0):
await self._ready_future
await self._ready_future
except asyncio.TimeoutError as err:
raise HandshakeAPIError(
f"{self._log_name}: Timeout during handshake"
) from err
finally:
handshake_handle.cancel()
def data_received(self, data: bytes) -> None:
self._buffer += data
@ -462,15 +469,36 @@ class APINoiseFrameHelper(APIFrameHelper):
self._set_state(NoiseConnectionState.HANDSHAKE)
self._send_handshake()
def _decode_noise_psk(self) -> bytes:
"""Decode the given noise psk from base64 format to raw bytes."""
psk = self._noise_psk
server_name = self._server_name
try:
psk_bytes = base64.b64decode(psk)
except ValueError:
raise InvalidEncryptionKeyAPIError(
f"{self._log_name}: Malformed PSK {psk}, expected "
"base64-encoded value",
server_name,
)
if len(psk_bytes) != 32:
raise InvalidEncryptionKeyAPIError(
f"{self._log_name}:Malformed PSK {psk}, expected"
f" 32-bytes of base64 data",
server_name,
)
return psk_bytes
def _setup_proto(self) -> None:
"""Set up the noise protocol."""
self._proto = NoiseConnection.from_name(
proto = NoiseConnection.from_name(
b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND
)
self._proto.set_as_initiator()
self._proto.set_psks(_decode_noise_psk(self._noise_psk, self._server_name))
self._proto.set_prologue(b"NoiseAPIInit" + b"\x00\x00")
self._proto.start_handshake()
proto.set_as_initiator()
proto.set_psks(self._decode_noise_psk())
proto.set_prologue(b"NoiseAPIInit\x00\x00")
proto.start_handshake()
self._proto = proto
def _send_handshake(self) -> None:
"""Send the handshake message."""
@ -515,10 +543,12 @@ class APINoiseFrameHelper(APIFrameHelper):
def write_packet(self, type_: int, data: bytes) -> None:
"""Write a packet to the socket."""
if self._state != NoiseConnectionState.READY:
if not self._is_ready:
raise HandshakeAPIError(f"{self._log_name}: Noise connection is not ready")
if TYPE_CHECKING:
assert self._encrypt is not None, "Handshake should be complete"
data_len = len(data)
self._write_frame(
self._encrypt(

View File

@ -715,8 +715,7 @@ class APIConnection:
def _process_packet_factory(self) -> Callable[[int, bytes], None]:
"""Factory to make a packet processor."""
message_type_to_proto = MESSAGE_TYPE_TO_PROTO
is_enabled_for = _LOGGER.isEnabledFor
logging_debug = logging.DEBUG
debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
message_handlers = self._message_handlers
internal_message_types = INTERNAL_MESSAGE_TYPES
@ -759,7 +758,7 @@ class APIConnection:
msg_type = type(msg)
if is_enabled_for(logging_debug):
if debug_enabled():
_LOGGER.debug(
"%s: Got message of type %s: %s",
self.log_name,

View File

@ -119,6 +119,7 @@ async def test_noise_frame_helper_incorrect_key():
log_name="test",
)
helper._transport = MagicMock()
helper._writer = MagicMock()
for pkt in outgoing_packets:
helper._write_frame(bytes.fromhex(pkt))
@ -159,6 +160,7 @@ async def test_noise_frame_helper_incorrect_key_fragments():
log_name="test",
)
helper._transport = MagicMock()
helper._writer = MagicMock()
for pkt in outgoing_packets:
helper._write_frame(bytes.fromhex(pkt))
@ -201,6 +203,7 @@ async def test_noise_incorrect_name():
log_name="test",
)
helper._transport = MagicMock()
helper._writer = MagicMock()
for pkt in outgoing_packets:
helper._write_frame(bytes.fromhex(pkt))

View File

@ -61,6 +61,7 @@ def _get_mock_protocol(conn: APIConnection):
)
protocol._connected_event.set()
protocol._transport = MagicMock()
protocol._writer = MagicMock()
return protocol