Add noise API transport support (#100)

This commit is contained in:
Otto Winter 2021-09-08 23:12:07 +02:00 committed by GitHub
parent 50656382f1
commit 015e9c8d5e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 308 additions and 63 deletions

View File

@ -124,6 +124,7 @@ class APIClient:
client_info: str = "aioesphomeapi",
keepalive: float = 15.0,
zeroconf_instance: ZeroconfInstanceType = None,
noise_psk: Optional[str] = None,
):
self._params = ConnectionParams(
eventloop=eventloop,
@ -133,6 +134,7 @@ class APIClient:
client_info=client_info,
keepalive=keepalive,
zeroconf_instance=zeroconf_instance,
noise_psk=noise_psk,
)
self._connection: Optional[APIConnection] = None
self._cached_name: Optional[str] = None
@ -305,6 +307,7 @@ class APIClient:
self,
on_log: Callable[[SubscribeLogsResponse], None],
log_level: Optional[LogLevel] = None,
dump_config: Optional[bool] = None,
) -> None:
self._check_authenticated()
@ -315,6 +318,8 @@ class APIClient:
req = SubscribeLogsRequest()
if log_level is not None:
req.level = log_level
if dump_config is not None:
req.dump_config = dump_config
assert self._connection is not None
await self._connection.send_message_callback_response(req, on_msg)

View File

