mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-22 12:05:12 +01:00
Speed up decrypting frames (#944)
This commit is contained in:
parent
0969e9339d
commit
7e7ece4ca1
@ -12,6 +12,21 @@ cdef unsigned int NOISE_STATE_READY
|
||||
cdef unsigned int NOISE_STATE_CLOSED
|
||||
|
||||
cdef bytes NOISE_HELLO
|
||||
cdef object PACK_NONCE
|
||||
|
||||
cdef class EncryptCipher:
|
||||
|
||||
cdef object _nonce
|
||||
cdef object _encrypt
|
||||
|
||||
cdef bytes encrypt(self, object frame)
|
||||
|
||||
cdef class DecryptCipher:
|
||||
|
||||
cdef object _nonce
|
||||
cdef object _decrypt
|
||||
|
||||
cdef bytes decrypt(self, object frame)
|
||||
|
||||
cdef class APINoiseFrameHelper(APIFrameHelper):
|
||||
|
||||
@ -20,8 +35,8 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
|
||||
cdef unsigned int _state
|
||||
cdef object _server_name
|
||||
cdef object _proto
|
||||
cdef object _decrypt
|
||||
cdef object _encrypt
|
||||
cdef EncryptCipher _encrypt_cipher
|
||||
cdef DecryptCipher _decrypt_cipher
|
||||
|
||||
@cython.locals(
|
||||
header=bytes,
|
||||
@ -59,6 +74,7 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
|
||||
@cython.locals(
|
||||
type_="unsigned int",
|
||||
data=bytes,
|
||||
data_header=bytes,
|
||||
packet=tuple,
|
||||
data_len=cython.uint,
|
||||
frame=bytes,
|
||||
|
@ -5,15 +5,17 @@ import binascii
|
||||
from functools import partial
|
||||
import logging
|
||||
from struct import Struct
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable
|
||||
from cryptography.exceptions import InvalidTag
|
||||
from noise.backends.default import DefaultNoiseBackend # type: ignore[import-untyped]
|
||||
from noise.backends.default.ciphers import ( # type: ignore[import-untyped]
|
||||
ChaCha20Cipher,
|
||||
CryptographyCipher,
|
||||
)
|
||||
from noise.connection import NoiseConnection # type: ignore[import-untyped]
|
||||
from noise.state import CipherState # type: ignore[import-untyped]
|
||||
|
||||
from ..core import (
|
||||
APIConnectionError,
|
||||
@ -30,6 +32,8 @@ if TYPE_CHECKING:
|
||||
|
||||
PACK_NONCE = partial(Struct("<LQ").pack, 0)
|
||||
|
||||
_bytes = bytes
|
||||
|
||||
|
||||
class ChaCha20CipherReuseable(ChaCha20Cipher): # type: ignore[misc]
|
||||
"""ChaCha20 cipher that can be reused."""
|
||||
@ -68,6 +72,44 @@ NOISE_HELLO = b"\x01\x00\x00"
|
||||
int_ = int
|
||||
|
||||
|
||||
class EncryptCipher:
|
||||
"""Wrapper around the ChaCha20Poly1305 cipher for encryption."""
|
||||
|
||||
__slots__ = ("_nonce", "_encrypt")
|
||||
|
||||
def __init__(self, cipher_state: CipherState) -> None:
|
||||
"""Initialize the cipher wrapper."""
|
||||
crypto_cipher: CryptographyCipher = cipher_state.cipher
|
||||
cipher: ChaCha20Poly1305Reusable = crypto_cipher.cipher
|
||||
self._nonce: int = cipher_state.n
|
||||
self._encrypt = cipher.encrypt
|
||||
|
||||
def encrypt(self, data: _bytes) -> bytes:
|
||||
"""Encrypt a frame."""
|
||||
ciphertext = self._encrypt(PACK_NONCE(self._nonce), data, None)
|
||||
self._nonce += 1
|
||||
return ciphertext
|
||||
|
||||
|
||||
class DecryptCipher:
|
||||
"""Wrapper around the ChaCha20Poly1305 cipher for decryption."""
|
||||
|
||||
__slots__ = ("_nonce", "_decrypt")
|
||||
|
||||
def __init__(self, cipher_state: CipherState) -> None:
|
||||
"""Initialize the cipher wrapper."""
|
||||
crypto_cipher: CryptographyCipher = cipher_state.cipher
|
||||
cipher: ChaCha20Poly1305Reusable = crypto_cipher.cipher
|
||||
self._nonce: int = cipher_state.n
|
||||
self._decrypt = cipher.decrypt
|
||||
|
||||
def decrypt(self, data: _bytes) -> bytes:
|
||||
"""Decrypt a frame."""
|
||||
plaintext = self._decrypt(PACK_NONCE(self._nonce), data, None)
|
||||
self._nonce += 1
|
||||
return plaintext
|
||||
|
||||
|
||||
class APINoiseFrameHelper(APIFrameHelper):
|
||||
"""Frame helper for noise encrypted connections."""
|
||||
|
||||
@ -77,8 +119,8 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
"_state",
|
||||
"_server_name",
|
||||
"_proto",
|
||||
"_decrypt",
|
||||
"_encrypt",
|
||||
"_encrypt_cipher",
|
||||
"_decrypt_cipher",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@ -95,8 +137,8 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
self._expected_name = expected_name
|
||||
self._state = NOISE_STATE_HELLO
|
||||
self._server_name: str | None = None
|
||||
self._decrypt: Callable[[bytes], bytes] | None = None
|
||||
self._encrypt: Callable[[bytes], bytes] | None = None
|
||||
self._encrypt_cipher: EncryptCipher | None = None
|
||||
self._decrypt_cipher: DecryptCipher | None = None
|
||||
self._setup_proto()
|
||||
|
||||
def close(self) -> None:
|
||||
@ -271,14 +313,8 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
self._proto.read_message(msg[1:])
|
||||
self._state = NOISE_STATE_READY
|
||||
noise_protocol = self._proto.noise_protocol
|
||||
self._decrypt = partial(
|
||||
noise_protocol.cipher_state_decrypt.decrypt_with_ad, # pylint: disable=no-member
|
||||
None,
|
||||
)
|
||||
self._encrypt = partial(
|
||||
noise_protocol.cipher_state_encrypt.encrypt_with_ad, # pylint: disable=no-member
|
||||
None,
|
||||
)
|
||||
self._decrypt_cipher = DecryptCipher(noise_protocol.cipher_state_decrypt) # pylint: disable=no-member
|
||||
self._encrypt_cipher = EncryptCipher(noise_protocol.cipher_state_encrypt) # pylint: disable=no-member
|
||||
self.ready_future.set_result(None)
|
||||
|
||||
def write_packets(
|
||||
@ -289,7 +325,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
Packets are in the format of tuple[protobuf_type, protobuf_data]
|
||||
"""
|
||||
if TYPE_CHECKING:
|
||||
assert self._encrypt is not None, "Handshake should be complete"
|
||||
assert self._encrypt_cipher is not None, "Handshake should be complete"
|
||||
|
||||
out: list[bytes] = []
|
||||
for packet in packets:
|
||||
@ -304,7 +340,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
data_len & 0xFF,
|
||||
)
|
||||
)
|
||||
frame = self._encrypt(data_header + data)
|
||||
frame = self._encrypt_cipher.encrypt(data_header + data)
|
||||
frame_len = len(frame)
|
||||
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
||||
out.append(header)
|
||||
@ -315,8 +351,8 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
def _handle_frame(self, frame: bytes) -> None:
|
||||
"""Handle an incoming frame."""
|
||||
if TYPE_CHECKING:
|
||||
assert self._decrypt is not None, "Handshake should be complete"
|
||||
msg = self._decrypt(frame)
|
||||
assert self._decrypt_cipher is not None, "Handshake should be complete"
|
||||
msg = self._decrypt_cipher.decrypt(frame)
|
||||
# Message layout is
|
||||
# 2 bytes: message type
|
||||
# 2 bytes: message length
|
||||
|
Loading…
Reference in New Issue
Block a user