From 7e7ece4ca1ee4c91e068554432cff7b4a8b9cf0c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 3 Sep 2024 11:43:00 -1000 Subject: [PATCH] Speed up decrypting frames (#944) --- aioesphomeapi/_frame_helper/noise.pxd | 20 +++++++- aioesphomeapi/_frame_helper/noise.py | 70 ++++++++++++++++++++------- 2 files changed, 71 insertions(+), 19 deletions(-) diff --git a/aioesphomeapi/_frame_helper/noise.pxd b/aioesphomeapi/_frame_helper/noise.pxd index c264eb3..f978344 100644 --- a/aioesphomeapi/_frame_helper/noise.pxd +++ b/aioesphomeapi/_frame_helper/noise.pxd @@ -12,6 +12,21 @@ cdef unsigned int NOISE_STATE_READY cdef unsigned int NOISE_STATE_CLOSED cdef bytes NOISE_HELLO +cdef object PACK_NONCE + +cdef class EncryptCipher: + + cdef object _nonce + cdef object _encrypt + + cdef bytes encrypt(self, object frame) + +cdef class DecryptCipher: + + cdef object _nonce + cdef object _decrypt + + cdef bytes decrypt(self, object frame) cdef class APINoiseFrameHelper(APIFrameHelper): @@ -20,8 +35,8 @@ cdef class APINoiseFrameHelper(APIFrameHelper): cdef unsigned int _state cdef object _server_name cdef object _proto - cdef object _decrypt - cdef object _encrypt + cdef EncryptCipher _encrypt_cipher + cdef DecryptCipher _decrypt_cipher @cython.locals( header=bytes, @@ -59,6 +74,7 @@ cdef class APINoiseFrameHelper(APIFrameHelper): @cython.locals( type_="unsigned int", data=bytes, + data_header=bytes, packet=tuple, data_len=cython.uint, frame=bytes, diff --git a/aioesphomeapi/_frame_helper/noise.py b/aioesphomeapi/_frame_helper/noise.py index d72a6a6..ee9ed1a 100644 --- a/aioesphomeapi/_frame_helper/noise.py +++ b/aioesphomeapi/_frame_helper/noise.py @@ -5,15 +5,17 @@ import binascii from functools import partial import logging from struct import Struct -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable from cryptography.exceptions import InvalidTag from noise.backends.default import DefaultNoiseBackend # type: ignore[import-untyped] from noise.backends.default.ciphers import ( # type: ignore[import-untyped] ChaCha20Cipher, + CryptographyCipher, ) from noise.connection import NoiseConnection # type: ignore[import-untyped] +from noise.state import CipherState # type: ignore[import-untyped] from ..core import ( APIConnectionError, @@ -30,6 +32,8 @@ if TYPE_CHECKING: PACK_NONCE = partial(Struct(" None: + """Initialize the cipher wrapper.""" + crypto_cipher: CryptographyCipher = cipher_state.cipher + cipher: ChaCha20Poly1305Reusable = crypto_cipher.cipher + self._nonce: int = cipher_state.n + self._encrypt = cipher.encrypt + + def encrypt(self, data: _bytes) -> bytes: + """Encrypt a frame.""" + ciphertext = self._encrypt(PACK_NONCE(self._nonce), data, None) + self._nonce += 1 + return ciphertext + + +class DecryptCipher: + """Wrapper around the ChaCha20Poly1305 cipher for decryption.""" + + __slots__ = ("_nonce", "_decrypt") + + def __init__(self, cipher_state: CipherState) -> None: + """Initialize the cipher wrapper.""" + crypto_cipher: CryptographyCipher = cipher_state.cipher + cipher: ChaCha20Poly1305Reusable = crypto_cipher.cipher + self._nonce: int = cipher_state.n + self._decrypt = cipher.decrypt + + def decrypt(self, data: _bytes) -> bytes: + """Decrypt a frame.""" + plaintext = self._decrypt(PACK_NONCE(self._nonce), data, None) + self._nonce += 1 + return plaintext + + class APINoiseFrameHelper(APIFrameHelper): """Frame helper for noise encrypted connections.""" @@ -77,8 +119,8 @@ class APINoiseFrameHelper(APIFrameHelper): "_state", "_server_name", "_proto", - "_decrypt", - "_encrypt", + "_encrypt_cipher", + "_decrypt_cipher", ) def __init__( @@ -95,8 +137,8 @@ class APINoiseFrameHelper(APIFrameHelper): self._expected_name = expected_name self._state = NOISE_STATE_HELLO self._server_name: str | None = None - self._decrypt: Callable[[bytes], bytes] | None = None - self._encrypt: Callable[[bytes], bytes] | None = None + self._encrypt_cipher: EncryptCipher | None = None + self._decrypt_cipher: DecryptCipher | None = None self._setup_proto() def close(self) -> None: @@ -271,14 +313,8 @@ class APINoiseFrameHelper(APIFrameHelper): self._proto.read_message(msg[1:]) self._state = NOISE_STATE_READY noise_protocol = self._proto.noise_protocol - self._decrypt = partial( - noise_protocol.cipher_state_decrypt.decrypt_with_ad, # pylint: disable=no-member - None, - ) - self._encrypt = partial( - noise_protocol.cipher_state_encrypt.encrypt_with_ad, # pylint: disable=no-member - None, - ) + self._decrypt_cipher = DecryptCipher(noise_protocol.cipher_state_decrypt) # pylint: disable=no-member + self._encrypt_cipher = EncryptCipher(noise_protocol.cipher_state_encrypt) # pylint: disable=no-member self.ready_future.set_result(None) def write_packets( @@ -289,7 +325,7 @@ class APINoiseFrameHelper(APIFrameHelper): Packets are in the format of tuple[protobuf_type, protobuf_data] """ if TYPE_CHECKING: - assert self._encrypt is not None, "Handshake should be complete" + assert self._encrypt_cipher is not None, "Handshake should be complete" out: list[bytes] = [] for packet in packets: @@ -304,7 +340,7 @@ class APINoiseFrameHelper(APIFrameHelper): data_len & 0xFF, ) ) - frame = self._encrypt(data_header + data) + frame = self._encrypt_cipher.encrypt(data_header + data) frame_len = len(frame) header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF)) out.append(header) @@ -315,8 +351,8 @@ class APINoiseFrameHelper(APIFrameHelper): def _handle_frame(self, frame: bytes) -> None: """Handle an incoming frame.""" if TYPE_CHECKING: - assert self._decrypt is not None, "Handshake should be complete" - msg = self._decrypt(frame) + assert self._decrypt_cipher is not None, "Handshake should be complete" + msg = self._decrypt_cipher.decrypt(frame) # Message layout is # 2 bytes: message type # 2 bytes: message length