aioesphomeapi/aioesphomeapi/_frame_helper.py

400 lines
14 KiB
Python

import asyncio
import base64
import logging
from abc import abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional, Union, cast
import async_timeout
from noise.connection import NoiseConnection # type: ignore
from .core import (
BadNameAPIError,
HandshakeAPIError,
InvalidEncryptionKeyAPIError,
ProtocolAPIError,
RequiresEncryptionAPIError,
SocketAPIError,
SocketClosedAPIError,
)
from .util import bytes_to_varuint, varuint_to_bytes
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
# Exception types grouped together as socket-level failures.
# NOTE(review): not referenced in the visible part of this file; presumably
# imported by callers that want to catch all of them at once -- confirm.
SOCKET_ERRORS = (
    ConnectionResetError,
    asyncio.IncompleteReadError,
    OSError,
    TimeoutError,
)
@dataclass
class Packet:
    """A single API message: a numeric message type plus its raw payload."""

    # Message type identifier (the plaintext helper writes it as a varuint,
    # the noise helper as a 16-bit big-endian value).
    type: int
    # Raw message payload.
    data: bytes
class APIFrameHelper(asyncio.Protocol):
    """Base protocol shared by the plaintext and noise frame helpers.

    Incoming bytes are accumulated in ``self._buffer``; ``self._pos`` tracks
    how far into that buffer the frame currently being parsed has been read.
    Decoded packets go to ``on_pkt`` and failures go to ``on_error``.
    """

    def __init__(
        self,
        on_pkt: Callable[[Packet], None],
        on_error: Callable[[Exception], None],
    ) -> None:
        """Store the callbacks and set up the empty receive state.

        Args:
            on_pkt: Called with every fully decoded packet.
            on_error: Called with every error raised by the connection.
        """
        self._on_pkt = on_pkt
        self._on_error = on_error
        self._transport: Optional[asyncio.Transport] = None
        self._connected_event = asyncio.Event()
        self._buffer = bytearray()
        self._pos = 0

    def _init_read(self, length: int) -> Optional[bytearray]:
        """Rewind to the start of the buffer and read ``length`` bytes.

        Returns:
            The first ``length`` buffered bytes, or None if fewer are buffered.
        """
        self._pos = 0
        return self._read_exactly(length)

    def _read_exactly(self, length: int) -> Optional[bytearray]:
        """Read ``length`` bytes at the current position.

        The read position only advances when the read succeeds.

        Returns:
            A copy of the requested bytes, or None when the buffer does not
            yet hold that many bytes past the current position.
        """
        start = self._pos
        end = start + length
        if end > len(self._buffer):
            return None
        self._pos = end
        return self._buffer[start:end]

    @abstractmethod
    async def perform_handshake(self) -> None:
        """Perform the handshake."""

    @abstractmethod
    def write_packet(self, packet: Packet) -> None:
        """Write a packet to the socket."""

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Remember the transport and signal that the connection is up."""
        self._transport = cast(asyncio.Transport, transport)
        self._connected_event.set()

    def _handle_error_and_close(self, exc: Exception) -> None:
        """Report ``exc`` through the error callback, then close."""
        self._handle_error(exc)
        self.close()

    def _handle_error(self, exc: Exception) -> None:
        """Forward ``exc`` to the error callback."""
        self._on_error(exc)

    def connection_lost(self, exc: Optional[Exception]) -> None:
        """Report a lost connection, substituting a generic error if none given."""
        error = exc or SocketClosedAPIError("Connection lost")
        self._handle_error(error)
        return super().connection_lost(exc)

    def eof_received(self) -> Optional[bool]:
        """Report that the peer signalled end-of-file."""
        error = SocketClosedAPIError("EOF received")
        self._handle_error(error)
        return super().eof_received()

    def close(self) -> None:
        """Close the underlying transport, if there is one."""
        transport = self._transport
        if transport:
            transport.close()
class APIPlaintextFrameHelper(APIFrameHelper):
    """Frame helper for unencrypted API connections.

    A frame is a 0x00 preamble byte, the payload length as a varuint, the
    message type as a varuint, and then the payload bytes.
    """

    def _callback_packet(self, type_: int, data: Union[bytes, bytearray]) -> None:
        """Drop the consumed frame from the buffer and deliver the packet."""
        consumed = self._pos
        del self._buffer[:consumed]
        self._on_pkt(Packet(type_, data))

    def write_packet(self, packet: Packet) -> None:
        """Serialize ``packet`` and write it to the transport.

        The whole frame is handed to the transport in a single write call.

        Raises:
            SocketAPIError: If the transport rejects the write.
        """
        assert self._transport is not None, "Transport should be set"
        length_prefix = varuint_to_bytes(len(packet.data))
        type_prefix = varuint_to_bytes(packet.type)
        frame = b"\0" + length_prefix + type_prefix + packet.data
        _LOGGER.debug("Sending plaintext frame %s", frame.hex())
        try:
            self._transport.write(frame)
        except (RuntimeError, ConnectionResetError, OSError) as err:
            raise SocketAPIError(f"Error while writing data: {err}") from err

    async def perform_handshake(self) -> None:
        """Wait until the connection is established; plaintext needs no more."""
        await self._connected_event.wait()

    def data_received(self, data: bytes) -> None:
        """Buffer ``data`` and deliver every complete frame it finishes.

        Parsing stops (keeping the buffered bytes) as soon as a frame is
        incomplete; an invalid preamble reports an error and closes.
        """
        self._buffer += data
        while len(self._buffer) >= 3:
            # Grab the preamble together with what is usually the complete
            # length and message type, so the common case needs one read.
            init_bytes = self._init_read(3)
            assert init_bytes is not None, "Buffer should have at least 3 bytes"

            if init_bytes[0] != 0x00:
                error: Exception
                if init_bytes[0] == 0x01:
                    # 0x01 is the marker byte of the encrypted framing.
                    error = RequiresEncryptionAPIError(
                        "Connection requires encryption"
                    )
                else:
                    error = ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}")
                self._handle_error_and_close(error)
                return

            if not (init_bytes[1] & 0x80):
                # Usual case: one byte of length followed by the first
                # (normally only) byte of the message type.
                length = init_bytes[1:2]
                msg_type = init_bytes[2:3]
            else:
                # Multi-byte length: the third byte belongs to the length,
                # so no part of the message type has been read yet.
                length = init_bytes[1:3]
                msg_type = b""

            # Keep pulling length bytes while the continuation bit is set.
            while length[-1] & 0x80:
                extra = self._read_exactly(1)
                if extra is None:
                    return
                length += extra

            # Same for the message type, which may not have been started.
            while not msg_type or msg_type[-1] & 0x80:
                extra = self._read_exactly(1)
                if extra is None:
                    return
                msg_type += extra

            length_int = bytes_to_varuint(bytes(length))
            assert length_int is not None
            msg_type_int = bytes_to_varuint(bytes(msg_type))
            assert msg_type_int is not None

            if length_int:
                packet_data = self._read_exactly(length_int)
                if packet_data is None:
                    return
                payload = bytes(packet_data)
            else:
                payload = b""
            # Delivering the packet also trims the frame from the buffer;
            # the loop then continues with whatever data is left.
            self._callback_packet(msg_type_int, payload)
def _decode_noise_psk(psk: str) -> bytes:
    """Decode the given noise psk from base64 format to raw bytes.

    Args:
        psk: The pre-shared key, base64 encoded.

    Returns:
        The 32 raw key bytes.

    Raises:
        InvalidEncryptionKeyAPIError: If ``psk`` is not valid base64 or does
            not decode to exactly 32 bytes.
    """
    try:
        psk_bytes = base64.b64decode(psk)
    except ValueError as err:
        # Chain explicitly so the underlying base64 error is kept as the
        # cause instead of showing up as an error "during handling".
        raise InvalidEncryptionKeyAPIError(
            f"Malformed PSK {psk}, expected base64-encoded value"
        ) from err
    if len(psk_bytes) != 32:
        raise InvalidEncryptionKeyAPIError(
            f"Malformed PSK {psk}, expected 32-bytes of base64 data"
        )
    return psk_bytes
class NoiseConnectionState(Enum):
    """Noise connection state."""

    # Initial state: the next frame is treated as the ServerHello.
    HELLO = 1
    # ServerHello accepted; the next frame is the server's handshake reply.
    HANDSHAKE = 2
    # Handshake complete; frames carry encrypted packets.
    READY = 3
    # close() was called; any further frame is reported as an error.
    CLOSED = 4
class APINoiseFrameHelper(APIFrameHelper):
    """Frame helper for noise encrypted connections.

    Every frame on the wire is a 3-byte header (marker byte 0x01 followed by
    the frame length as a big-endian 16-bit integer) and then the frame body.
    Each complete incoming frame is handed to the handler that
    ``STATE_TO_CALLABLE`` maps to the current ``NoiseConnectionState``.
    """

    def __init__(
        self,
        on_pkt: Callable[[Packet], None],
        on_error: Callable[[Exception], None],
        noise_psk: str,
        expected_name: Optional[str],
    ) -> None:
        """Initialize the API frame helper.

        Args:
            on_pkt: Called with every decrypted packet.
            on_error: Called with every error raised by the connection.
            noise_psk: Base64 encoded pre-shared key.
            expected_name: If not None, the name the server must announce in
                its hello; a different name fails the handshake.

        Raises:
            InvalidEncryptionKeyAPIError: If ``noise_psk`` is malformed.
        """
        super().__init__(on_pkt, on_error)
        self._ready_event = asyncio.Event()
        self._noise_psk = noise_psk
        self._expected_name = expected_name
        self._state = NoiseConnectionState.HELLO
        self._setup_proto()

    def close(self) -> None:
        """Close the connection."""
        # Make sure we set the ready event if it's not already set
        # so that we don't block forever on the ready event if we
        # are waiting for the handshake to complete.
        self._ready_event.set()
        self._state = NoiseConnectionState.CLOSED
        super().close()

    def _write_frame(self, frame: bytes) -> None:
        """Prefix ``frame`` with the 3-byte header and write it out.

        Header and frame are handed to the transport in a single write call.

        Raises:
            SocketAPIError: If the transport rejects the write.
        """
        _LOGGER.debug("Sending frame %s", frame.hex())
        assert self._transport is not None, "Transport is not set"
        try:
            header = bytes(
                [
                    0x01,
                    (len(frame) >> 8) & 0xFF,
                    len(frame) & 0xFF,
                ]
            )
            self._transport.write(header + frame)
        except (RuntimeError, ConnectionResetError, OSError) as err:
            raise SocketAPIError(f"Error while writing data: {err}") from err

    async def perform_handshake(self) -> None:
        """Perform the handshake with the server.

        Sends the ClientHello and waits up to 60 seconds for the ready event,
        which is set when the handshake completes or the helper is closed.

        Raises:
            HandshakeAPIError: If the ready event is not set in time.
        """
        self._send_hello()
        try:
            async with async_timeout.timeout(60.0):
                await self._ready_event.wait()
        except asyncio.TimeoutError as err:
            raise HandshakeAPIError("Timeout during handshake") from err

    def data_received(self, data: bytes) -> None:
        """Buffer ``data`` and dispatch every complete frame it finishes.

        An invalid marker byte reports a ProtocolAPIError, closes the
        connection and stops parsing. Any exception raised by a frame handler
        is reported through the error callback and closes the connection.
        """
        self._buffer += data
        while len(self._buffer) >= 3:
            header = self._init_read(3)
            assert header is not None, "Buffer should have at least 3 bytes"
            if header[0] != 0x01:
                self._handle_error_and_close(
                    ProtocolAPIError(f"Marker byte invalid: {header[0]}")
                )
                # The stream is not framed the way we expect, so the length
                # bytes cannot be trusted: stop instead of dispatching
                # garbage (and a second error) to the CLOSED handler.
                return
            msg_size = (header[1] << 8) | header[2]
            frame = self._read_exactly(msg_size)
            if frame is None:
                # Incomplete frame: keep the buffer and wait for more data.
                return
            try:
                self.STATE_TO_CALLABLE[self._state](self, frame)
            except Exception as err:  # pylint: disable=broad-except
                self._handle_error_and_close(err)
            finally:
                # The frame is consumed whether or not its handler succeeded.
                del self._buffer[: self._pos]

    def _send_hello(self) -> None:
        """Send a ClientHello to the server."""
        self._write_frame(b"")  # ClientHello

    def _handle_hello(self, server_hello: bytearray) -> None:
        """Validate the ServerHello and start the noise handshake.

        Raises:
            HandshakeAPIError: If the hello is empty or selects an unknown
                protocol.
            BadNameAPIError: If the announced server name differs from the
                expected one.
        """
        if not server_hello:
            raise HandshakeAPIError("ServerHello is empty")
        # First byte of server hello is the protocol the server chose
        # for this session. Currently only 0x01 (Noise_NNpsk0_25519_ChaChaPoly_SHA256)
        # exists.
        chosen_proto = server_hello[0]
        if chosen_proto != 0x01:
            raise HandshakeAPIError(
                f"Unknown protocol selected by client {chosen_proto}"
            )
        # Check name matches expected name (for noise sessions, this is done
        # during hello phase before a connection is set up)
        # Server name is encoded as a string followed by a zero byte after the chosen proto byte
        server_name_i = server_hello.find(b"\0", 1)
        if server_name_i != -1:
            # server name found, this extension was added in 2022.2
            server_name = server_hello[1:server_name_i].decode()
            if self._expected_name is not None and self._expected_name != server_name:
                raise BadNameAPIError(
                    f"Server sent a different name '{server_name}'", server_name
                )
        self._state = NoiseConnectionState.HANDSHAKE
        self._send_handshake()

    def _setup_proto(self) -> None:
        """Set up the noise protocol as initiator with the decoded PSK.

        Raises:
            InvalidEncryptionKeyAPIError: If the configured PSK is malformed.
        """
        self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
        self._proto.set_as_initiator()
        self._proto.set_psks(_decode_noise_psk(self._noise_psk))
        self._proto.set_prologue(b"NoiseAPIInit" + b"\x00\x00")
        self._proto.start_handshake()

    def _send_handshake(self) -> None:
        """Send the handshake message, prefixed with a zero status byte."""
        self._write_frame(b"\x00" + self._proto.write_message())

    def _handle_handshake(self, msg: bytearray) -> None:
        """Process the server's handshake reply and mark the helper ready.

        The first byte is a status byte: non-zero means failure and the rest
        of the frame is a textual explanation.

        Raises:
            InvalidEncryptionKeyAPIError: If the server reports a MAC failure.
            HandshakeAPIError: If the frame is empty or the server reports
                any other failure.
        """
        _LOGGER.debug("Starting handshake...")
        if not msg:
            # Without this guard the status-byte lookup below would raise
            # an opaque IndexError.
            raise HandshakeAPIError("Handshake message is empty")
        if msg[0] != 0:
            explanation = msg[1:].decode()
            if explanation == "Handshake MAC failure":
                raise InvalidEncryptionKeyAPIError("Invalid encryption key")
            raise HandshakeAPIError(f"Handshake failure: {explanation}")
        self._proto.read_message(msg[1:])
        _LOGGER.debug("Handshake complete")
        self._state = NoiseConnectionState.READY
        self._ready_event.set()

    def write_packet(self, packet: Packet) -> None:
        """Encrypt ``packet`` and write it to the socket.

        The plaintext is a 4-byte header (message type and payload length,
        each as a big-endian 16-bit integer) followed by the payload.

        Raises:
            SocketAPIError: If the transport rejects the write.
        """
        self._write_frame(
            self._proto.encrypt(
                (
                    bytes(
                        [
                            (packet.type >> 8) & 0xFF,
                            (packet.type >> 0) & 0xFF,
                            (len(packet.data) >> 8) & 0xFF,
                            (len(packet.data) >> 0) & 0xFF,
                        ]
                    )
                    + packet.data
                )
            )
        )

    def _handle_frame(self, frame: bytearray) -> None:
        """Decrypt an incoming frame and deliver the packet it contains.

        Raises:
            ProtocolAPIError: If the decrypted message is shorter than its
                4-byte header or than the payload length it declares.
        """
        assert self._proto is not None
        msg = self._proto.decrypt(bytes(frame))
        if len(msg) < 4:
            raise ProtocolAPIError(f"Bad packet frame: {msg}")
        pkt_type = (msg[0] << 8) | msg[1]
        data_len = (msg[2] << 8) | msg[3]
        if data_len + 4 > len(msg):
            raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
        data = msg[4 : 4 + data_len]
        self._on_pkt(Packet(pkt_type, data))

    def _handle_closed(  # pylint: disable=unused-argument
        self, frame: bytearray
    ) -> None:
        """Report a frame that arrived after the helper was closed."""
        self._handle_error(ProtocolAPIError("Connection closed"))

    # Maps each connection state to the handler for frames received in it.
    # The values are plain functions, so data_received passes self explicitly.
    STATE_TO_CALLABLE = {
        NoiseConnectionState.HELLO: _handle_hello,
        NoiseConnectionState.HANDSHAKE: _handle_handshake,
        NoiseConnectionState.READY: _handle_frame,
        NoiseConnectionState.CLOSED: _handle_closed,
    }