diff --git a/aioesphomeapi/_frame_helper.py b/aioesphomeapi/_frame_helper.py index 5a8f1f6..9b64540 100644 --- a/aioesphomeapi/_frame_helper.py +++ b/aioesphomeapi/_frame_helper.py @@ -52,6 +52,15 @@ ESPHOME_NOISE_BACKEND = ESPHomeNoiseBackend() class APIFrameHelper(asyncio.Protocol): """Helper class to handle the API frame protocol.""" + __slots__ = ( + "_on_pkt", + "_on_error", + "_transport", + "_connected_event", + "_buffer", + "_pos", + ) + def __init__( self, on_pkt: Callable[[int, bytes], None], @@ -203,17 +212,17 @@ class APIPlaintextFrameHelper(APIFrameHelper): # If we have more data, continue processing -def _decode_noise_psk(psk: str) -> bytes: +def _decode_noise_psk(psk: str, server_name: Optional[str]) -> bytes: """Decode the given noise psk from base64 format to raw bytes.""" try: psk_bytes = base64.b64decode(psk) except ValueError: raise InvalidEncryptionKeyAPIError( - f"Malformed PSK {psk}, expected base64-encoded value" + f"Malformed PSK {psk}, expected base64-encoded value", server_name ) if len(psk_bytes) != 32: raise InvalidEncryptionKeyAPIError( - f"Malformed PSK {psk}, expected 32-bytes of base64 data" + f"Malformed PSK {psk}, expected 32-bytes of base64 data", server_name ) return psk_bytes @@ -230,6 +239,15 @@ class NoiseConnectionState(Enum): class APINoiseFrameHelper(APIFrameHelper): """Frame helper for noise encrypted connections.""" + __slots__ = ( + "_ready_future", + "_noise_psk", + "_expected_name", + "_state", + "_server_name", + "_proto", + ) + def __init__( self, on_pkt: Callable[[int, bytes], None], @@ -243,6 +261,7 @@ class APINoiseFrameHelper(APIFrameHelper): self._noise_psk = noise_psk self._expected_name = expected_name self._state = NoiseConnectionState.HELLO + self._server_name: Optional[str] = None self._setup_proto() def _set_ready_future_exception(self, exc: Exception) -> None: @@ -343,6 +362,8 @@ class APINoiseFrameHelper(APIFrameHelper): if server_name_i != -1: # server name found, this extension was added in 2022.2 server_name = server_hello[1:server_name_i].decode() + self._server_name = server_name + if self._expected_name is not None and self._expected_name != server_name: self._handle_error_and_close( BadNameAPIError( @@ -360,7 +381,7 @@ class APINoiseFrameHelper(APIFrameHelper): b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND ) self._proto.set_as_initiator() - self._proto.set_psks(_decode_noise_psk(self._noise_psk)) + self._proto.set_psks(_decode_noise_psk(self._noise_psk, self._server_name)) self._proto.set_prologue(b"NoiseAPIInit" + b"\x00\x00") self._proto.start_handshake() @@ -374,7 +395,9 @@ class APINoiseFrameHelper(APIFrameHelper): explanation = msg[1:].decode() if explanation == "Handshake MAC failure": self._handle_error_and_close( - InvalidEncryptionKeyAPIError("Invalid encryption key") + InvalidEncryptionKeyAPIError( + "Invalid encryption key", self._server_name + ) ) return self._handle_error_and_close( @@ -384,7 +407,9 @@ class APINoiseFrameHelper(APIFrameHelper): try: self._proto.read_message(msg[1:]) except InvalidTag as invalid_tag_exc: - ex = InvalidEncryptionKeyAPIError("Invalid encryption key") + ex = InvalidEncryptionKeyAPIError( + "Invalid encryption key", self._server_name + ) ex.__cause__ = invalid_tag_exc self._handle_error_and_close(ex) return diff --git a/aioesphomeapi/core.py b/aioesphomeapi/core.py index 2f3bb0a..f2acc6c 100644 --- a/aioesphomeapi/core.py +++ b/aioesphomeapi/core.py @@ -1,4 +1,5 @@ import re +from typing import Optional from aioesphomeapi.model import BluetoothGATTError @@ -191,12 +192,16 @@ class BadNameAPIError(APIConnectionError): """Raised when a name received from the remote but does not much the expected name.""" def __init__(self, msg: str, received_name: str) -> None: - super().__init__(msg) + super().__init__(f"{msg}: received_name={received_name}") self.received_name = received_name class InvalidEncryptionKeyAPIError(HandshakeAPIError): - pass + def __init__( + self, msg: Optional[str] = None, received_name: Optional[str] = None + ) -> None: + super().__init__(f"{msg}: received_name={received_name}") + self.received_name = received_name class PingFailedAPIError(APIConnectionError):