mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-23 12:15:13 +01:00
Include the server name in the exception for invalid encryption key (#452)
This commit is contained in:
parent
cfb3972c9d
commit
59a66ba870
@ -52,6 +52,15 @@ ESPHOME_NOISE_BACKEND = ESPHomeNoiseBackend()
|
||||
class APIFrameHelper(asyncio.Protocol):
|
||||
"""Helper class to handle the API frame protocol."""
|
||||
|
||||
__slots__ = (
|
||||
"_on_pkt",
|
||||
"_on_error",
|
||||
"_transport",
|
||||
"_connected_event",
|
||||
"_buffer",
|
||||
"_pos",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
on_pkt: Callable[[int, bytes], None],
|
||||
@ -203,17 +212,17 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
# If we have more data, continue processing
|
||||
|
||||
|
||||
def _decode_noise_psk(psk: str) -> bytes:
|
||||
def _decode_noise_psk(psk: str, server_name: Optional[str]) -> bytes:
|
||||
"""Decode the given noise psk from base64 format to raw bytes."""
|
||||
try:
|
||||
psk_bytes = base64.b64decode(psk)
|
||||
except ValueError:
|
||||
raise InvalidEncryptionKeyAPIError(
|
||||
f"Malformed PSK {psk}, expected base64-encoded value"
|
||||
f"Malformed PSK {psk}, expected base64-encoded value", server_name
|
||||
)
|
||||
if len(psk_bytes) != 32:
|
||||
raise InvalidEncryptionKeyAPIError(
|
||||
f"Malformed PSK {psk}, expected 32-bytes of base64 data"
|
||||
f"Malformed PSK {psk}, expected 32-bytes of base64 data", server_name
|
||||
)
|
||||
return psk_bytes
|
||||
|
||||
@ -230,6 +239,15 @@ class NoiseConnectionState(Enum):
|
||||
class APINoiseFrameHelper(APIFrameHelper):
|
||||
"""Frame helper for noise encrypted connections."""
|
||||
|
||||
__slots__ = (
|
||||
"_ready_future",
|
||||
"_noise_psk",
|
||||
"_expected_name",
|
||||
"_state",
|
||||
"_server_name",
|
||||
"_proto",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
on_pkt: Callable[[int, bytes], None],
|
||||
@ -243,6 +261,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
self._noise_psk = noise_psk
|
||||
self._expected_name = expected_name
|
||||
self._state = NoiseConnectionState.HELLO
|
||||
self._server_name: Optional[str] = None
|
||||
self._setup_proto()
|
||||
|
||||
def _set_ready_future_exception(self, exc: Exception) -> None:
|
||||
@ -343,6 +362,8 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
if server_name_i != -1:
|
||||
# server name found, this extension was added in 2022.2
|
||||
server_name = server_hello[1:server_name_i].decode()
|
||||
self._server_name = server_name
|
||||
|
||||
if self._expected_name is not None and self._expected_name != server_name:
|
||||
self._handle_error_and_close(
|
||||
BadNameAPIError(
|
||||
@ -360,7 +381,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND
|
||||
)
|
||||
self._proto.set_as_initiator()
|
||||
self._proto.set_psks(_decode_noise_psk(self._noise_psk))
|
||||
self._proto.set_psks(_decode_noise_psk(self._noise_psk, self._server_name))
|
||||
self._proto.set_prologue(b"NoiseAPIInit" + b"\x00\x00")
|
||||
self._proto.start_handshake()
|
||||
|
||||
@ -374,7 +395,9 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
explanation = msg[1:].decode()
|
||||
if explanation == "Handshake MAC failure":
|
||||
self._handle_error_and_close(
|
||||
InvalidEncryptionKeyAPIError("Invalid encryption key")
|
||||
InvalidEncryptionKeyAPIError(
|
||||
"Invalid encryption key", self._server_name
|
||||
)
|
||||
)
|
||||
return
|
||||
self._handle_error_and_close(
|
||||
@ -384,7 +407,9 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
try:
|
||||
self._proto.read_message(msg[1:])
|
||||
except InvalidTag as invalid_tag_exc:
|
||||
ex = InvalidEncryptionKeyAPIError("Invalid encryption key")
|
||||
ex = InvalidEncryptionKeyAPIError(
|
||||
"Invalid encryption key", self._server_name
|
||||
)
|
||||
ex.__cause__ = invalid_tag_exc
|
||||
self._handle_error_and_close(ex)
|
||||
return
|
||||
|
@ -1,4 +1,5 @@
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from aioesphomeapi.model import BluetoothGATTError
|
||||
|
||||
@ -191,12 +192,16 @@ class BadNameAPIError(APIConnectionError):
|
||||
"""Raised when a name received from the remote but does not much the expected name."""
|
||||
|
||||
def __init__(self, msg: str, received_name: str) -> None:
|
||||
super().__init__(msg)
|
||||
super().__init__(f"{msg}: received_name={received_name}")
|
||||
self.received_name = received_name
|
||||
|
||||
|
||||
class InvalidEncryptionKeyAPIError(HandshakeAPIError):
|
||||
pass
|
||||
def __init__(
|
||||
self, msg: Optional[str] = None, received_name: Optional[str] = None
|
||||
) -> None:
|
||||
super().__init__(f"{msg}: received_name={received_name}")
|
||||
self.received_name = received_name
|
||||
|
||||
|
||||
class PingFailedAPIError(APIConnectionError):
|
||||
|
Loading…
Reference in New Issue
Block a user