from __future__ import annotations import asyncio import binascii import logging from functools import partial from struct import Struct from typing import TYPE_CHECKING, Any, Callable from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable from cryptography.exceptions import InvalidTag from noise.backends.default import DefaultNoiseBackend # type: ignore[import-untyped] from noise.backends.default.ciphers import ( # type: ignore[import-untyped] ChaCha20Cipher, ) from noise.connection import NoiseConnection # type: ignore[import-untyped] from ..core import ( APIConnectionError, BadNameAPIError, HandshakeAPIError, InvalidEncryptionKeyAPIError, ProtocolAPIError, ) from .base import _LOGGER, APIFrameHelper if TYPE_CHECKING: from ..connection import APIConnection PACK_NONCE = partial(Struct(" type[ChaCha20Poly1305Reusable]: return ChaCha20Poly1305Reusable class ESPHomeNoiseBackend(DefaultNoiseBackend): # type: ignore[misc] def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.ciphers["ChaChaPoly"] = ChaCha20CipherReuseable ESPHOME_NOISE_BACKEND = ESPHomeNoiseBackend() # This is effectively an enum but we don't want to use an enum # because we have a simple dispatch in the data_received method # that would be more complicated with an enum and we want to add # cdefs for each different state so we have a good test for each # state receiving data since we found that the protractor event # loop will send use a bytearray instead of bytes was not handled # correctly. NOISE_STATE_HELLO = 1 NOISE_STATE_HANDSHAKE = 2 NOISE_STATE_READY = 3 NOISE_STATE_CLOSED = 4 NOISE_HELLO = b"\x01\x00\x00" int_ = int class APINoiseFrameHelper(APIFrameHelper): """Frame helper for noise encrypted connections.""" __slots__ = ( "_noise_psk", "_expected_name", "_state", "_server_name", "_proto", "_decrypt", "_encrypt", ) def __init__( self, connection: APIConnection, noise_psk: str, expected_name: str | None, client_info: str, log_name: str, ) -> None: """Initialize the API frame helper.""" super().__init__(connection, client_info, log_name) self._noise_psk = noise_psk self._expected_name = expected_name self._state = NOISE_STATE_HELLO self._server_name: str | None = None self._decrypt: Callable[[bytes], bytes] | None = None self._encrypt: Callable[[bytes], bytes] | None = None self._setup_proto() def close(self) -> None: """Close the connection.""" # Make sure we set the ready event if its not already set # so that we don't block forever on the ready event if we # are waiting for the handshake to complete. self._set_ready_future_exception( APIConnectionError(f"{self._log_name}: Connection closed") ) self._state = NOISE_STATE_CLOSED super().close() def _handle_error(self, exc: Exception) -> None: """Handle an error, and provide a good message when during hello.""" if self._state == NOISE_STATE_HELLO and isinstance(exc, ConnectionResetError): original_exc: Exception = exc exc = HandshakeAPIError( f"{self._log_name}: The connection dropped immediately after encrypted hello; " "Try enabling encryption on the device or turning off " f"encryption on the client ({self._client_info})." ) exc.__cause__ = original_exc elif isinstance(exc, InvalidTag): original_exc = exc exc = InvalidEncryptionKeyAPIError( f"{self._log_name}: Invalid encryption key", self._server_name ) exc.__cause__ = original_exc super()._handle_error(exc) def connection_made(self, transport: asyncio.BaseTransport) -> None: """Handle a new connection.""" super().connection_made(transport) self._send_hello_handshake() def data_received(self, data: bytes | bytearray | memoryview) -> None: self._add_to_buffer(data) while self._buffer_len: self._pos = 0 if (header := self._read(3)) is None: return preamble = header[0] msg_size_high = header[1] msg_size_low = header[2] if preamble != 0x01: self._handle_error_and_close( ProtocolAPIError( f"{self._log_name}: Marker byte invalid: {header[0]}" ) ) return if (frame := self._read((msg_size_high << 8) | msg_size_low)) is None: # The complete frame is not yet available, wait for more data # to arrive before continuing, since callback_packet has not # been called yet the buffer will not be cleared and the next # call to data_received will continue processing the packet # at the start of the frame. return # asyncio already runs data_received in a try block # which will call connection_lost if an exception is raised if self._state == NOISE_STATE_READY: self._handle_frame(frame) elif self._state == NOISE_STATE_HELLO: self._handle_hello(frame) elif self._state == NOISE_STATE_HANDSHAKE: self._handle_handshake(frame) else: self._handle_closed(frame) self._remove_from_buffer() def _send_hello_handshake(self) -> None: """Send a ClientHello to the server.""" handshake_frame = self._proto.write_message() frame_len = len(handshake_frame) + 1 header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF)) self._write_bytes( b"".join((NOISE_HELLO, header, b"\x00", handshake_frame)), _LOGGER.isEnabledFor(logging.DEBUG), ) def _handle_hello(self, server_hello: bytes) -> None: """Perform the handshake with the server.""" if not server_hello: self._handle_error_and_close( HandshakeAPIError(f"{self._log_name}: ServerHello is empty") ) return # First byte of server hello is the protocol the server chose # for this session. Currently only 0x01 (Noise_NNpsk0_25519_ChaChaPoly_SHA256) # exists. chosen_proto = server_hello[0] if chosen_proto != 0x01: self._handle_error_and_close( HandshakeAPIError( f"{self._log_name}: Unknown protocol selected by client {chosen_proto}" ) ) return # Check name matches expected name (for noise sessions, this is done # during hello phase before a connection is set up) # Server name is encoded as a string followed by a zero byte after the chosen proto byte server_name_i = server_hello.find(b"\0", 1) if server_name_i != -1: # server name found, this extension was added in 2022.2 server_name = server_hello[1:server_name_i].decode() self._server_name = server_name if self._expected_name is not None and self._expected_name != server_name: self._handle_error_and_close( BadNameAPIError( f"{self._log_name}: Server sent a different name '{server_name}'", server_name, ) ) return self._state = NOISE_STATE_HANDSHAKE def _decode_noise_psk(self) -> bytes: """Decode the given noise psk from base64 format to raw bytes.""" psk = self._noise_psk server_name = self._server_name try: psk_bytes = binascii.a2b_base64(psk) except ValueError: raise InvalidEncryptionKeyAPIError( f"{self._log_name}: Malformed PSK `{psk}`, expected " "base64-encoded value", server_name, ) if len(psk_bytes) != 32: raise InvalidEncryptionKeyAPIError( f"{self._log_name}:Malformed PSK `{psk}`, expected" f" 32-bytes of base64 data", server_name, ) return psk_bytes def _setup_proto(self) -> None: """Set up the noise protocol.""" proto = NoiseConnection.from_name( b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND ) proto.set_as_initiator() proto.set_psks(self._decode_noise_psk()) proto.set_prologue(b"NoiseAPIInit\x00\x00") proto.start_handshake() self._proto = proto def _error_on_incorrect_preamble(self, msg: bytes) -> None: """Handle an incorrect preamble.""" explanation = msg[1:].decode() if explanation != "Handshake MAC failure": exc = HandshakeAPIError( f"{self._log_name}: Handshake failure: {explanation}" ) else: exc = InvalidEncryptionKeyAPIError( f"{self._log_name}: Invalid encryption key", self._server_name ) self._handle_error_and_close(exc) def _handle_handshake(self, msg: bytes) -> None: if msg[0] != 0: self._error_on_incorrect_preamble(msg) return self._proto.read_message(msg[1:]) self._state = NOISE_STATE_READY noise_protocol = self._proto.noise_protocol self._decrypt = partial( noise_protocol.cipher_state_decrypt.decrypt_with_ad, # pylint: disable=no-member None, ) self._encrypt = partial( noise_protocol.cipher_state_encrypt.encrypt_with_ad, # pylint: disable=no-member None, ) self._ready_future.set_result(None) def write_packets( self, packets: list[tuple[int, bytes]], debug_enabled: bool ) -> None: """Write a packets to the socket. Packets are in the format of tuple[protobuf_type, protobuf_data] """ if TYPE_CHECKING: assert self._encrypt is not None, "Handshake should be complete" out: list[bytes] = [] for packet in packets: type_: int = packet[0] data: bytes = packet[1] data_len = len(data) data_header = bytes( ( (type_ >> 8) & 0xFF, type_ & 0xFF, (data_len >> 8) & 0xFF, data_len & 0xFF, ) ) frame = self._encrypt(data_header + data) frame_len = len(frame) header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF)) out.append(header) out.append(frame) self._write_bytes(b"".join(out), debug_enabled) def _handle_frame(self, frame: bytes) -> None: """Handle an incoming frame.""" if TYPE_CHECKING: assert self._decrypt is not None, "Handshake should be complete" msg = self._decrypt(frame) # Message layout is # 2 bytes: message type # 2 bytes: message length # N bytes: message data type_high = msg[0] type_low = msg[1] self._connection.process_packet((type_high << 8) | type_low, msg[4:]) def _handle_closed(self, frame: bytes) -> None: # pylint: disable=unused-argument """Handle a closed frame.""" self._handle_error(ProtocolAPIError(f"{self._log_name}: Connection closed"))