mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-02-17 01:51:23 +01:00
Add noise API transport support (#100)
This commit is contained in:
parent
50656382f1
commit
015e9c8d5e
@ -124,6 +124,7 @@ class APIClient:
|
||||
client_info: str = "aioesphomeapi",
|
||||
keepalive: float = 15.0,
|
||||
zeroconf_instance: ZeroconfInstanceType = None,
|
||||
noise_psk: Optional[str] = None,
|
||||
):
|
||||
self._params = ConnectionParams(
|
||||
eventloop=eventloop,
|
||||
@ -133,6 +134,7 @@ class APIClient:
|
||||
client_info=client_info,
|
||||
keepalive=keepalive,
|
||||
zeroconf_instance=zeroconf_instance,
|
||||
noise_psk=noise_psk,
|
||||
)
|
||||
self._connection: Optional[APIConnection] = None
|
||||
self._cached_name: Optional[str] = None
|
||||
@ -305,6 +307,7 @@ class APIClient:
|
||||
self,
|
||||
on_log: Callable[[SubscribeLogsResponse], None],
|
||||
log_level: Optional[LogLevel] = None,
|
||||
dump_config: Optional[bool] = None,
|
||||
) -> None:
|
||||
self._check_authenticated()
|
||||
|
||||
@ -315,6 +318,8 @@ class APIClient:
|
||||
req = SubscribeLogsRequest()
|
||||
if log_level is not None:
|
||||
req.level = log_level
|
||||
if dump_config is not None:
|
||||
req.dump_config = dump_config
|
||||
assert self._connection is not None
|
||||
await self._connection.send_message_callback_response(req, on_msg)
|
||||
|
||||
|
@ -1,11 +1,13 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from dataclasses import astuple, dataclass
|
||||
from typing import Any, Awaitable, Callable, List, Optional, cast
|
||||
from typing import Any, Awaitable, Callable, List, Optional
|
||||
|
||||
from google.protobuf import message
|
||||
from noise.connection import NoiseConnection # type: ignore
|
||||
|
||||
import aioesphomeapi.host_resolver as hr
|
||||
|
||||
@ -37,6 +39,185 @@ class ConnectionParams:
|
||||
client_info: str
|
||||
keepalive: float
|
||||
zeroconf_instance: hr.ZeroconfInstanceType
|
||||
noise_psk: Optional[str]
|
||||
|
||||
@property
|
||||
def noise_psk_bytes(self) -> Optional[bytes]:
|
||||
if self.noise_psk is None:
|
||||
return None
|
||||
return base64.b64decode(self.noise_psk)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Packet:
|
||||
type: int
|
||||
data: bytes
|
||||
|
||||
|
||||
class APIFrameHelper:
|
||||
def __init__(
|
||||
self,
|
||||
reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter,
|
||||
params: ConnectionParams,
|
||||
):
|
||||
self._reader = reader
|
||||
self._writer = writer
|
||||
self._params = params
|
||||
self._write_lock = asyncio.Lock()
|
||||
self._read_lock = asyncio.Lock()
|
||||
self._ready_event = asyncio.Event()
|
||||
self._proto: Optional[NoiseConnection] = None
|
||||
|
||||
async def close(self) -> None:
|
||||
async with self._write_lock:
|
||||
self._writer.close()
|
||||
|
||||
async def _write_frame_noise(self, frame: bytes) -> None:
|
||||
try:
|
||||
async with self._write_lock:
|
||||
_LOGGER.debug("Sending frame %s", frame.hex())
|
||||
header = bytes(
|
||||
[
|
||||
0x01,
|
||||
(len(frame) >> 8) & 0xFF,
|
||||
len(frame) & 0xFF,
|
||||
]
|
||||
)
|
||||
self._writer.write(header + frame)
|
||||
await self._writer.drain()
|
||||
except OSError as err:
|
||||
raise APIConnectionError(f"Error while writing data: {err}") from err
|
||||
|
||||
async def _read_frame_noise(self) -> bytes:
|
||||
try:
|
||||
async with self._read_lock:
|
||||
header = await self._reader.readexactly(3)
|
||||
if header[0] != 0x01:
|
||||
raise APIConnectionError(f"Marker byte invalid: {header[0]}")
|
||||
msg_size = (header[1] << 8) | header[2]
|
||||
frame = await self._reader.readexactly(msg_size)
|
||||
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
|
||||
raise APIConnectionError(f"Error while reading data: {err}") from err
|
||||
|
||||
_LOGGER.debug("Received frame %s", frame.hex())
|
||||
return frame
|
||||
|
||||
async def perform_handshake(self) -> None:
|
||||
if self._params.noise_psk is None:
|
||||
return
|
||||
await self._write_frame_noise(b"") # ClientHello
|
||||
prologue = b"NoiseAPIInit" + b"\x00\x00"
|
||||
server_hello = await self._read_frame_noise() # ServerHello
|
||||
if not server_hello:
|
||||
raise APIConnectionError("ServerHello is empty")
|
||||
chosen_proto = server_hello[0]
|
||||
if chosen_proto != 0x01:
|
||||
raise APIConnectionError(
|
||||
f"Unknown protocol selected by client {chosen_proto}"
|
||||
)
|
||||
|
||||
self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
|
||||
self._proto.set_as_initiator()
|
||||
self._proto.set_psks(self._params.noise_psk_bytes)
|
||||
self._proto.set_prologue(prologue)
|
||||
self._proto.start_handshake()
|
||||
|
||||
_LOGGER.debug("Starting handshake...")
|
||||
do_write = True
|
||||
while not self._proto.handshake_finished:
|
||||
if do_write:
|
||||
msg = self._proto.write_message()
|
||||
await self._write_frame_noise(b"\x00" + msg)
|
||||
else:
|
||||
msg = await self._read_frame_noise()
|
||||
if not msg or msg[0] != 0:
|
||||
raise APIConnectionError(f"Handshake failure: {msg[1:].decode()}")
|
||||
self._proto.read_message(msg[1:])
|
||||
|
||||
do_write = not do_write
|
||||
|
||||
_LOGGER.debug("Handshake complete!")
|
||||
self._ready_event.set()
|
||||
|
||||
async def _write_packet_noise(self, packet: Packet) -> None:
|
||||
await self._ready_event.wait()
|
||||
padding = 0
|
||||
data = (
|
||||
bytes(
|
||||
[
|
||||
(packet.type >> 8) & 0xFF,
|
||||
(packet.type >> 0) & 0xFF,
|
||||
(len(packet.data) >> 8) & 0xFF,
|
||||
(len(packet.data) >> 0) & 0xFF,
|
||||
]
|
||||
)
|
||||
+ packet.data
|
||||
+ b"\x00" * padding
|
||||
)
|
||||
assert self._proto is not None
|
||||
frame = self._proto.encrypt(data)
|
||||
await self._write_frame_noise(frame)
|
||||
|
||||
async def _write_packet_plaintext(self, packet: Packet) -> None:
|
||||
data = b"\0"
|
||||
data += varuint_to_bytes(len(packet.data))
|
||||
data += varuint_to_bytes(packet.type)
|
||||
data += packet.data
|
||||
try:
|
||||
async with self._write_lock:
|
||||
_LOGGER.debug("Sending frame %s", data.hex())
|
||||
self._writer.write(data)
|
||||
await self._writer.drain()
|
||||
except OSError as err:
|
||||
raise APIConnectionError(f"Error while writing data: {err}") from err
|
||||
|
||||
async def write_packet(self, packet: Packet) -> None:
|
||||
if self._params.noise_psk is None:
|
||||
await self._write_packet_plaintext(packet)
|
||||
else:
|
||||
await self._write_packet_noise(packet)
|
||||
|
||||
async def _read_packet_noise(self) -> Packet:
|
||||
await self._ready_event.wait()
|
||||
frame = await self._read_frame_noise()
|
||||
assert self._proto is not None
|
||||
msg = self._proto.decrypt(frame)
|
||||
if len(msg) < 4:
|
||||
raise APIConnectionError(f"Bad packet frame: {msg}")
|
||||
pkt_type = (msg[0] << 8) | msg[1]
|
||||
data_len = (msg[2] << 8) | msg[3]
|
||||
if data_len + 4 > len(msg):
|
||||
raise APIConnectionError(f"Bad data len: {data_len} vs {len(msg)}")
|
||||
data = msg[4 : 4 + data_len]
|
||||
return Packet(type=pkt_type, data=data)
|
||||
|
||||
async def _read_packet_plaintext(self) -> Packet:
|
||||
async with self._read_lock:
|
||||
preamble = await self._reader.readexactly(1)
|
||||
if preamble[0] != 0x00:
|
||||
raise APIConnectionError("Invalid preamble")
|
||||
|
||||
length = b""
|
||||
while not length or (length[-1] & 0x80) == 0x80:
|
||||
length += await self._reader.readexactly(1)
|
||||
length_int = bytes_to_varuint(length)
|
||||
assert length_int is not None
|
||||
msg_type = b""
|
||||
while not msg_type or (msg_type[-1] & 0x80) == 0x80:
|
||||
msg_type += await self._reader.readexactly(1)
|
||||
msg_type_int = bytes_to_varuint(msg_type)
|
||||
assert msg_type_int is not None
|
||||
|
||||
raw_msg = b""
|
||||
if length_int != 0:
|
||||
raw_msg = await self._reader.readexactly(length_int)
|
||||
return Packet(type=msg_type_int, data=raw_msg)
|
||||
|
||||
async def read_packet(self) -> Packet:
|
||||
if self._params.noise_psk is None:
|
||||
return await self._read_packet_plaintext()
|
||||
return await self._read_packet_noise()
|
||||
|
||||
|
||||
class APIConnection:
|
||||
@ -47,9 +228,7 @@ class APIConnection:
|
||||
self.on_stop = on_stop
|
||||
self._stopped = False
|
||||
self._socket: Optional[socket.socket] = None
|
||||
self._socket_reader: Optional[asyncio.StreamReader] = None
|
||||
self._socket_writer: Optional[asyncio.StreamWriter] = None
|
||||
self._write_lock = asyncio.Lock()
|
||||
self._frame_helper: Optional[APIFrameHelper] = None
|
||||
self._connected = False
|
||||
self._authenticated = False
|
||||
self._socket_connected = False
|
||||
@ -58,15 +237,13 @@ class APIConnection:
|
||||
|
||||
self._message_handlers: List[Callable[[message.Message], None]] = []
|
||||
self.log_name = params.address
|
||||
self._ping_task: Optional[asyncio.Task[None]] = None
|
||||
|
||||
def _start_ping(self) -> None:
|
||||
async def func() -> None:
|
||||
while self._connected:
|
||||
while True:
|
||||
await asyncio.sleep(self._params.keepalive)
|
||||
|
||||
if not self._connected:
|
||||
return
|
||||
|
||||
try:
|
||||
await self.ping()
|
||||
except APIConnectionError:
|
||||
@ -74,18 +251,20 @@ class APIConnection:
|
||||
await self._on_error()
|
||||
return
|
||||
|
||||
self._params.eventloop.create_task(func())
|
||||
self._ping_task = asyncio.create_task(func())
|
||||
|
||||
async def _close_socket(self) -> None:
|
||||
if not self._socket_connected:
|
||||
return
|
||||
async with self._write_lock:
|
||||
if self._socket_writer is not None:
|
||||
self._socket_writer.close()
|
||||
self._socket_writer = None
|
||||
self._socket_reader = None
|
||||
if self._frame_helper is not None:
|
||||
await self._frame_helper.close()
|
||||
self._frame_helper = None
|
||||
if self._socket is not None:
|
||||
self._socket.close()
|
||||
self._socket = None
|
||||
if self._ping_task is not None:
|
||||
self._ping_task.cancel()
|
||||
self._ping_task = None
|
||||
self._socket_connected = False
|
||||
self._connected = False
|
||||
self._authenticated = False
|
||||
@ -106,6 +285,7 @@ class APIConnection:
|
||||
async def _on_error(self) -> None:
|
||||
await self.stop(force=True)
|
||||
|
||||
# pylint: disable=too-many-statements
|
||||
async def connect(self) -> None:
|
||||
if self._stopped:
|
||||
raise APIConnectionError(f"Connection is closed for {self.log_name}!")
|
||||
@ -154,19 +334,25 @@ class APIConnection:
|
||||
raise APIConnectionError(f"Timeout while connecting to {sockaddr}")
|
||||
|
||||
_LOGGER.debug("%s: Opened socket for", self._params.address)
|
||||
self._socket_reader, self._socket_writer = await asyncio.open_connection(
|
||||
sock=self._socket
|
||||
)
|
||||
reader, writer = await asyncio.open_connection(sock=self._socket)
|
||||
self._frame_helper = APIFrameHelper(reader, writer, self._params)
|
||||
self._socket_connected = True
|
||||
|
||||
try:
|
||||
await self._frame_helper.perform_handshake()
|
||||
except APIConnectionError:
|
||||
await self._on_error()
|
||||
raise
|
||||
|
||||
self._params.eventloop.create_task(self.run_forever())
|
||||
|
||||
hello = HelloRequest()
|
||||
hello.client_info = self._params.client_info
|
||||
try:
|
||||
resp = await self.send_message_await_response(hello, HelloResponse)
|
||||
except APIConnectionError as err:
|
||||
except APIConnectionError:
|
||||
await self._on_error()
|
||||
raise err
|
||||
raise
|
||||
_LOGGER.debug(
|
||||
"%s: Successfully connected ('%s' API=%s.%s)",
|
||||
self.log_name,
|
||||
@ -213,21 +399,10 @@ class APIConnection:
|
||||
def is_authenticated(self) -> bool:
|
||||
return self._authenticated
|
||||
|
||||
async def _write(self, data: bytes) -> None:
|
||||
# _LOGGER.debug("%s: Write: %s", self._params.address,
|
||||
# ' '.join('{:02X}'.format(x) for x in data))
|
||||
async def send_message(self, msg: message.Message) -> None:
|
||||
if not self._socket_connected:
|
||||
raise APIConnectionError("Socket is not connected")
|
||||
try:
|
||||
async with self._write_lock:
|
||||
if self._socket_writer is not None:
|
||||
self._socket_writer.write(data)
|
||||
await self._socket_writer.drain()
|
||||
except OSError as err:
|
||||
await self._on_error()
|
||||
raise APIConnectionError("Error while writing data: {}".format(err))
|
||||
|
||||
async def send_message(self, msg: message.Message) -> None:
|
||||
for message_type, klass in MESSAGE_TYPE_TO_PROTO.items():
|
||||
if isinstance(msg, klass):
|
||||
break
|
||||
@ -236,12 +411,14 @@ class APIConnection:
|
||||
|
||||
encoded = msg.SerializeToString()
|
||||
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
|
||||
req = bytes([0])
|
||||
req += varuint_to_bytes(len(encoded))
|
||||
# pylint: disable=undefined-loop-variable
|
||||
req += varuint_to_bytes(message_type)
|
||||
req += encoded
|
||||
await self._write(req)
|
||||
assert self._frame_helper is not None
|
||||
await self._frame_helper.write_packet(
|
||||
Packet(
|
||||
type=message_type,
|
||||
data=encoded,
|
||||
)
|
||||
)
|
||||
|
||||
async def send_message_callback_response(
|
||||
self, send_msg: message.Message, on_message: Callable[[Any], None]
|
||||
@ -254,7 +431,7 @@ class APIConnection:
|
||||
send_msg: message.Message,
|
||||
do_append: Callable[[Any], bool],
|
||||
do_stop: Callable[[Any], bool],
|
||||
timeout: float = 5.0,
|
||||
timeout: float = 10.0,
|
||||
) -> List[Any]:
|
||||
fut = self._params.eventloop.create_future()
|
||||
responses = []
|
||||
@ -285,7 +462,7 @@ class APIConnection:
|
||||
return responses
|
||||
|
||||
async def send_message_await_response(
|
||||
self, send_msg: message.Message, response_type: Any, timeout: float = 5.0
|
||||
self, send_msg: message.Message, response_type: Any, timeout: float = 10.0
|
||||
) -> Any:
|
||||
def is_response(msg: message.Message) -> bool:
|
||||
return isinstance(msg, response_type)
|
||||
@ -298,33 +475,12 @@ class APIConnection:
|
||||
|
||||
return res[0]
|
||||
|
||||
async def _recv(self, amount: int) -> bytes:
|
||||
if amount == 0:
|
||||
return bytes()
|
||||
|
||||
try:
|
||||
assert self._socket_reader is not None
|
||||
ret = await self._socket_reader.readexactly(amount)
|
||||
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
|
||||
raise APIConnectionError("Error while receiving data: {}".format(err))
|
||||
|
||||
return ret
|
||||
|
||||
async def _recv_varint(self) -> int:
|
||||
raw = bytes()
|
||||
while not raw or raw[-1] & 0x80:
|
||||
raw += await self._recv(1)
|
||||
return cast(int, bytes_to_varuint(raw))
|
||||
|
||||
async def _run_once(self) -> None:
|
||||
preamble = await self._recv(1)
|
||||
if preamble[0] != 0x00:
|
||||
raise APIConnectionError("Invalid preamble")
|
||||
assert self._frame_helper is not None
|
||||
pkt = await self._frame_helper.read_packet()
|
||||
|
||||
length = await self._recv_varint()
|
||||
msg_type = await self._recv_varint()
|
||||
|
||||
raw_msg = await self._recv(length)
|
||||
msg_type = pkt.type
|
||||
raw_msg = pkt.data
|
||||
if msg_type not in MESSAGE_TYPE_TO_PROTO:
|
||||
_LOGGER.debug(
|
||||
"%s: Skipping message type %s", self._params.address, msg_type
|
||||
@ -360,6 +516,7 @@ class APIConnection:
|
||||
"%s: Unexpected error while reading incoming messages: %s",
|
||||
self.log_name,
|
||||
err,
|
||||
exc_info=True,
|
||||
)
|
||||
await self._on_error()
|
||||
break
|
||||
|
81
aioesphomeapi/log_reader.py
Normal file
81
aioesphomeapi/log_reader.py
Normal file
@ -0,0 +1,81 @@
|
||||
# Helper script and aioesphomeapi to view logs from an esphome device
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
import zeroconf
|
||||
|
||||
from aioesphomeapi.api_pb2 import SubscribeLogsResponse # type: ignore
|
||||
from aioesphomeapi.client import APIClient
|
||||
from aioesphomeapi.core import APIConnectionError
|
||||
from aioesphomeapi.model import LogLevel
|
||||
from aioesphomeapi.reconnect_logic import ReconnectLogic
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def main(argv: List[str]) -> None:
|
||||
parser = argparse.ArgumentParser("aioesphomeapi-logs")
|
||||
parser.add_argument("--port", type=int, default=6053)
|
||||
parser.add_argument("--password", type=str)
|
||||
parser.add_argument("--noise-psk", type=str)
|
||||
parser.add_argument("-v", "--verbose", action="store_true")
|
||||
parser.add_argument("address")
|
||||
args = parser.parse_args(argv[1:])
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s %(levelname)-8s %(message)s",
|
||||
level=logging.DEBUG if args.verbose else logging.INFO,
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
cli = APIClient(
|
||||
asyncio.get_event_loop(),
|
||||
args.address,
|
||||
args.port,
|
||||
args.password or "",
|
||||
noise_psk=args.noise_psk,
|
||||
keepalive=10,
|
||||
)
|
||||
|
||||
def on_log(msg: SubscribeLogsResponse) -> None:
|
||||
time_ = datetime.now().time().strftime("[%H:%M:%S]")
|
||||
text = msg.message
|
||||
print(time_ + text.decode("utf8", "backslashreplace"))
|
||||
|
||||
has_connects = False
|
||||
|
||||
async def on_connect() -> None:
|
||||
nonlocal has_connects
|
||||
try:
|
||||
await cli.subscribe_logs(
|
||||
on_log,
|
||||
log_level=LogLevel.LOG_LEVEL_VERY_VERBOSE,
|
||||
dump_config=not has_connects,
|
||||
)
|
||||
has_connects = True
|
||||
except APIConnectionError:
|
||||
cli.disconnect()
|
||||
|
||||
async def on_disconnect() -> None:
|
||||
_LOGGER.warning("Disconnected from API")
|
||||
|
||||
logic = ReconnectLogic(
|
||||
client=cli,
|
||||
on_connect=on_connect,
|
||||
on_disconnect=on_disconnect,
|
||||
zeroconf_instance=zeroconf.Zeroconf(),
|
||||
)
|
||||
await logic.start()
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(60)
|
||||
except KeyboardInterrupt:
|
||||
await logic.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(asyncio.run(main(sys.argv)) or 0)
|
@ -1,2 +1,3 @@
|
||||
protobuf>=3.12.2,<4.0
|
||||
zeroconf>=0.28.0,<1.0
|
||||
noiseprotocol>=0.3.1,<1.0
|
@ -20,6 +20,7 @@ def connection_params() -> ConnectionParams:
|
||||
client_info="Tests client",
|
||||
keepalive=15.0,
|
||||
zeroconf_instance=None,
|
||||
noise_psk=None,
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user