Include the server name in the exception for invalid encryption key (#452)

This commit is contained in:
J. Nick Koston 2023-07-01 11:12:38 -05:00 committed by GitHub
parent cfb3972c9d
commit 59a66ba870
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 8 deletions

View File

@ -52,6 +52,15 @@ ESPHOME_NOISE_BACKEND = ESPHomeNoiseBackend()
class APIFrameHelper(asyncio.Protocol):
"""Helper class to handle the API frame protocol."""
__slots__ = (
"_on_pkt",
"_on_error",
"_transport",
"_connected_event",
"_buffer",
"_pos",
)
def __init__(
self,
on_pkt: Callable[[int, bytes], None],
@ -203,17 +212,17 @@ class APIPlaintextFrameHelper(APIFrameHelper):
# If we have more data, continue processing
def _decode_noise_psk(psk: str) -> bytes:
def _decode_noise_psk(psk: str, server_name: Optional[str]) -> bytes:
"""Decode the given noise psk from base64 format to raw bytes."""
try:
psk_bytes = base64.b64decode(psk)
except ValueError:
raise InvalidEncryptionKeyAPIError(
f"Malformed PSK {psk}, expected base64-encoded value"
f"Malformed PSK {psk}, expected base64-encoded value", server_name
)
if len(psk_bytes) != 32:
raise InvalidEncryptionKeyAPIError(
f"Malformed PSK {psk}, expected 32-bytes of base64 data"
f"Malformed PSK {psk}, expected 32-bytes of base64 data", server_name
)
return psk_bytes
@ -230,6 +239,15 @@ class NoiseConnectionState(Enum):
class APINoiseFrameHelper(APIFrameHelper):
"""Frame helper for noise encrypted connections."""
__slots__ = (
"_ready_future",
"_noise_psk",
"_expected_name",
"_state",
"_server_name",
"_proto",
)
def __init__(
self,
on_pkt: Callable[[int, bytes], None],
@ -243,6 +261,7 @@ class APINoiseFrameHelper(APIFrameHelper):
self._noise_psk = noise_psk
self._expected_name = expected_name
self._state = NoiseConnectionState.HELLO
self._server_name: Optional[str] = None
self._setup_proto()
def _set_ready_future_exception(self, exc: Exception) -> None:
@ -343,6 +362,8 @@ class APINoiseFrameHelper(APIFrameHelper):
if server_name_i != -1:
# server name found, this extension was added in 2022.2
server_name = server_hello[1:server_name_i].decode()
self._server_name = server_name
if self._expected_name is not None and self._expected_name != server_name:
self._handle_error_and_close(
BadNameAPIError(
@ -360,7 +381,7 @@ class APINoiseFrameHelper(APIFrameHelper):
b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND
)
self._proto.set_as_initiator()
self._proto.set_psks(_decode_noise_psk(self._noise_psk))
self._proto.set_psks(_decode_noise_psk(self._noise_psk, self._server_name))
self._proto.set_prologue(b"NoiseAPIInit" + b"\x00\x00")
self._proto.start_handshake()
@ -374,7 +395,9 @@ class APINoiseFrameHelper(APIFrameHelper):
explanation = msg[1:].decode()
if explanation == "Handshake MAC failure":
self._handle_error_and_close(
InvalidEncryptionKeyAPIError("Invalid encryption key")
InvalidEncryptionKeyAPIError(
"Invalid encryption key", self._server_name
)
)
return
self._handle_error_and_close(
@ -384,7 +407,9 @@ class APINoiseFrameHelper(APIFrameHelper):
try:
self._proto.read_message(msg[1:])
except InvalidTag as invalid_tag_exc:
ex = InvalidEncryptionKeyAPIError("Invalid encryption key")
ex = InvalidEncryptionKeyAPIError(
"Invalid encryption key", self._server_name
)
ex.__cause__ = invalid_tag_exc
self._handle_error_and_close(ex)
return

View File

@ -1,4 +1,5 @@
import re
from typing import Optional
from aioesphomeapi.model import BluetoothGATTError
@ -191,12 +192,16 @@ class BadNameAPIError(APIConnectionError):
"""Raised when a name received from the remote but does not much the expected name."""
def __init__(self, msg: str, received_name: str) -> None:
super().__init__(msg)
super().__init__(f"{msg}: received_name={received_name}")
self.received_name = received_name
class InvalidEncryptionKeyAPIError(HandshakeAPIError):
pass
def __init__(
self, msg: Optional[str] = None, received_name: Optional[str] = None
) -> None:
super().__init__(f"{msg}: received_name={received_name}")
self.received_name = received_name
class PingFailedAPIError(APIConnectionError):