Fix more cases where incorrect encryption keys were not detected (#447)
This commit is contained in:
parent
c1752bcf49
commit
eaa5e295cf
|
@ -6,9 +6,11 @@ from enum import Enum
|
|||
from typing import Callable, Optional, Union, cast
|
||||
|
||||
import async_timeout
|
||||
from cryptography.exceptions import InvalidTag
|
||||
from noise.connection import NoiseConnection # type: ignore
|
||||
|
||||
from .core import (
|
||||
APIConnectionError,
|
||||
BadNameAPIError,
|
||||
HandshakeAPIError,
|
||||
InvalidEncryptionKeyAPIError,
|
||||
|
@ -219,29 +221,38 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||
) -> None:
|
||||
"""Initialize the API frame helper."""
|
||||
super().__init__(on_pkt, on_error)
|
||||
self._ready_event = asyncio.Event()
|
||||
self._ready_future = asyncio.get_event_loop().create_future()
|
||||
self._noise_psk = noise_psk
|
||||
self._expected_name = expected_name
|
||||
self._state = NoiseConnectionState.HELLO
|
||||
self._setup_proto()
|
||||
|
||||
def _set_ready_future_exception(self, exc: Exception) -> None:
|
||||
if not self._ready_future.done():
|
||||
self._ready_future.set_exception(exc)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the connection."""
|
||||
# Make sure we set the ready event if its not already set
|
||||
# so that we don't block forever on the ready event if we
|
||||
# are waiting for the handshake to complete.
|
||||
self._ready_event.set()
|
||||
self._set_ready_future_exception(APIConnectionError("Connection closed"))
|
||||
self._state = NoiseConnectionState.CLOSED
|
||||
super().close()
|
||||
|
||||
def _handle_error_and_close(self, exc: Exception) -> None:
|
||||
self._set_ready_future_exception(exc)
|
||||
super()._handle_error_and_close(exc)
|
||||
|
||||
def _write_frame(self, frame: bytes) -> None:
|
||||
"""Write a packet to the socket, the caller should not have the lock.
|
||||
|
||||
The entire packet must be written in a single call to write
|
||||
to avoid locking.
|
||||
"""
|
||||
_LOGGER.debug("Sending frame %s", frame.hex())
|
||||
assert self._transport is not None, "Transport is not set"
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug("Sending frame: [%s]", frame.hex())
|
||||
|
||||
try:
|
||||
header = bytes(
|
||||
|
@ -260,7 +271,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||
self._send_hello()
|
||||
try:
|
||||
async with async_timeout.timeout(60.0):
|
||||
await self._ready_event.wait()
|
||||
await self._ready_future
|
||||
except asyncio.TimeoutError as err:
|
||||
raise HandshakeAPIError("Timeout during handshake") from err
|
||||
|
||||
|
@ -273,8 +284,10 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||
self._handle_error_and_close(
|
||||
ProtocolAPIError(f"Marker byte invalid: {header[0]}")
|
||||
)
|
||||
return
|
||||
msg_size = (header[1] << 8) | header[2]
|
||||
frame = self._read_exactly(msg_size)
|
||||
|
||||
if frame is None:
|
||||
return
|
||||
|
||||
|
@ -292,16 +305,18 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||
def _handle_hello(self, server_hello: bytearray) -> None:
|
||||
"""Perform the handshake with the server, the caller is responsible for having the lock."""
|
||||
if not server_hello:
|
||||
raise HandshakeAPIError("ServerHello is empty")
|
||||
self._handle_error_and_close(HandshakeAPIError("ServerHello is empty"))
|
||||
return
|
||||
|
||||
# First byte of server hello is the protocol the server chose
|
||||
# for this session. Currently only 0x01 (Noise_NNpsk0_25519_ChaChaPoly_SHA256)
|
||||
# exists.
|
||||
chosen_proto = server_hello[0]
|
||||
if chosen_proto != 0x01:
|
||||
raise HandshakeAPIError(
|
||||
f"Unknown protocol selected by client {chosen_proto}"
|
||||
self._handle_error_and_close(
|
||||
HandshakeAPIError(f"Unknown protocol selected by client {chosen_proto}")
|
||||
)
|
||||
return
|
||||
|
||||
# Check name matches expected name (for noise sessions, this is done
|
||||
# during hello phase before a connection is set up)
|
||||
|
@ -311,9 +326,12 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||
# server name found, this extension was added in 2022.2
|
||||
server_name = server_hello[1:server_name_i].decode()
|
||||
if self._expected_name is not None and self._expected_name != server_name:
|
||||
raise BadNameAPIError(
|
||||
f"Server sent a different name '{server_name}'", server_name
|
||||
self._handle_error_and_close(
|
||||
BadNameAPIError(
|
||||
f"Server sent a different name '{server_name}'", server_name
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
self._state = NoiseConnectionState.HANDSHAKE
|
||||
self._send_handshake()
|
||||
|
@ -335,12 +353,24 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||
if msg[0] != 0:
|
||||
explanation = msg[1:].decode()
|
||||
if explanation == "Handshake MAC failure":
|
||||
raise InvalidEncryptionKeyAPIError("Invalid encryption key")
|
||||
raise HandshakeAPIError(f"Handshake failure: {explanation}")
|
||||
self._proto.read_message(msg[1:])
|
||||
self._handle_error_and_close(
|
||||
InvalidEncryptionKeyAPIError("Invalid encryption key")
|
||||
)
|
||||
return
|
||||
self._handle_error_and_close(
|
||||
HandshakeAPIError(f"Handshake failure: {explanation}")
|
||||
)
|
||||
return
|
||||
try:
|
||||
self._proto.read_message(msg[1:])
|
||||
except InvalidTag as invalid_tag_exc:
|
||||
ex = InvalidEncryptionKeyAPIError("Invalid encryption key")
|
||||
ex.__cause__ = invalid_tag_exc
|
||||
self._handle_error_and_close(ex)
|
||||
return
|
||||
_LOGGER.debug("Handshake complete")
|
||||
self._state = NoiseConnectionState.READY
|
||||
self._ready_event.set()
|
||||
self._ready_future.set_result(None)
|
||||
|
||||
def write_packet(self, type_: int, data: bytes) -> None:
|
||||
"""Write a packet to the socket."""
|
||||
|
@ -367,13 +397,17 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||
assert self._proto is not None
|
||||
msg = self._proto.decrypt(bytes(frame))
|
||||
if len(msg) < 4:
|
||||
raise ProtocolAPIError(f"Bad packet frame: {msg}")
|
||||
self._handle_error_and_close(ProtocolAPIError(f"Bad packet frame: {msg}"))
|
||||
return
|
||||
pkt_type = (msg[0] << 8) | msg[1]
|
||||
data_len = (msg[2] << 8) | msg[3]
|
||||
if data_len + 4 > len(msg):
|
||||
raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
|
||||
self._handle_error_and_close(
|
||||
ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
|
||||
)
|
||||
return
|
||||
data = msg[4 : 4 + data_len]
|
||||
return self._on_pkt(pkt_type, data)
|
||||
self._on_pkt(pkt_type, data)
|
||||
|
||||
def _handle_closed( # pylint: disable=unused-argument
|
||||
self, frame: bytearray
|
||||
|
|
|
@ -30,6 +30,7 @@ from .core import (
|
|||
MESSAGE_TYPE_TO_PROTO,
|
||||
APIConnectionError,
|
||||
BadNameAPIError,
|
||||
ConnectionNotEstablishedAPIError,
|
||||
HandshakeAPIError,
|
||||
InvalidAuthAPIError,
|
||||
PingFailedAPIError,
|
||||
|
@ -432,7 +433,7 @@ class APIConnection:
|
|||
self._cleanup()
|
||||
raise self._fatal_exception or APIConnectionError("Connection cancelled")
|
||||
except Exception: # pylint: disable=broad-except
|
||||
# Always clean up the connection if an error occured during connect
|
||||
# Always clean up the connection if an error occurred during connect
|
||||
self._connection_state = ConnectionState.CLOSED
|
||||
self._cleanup()
|
||||
raise
|
||||
|
@ -493,7 +494,12 @@ class APIConnection:
|
|||
def send_message(self, msg: message.Message) -> None:
|
||||
"""Send a protobuf message to the remote."""
|
||||
if not self._is_socket_open:
|
||||
raise APIConnectionError(
|
||||
if in_do_connect.get(False):
|
||||
# If we are in the do_connect task, we can't raise an error
|
||||
# because it would obscure the original exception (ie encrypt error).
|
||||
_LOGGER.debug("%s: Connection isn't established yet", self.log_name)
|
||||
return
|
||||
raise ConnectionNotEstablishedAPIError(
|
||||
f"Connection isn't established yet ({self._connection_state})"
|
||||
)
|
||||
|
||||
|
|
|
@ -183,6 +183,10 @@ class HandshakeAPIError(APIConnectionError):
|
|||
pass
|
||||
|
||||
|
||||
class ConnectionNotEstablishedAPIError(APIConnectionError):
|
||||
pass
|
||||
|
||||
|
||||
class BadNameAPIError(APIConnectionError):
|
||||
"""Raised when a name received from the remote but does not much the expected name."""
|
||||
|
||||
|
|
|
@ -5,11 +5,17 @@ from typing import Awaitable, Callable, List, Optional
|
|||
import zeroconf
|
||||
|
||||
from .client import APIClient
|
||||
from .core import APIConnectionError
|
||||
from .core import (
|
||||
APIConnectionError,
|
||||
InvalidAuthAPIError,
|
||||
InvalidEncryptionKeyAPIError,
|
||||
RequiresEncryptionAPIError,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
EXPECTED_DISCONNECT_COOLDOWN = 3.0
|
||||
MAXIMUM_BACKOFF_TRIES = 100
|
||||
|
||||
|
||||
class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
|
@ -103,13 +109,26 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||
level = logging.WARNING if self._tries == 0 else logging.DEBUG
|
||||
_LOGGER.log(
|
||||
level,
|
||||
"Can't connect to ESPHome API for %s: %s",
|
||||
"Can't connect to ESPHome API for %s: %s (%s)",
|
||||
self._log_name,
|
||||
err,
|
||||
type(err).__name__,
|
||||
# Print stacktrace if unhandled (not APIConnectionError)
|
||||
exc_info=not isinstance(err, APIConnectionError),
|
||||
)
|
||||
self._tries += 1
|
||||
if isinstance(
|
||||
err,
|
||||
(
|
||||
RequiresEncryptionAPIError,
|
||||
InvalidEncryptionKeyAPIError,
|
||||
InvalidAuthAPIError,
|
||||
),
|
||||
):
|
||||
# If we get an encryption or password error,
|
||||
# backoff for the maximum amount of time
|
||||
self._tries = MAXIMUM_BACKOFF_TRIES
|
||||
else:
|
||||
self._tries += 1
|
||||
return False
|
||||
_LOGGER.info("Successfully connected to %s", self._log_name)
|
||||
self._connected = True
|
||||
|
|
|
@ -3,7 +3,8 @@ from unittest.mock import MagicMock
|
|||
|
||||
import pytest
|
||||
|
||||
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
||||
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
|
||||
from aioesphomeapi.core import BadNameAPIError, InvalidEncryptionKeyAPIError
|
||||
from aioesphomeapi.util import varuint_to_bytes
|
||||
|
||||
PREAMBLE = b"\x00"
|
||||
|
@ -63,3 +64,79 @@ async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
|
|||
|
||||
assert type_ == pkt_type
|
||||
assert data == pkt_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noise_frame_helper_incorrect_key():
|
||||
"""Test that the noise frame helper raises InvalidEncryptionKeyAPIError on bad key."""
|
||||
outgoing_packets = [
|
||||
"010000", # hello packet
|
||||
"010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4",
|
||||
]
|
||||
incoming_packets = [
|
||||
"01000d01736572766963657465737400",
|
||||
"0100160148616e647368616b65204d4143206661696c757265",
|
||||
]
|
||||
packets = []
|
||||
|
||||
def _packet(type_: int, data: bytes):
|
||||
packets.append((type_, data))
|
||||
|
||||
def _on_error(exc: Exception):
|
||||
raise exc
|
||||
|
||||
helper = APINoiseFrameHelper(
|
||||
on_pkt=_packet,
|
||||
on_error=_on_error,
|
||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||
expected_name="servicetest",
|
||||
)
|
||||
helper._transport = MagicMock()
|
||||
|
||||
for pkt in outgoing_packets:
|
||||
helper._write_frame(bytes.fromhex(pkt))
|
||||
|
||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||
for pkt in incoming_packets:
|
||||
helper.data_received(bytes.fromhex(pkt))
|
||||
|
||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||
await helper.perform_handshake()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noise_incorrect_name():
|
||||
"""Test we raise on bad name."""
|
||||
outgoing_packets = [
|
||||
"010000", # hello packet
|
||||
"010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4",
|
||||
]
|
||||
incoming_packets = [
|
||||
"01000d01736572766963657465737400",
|
||||
"0100160148616e647368616b65204d4143206661696c757265",
|
||||
]
|
||||
packets = []
|
||||
|
||||
def _packet(type_: int, data: bytes):
|
||||
packets.append((type_, data))
|
||||
|
||||
def _on_error(exc: Exception):
|
||||
raise exc
|
||||
|
||||
helper = APINoiseFrameHelper(
|
||||
on_pkt=_packet,
|
||||
on_error=_on_error,
|
||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||
expected_name="wrongname",
|
||||
)
|
||||
helper._transport = MagicMock()
|
||||
|
||||
for pkt in outgoing_packets:
|
||||
helper._write_frame(bytes.fromhex(pkt))
|
||||
|
||||
with pytest.raises(BadNameAPIError):
|
||||
for pkt in incoming_packets:
|
||||
helper.data_received(bytes.fromhex(pkt))
|
||||
|
||||
with pytest.raises(BadNameAPIError):
|
||||
await helper.perform_handshake()
|
||||
|
|
Loading…
Reference in New Issue