mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-06-26 10:35:00 +02:00
Fix more cases where incorrect encryption keys were not detected (#447)
This commit is contained in:
parent
c1752bcf49
commit
eaa5e295cf
|
@ -6,9 +6,11 @@ from enum import Enum
|
||||||
from typing import Callable, Optional, Union, cast
|
from typing import Callable, Optional, Union, cast
|
||||||
|
|
||||||
import async_timeout
|
import async_timeout
|
||||||
|
from cryptography.exceptions import InvalidTag
|
||||||
from noise.connection import NoiseConnection # type: ignore
|
from noise.connection import NoiseConnection # type: ignore
|
||||||
|
|
||||||
from .core import (
|
from .core import (
|
||||||
|
APIConnectionError,
|
||||||
BadNameAPIError,
|
BadNameAPIError,
|
||||||
HandshakeAPIError,
|
HandshakeAPIError,
|
||||||
InvalidEncryptionKeyAPIError,
|
InvalidEncryptionKeyAPIError,
|
||||||
|
@ -219,29 +221,38 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the API frame helper."""
|
"""Initialize the API frame helper."""
|
||||||
super().__init__(on_pkt, on_error)
|
super().__init__(on_pkt, on_error)
|
||||||
self._ready_event = asyncio.Event()
|
self._ready_future = asyncio.get_event_loop().create_future()
|
||||||
self._noise_psk = noise_psk
|
self._noise_psk = noise_psk
|
||||||
self._expected_name = expected_name
|
self._expected_name = expected_name
|
||||||
self._state = NoiseConnectionState.HELLO
|
self._state = NoiseConnectionState.HELLO
|
||||||
self._setup_proto()
|
self._setup_proto()
|
||||||
|
|
||||||
|
def _set_ready_future_exception(self, exc: Exception) -> None:
|
||||||
|
if not self._ready_future.done():
|
||||||
|
self._ready_future.set_exception(exc)
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
"""Close the connection."""
|
"""Close the connection."""
|
||||||
# Make sure we set the ready event if its not already set
|
# Make sure we set the ready event if its not already set
|
||||||
# so that we don't block forever on the ready event if we
|
# so that we don't block forever on the ready event if we
|
||||||
# are waiting for the handshake to complete.
|
# are waiting for the handshake to complete.
|
||||||
self._ready_event.set()
|
self._set_ready_future_exception(APIConnectionError("Connection closed"))
|
||||||
self._state = NoiseConnectionState.CLOSED
|
self._state = NoiseConnectionState.CLOSED
|
||||||
super().close()
|
super().close()
|
||||||
|
|
||||||
|
def _handle_error_and_close(self, exc: Exception) -> None:
|
||||||
|
self._set_ready_future_exception(exc)
|
||||||
|
super()._handle_error_and_close(exc)
|
||||||
|
|
||||||
def _write_frame(self, frame: bytes) -> None:
|
def _write_frame(self, frame: bytes) -> None:
|
||||||
"""Write a packet to the socket, the caller should not have the lock.
|
"""Write a packet to the socket, the caller should not have the lock.
|
||||||
|
|
||||||
The entire packet must be written in a single call to write
|
The entire packet must be written in a single call to write
|
||||||
to avoid locking.
|
to avoid locking.
|
||||||
"""
|
"""
|
||||||
_LOGGER.debug("Sending frame %s", frame.hex())
|
|
||||||
assert self._transport is not None, "Transport is not set"
|
assert self._transport is not None, "Transport is not set"
|
||||||
|
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||||
|
_LOGGER.debug("Sending frame: [%s]", frame.hex())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
header = bytes(
|
header = bytes(
|
||||||
|
@ -260,7 +271,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||||
self._send_hello()
|
self._send_hello()
|
||||||
try:
|
try:
|
||||||
async with async_timeout.timeout(60.0):
|
async with async_timeout.timeout(60.0):
|
||||||
await self._ready_event.wait()
|
await self._ready_future
|
||||||
except asyncio.TimeoutError as err:
|
except asyncio.TimeoutError as err:
|
||||||
raise HandshakeAPIError("Timeout during handshake") from err
|
raise HandshakeAPIError("Timeout during handshake") from err
|
||||||
|
|
||||||
|
@ -273,8 +284,10 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||||
self._handle_error_and_close(
|
self._handle_error_and_close(
|
||||||
ProtocolAPIError(f"Marker byte invalid: {header[0]}")
|
ProtocolAPIError(f"Marker byte invalid: {header[0]}")
|
||||||
)
|
)
|
||||||
|
return
|
||||||
msg_size = (header[1] << 8) | header[2]
|
msg_size = (header[1] << 8) | header[2]
|
||||||
frame = self._read_exactly(msg_size)
|
frame = self._read_exactly(msg_size)
|
||||||
|
|
||||||
if frame is None:
|
if frame is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -292,16 +305,18 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||||
def _handle_hello(self, server_hello: bytearray) -> None:
|
def _handle_hello(self, server_hello: bytearray) -> None:
|
||||||
"""Perform the handshake with the server, the caller is responsible for having the lock."""
|
"""Perform the handshake with the server, the caller is responsible for having the lock."""
|
||||||
if not server_hello:
|
if not server_hello:
|
||||||
raise HandshakeAPIError("ServerHello is empty")
|
self._handle_error_and_close(HandshakeAPIError("ServerHello is empty"))
|
||||||
|
return
|
||||||
|
|
||||||
# First byte of server hello is the protocol the server chose
|
# First byte of server hello is the protocol the server chose
|
||||||
# for this session. Currently only 0x01 (Noise_NNpsk0_25519_ChaChaPoly_SHA256)
|
# for this session. Currently only 0x01 (Noise_NNpsk0_25519_ChaChaPoly_SHA256)
|
||||||
# exists.
|
# exists.
|
||||||
chosen_proto = server_hello[0]
|
chosen_proto = server_hello[0]
|
||||||
if chosen_proto != 0x01:
|
if chosen_proto != 0x01:
|
||||||
raise HandshakeAPIError(
|
self._handle_error_and_close(
|
||||||
f"Unknown protocol selected by client {chosen_proto}"
|
HandshakeAPIError(f"Unknown protocol selected by client {chosen_proto}")
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# Check name matches expected name (for noise sessions, this is done
|
# Check name matches expected name (for noise sessions, this is done
|
||||||
# during hello phase before a connection is set up)
|
# during hello phase before a connection is set up)
|
||||||
|
@ -311,9 +326,12 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||||
# server name found, this extension was added in 2022.2
|
# server name found, this extension was added in 2022.2
|
||||||
server_name = server_hello[1:server_name_i].decode()
|
server_name = server_hello[1:server_name_i].decode()
|
||||||
if self._expected_name is not None and self._expected_name != server_name:
|
if self._expected_name is not None and self._expected_name != server_name:
|
||||||
raise BadNameAPIError(
|
self._handle_error_and_close(
|
||||||
f"Server sent a different name '{server_name}'", server_name
|
BadNameAPIError(
|
||||||
|
f"Server sent a different name '{server_name}'", server_name
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
self._state = NoiseConnectionState.HANDSHAKE
|
self._state = NoiseConnectionState.HANDSHAKE
|
||||||
self._send_handshake()
|
self._send_handshake()
|
||||||
|
@ -335,12 +353,24 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||||
if msg[0] != 0:
|
if msg[0] != 0:
|
||||||
explanation = msg[1:].decode()
|
explanation = msg[1:].decode()
|
||||||
if explanation == "Handshake MAC failure":
|
if explanation == "Handshake MAC failure":
|
||||||
raise InvalidEncryptionKeyAPIError("Invalid encryption key")
|
self._handle_error_and_close(
|
||||||
raise HandshakeAPIError(f"Handshake failure: {explanation}")
|
InvalidEncryptionKeyAPIError("Invalid encryption key")
|
||||||
self._proto.read_message(msg[1:])
|
)
|
||||||
|
return
|
||||||
|
self._handle_error_and_close(
|
||||||
|
HandshakeAPIError(f"Handshake failure: {explanation}")
|
||||||
|
)
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
self._proto.read_message(msg[1:])
|
||||||
|
except InvalidTag as invalid_tag_exc:
|
||||||
|
ex = InvalidEncryptionKeyAPIError("Invalid encryption key")
|
||||||
|
ex.__cause__ = invalid_tag_exc
|
||||||
|
self._handle_error_and_close(ex)
|
||||||
|
return
|
||||||
_LOGGER.debug("Handshake complete")
|
_LOGGER.debug("Handshake complete")
|
||||||
self._state = NoiseConnectionState.READY
|
self._state = NoiseConnectionState.READY
|
||||||
self._ready_event.set()
|
self._ready_future.set_result(None)
|
||||||
|
|
||||||
def write_packet(self, type_: int, data: bytes) -> None:
|
def write_packet(self, type_: int, data: bytes) -> None:
|
||||||
"""Write a packet to the socket."""
|
"""Write a packet to the socket."""
|
||||||
|
@ -367,13 +397,17 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||||
assert self._proto is not None
|
assert self._proto is not None
|
||||||
msg = self._proto.decrypt(bytes(frame))
|
msg = self._proto.decrypt(bytes(frame))
|
||||||
if len(msg) < 4:
|
if len(msg) < 4:
|
||||||
raise ProtocolAPIError(f"Bad packet frame: {msg}")
|
self._handle_error_and_close(ProtocolAPIError(f"Bad packet frame: {msg}"))
|
||||||
|
return
|
||||||
pkt_type = (msg[0] << 8) | msg[1]
|
pkt_type = (msg[0] << 8) | msg[1]
|
||||||
data_len = (msg[2] << 8) | msg[3]
|
data_len = (msg[2] << 8) | msg[3]
|
||||||
if data_len + 4 > len(msg):
|
if data_len + 4 > len(msg):
|
||||||
raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
|
self._handle_error_and_close(
|
||||||
|
ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
|
||||||
|
)
|
||||||
|
return
|
||||||
data = msg[4 : 4 + data_len]
|
data = msg[4 : 4 + data_len]
|
||||||
return self._on_pkt(pkt_type, data)
|
self._on_pkt(pkt_type, data)
|
||||||
|
|
||||||
def _handle_closed( # pylint: disable=unused-argument
|
def _handle_closed( # pylint: disable=unused-argument
|
||||||
self, frame: bytearray
|
self, frame: bytearray
|
||||||
|
|
|
@ -30,6 +30,7 @@ from .core import (
|
||||||
MESSAGE_TYPE_TO_PROTO,
|
MESSAGE_TYPE_TO_PROTO,
|
||||||
APIConnectionError,
|
APIConnectionError,
|
||||||
BadNameAPIError,
|
BadNameAPIError,
|
||||||
|
ConnectionNotEstablishedAPIError,
|
||||||
HandshakeAPIError,
|
HandshakeAPIError,
|
||||||
InvalidAuthAPIError,
|
InvalidAuthAPIError,
|
||||||
PingFailedAPIError,
|
PingFailedAPIError,
|
||||||
|
@ -432,7 +433,7 @@ class APIConnection:
|
||||||
self._cleanup()
|
self._cleanup()
|
||||||
raise self._fatal_exception or APIConnectionError("Connection cancelled")
|
raise self._fatal_exception or APIConnectionError("Connection cancelled")
|
||||||
except Exception: # pylint: disable=broad-except
|
except Exception: # pylint: disable=broad-except
|
||||||
# Always clean up the connection if an error occured during connect
|
# Always clean up the connection if an error occurred during connect
|
||||||
self._connection_state = ConnectionState.CLOSED
|
self._connection_state = ConnectionState.CLOSED
|
||||||
self._cleanup()
|
self._cleanup()
|
||||||
raise
|
raise
|
||||||
|
@ -493,7 +494,12 @@ class APIConnection:
|
||||||
def send_message(self, msg: message.Message) -> None:
|
def send_message(self, msg: message.Message) -> None:
|
||||||
"""Send a protobuf message to the remote."""
|
"""Send a protobuf message to the remote."""
|
||||||
if not self._is_socket_open:
|
if not self._is_socket_open:
|
||||||
raise APIConnectionError(
|
if in_do_connect.get(False):
|
||||||
|
# If we are in the do_connect task, we can't raise an error
|
||||||
|
# because it would obscure the original exception (ie encrypt error).
|
||||||
|
_LOGGER.debug("%s: Connection isn't established yet", self.log_name)
|
||||||
|
return
|
||||||
|
raise ConnectionNotEstablishedAPIError(
|
||||||
f"Connection isn't established yet ({self._connection_state})"
|
f"Connection isn't established yet ({self._connection_state})"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -183,6 +183,10 @@ class HandshakeAPIError(APIConnectionError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionNotEstablishedAPIError(APIConnectionError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BadNameAPIError(APIConnectionError):
|
class BadNameAPIError(APIConnectionError):
|
||||||
"""Raised when a name received from the remote but does not much the expected name."""
|
"""Raised when a name received from the remote but does not much the expected name."""
|
||||||
|
|
||||||
|
|
|
@ -5,11 +5,17 @@ from typing import Awaitable, Callable, List, Optional
|
||||||
import zeroconf
|
import zeroconf
|
||||||
|
|
||||||
from .client import APIClient
|
from .client import APIClient
|
||||||
from .core import APIConnectionError
|
from .core import (
|
||||||
|
APIConnectionError,
|
||||||
|
InvalidAuthAPIError,
|
||||||
|
InvalidEncryptionKeyAPIError,
|
||||||
|
RequiresEncryptionAPIError,
|
||||||
|
)
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
EXPECTED_DISCONNECT_COOLDOWN = 3.0
|
EXPECTED_DISCONNECT_COOLDOWN = 3.0
|
||||||
|
MAXIMUM_BACKOFF_TRIES = 100
|
||||||
|
|
||||||
|
|
||||||
class ReconnectLogic(zeroconf.RecordUpdateListener):
|
class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||||
|
@ -103,13 +109,26 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||||
level = logging.WARNING if self._tries == 0 else logging.DEBUG
|
level = logging.WARNING if self._tries == 0 else logging.DEBUG
|
||||||
_LOGGER.log(
|
_LOGGER.log(
|
||||||
level,
|
level,
|
||||||
"Can't connect to ESPHome API for %s: %s",
|
"Can't connect to ESPHome API for %s: %s (%s)",
|
||||||
self._log_name,
|
self._log_name,
|
||||||
err,
|
err,
|
||||||
|
type(err).__name__,
|
||||||
# Print stacktrace if unhandled (not APIConnectionError)
|
# Print stacktrace if unhandled (not APIConnectionError)
|
||||||
exc_info=not isinstance(err, APIConnectionError),
|
exc_info=not isinstance(err, APIConnectionError),
|
||||||
)
|
)
|
||||||
self._tries += 1
|
if isinstance(
|
||||||
|
err,
|
||||||
|
(
|
||||||
|
RequiresEncryptionAPIError,
|
||||||
|
InvalidEncryptionKeyAPIError,
|
||||||
|
InvalidAuthAPIError,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
# If we get an encryption or password error,
|
||||||
|
# backoff for the maximum amount of time
|
||||||
|
self._tries = MAXIMUM_BACKOFF_TRIES
|
||||||
|
else:
|
||||||
|
self._tries += 1
|
||||||
return False
|
return False
|
||||||
_LOGGER.info("Successfully connected to %s", self._log_name)
|
_LOGGER.info("Successfully connected to %s", self._log_name)
|
||||||
self._connected = True
|
self._connected = True
|
||||||
|
|
|
@ -3,7 +3,8 @@ from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
|
||||||
|
from aioesphomeapi.core import BadNameAPIError, InvalidEncryptionKeyAPIError
|
||||||
from aioesphomeapi.util import varuint_to_bytes
|
from aioesphomeapi.util import varuint_to_bytes
|
||||||
|
|
||||||
PREAMBLE = b"\x00"
|
PREAMBLE = b"\x00"
|
||||||
|
@ -63,3 +64,79 @@ async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
|
||||||
|
|
||||||
assert type_ == pkt_type
|
assert type_ == pkt_type
|
||||||
assert data == pkt_data
|
assert data == pkt_data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_noise_frame_helper_incorrect_key():
|
||||||
|
"""Test that the noise frame helper raises InvalidEncryptionKeyAPIError on bad key."""
|
||||||
|
outgoing_packets = [
|
||||||
|
"010000", # hello packet
|
||||||
|
"010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4",
|
||||||
|
]
|
||||||
|
incoming_packets = [
|
||||||
|
"01000d01736572766963657465737400",
|
||||||
|
"0100160148616e647368616b65204d4143206661696c757265",
|
||||||
|
]
|
||||||
|
packets = []
|
||||||
|
|
||||||
|
def _packet(type_: int, data: bytes):
|
||||||
|
packets.append((type_, data))
|
||||||
|
|
||||||
|
def _on_error(exc: Exception):
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
helper = APINoiseFrameHelper(
|
||||||
|
on_pkt=_packet,
|
||||||
|
on_error=_on_error,
|
||||||
|
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||||
|
expected_name="servicetest",
|
||||||
|
)
|
||||||
|
helper._transport = MagicMock()
|
||||||
|
|
||||||
|
for pkt in outgoing_packets:
|
||||||
|
helper._write_frame(bytes.fromhex(pkt))
|
||||||
|
|
||||||
|
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||||
|
for pkt in incoming_packets:
|
||||||
|
helper.data_received(bytes.fromhex(pkt))
|
||||||
|
|
||||||
|
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||||
|
await helper.perform_handshake()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_noise_incorrect_name():
|
||||||
|
"""Test we raise on bad name."""
|
||||||
|
outgoing_packets = [
|
||||||
|
"010000", # hello packet
|
||||||
|
"010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4",
|
||||||
|
]
|
||||||
|
incoming_packets = [
|
||||||
|
"01000d01736572766963657465737400",
|
||||||
|
"0100160148616e647368616b65204d4143206661696c757265",
|
||||||
|
]
|
||||||
|
packets = []
|
||||||
|
|
||||||
|
def _packet(type_: int, data: bytes):
|
||||||
|
packets.append((type_, data))
|
||||||
|
|
||||||
|
def _on_error(exc: Exception):
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
helper = APINoiseFrameHelper(
|
||||||
|
on_pkt=_packet,
|
||||||
|
on_error=_on_error,
|
||||||
|
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||||
|
expected_name="wrongname",
|
||||||
|
)
|
||||||
|
helper._transport = MagicMock()
|
||||||
|
|
||||||
|
for pkt in outgoing_packets:
|
||||||
|
helper._write_frame(bytes.fromhex(pkt))
|
||||||
|
|
||||||
|
with pytest.raises(BadNameAPIError):
|
||||||
|
for pkt in incoming_packets:
|
||||||
|
helper.data_received(bytes.fromhex(pkt))
|
||||||
|
|
||||||
|
with pytest.raises(BadNameAPIError):
|
||||||
|
await helper.perform_handshake()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user