@ -1,11 +1,13 @@
import asyncio
import base64
import logging
import socket
import time
from dataclasses import astuple, dataclass
from typing import Any, Awaitable, Callable, List, Optional, cast
from typing import Any, Awaitable, Callable, List, Optional
from google.protobuf import message
from noise.connection import NoiseConnection # type: ignore
import aioesphomeapi.host_resolver as hr
@ -37,6 +39,185 @@ class ConnectionParams:
client_info: str
keepalive: float
zeroconf_instance: hr.ZeroconfInstanceType
noise_psk: Optional[str]
@property
def noise_psk_bytes(self) -> Optional[bytes]:
if self.noise_psk is None:
return None
return base64.b64decode(self.noise_psk)
@dataclass
class Packet:
type: int
data: bytes
class APIFrameHelper:
def __init__(
self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
params: ConnectionParams,
):
self._reader = reader
self._writer = writer
self._params = params
self._write_lock = asyncio.Lock()
self._read_lock = asyncio.Lock()
self._ready_event = asyncio.Event()
self._proto: Optional[NoiseConnection] = None
async def close(self) -> None:
async with self._write_lock:
self._writer.close()
async def _write_frame_noise(self, frame: bytes) -> None:
try:
async with self._write_lock:
_LOGGER.debug("Sending frame %s", frame.hex())
header = bytes(
[
0x01,
(len(frame) >> 8) & 0xFF,
len(frame) & 0xFF,
]
)
self._writer.write(header + frame)
await self._writer.drain()
except OSError as err:
raise APIConnectionError(f"Error while writing data: {err}") from err
async def _read_frame_noise(self) -> bytes:
try:
async with self._read_lock:
header = await self._reader.readexactly(3)
if header[0] != 0x01:
raise APIConnectionError(f"Marker byte invalid: {header[0]}")
msg_size = (header[1] << 8) | header[2]
frame = await self._reader.readexactly(msg_size)
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
raise APIConnectionError(f"Error while reading data: {err}") from err
_LOGGER.debug("Received frame %s", frame.hex())
return frame
async def perform_handshake(self) -> None:
if self._params.noise_psk is None:
return
await self._write_frame_noise(b"") # ClientHello
prologue = b"NoiseAPIInit" + b"\x00\x00"
server_hello = await self._read_frame_noise() # ServerHello
if not server_hello:
raise APIConnectionError("ServerHello is empty")
chosen_proto = server_hello[0]
if chosen_proto != 0x01:
raise APIConnectionError(
f"Unknown protocol selected by client {chosen_proto}"
)
self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
self._proto.set_as_initiator()
self._proto.set_psks(self._params.noise_psk_bytes)
self._proto.set_prologue(prologue)
self._proto.start_handshake()
_LOGGER.debug("Starting handshake...")
do_write = True
while not self._proto.handshake_finished:
if do_write:
msg = self._proto.write_message()
await self._write_frame_noise(b"\x00" + msg)
else:
msg = await self._read_frame_noise()
if not msg or msg[0] != 0:
raise APIConnectionError(f"Handshake failure: {msg[1:].decode()}")
self._proto.read_message(msg[1:])
do_write = not do_write
_LOGGER.debug("Handshake complete!")
self._ready_event.set()
async def _write_packet_noise(self, packet: Packet) -> None:
await self._ready_event.wait()
padding = 0
data = (
bytes(
[
(packet.type >> 8) & 0xFF,
(packet.type >> 0) & 0xFF,
(len(packet.data) >> 8) & 0xFF,
(len(packet.data) >> 0) & 0xFF,
]
)
+ packet.data
+ b"\x00" * padding
)
assert self._proto is not None
frame = self._proto.encrypt(data)
await self._write_frame_noise(frame)
async def _write_packet_plaintext(self, packet: Packet) -> None:
data = b"\0"
data += varuint_to_bytes(len(packet.data))
data += varuint_to_bytes(packet.type)
data += packet.data
try:
async with self._write_lock:
_LOGGER.debug("Sending frame %s", data.hex())
self._writer.write(data)
await self._writer.drain()
except OSError as err:
raise APIConnectionError(f"Error while writing data: {err}") from err
async def write_packet(self, packet: Packet) -> None:
if self._params.noise_psk is None:
await self._write_packet_plaintext(packet)
else:
await self._write_packet_noise(packet)
async def _read_packet_noise(self) -> Packet:
await self._ready_event.wait()
frame = await self._read_frame_noise()
assert self._proto is not None
msg = self._proto.decrypt(frame)
if len(msg) < 4:
raise APIConnectionError(f"Bad packet frame: {msg}")
pkt_type = (msg[0] << 8) | msg[1]
data_len = (msg[2] << 8) | msg[3]
if data_len + 4 > len(msg):
raise APIConnectionError(f"Bad data len: {data_len} vs {len(msg)}")
data = msg[4 : 4 + data_len]
return Packet(type=pkt_type, data=data)
async def _read_packet_plaintext(self) -> Packet:
async with self._read_lock:
preamble = await self._reader.readexactly(1)
if preamble[0] != 0x00:
raise APIConnectionError("Invalid preamble")
length = b""
while not length or (length[-1] & 0x80) == 0x80:
length += await self._reader.readexactly(1)
length_int = bytes_to_varuint(length)
assert length_int is not None
msg_type = b""
while not msg_type or (msg_type[-1] & 0x80) == 0x80:
msg_type += await self._reader.readexactly(1)
msg_type_int = bytes_to_varuint(msg_type)
assert msg_type_int is not None
raw_msg = b""
if length_int != 0:
raw_msg = await self._reader.readexactly(length_int)
return Packet(type=msg_type_int, data=raw_msg)
async def read_packet(self) -> Packet:
if self._params.noise_psk is None:
return await self._read_packet_plaintext()
return await self._read_packet_noise()
class APIConnection:
@ -47,9 +228,7 @@ class APIConnection:
self.on_stop = on_stop
self._stopped = False
self._socket: Optional[socket.socket] = None
self._socket_reader: Optional[asyncio.StreamReader] = None
self._socket_writer: Optional[asyncio.StreamWriter] = None
self._write_lock = asyncio.Lock()
self._frame_helper: Optional[APIFrameHelper] = None
self._connected = False
self._authenticated = False
self._socket_connected = False
@ -58,15 +237,13 @@ class APIConnection:
self._message_handlers: List[Callable[[message.Message], None]] = []
self.log_name = params.address
self._ping_task: Optional[asyncio.Task[None]] = None
def _start_ping(self) -> None:
async def func() -> None:
while self._connected:
while True:
await asyncio.sleep(self._params.keepalive)
if not self._connected:
return
try:
await self.ping()
except APIConnectionError:
@ -74,18 +251,20 @@ class APIConnection:
await self._on_error()
return
self._params.eventloop.create_task(func())
self._ping_task = asyncio.create_task(func())
async def _close_socket(self) -> None:
if not self._socket_connected:
return
async with self._write_lock:
if self._socket_writer is not None:
self._socket_writer.close()
self._socket_writer = None
self._socket_reader = None
if self._frame_helper is not None:
await self._frame_helper.close()
self._frame_helper = None
if self._socket is not None:
self._socket.close()
self._socket = None
if self._ping_task is not None:
self._ping_task.cancel()
self._ping_task = None
self._socket_connected = False
self._connected = False
self._authenticated = False
@ -106,6 +285,7 @@ class APIConnection:
async def _on_error(self) -> None:
await self.stop(force=True)
# pylint: disable=too-many-statements
async def connect(self) -> None:
if self._stopped:
raise APIConnectionError(f"Connection is closed for {self.log_name}!")
@ -154,19 +334,25 @@ class APIConnection:
raise APIConnectionError(f"Timeout while connecting to {sockaddr}")
_LOGGER.debug("%s: Opened socket for", self._params.address)
self._socket_reader, self._socket_writer = await asyncio.open_connection(
sock=self._socket
)
reader, writer = await asyncio.open_connection(sock=self._socket)
self._frame_helper = APIFrameHelper(reader, writer, self._params)
self._socket_connected = True
try:
await self._frame_helper.perform_handshake()
except APIConnectionError:
await self._on_error()
raise
self._params.eventloop.create_task(self.run_forever())
hello = HelloRequest()
hello.client_info = self._params.client_info
try:
resp = await self.send_message_await_response(hello, HelloResponse)
except APIConnectionError as err:
except APIConnectionError:
await self._on_error()
raise err
raise
_LOGGER.debug(
"%s: Successfully connected ('%s' API=%s.%s)",
self.log_name,
@ -213,21 +399,10 @@ class APIConnection:
def is_authenticated(self) -> bool:
return self._authenticated
async def _write(self, data: bytes) -> None:
# _LOGGER.debug("%s: Write: %s", self._params.address,
# ' '.join('{:02X}'.format(x) for x in data))
async def send_message(self, msg: message.Message) -> None:
if not self._socket_connected:
raise APIConnectionError("Socket is not connected")
try:
async with self._write_lock:
if self._socket_writer is not None:
self._socket_writer.write(data)
await self._socket_writer.drain()
except OSError as err:
await self._on_error()
raise APIConnectionError("Error while writing data: {}".format(err))
async def send_message(self, msg: message.Message) -> None:
for message_type, klass in MESSAGE_TYPE_TO_PROTO.items():
if isinstance(msg, klass):
break
@ -236,12 +411,14 @@ class APIConnection:
encoded = msg.SerializeToString()
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
req = bytes([0])
req += varuint_to_bytes(len(encoded))
# pylint: disable=undefined-loop-variable
req += varuint_to_bytes(message_type)
req += encoded
await self._write(req)
assert self._frame_helper is not None
await self._frame_helper.write_packet(
Packet(
type=message_type,
data=encoded,
)
)
async def send_message_callback_response(
self, send_msg: message.Message, on_message: Callable[[Any], None]
@ -254,7 +431,7 @@ class APIConnection:
send_msg: message.Message,
do_append: Callable[[Any], bool],
do_stop: Callable[[Any], bool],
timeout: float = 5.0,
timeout: float = 10.0,
) -> List[Any]:
fut = self._params.eventloop.create_future()
responses = []
@ -285,7 +462,7 @@ class APIConnection:
return responses
async def send_message_await_response(
self, send_msg: message.Message, response_type: Any, timeout: float = 5.0
self, send_msg: message.Message, response_type: Any, timeout: float = 10.0
) -> Any:
def is_response(msg: message.Message) -> bool:
return isinstance(msg, response_type)
@ -298,33 +475,12 @@ class APIConnection:
return res[0]
async def _recv(self, amount: int) -> bytes:
if amount == 0:
return bytes()
try:
assert self._socket_reader is not None
ret = await self._socket_reader.readexactly(amount)
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
raise APIConnectionError("Error while receiving data: {}".format(err))
return ret
async def _recv_varint(self) -> int:
raw = bytes()
while not raw or raw[-1] & 0x80:
raw += await self._recv(1)
return cast(int, bytes_to_varuint(raw))
async def _run_once(self) -> None:
preamble = await self._recv(1)
if preamble[0] != 0x00:
raise APIConnectionError("Invalid preamble")
assert self._frame_helper is not None
pkt = await self._frame_helper.read_packet()
length = await self._recv_varint()
msg_type = await self._recv_varint()
raw_msg = await self._recv(length)
msg_type = pkt.type
raw_msg = pkt.data
if msg_type not in MESSAGE_TYPE_TO_PROTO:
_LOGGER.debug(
"%s: Skipping message type %s", self._params.address, msg_type
@ -360,6 +516,7 @@ class APIConnection:
"%s: Unexpected error while reading incoming messages: %s",
self.log_name,
err,
exc_info=True,
)
await self._on_error()
break

View File

@ -0,0 +1,81 @@
# Helper script and aioesphomeapi to view logs from an esphome device
import argparse
import asyncio
import logging
import sys
from datetime import datetime
from typing import List
import zeroconf
from aioesphomeapi.api_pb2 import SubscribeLogsResponse # type: ignore
from aioesphomeapi.client import APIClient
from aioesphomeapi.core import APIConnectionError
from aioesphomeapi.model import LogLevel
from aioesphomeapi.reconnect_logic import ReconnectLogic
_LOGGER = logging.getLogger(__name__)
async def main(argv: List[str]) -> None:
parser = argparse.ArgumentParser("aioesphomeapi-logs")
parser.add_argument("--port", type=int, default=6053)
parser.add_argument("--password", type=str)
parser.add_argument("--noise-psk", type=str)
parser.add_argument("-v", "--verbose", action="store_true")
parser.add_argument("address")
args = parser.parse_args(argv[1:])
logging.basicConfig(
format="%(asctime)s %(levelname)-8s %(message)s",
level=logging.DEBUG if args.verbose else logging.INFO,
datefmt="%Y-%m-%d %H:%M:%S",
)
cli = APIClient(
asyncio.get_event_loop(),
args.address,
args.port,
args.password or "",
noise_psk=args.noise_psk,
keepalive=10,
)
def on_log(msg: SubscribeLogsResponse) -> None:
time_ = datetime.now().time().strftime("[%H:%M:%S]")
text = msg.message
print(time_ + text.decode("utf8", "backslashreplace"))
has_connects = False
async def on_connect() -> None:
nonlocal has_connects
try:
await cli.subscribe_logs(
on_log,
log_level=LogLevel.LOG_LEVEL_VERY_VERBOSE,
dump_config=not has_connects,
)
has_connects = True
except APIConnectionError:
cli.disconnect()
async def on_disconnect() -> None:
_LOGGER.warning("Disconnected from API")
logic = ReconnectLogic(
client=cli,
on_connect=on_connect,
on_disconnect=on_disconnect,
zeroconf_instance=zeroconf.Zeroconf(),
)
await logic.start()
try:
while True:
await asyncio.sleep(60)
except KeyboardInterrupt:
await logic.stop()
if __name__ == "__main__":
sys.exit(asyncio.run(main(sys.argv)) or 0)

View File

@ -1,2 +1,3 @@
protobuf>=3.12.2,<4.0
zeroconf>=0.28.0,<1.0
noiseprotocol>=0.3.1,<1.0

View File

@ -20,6 +20,7 @@ def connection_params() -> ConnectionParams:
client_info="Tests client",
keepalive=15.0,
zeroconf_instance=None,
noise_psk=None,
)