mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-22 16:48:04 +01:00
Refactor frame_helper into new module (#109)
This commit is contained in:
parent
48ea96f9da
commit
9ca228cd1e
254
aioesphomeapi/_frame_helper.py
Normal file
254
aioesphomeapi/_frame_helper.py
Normal file
@ -0,0 +1,254 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from noise.connection import NoiseConnection # type: ignore
|
||||
|
||||
from .core import (
|
||||
HandshakeAPIError,
|
||||
InvalidEncryptionKeyAPIError,
|
||||
ProtocolAPIError,
|
||||
RequiresEncryptionAPIError,
|
||||
SocketAPIError,
|
||||
SocketClosedAPIError,
|
||||
)
|
||||
from .util import bytes_to_varuint, varuint_to_bytes
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Packet:
|
||||
type: int
|
||||
data: bytes
|
||||
|
||||
|
||||
class APIFrameHelper(ABC):
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def write_packet(self, packet: Packet) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def read_packet(self) -> Packet:
|
||||
pass
|
||||
|
||||
|
||||
class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
def __init__(
|
||||
self,
|
||||
reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter,
|
||||
):
|
||||
self._reader = reader
|
||||
self._writer = writer
|
||||
self._write_lock = asyncio.Lock()
|
||||
self._read_lock = asyncio.Lock()
|
||||
self._closed_event = asyncio.Event()
|
||||
|
||||
async def close(self) -> None:
|
||||
self._closed_event.set()
|
||||
self._writer.close()
|
||||
|
||||
async def write_packet(self, packet: Packet) -> None:
|
||||
data = b"\0"
|
||||
data += varuint_to_bytes(len(packet.data))
|
||||
data += varuint_to_bytes(packet.type)
|
||||
data += packet.data
|
||||
try:
|
||||
async with self._write_lock:
|
||||
_LOGGER.debug("Sending plaintext frame %s", data.hex())
|
||||
self._writer.write(data)
|
||||
await self._writer.drain()
|
||||
except OSError as err:
|
||||
raise SocketAPIError(f"Error while writing data: {err}") from err
|
||||
|
||||
async def read_packet(self) -> Packet:
|
||||
async with self._read_lock:
|
||||
try:
|
||||
preamble = await self._reader.readexactly(1)
|
||||
if preamble[0] != 0x00:
|
||||
if preamble[0] == 0x01:
|
||||
raise RequiresEncryptionAPIError(
|
||||
"Connection requires encryption"
|
||||
)
|
||||
raise ProtocolAPIError(f"Invalid preamble {preamble[0]:02x}")
|
||||
|
||||
length = b""
|
||||
while not length or (length[-1] & 0x80) == 0x80:
|
||||
length += await self._reader.readexactly(1)
|
||||
length_int = bytes_to_varuint(length)
|
||||
assert length_int is not None
|
||||
msg_type = b""
|
||||
while not msg_type or (msg_type[-1] & 0x80) == 0x80:
|
||||
msg_type += await self._reader.readexactly(1)
|
||||
msg_type_int = bytes_to_varuint(msg_type)
|
||||
assert msg_type_int is not None
|
||||
|
||||
raw_msg = b""
|
||||
if length_int != 0:
|
||||
raw_msg = await self._reader.readexactly(length_int)
|
||||
return Packet(type=msg_type_int, data=raw_msg)
|
||||
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
|
||||
if (
|
||||
isinstance(err, asyncio.IncompleteReadError)
|
||||
and self._closed_event.is_set()
|
||||
):
|
||||
raise SocketClosedAPIError(
|
||||
f"Socket closed while reading data: {err}"
|
||||
) from err
|
||||
raise SocketAPIError(f"Error while reading data: {err}") from err
|
||||
|
||||
|
||||
def _decode_noise_psk(psk: str) -> bytes:
|
||||
"""Decode the given noise psk from base64 format to raw bytes."""
|
||||
try:
|
||||
psk_bytes = base64.b64decode(psk)
|
||||
except ValueError:
|
||||
raise InvalidEncryptionKeyAPIError(
|
||||
f"Malformed PSK {psk}, expected base64-encoded value"
|
||||
)
|
||||
if len(psk_bytes) != 32:
|
||||
raise InvalidEncryptionKeyAPIError(
|
||||
f"Malformed PSK {psk}, expected 32-bytes of base64 data"
|
||||
)
|
||||
return psk_bytes
|
||||
|
||||
|
||||
class APINoiseFrameHelper(APIFrameHelper):
|
||||
def __init__(
|
||||
self,
|
||||
reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter,
|
||||
noise_psk: str,
|
||||
):
|
||||
self._reader = reader
|
||||
self._writer = writer
|
||||
self._write_lock = asyncio.Lock()
|
||||
self._read_lock = asyncio.Lock()
|
||||
self._ready_event = asyncio.Event()
|
||||
self._closed_event = asyncio.Event()
|
||||
self._proto: Optional[NoiseConnection] = None
|
||||
self._noise_psk = noise_psk
|
||||
|
||||
async def close(self) -> None:
|
||||
self._closed_event.set()
|
||||
self._writer.close()
|
||||
|
||||
async def _write_frame(self, frame: bytes) -> None:
|
||||
try:
|
||||
async with self._write_lock:
|
||||
_LOGGER.debug("Sending frame %s", frame.hex())
|
||||
header = bytes(
|
||||
[
|
||||
0x01,
|
||||
(len(frame) >> 8) & 0xFF,
|
||||
len(frame) & 0xFF,
|
||||
]
|
||||
)
|
||||
self._writer.write(header + frame)
|
||||
await self._writer.drain()
|
||||
except OSError as err:
|
||||
raise SocketAPIError(f"Error while writing data: {err}") from err
|
||||
|
||||
async def _read_frame(self) -> bytes:
|
||||
try:
|
||||
async with self._read_lock:
|
||||
header = await self._reader.readexactly(3)
|
||||
if header[0] != 0x01:
|
||||
raise ProtocolAPIError(f"Marker byte invalid: {header[0]}")
|
||||
msg_size = (header[1] << 8) | header[2]
|
||||
frame = await self._reader.readexactly(msg_size)
|
||||
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
|
||||
if (
|
||||
isinstance(err, asyncio.IncompleteReadError)
|
||||
and self._closed_event.is_set()
|
||||
):
|
||||
raise SocketClosedAPIError(
|
||||
f"Socket closed while reading data: {err}"
|
||||
) from err
|
||||
raise SocketAPIError(f"Error while reading data: {err}") from err
|
||||
|
||||
_LOGGER.debug("Received frame %s", frame.hex())
|
||||
return frame
|
||||
|
||||
async def perform_handshake(self) -> None:
|
||||
await self._write_frame(b"") # ClientHello
|
||||
prologue = b"NoiseAPIInit" + b"\x00\x00"
|
||||
server_hello = await self._read_frame() # ServerHello
|
||||
if not server_hello:
|
||||
raise HandshakeAPIError("ServerHello is empty")
|
||||
chosen_proto = server_hello[0]
|
||||
if chosen_proto != 0x01:
|
||||
raise HandshakeAPIError(
|
||||
f"Unknown protocol selected by client {chosen_proto}"
|
||||
)
|
||||
|
||||
self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
|
||||
self._proto.set_as_initiator()
|
||||
self._proto.set_psks(_decode_noise_psk(self._noise_psk))
|
||||
self._proto.set_prologue(prologue)
|
||||
self._proto.start_handshake()
|
||||
|
||||
_LOGGER.debug("Starting handshake...")
|
||||
do_write = True
|
||||
while not self._proto.handshake_finished:
|
||||
if do_write:
|
||||
msg = self._proto.write_message()
|
||||
await self._write_frame(b"\x00" + msg)
|
||||
else:
|
||||
msg = await self._read_frame()
|
||||
if not msg:
|
||||
raise HandshakeAPIError("Handshake message too short")
|
||||
if msg[0] != 0:
|
||||
explanation = msg[1:].decode()
|
||||
if explanation == "Handshake MAC failure":
|
||||
raise InvalidEncryptionKeyAPIError("Invalid encryption key")
|
||||
raise HandshakeAPIError(f"Handshake failure: {explanation}")
|
||||
self._proto.read_message(msg[1:])
|
||||
|
||||
do_write = not do_write
|
||||
|
||||
_LOGGER.debug("Handshake complete!")
|
||||
self._ready_event.set()
|
||||
|
||||
async def write_packet(self, packet: Packet) -> None:
|
||||
# Wait for handshake to complete
|
||||
await self._ready_event.wait()
|
||||
padding = 0
|
||||
data = (
|
||||
bytes(
|
||||
[
|
||||
(packet.type >> 8) & 0xFF,
|
||||
(packet.type >> 0) & 0xFF,
|
||||
(len(packet.data) >> 8) & 0xFF,
|
||||
(len(packet.data) >> 0) & 0xFF,
|
||||
]
|
||||
)
|
||||
+ packet.data
|
||||
+ b"\x00" * padding
|
||||
)
|
||||
assert self._proto is not None
|
||||
frame = self._proto.encrypt(data)
|
||||
await self._write_frame(frame)
|
||||
|
||||
async def read_packet(self) -> Packet:
|
||||
# Wait for handshake to complete
|
||||
await self._ready_event.wait()
|
||||
frame = await self._read_frame()
|
||||
assert self._proto is not None
|
||||
msg = self._proto.decrypt(frame)
|
||||
if len(msg) < 4:
|
||||
raise ProtocolAPIError(f"Bad packet frame: {msg}")
|
||||
pkt_type = (msg[0] << 8) | msg[1]
|
||||
data_len = (msg[2] << 8) | msg[3]
|
||||
if data_len + 4 > len(msg):
|
||||
raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
|
||||
data = msg[4 : 4 + data_len]
|
||||
return Packet(type=pkt_type, data=data)
|
@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import enum
|
||||
import logging
|
||||
import socket
|
||||
@ -9,10 +8,15 @@ from dataclasses import astuple, dataclass
|
||||
from typing import Any, Awaitable, Callable, List, Optional
|
||||
|
||||
from google.protobuf import message
|
||||
from noise.connection import NoiseConnection # type: ignore
|
||||
|
||||
import aioesphomeapi.host_resolver as hr
|
||||
|
||||
from ._frame_helper import (
|
||||
APIFrameHelper,
|
||||
APINoiseFrameHelper,
|
||||
APIPlaintextFrameHelper,
|
||||
Packet,
|
||||
)
|
||||
from .api_pb2 import ( # type: ignore
|
||||
ConnectRequest,
|
||||
ConnectResponse,
|
||||
@ -28,20 +32,16 @@ from .api_pb2 import ( # type: ignore
|
||||
from .core import (
|
||||
MESSAGE_TYPE_TO_PROTO,
|
||||
APIConnectionError,
|
||||
HandshakeAPIError,
|
||||
InvalidAuthAPIError,
|
||||
InvalidEncryptionKeyAPIError,
|
||||
PingFailedAPIError,
|
||||
ProtocolAPIError,
|
||||
ReadFailedAPIError,
|
||||
RequiresEncryptionAPIError,
|
||||
ResolveAPIError,
|
||||
SocketAPIError,
|
||||
SocketClosedAPIError,
|
||||
TimeoutAPIError,
|
||||
)
|
||||
from .model import APIVersion
|
||||
from .util import bytes_to_varuint, varuint_to_bytes
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -58,217 +58,6 @@ class ConnectionParams:
|
||||
noise_psk: Optional[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Packet:
|
||||
type: int
|
||||
data: bytes
|
||||
|
||||
|
||||
class APIFrameHelper:
|
||||
def __init__(
|
||||
self,
|
||||
reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter,
|
||||
params: ConnectionParams,
|
||||
):
|
||||
self._reader = reader
|
||||
self._writer = writer
|
||||
self._params = params
|
||||
self._write_lock = asyncio.Lock()
|
||||
self._read_lock = asyncio.Lock()
|
||||
self._ready_event = asyncio.Event()
|
||||
self._proto: Optional[NoiseConnection] = None
|
||||
self._closed_event = asyncio.Event()
|
||||
|
||||
async def close(self) -> None:
|
||||
self._closed_event.set()
|
||||
self._writer.close()
|
||||
|
||||
async def _write_frame_noise(self, frame: bytes) -> None:
|
||||
try:
|
||||
async with self._write_lock:
|
||||
_LOGGER.debug("Sending frame %s", frame.hex())
|
||||
header = bytes(
|
||||
[
|
||||
0x01,
|
||||
(len(frame) >> 8) & 0xFF,
|
||||
len(frame) & 0xFF,
|
||||
]
|
||||
)
|
||||
self._writer.write(header + frame)
|
||||
await self._writer.drain()
|
||||
except OSError as err:
|
||||
raise SocketAPIError(f"Error while writing data: {err}") from err
|
||||
|
||||
async def _read_frame_noise(self) -> bytes:
|
||||
try:
|
||||
async with self._read_lock:
|
||||
header = await self._reader.readexactly(3)
|
||||
if header[0] != 0x01:
|
||||
raise ProtocolAPIError(f"Marker byte invalid: {header[0]}")
|
||||
msg_size = (header[1] << 8) | header[2]
|
||||
frame = await self._reader.readexactly(msg_size)
|
||||
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
|
||||
if (
|
||||
isinstance(err, asyncio.IncompleteReadError)
|
||||
and self._closed_event.is_set()
|
||||
):
|
||||
raise SocketClosedAPIError(
|
||||
f"Socket closed while reading data: {err}"
|
||||
) from err
|
||||
raise SocketAPIError(f"Error while reading data: {err}") from err
|
||||
|
||||
_LOGGER.debug("Received frame %s", frame.hex())
|
||||
return frame
|
||||
|
||||
async def perform_handshake(self) -> None:
|
||||
if self._params.noise_psk is None:
|
||||
return
|
||||
await self._write_frame_noise(b"") # ClientHello
|
||||
prologue = b"NoiseAPIInit" + b"\x00\x00"
|
||||
server_hello = await self._read_frame_noise() # ServerHello
|
||||
if not server_hello:
|
||||
raise HandshakeAPIError("ServerHello is empty")
|
||||
chosen_proto = server_hello[0]
|
||||
if chosen_proto != 0x01:
|
||||
raise HandshakeAPIError(
|
||||
f"Unknown protocol selected by client {chosen_proto}"
|
||||
)
|
||||
|
||||
self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
|
||||
self._proto.set_as_initiator()
|
||||
|
||||
try:
|
||||
noise_psk_bytes = base64.b64decode(self._params.noise_psk)
|
||||
except ValueError:
|
||||
raise InvalidEncryptionKeyAPIError(
|
||||
f"Malformed PSK {self._params.noise_psk}, expected base64-encoded value"
|
||||
)
|
||||
if len(noise_psk_bytes) != 32:
|
||||
raise InvalidEncryptionKeyAPIError(
|
||||
f"Malformed PSK {self._params.noise_psk}, expected 32-bytes of base64 data"
|
||||
)
|
||||
|
||||
self._proto.set_psks(noise_psk_bytes)
|
||||
self._proto.set_prologue(prologue)
|
||||
self._proto.start_handshake()
|
||||
|
||||
_LOGGER.debug("Starting handshake...")
|
||||
do_write = True
|
||||
while not self._proto.handshake_finished:
|
||||
if do_write:
|
||||
msg = self._proto.write_message()
|
||||
await self._write_frame_noise(b"\x00" + msg)
|
||||
else:
|
||||
msg = await self._read_frame_noise()
|
||||
if not msg:
|
||||
raise HandshakeAPIError("Handshake message too short")
|
||||
if msg[0] != 0:
|
||||
explanation = msg[1:].decode()
|
||||
if explanation == "Handshake MAC failure":
|
||||
raise InvalidEncryptionKeyAPIError("Invalid encryption key")
|
||||
raise HandshakeAPIError(f"Handshake failure: {explanation}")
|
||||
self._proto.read_message(msg[1:])
|
||||
|
||||
do_write = not do_write
|
||||
|
||||
_LOGGER.debug("Handshake complete!")
|
||||
self._ready_event.set()
|
||||
|
||||
async def _write_packet_noise(self, packet: Packet) -> None:
|
||||
await self._ready_event.wait()
|
||||
padding = 0
|
||||
data = (
|
||||
bytes(
|
||||
[
|
||||
(packet.type >> 8) & 0xFF,
|
||||
(packet.type >> 0) & 0xFF,
|
||||
(len(packet.data) >> 8) & 0xFF,
|
||||
(len(packet.data) >> 0) & 0xFF,
|
||||
]
|
||||
)
|
||||
+ packet.data
|
||||
+ b"\x00" * padding
|
||||
)
|
||||
assert self._proto is not None
|
||||
frame = self._proto.encrypt(data)
|
||||
await self._write_frame_noise(frame)
|
||||
|
||||
async def _write_packet_plaintext(self, packet: Packet) -> None:
|
||||
data = b"\0"
|
||||
data += varuint_to_bytes(len(packet.data))
|
||||
data += varuint_to_bytes(packet.type)
|
||||
data += packet.data
|
||||
try:
|
||||
async with self._write_lock:
|
||||
_LOGGER.debug("Sending frame %s", data.hex())
|
||||
self._writer.write(data)
|
||||
await self._writer.drain()
|
||||
except OSError as err:
|
||||
raise SocketAPIError(f"Error while writing data: {err}") from err
|
||||
|
||||
async def write_packet(self, packet: Packet) -> None:
|
||||
if self._params.noise_psk is None:
|
||||
await self._write_packet_plaintext(packet)
|
||||
else:
|
||||
await self._write_packet_noise(packet)
|
||||
|
||||
async def _read_packet_noise(self) -> Packet:
|
||||
await self._ready_event.wait()
|
||||
frame = await self._read_frame_noise()
|
||||
assert self._proto is not None
|
||||
msg = self._proto.decrypt(frame)
|
||||
if len(msg) < 4:
|
||||
raise ProtocolAPIError(f"Bad packet frame: {msg}")
|
||||
pkt_type = (msg[0] << 8) | msg[1]
|
||||
data_len = (msg[2] << 8) | msg[3]
|
||||
if data_len + 4 > len(msg):
|
||||
raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
|
||||
data = msg[4 : 4 + data_len]
|
||||
return Packet(type=pkt_type, data=data)
|
||||
|
||||
async def _read_packet_plaintext(self) -> Packet:
|
||||
async with self._read_lock:
|
||||
try:
|
||||
preamble = await self._reader.readexactly(1)
|
||||
if preamble[0] != 0x00:
|
||||
if preamble[0] == 0x01:
|
||||
raise RequiresEncryptionAPIError(
|
||||
"Connection requires encryption"
|
||||
)
|
||||
raise ProtocolAPIError(f"Invalid preamble {preamble[0]:02x}")
|
||||
|
||||
length = b""
|
||||
while not length or (length[-1] & 0x80) == 0x80:
|
||||
length += await self._reader.readexactly(1)
|
||||
length_int = bytes_to_varuint(length)
|
||||
assert length_int is not None
|
||||
msg_type = b""
|
||||
while not msg_type or (msg_type[-1] & 0x80) == 0x80:
|
||||
msg_type += await self._reader.readexactly(1)
|
||||
msg_type_int = bytes_to_varuint(msg_type)
|
||||
assert msg_type_int is not None
|
||||
|
||||
raw_msg = b""
|
||||
if length_int != 0:
|
||||
raw_msg = await self._reader.readexactly(length_int)
|
||||
return Packet(type=msg_type_int, data=raw_msg)
|
||||
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
|
||||
if (
|
||||
isinstance(err, asyncio.IncompleteReadError)
|
||||
and self._closed_event.is_set()
|
||||
):
|
||||
raise SocketClosedAPIError(
|
||||
f"Socket closed while reading data: {err}"
|
||||
) from err
|
||||
raise SocketAPIError(f"Error while reading data: {err}") from err
|
||||
|
||||
async def read_packet(self) -> Packet:
|
||||
if self._params.noise_psk is None:
|
||||
return await self._read_packet_plaintext()
|
||||
return await self._read_packet_noise()
|
||||
|
||||
|
||||
class ConnectionState(enum.Enum):
|
||||
# The connection is initialized, but connect() wasn't called yet
|
||||
INITIALIZED = 0
|
||||
@ -381,8 +170,13 @@ class APIConnection:
|
||||
"""Step 3 in connect process: initialize the frame helper and init read loop."""
|
||||
reader, writer = await asyncio.open_connection(sock=self._socket)
|
||||
|
||||
self._frame_helper = APIFrameHelper(reader, writer, self._params)
|
||||
await self._frame_helper.perform_handshake()
|
||||
if self._params.noise_psk is None:
|
||||
self._frame_helper = APIPlaintextFrameHelper(reader, writer)
|
||||
else:
|
||||
fh = self._frame_helper = APINoiseFrameHelper(
|
||||
reader, writer, self._params.noise_psk
|
||||
)
|
||||
await fh.perform_handshake()
|
||||
|
||||
self._connection_state = ConnectionState.SOCKET_OPENED
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user