Refactor frame_helper into new module (#109)

This commit is contained in:
Otto Winter 2021-10-13 10:05:08 +02:00 committed by GitHub
parent 48ea96f9da
commit 9ca228cd1e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 267 additions and 219 deletions

View File

@ -0,0 +1,254 @@
import asyncio
import base64
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional
from noise.connection import NoiseConnection # type: ignore
from .core import (
HandshakeAPIError,
InvalidEncryptionKeyAPIError,
ProtocolAPIError,
RequiresEncryptionAPIError,
SocketAPIError,
SocketClosedAPIError,
)
from .util import bytes_to_varuint, varuint_to_bytes
_LOGGER = logging.getLogger(__name__)
@dataclass
class Packet:
type: int
data: bytes
class APIFrameHelper(ABC):
@abstractmethod
async def close(self) -> None:
pass
@abstractmethod
async def write_packet(self, packet: Packet) -> None:
pass
@abstractmethod
async def read_packet(self) -> Packet:
pass
class APIPlaintextFrameHelper(APIFrameHelper):
def __init__(
self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
):
self._reader = reader
self._writer = writer
self._write_lock = asyncio.Lock()
self._read_lock = asyncio.Lock()
self._closed_event = asyncio.Event()
async def close(self) -> None:
self._closed_event.set()
self._writer.close()
async def write_packet(self, packet: Packet) -> None:
data = b"\0"
data += varuint_to_bytes(len(packet.data))
data += varuint_to_bytes(packet.type)
data += packet.data
try:
async with self._write_lock:
_LOGGER.debug("Sending plaintext frame %s", data.hex())
self._writer.write(data)
await self._writer.drain()
except OSError as err:
raise SocketAPIError(f"Error while writing data: {err}") from err
async def read_packet(self) -> Packet:
async with self._read_lock:
try:
preamble = await self._reader.readexactly(1)
if preamble[0] != 0x00:
if preamble[0] == 0x01:
raise RequiresEncryptionAPIError(
"Connection requires encryption"
)
raise ProtocolAPIError(f"Invalid preamble {preamble[0]:02x}")
length = b""
while not length or (length[-1] & 0x80) == 0x80:
length += await self._reader.readexactly(1)
length_int = bytes_to_varuint(length)
assert length_int is not None
msg_type = b""
while not msg_type or (msg_type[-1] & 0x80) == 0x80:
msg_type += await self._reader.readexactly(1)
msg_type_int = bytes_to_varuint(msg_type)
assert msg_type_int is not None
raw_msg = b""
if length_int != 0:
raw_msg = await self._reader.readexactly(length_int)
return Packet(type=msg_type_int, data=raw_msg)
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
if (
isinstance(err, asyncio.IncompleteReadError)
and self._closed_event.is_set()
):
raise SocketClosedAPIError(
f"Socket closed while reading data: {err}"
) from err
raise SocketAPIError(f"Error while reading data: {err}") from err
def _decode_noise_psk(psk: str) -> bytes:
"""Decode the given noise psk from base64 format to raw bytes."""
try:
psk_bytes = base64.b64decode(psk)
except ValueError:
raise InvalidEncryptionKeyAPIError(
f"Malformed PSK {psk}, expected base64-encoded value"
)
if len(psk_bytes) != 32:
raise InvalidEncryptionKeyAPIError(
f"Malformed PSK {psk}, expected 32-bytes of base64 data"
)
return psk_bytes
class APINoiseFrameHelper(APIFrameHelper):
def __init__(
self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
noise_psk: str,
):
self._reader = reader
self._writer = writer
self._write_lock = asyncio.Lock()
self._read_lock = asyncio.Lock()
self._ready_event = asyncio.Event()
self._closed_event = asyncio.Event()
self._proto: Optional[NoiseConnection] = None
self._noise_psk = noise_psk
async def close(self) -> None:
self._closed_event.set()
self._writer.close()
async def _write_frame(self, frame: bytes) -> None:
try:
async with self._write_lock:
_LOGGER.debug("Sending frame %s", frame.hex())
header = bytes(
[
0x01,
(len(frame) >> 8) & 0xFF,
len(frame) & 0xFF,
]
)
self._writer.write(header + frame)
await self._writer.drain()
except OSError as err:
raise SocketAPIError(f"Error while writing data: {err}") from err
async def _read_frame(self) -> bytes:
try:
async with self._read_lock:
header = await self._reader.readexactly(3)
if header[0] != 0x01:
raise ProtocolAPIError(f"Marker byte invalid: {header[0]}")
msg_size = (header[1] << 8) | header[2]
frame = await self._reader.readexactly(msg_size)
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
if (
isinstance(err, asyncio.IncompleteReadError)
and self._closed_event.is_set()
):
raise SocketClosedAPIError(
f"Socket closed while reading data: {err}"
) from err
raise SocketAPIError(f"Error while reading data: {err}") from err
_LOGGER.debug("Received frame %s", frame.hex())
return frame
async def perform_handshake(self) -> None:
await self._write_frame(b"") # ClientHello
prologue = b"NoiseAPIInit" + b"\x00\x00"
server_hello = await self._read_frame() # ServerHello
if not server_hello:
raise HandshakeAPIError("ServerHello is empty")
chosen_proto = server_hello[0]
if chosen_proto != 0x01:
raise HandshakeAPIError(
f"Unknown protocol selected by client {chosen_proto}"
)
self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
self._proto.set_as_initiator()
self._proto.set_psks(_decode_noise_psk(self._noise_psk))
self._proto.set_prologue(prologue)
self._proto.start_handshake()
_LOGGER.debug("Starting handshake...")
do_write = True
while not self._proto.handshake_finished:
if do_write:
msg = self._proto.write_message()
await self._write_frame(b"\x00" + msg)
else:
msg = await self._read_frame()
if not msg:
raise HandshakeAPIError("Handshake message too short")
if msg[0] != 0:
explanation = msg[1:].decode()
if explanation == "Handshake MAC failure":
raise InvalidEncryptionKeyAPIError("Invalid encryption key")
raise HandshakeAPIError(f"Handshake failure: {explanation}")
self._proto.read_message(msg[1:])
do_write = not do_write
_LOGGER.debug("Handshake complete!")
self._ready_event.set()
async def write_packet(self, packet: Packet) -> None:
# Wait for handshake to complete
await self._ready_event.wait()
padding = 0
data = (
bytes(
[
(packet.type >> 8) & 0xFF,
(packet.type >> 0) & 0xFF,
(len(packet.data) >> 8) & 0xFF,
(len(packet.data) >> 0) & 0xFF,
]
)
+ packet.data
+ b"\x00" * padding
)
assert self._proto is not None
frame = self._proto.encrypt(data)
await self._write_frame(frame)
async def read_packet(self) -> Packet:
# Wait for handshake to complete
await self._ready_event.wait()
frame = await self._read_frame()
assert self._proto is not None
msg = self._proto.decrypt(frame)
if len(msg) < 4:
raise ProtocolAPIError(f"Bad packet frame: {msg}")
pkt_type = (msg[0] << 8) | msg[1]
data_len = (msg[2] << 8) | msg[3]
if data_len + 4 > len(msg):
raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
data = msg[4 : 4 + data_len]
return Packet(type=pkt_type, data=data)

View File

@ -1,5 +1,4 @@
import asyncio import asyncio
import base64
import enum import enum
import logging import logging
import socket import socket
@ -9,10 +8,15 @@ from dataclasses import astuple, dataclass
from typing import Any, Awaitable, Callable, List, Optional from typing import Any, Awaitable, Callable, List, Optional
from google.protobuf import message from google.protobuf import message
from noise.connection import NoiseConnection # type: ignore
import aioesphomeapi.host_resolver as hr import aioesphomeapi.host_resolver as hr
from ._frame_helper import (
APIFrameHelper,
APINoiseFrameHelper,
APIPlaintextFrameHelper,
Packet,
)
from .api_pb2 import ( # type: ignore from .api_pb2 import ( # type: ignore
ConnectRequest, ConnectRequest,
ConnectResponse, ConnectResponse,
@ -28,20 +32,16 @@ from .api_pb2 import ( # type: ignore
from .core import ( from .core import (
MESSAGE_TYPE_TO_PROTO, MESSAGE_TYPE_TO_PROTO,
APIConnectionError, APIConnectionError,
HandshakeAPIError,
InvalidAuthAPIError, InvalidAuthAPIError,
InvalidEncryptionKeyAPIError,
PingFailedAPIError, PingFailedAPIError,
ProtocolAPIError, ProtocolAPIError,
ReadFailedAPIError, ReadFailedAPIError,
RequiresEncryptionAPIError,
ResolveAPIError, ResolveAPIError,
SocketAPIError, SocketAPIError,
SocketClosedAPIError, SocketClosedAPIError,
TimeoutAPIError, TimeoutAPIError,
) )
from .model import APIVersion from .model import APIVersion
from .util import bytes_to_varuint, varuint_to_bytes
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -58,217 +58,6 @@ class ConnectionParams:
noise_psk: Optional[str] noise_psk: Optional[str]
@dataclass
class Packet:
type: int
data: bytes
class APIFrameHelper:
def __init__(
self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
params: ConnectionParams,
):
self._reader = reader
self._writer = writer
self._params = params
self._write_lock = asyncio.Lock()
self._read_lock = asyncio.Lock()
self._ready_event = asyncio.Event()
self._proto: Optional[NoiseConnection] = None
self._closed_event = asyncio.Event()
async def close(self) -> None:
self._closed_event.set()
self._writer.close()
async def _write_frame_noise(self, frame: bytes) -> None:
try:
async with self._write_lock:
_LOGGER.debug("Sending frame %s", frame.hex())
header = bytes(
[
0x01,
(len(frame) >> 8) & 0xFF,
len(frame) & 0xFF,
]
)
self._writer.write(header + frame)
await self._writer.drain()
except OSError as err:
raise SocketAPIError(f"Error while writing data: {err}") from err
async def _read_frame_noise(self) -> bytes:
try:
async with self._read_lock:
header = await self._reader.readexactly(3)
if header[0] != 0x01:
raise ProtocolAPIError(f"Marker byte invalid: {header[0]}")
msg_size = (header[1] << 8) | header[2]
frame = await self._reader.readexactly(msg_size)
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
if (
isinstance(err, asyncio.IncompleteReadError)
and self._closed_event.is_set()
):
raise SocketClosedAPIError(
f"Socket closed while reading data: {err}"
) from err
raise SocketAPIError(f"Error while reading data: {err}") from err
_LOGGER.debug("Received frame %s", frame.hex())
return frame
async def perform_handshake(self) -> None:
if self._params.noise_psk is None:
return
await self._write_frame_noise(b"") # ClientHello
prologue = b"NoiseAPIInit" + b"\x00\x00"
server_hello = await self._read_frame_noise() # ServerHello
if not server_hello:
raise HandshakeAPIError("ServerHello is empty")
chosen_proto = server_hello[0]
if chosen_proto != 0x01:
raise HandshakeAPIError(
f"Unknown protocol selected by client {chosen_proto}"
)
self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
self._proto.set_as_initiator()
try:
noise_psk_bytes = base64.b64decode(self._params.noise_psk)
except ValueError:
raise InvalidEncryptionKeyAPIError(
f"Malformed PSK {self._params.noise_psk}, expected base64-encoded value"
)
if len(noise_psk_bytes) != 32:
raise InvalidEncryptionKeyAPIError(
f"Malformed PSK {self._params.noise_psk}, expected 32-bytes of base64 data"
)
self._proto.set_psks(noise_psk_bytes)
self._proto.set_prologue(prologue)
self._proto.start_handshake()
_LOGGER.debug("Starting handshake...")
do_write = True
while not self._proto.handshake_finished:
if do_write:
msg = self._proto.write_message()
await self._write_frame_noise(b"\x00" + msg)
else:
msg = await self._read_frame_noise()
if not msg:
raise HandshakeAPIError("Handshake message too short")
if msg[0] != 0:
explanation = msg[1:].decode()
if explanation == "Handshake MAC failure":
raise InvalidEncryptionKeyAPIError("Invalid encryption key")
raise HandshakeAPIError(f"Handshake failure: {explanation}")
self._proto.read_message(msg[1:])
do_write = not do_write
_LOGGER.debug("Handshake complete!")
self._ready_event.set()
async def _write_packet_noise(self, packet: Packet) -> None:
await self._ready_event.wait()
padding = 0
data = (
bytes(
[
(packet.type >> 8) & 0xFF,
(packet.type >> 0) & 0xFF,
(len(packet.data) >> 8) & 0xFF,
(len(packet.data) >> 0) & 0xFF,
]
)
+ packet.data
+ b"\x00" * padding
)
assert self._proto is not None
frame = self._proto.encrypt(data)
await self._write_frame_noise(frame)
async def _write_packet_plaintext(self, packet: Packet) -> None:
data = b"\0"
data += varuint_to_bytes(len(packet.data))
data += varuint_to_bytes(packet.type)
data += packet.data
try:
async with self._write_lock:
_LOGGER.debug("Sending frame %s", data.hex())
self._writer.write(data)
await self._writer.drain()
except OSError as err:
raise SocketAPIError(f"Error while writing data: {err}") from err
async def write_packet(self, packet: Packet) -> None:
if self._params.noise_psk is None:
await self._write_packet_plaintext(packet)
else:
await self._write_packet_noise(packet)
async def _read_packet_noise(self) -> Packet:
await self._ready_event.wait()
frame = await self._read_frame_noise()
assert self._proto is not None
msg = self._proto.decrypt(frame)
if len(msg) < 4:
raise ProtocolAPIError(f"Bad packet frame: {msg}")
pkt_type = (msg[0] << 8) | msg[1]
data_len = (msg[2] << 8) | msg[3]
if data_len + 4 > len(msg):
raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
data = msg[4 : 4 + data_len]
return Packet(type=pkt_type, data=data)
async def _read_packet_plaintext(self) -> Packet:
async with self._read_lock:
try:
preamble = await self._reader.readexactly(1)
if preamble[0] != 0x00:
if preamble[0] == 0x01:
raise RequiresEncryptionAPIError(
"Connection requires encryption"
)
raise ProtocolAPIError(f"Invalid preamble {preamble[0]:02x}")
length = b""
while not length or (length[-1] & 0x80) == 0x80:
length += await self._reader.readexactly(1)
length_int = bytes_to_varuint(length)
assert length_int is not None
msg_type = b""
while not msg_type or (msg_type[-1] & 0x80) == 0x80:
msg_type += await self._reader.readexactly(1)
msg_type_int = bytes_to_varuint(msg_type)
assert msg_type_int is not None
raw_msg = b""
if length_int != 0:
raw_msg = await self._reader.readexactly(length_int)
return Packet(type=msg_type_int, data=raw_msg)
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
if (
isinstance(err, asyncio.IncompleteReadError)
and self._closed_event.is_set()
):
raise SocketClosedAPIError(
f"Socket closed while reading data: {err}"
) from err
raise SocketAPIError(f"Error while reading data: {err}") from err
async def read_packet(self) -> Packet:
if self._params.noise_psk is None:
return await self._read_packet_plaintext()
return await self._read_packet_noise()
class ConnectionState(enum.Enum): class ConnectionState(enum.Enum):
# The connection is initialized, but connect() wasn't called yet # The connection is initialized, but connect() wasn't called yet
INITIALIZED = 0 INITIALIZED = 0
@ -381,8 +170,13 @@ class APIConnection:
"""Step 3 in connect process: initialize the frame helper and init read loop.""" """Step 3 in connect process: initialize the frame helper and init read loop."""
reader, writer = await asyncio.open_connection(sock=self._socket) reader, writer = await asyncio.open_connection(sock=self._socket)
self._frame_helper = APIFrameHelper(reader, writer, self._params) if self._params.noise_psk is None:
await self._frame_helper.perform_handshake() self._frame_helper = APIPlaintextFrameHelper(reader, writer)
else:
fh = self._frame_helper = APINoiseFrameHelper(
reader, writer, self._params.noise_psk
)
await fh.perform_handshake()
self._connection_state = ConnectionState.SOCKET_OPENED self._connection_state = ConnectionState.SOCKET_OPENED