Fix more cases where incorrect encryption keys were not detected (#447)

This commit is contained in:
J. Nick Koston 2023-06-24 10:47:24 -05:00 committed by GitHub
parent c1752bcf49
commit eaa5e295cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 162 additions and 22 deletions

View File

@ -6,9 +6,11 @@ from enum import Enum
from typing import Callable, Optional, Union, cast from typing import Callable, Optional, Union, cast
import async_timeout import async_timeout
from cryptography.exceptions import InvalidTag
from noise.connection import NoiseConnection # type: ignore from noise.connection import NoiseConnection # type: ignore
from .core import ( from .core import (
APIConnectionError,
BadNameAPIError, BadNameAPIError,
HandshakeAPIError, HandshakeAPIError,
InvalidEncryptionKeyAPIError, InvalidEncryptionKeyAPIError,
@ -219,29 +221,38 @@ class APINoiseFrameHelper(APIFrameHelper):
) -> None: ) -> None:
"""Initialize the API frame helper.""" """Initialize the API frame helper."""
super().__init__(on_pkt, on_error) super().__init__(on_pkt, on_error)
self._ready_event = asyncio.Event() self._ready_future = asyncio.get_event_loop().create_future()
self._noise_psk = noise_psk self._noise_psk = noise_psk
self._expected_name = expected_name self._expected_name = expected_name
self._state = NoiseConnectionState.HELLO self._state = NoiseConnectionState.HELLO
self._setup_proto() self._setup_proto()
def _set_ready_future_exception(self, exc: Exception) -> None:
if not self._ready_future.done():
self._ready_future.set_exception(exc)
def close(self) -> None: def close(self) -> None:
"""Close the connection.""" """Close the connection."""
# Make sure we set the ready event if its not already set # Make sure we set the ready event if its not already set
# so that we don't block forever on the ready event if we # so that we don't block forever on the ready event if we
# are waiting for the handshake to complete. # are waiting for the handshake to complete.
self._ready_event.set() self._set_ready_future_exception(APIConnectionError("Connection closed"))
self._state = NoiseConnectionState.CLOSED self._state = NoiseConnectionState.CLOSED
super().close() super().close()
def _handle_error_and_close(self, exc: Exception) -> None:
self._set_ready_future_exception(exc)
super()._handle_error_and_close(exc)
def _write_frame(self, frame: bytes) -> None: def _write_frame(self, frame: bytes) -> None:
"""Write a packet to the socket, the caller should not have the lock. """Write a packet to the socket, the caller should not have the lock.
The entire packet must be written in a single call to write The entire packet must be written in a single call to write
to avoid locking. to avoid locking.
""" """
_LOGGER.debug("Sending frame %s", frame.hex())
assert self._transport is not None, "Transport is not set" assert self._transport is not None, "Transport is not set"
if _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug("Sending frame: [%s]", frame.hex())
try: try:
header = bytes( header = bytes(
@ -260,7 +271,7 @@ class APINoiseFrameHelper(APIFrameHelper):
self._send_hello() self._send_hello()
try: try:
async with async_timeout.timeout(60.0): async with async_timeout.timeout(60.0):
await self._ready_event.wait() await self._ready_future
except asyncio.TimeoutError as err: except asyncio.TimeoutError as err:
raise HandshakeAPIError("Timeout during handshake") from err raise HandshakeAPIError("Timeout during handshake") from err
@ -273,8 +284,10 @@ class APINoiseFrameHelper(APIFrameHelper):
self._handle_error_and_close( self._handle_error_and_close(
ProtocolAPIError(f"Marker byte invalid: {header[0]}") ProtocolAPIError(f"Marker byte invalid: {header[0]}")
) )
return
msg_size = (header[1] << 8) | header[2] msg_size = (header[1] << 8) | header[2]
frame = self._read_exactly(msg_size) frame = self._read_exactly(msg_size)
if frame is None: if frame is None:
return return
@ -292,16 +305,18 @@ class APINoiseFrameHelper(APIFrameHelper):
def _handle_hello(self, server_hello: bytearray) -> None: def _handle_hello(self, server_hello: bytearray) -> None:
"""Perform the handshake with the server, the caller is responsible for having the lock.""" """Perform the handshake with the server, the caller is responsible for having the lock."""
if not server_hello: if not server_hello:
raise HandshakeAPIError("ServerHello is empty") self._handle_error_and_close(HandshakeAPIError("ServerHello is empty"))
return
# First byte of server hello is the protocol the server chose # First byte of server hello is the protocol the server chose
# for this session. Currently only 0x01 (Noise_NNpsk0_25519_ChaChaPoly_SHA256) # for this session. Currently only 0x01 (Noise_NNpsk0_25519_ChaChaPoly_SHA256)
# exists. # exists.
chosen_proto = server_hello[0] chosen_proto = server_hello[0]
if chosen_proto != 0x01: if chosen_proto != 0x01:
raise HandshakeAPIError( self._handle_error_and_close(
f"Unknown protocol selected by client {chosen_proto}" HandshakeAPIError(f"Unknown protocol selected by client {chosen_proto}")
) )
return
# Check name matches expected name (for noise sessions, this is done # Check name matches expected name (for noise sessions, this is done
# during hello phase before a connection is set up) # during hello phase before a connection is set up)
@ -311,9 +326,12 @@ class APINoiseFrameHelper(APIFrameHelper):
# server name found, this extension was added in 2022.2 # server name found, this extension was added in 2022.2
server_name = server_hello[1:server_name_i].decode() server_name = server_hello[1:server_name_i].decode()
if self._expected_name is not None and self._expected_name != server_name: if self._expected_name is not None and self._expected_name != server_name:
raise BadNameAPIError( self._handle_error_and_close(
f"Server sent a different name '{server_name}'", server_name BadNameAPIError(
f"Server sent a different name '{server_name}'", server_name
)
) )
return
self._state = NoiseConnectionState.HANDSHAKE self._state = NoiseConnectionState.HANDSHAKE
self._send_handshake() self._send_handshake()
@ -335,12 +353,24 @@ class APINoiseFrameHelper(APIFrameHelper):
if msg[0] != 0: if msg[0] != 0:
explanation = msg[1:].decode() explanation = msg[1:].decode()
if explanation == "Handshake MAC failure": if explanation == "Handshake MAC failure":
raise InvalidEncryptionKeyAPIError("Invalid encryption key") self._handle_error_and_close(
raise HandshakeAPIError(f"Handshake failure: {explanation}") InvalidEncryptionKeyAPIError("Invalid encryption key")
self._proto.read_message(msg[1:]) )
return
self._handle_error_and_close(
HandshakeAPIError(f"Handshake failure: {explanation}")
)
return
try:
self._proto.read_message(msg[1:])
except InvalidTag as invalid_tag_exc:
ex = InvalidEncryptionKeyAPIError("Invalid encryption key")
ex.__cause__ = invalid_tag_exc
self._handle_error_and_close(ex)
return
_LOGGER.debug("Handshake complete") _LOGGER.debug("Handshake complete")
self._state = NoiseConnectionState.READY self._state = NoiseConnectionState.READY
self._ready_event.set() self._ready_future.set_result(None)
def write_packet(self, type_: int, data: bytes) -> None: def write_packet(self, type_: int, data: bytes) -> None:
"""Write a packet to the socket.""" """Write a packet to the socket."""
@ -367,13 +397,17 @@ class APINoiseFrameHelper(APIFrameHelper):
assert self._proto is not None assert self._proto is not None
msg = self._proto.decrypt(bytes(frame)) msg = self._proto.decrypt(bytes(frame))
if len(msg) < 4: if len(msg) < 4:
raise ProtocolAPIError(f"Bad packet frame: {msg}") self._handle_error_and_close(ProtocolAPIError(f"Bad packet frame: {msg}"))
return
pkt_type = (msg[0] << 8) | msg[1] pkt_type = (msg[0] << 8) | msg[1]
data_len = (msg[2] << 8) | msg[3] data_len = (msg[2] << 8) | msg[3]
if data_len + 4 > len(msg): if data_len + 4 > len(msg):
raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}") self._handle_error_and_close(
ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
)
return
data = msg[4 : 4 + data_len] data = msg[4 : 4 + data_len]
return self._on_pkt(pkt_type, data) self._on_pkt(pkt_type, data)
def _handle_closed( # pylint: disable=unused-argument def _handle_closed( # pylint: disable=unused-argument
self, frame: bytearray self, frame: bytearray

View File

@ -30,6 +30,7 @@ from .core import (
MESSAGE_TYPE_TO_PROTO, MESSAGE_TYPE_TO_PROTO,
APIConnectionError, APIConnectionError,
BadNameAPIError, BadNameAPIError,
ConnectionNotEstablishedAPIError,
HandshakeAPIError, HandshakeAPIError,
InvalidAuthAPIError, InvalidAuthAPIError,
PingFailedAPIError, PingFailedAPIError,
@ -432,7 +433,7 @@ class APIConnection:
self._cleanup() self._cleanup()
raise self._fatal_exception or APIConnectionError("Connection cancelled") raise self._fatal_exception or APIConnectionError("Connection cancelled")
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
# Always clean up the connection if an error occured during connect # Always clean up the connection if an error occurred during connect
self._connection_state = ConnectionState.CLOSED self._connection_state = ConnectionState.CLOSED
self._cleanup() self._cleanup()
raise raise
@ -493,7 +494,12 @@ class APIConnection:
def send_message(self, msg: message.Message) -> None: def send_message(self, msg: message.Message) -> None:
"""Send a protobuf message to the remote.""" """Send a protobuf message to the remote."""
if not self._is_socket_open: if not self._is_socket_open:
raise APIConnectionError( if in_do_connect.get(False):
# If we are in the do_connect task, we can't raise an error
# because it would obscure the original exception (ie encrypt error).
_LOGGER.debug("%s: Connection isn't established yet", self.log_name)
return
raise ConnectionNotEstablishedAPIError(
f"Connection isn't established yet ({self._connection_state})" f"Connection isn't established yet ({self._connection_state})"
) )

View File

@ -183,6 +183,10 @@ class HandshakeAPIError(APIConnectionError):
pass pass
class ConnectionNotEstablishedAPIError(APIConnectionError):
pass
class BadNameAPIError(APIConnectionError): class BadNameAPIError(APIConnectionError):
"""Raised when a name received from the remote but does not much the expected name.""" """Raised when a name received from the remote but does not much the expected name."""

View File

@ -5,11 +5,17 @@ from typing import Awaitable, Callable, List, Optional
import zeroconf import zeroconf
from .client import APIClient from .client import APIClient
from .core import APIConnectionError from .core import (
APIConnectionError,
InvalidAuthAPIError,
InvalidEncryptionKeyAPIError,
RequiresEncryptionAPIError,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
EXPECTED_DISCONNECT_COOLDOWN = 3.0 EXPECTED_DISCONNECT_COOLDOWN = 3.0
MAXIMUM_BACKOFF_TRIES = 100
class ReconnectLogic(zeroconf.RecordUpdateListener): class ReconnectLogic(zeroconf.RecordUpdateListener):
@ -103,13 +109,26 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
level = logging.WARNING if self._tries == 0 else logging.DEBUG level = logging.WARNING if self._tries == 0 else logging.DEBUG
_LOGGER.log( _LOGGER.log(
level, level,
"Can't connect to ESPHome API for %s: %s", "Can't connect to ESPHome API for %s: %s (%s)",
self._log_name, self._log_name,
err, err,
type(err).__name__,
# Print stacktrace if unhandled (not APIConnectionError) # Print stacktrace if unhandled (not APIConnectionError)
exc_info=not isinstance(err, APIConnectionError), exc_info=not isinstance(err, APIConnectionError),
) )
self._tries += 1 if isinstance(
err,
(
RequiresEncryptionAPIError,
InvalidEncryptionKeyAPIError,
InvalidAuthAPIError,
),
):
# If we get an encryption or password error,
# backoff for the maximum amount of time
self._tries = MAXIMUM_BACKOFF_TRIES
else:
self._tries += 1
return False return False
_LOGGER.info("Successfully connected to %s", self._log_name) _LOGGER.info("Successfully connected to %s", self._log_name)
self._connected = True self._connected = True

View File

@ -3,7 +3,8 @@ from unittest.mock import MagicMock
import pytest import pytest
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
from aioesphomeapi.core import BadNameAPIError, InvalidEncryptionKeyAPIError
from aioesphomeapi.util import varuint_to_bytes from aioesphomeapi.util import varuint_to_bytes
PREAMBLE = b"\x00" PREAMBLE = b"\x00"
@ -63,3 +64,79 @@ async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
assert type_ == pkt_type assert type_ == pkt_type
assert data == pkt_data assert data == pkt_data
@pytest.mark.asyncio
async def test_noise_frame_helper_incorrect_key():
"""Test that the noise frame helper raises InvalidEncryptionKeyAPIError on bad key."""
outgoing_packets = [
"010000", # hello packet
"010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4",
]
incoming_packets = [
"01000d01736572766963657465737400",
"0100160148616e647368616b65204d4143206661696c757265",
]
packets = []
def _packet(type_: int, data: bytes):
packets.append((type_, data))
def _on_error(exc: Exception):
raise exc
helper = APINoiseFrameHelper(
on_pkt=_packet,
on_error=_on_error,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="servicetest",
)
helper._transport = MagicMock()
for pkt in outgoing_packets:
helper._write_frame(bytes.fromhex(pkt))
with pytest.raises(InvalidEncryptionKeyAPIError):
for pkt in incoming_packets:
helper.data_received(bytes.fromhex(pkt))
with pytest.raises(InvalidEncryptionKeyAPIError):
await helper.perform_handshake()
@pytest.mark.asyncio
async def test_noise_incorrect_name():
"""Test we raise on bad name."""
outgoing_packets = [
"010000", # hello packet
"010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4",
]
incoming_packets = [
"01000d01736572766963657465737400",
"0100160148616e647368616b65204d4143206661696c757265",
]
packets = []
def _packet(type_: int, data: bytes):
packets.append((type_, data))
def _on_error(exc: Exception):
raise exc
helper = APINoiseFrameHelper(
on_pkt=_packet,
on_error=_on_error,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="wrongname",
)
helper._transport = MagicMock()
for pkt in outgoing_packets:
helper._write_frame(bytes.fromhex(pkt))
with pytest.raises(BadNameAPIError):
for pkt in incoming_packets:
helper.data_received(bytes.fromhex(pkt))
with pytest.raises(BadNameAPIError):
await helper.perform_handshake()