mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-22 12:05:12 +01:00
Simplify connection flow with an asyncio.Protocol (#352)
This commit is contained in:
parent
049dc8bb56
commit
2886d361f0
@ -1,9 +1,10 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod, abstractproperty
|
from abc import abstractmethod, abstractproperty
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from enum import Enum
|
||||||
|
from typing import Callable, Optional, Union, cast
|
||||||
|
|
||||||
import async_timeout
|
import async_timeout
|
||||||
from noise.connection import NoiseConnection # type: ignore
|
from noise.connection import NoiseConnection # type: ignore
|
||||||
@ -32,57 +33,93 @@ SOCKET_ERRORS = (
|
|||||||
@dataclass
|
@dataclass
|
||||||
class Packet:
|
class Packet:
|
||||||
type: int
|
type: int
|
||||||
data: bytes
|
data: Union[bytes, bytearray]
|
||||||
|
|
||||||
|
|
||||||
class APIFrameHelper(ABC):
|
class APIFrameHelper(asyncio.Protocol):
|
||||||
"""Helper class to handle the API frame protocol."""
|
"""Helper class to handle the API frame protocol."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
reader: asyncio.StreamReader,
|
on_pkt: Callable[[Packet], None],
|
||||||
writer: asyncio.StreamWriter,
|
on_error: Callable[[Exception], None],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the API frame helper."""
|
"""Initialize the API frame helper."""
|
||||||
self._reader = reader
|
self._on_pkt = on_pkt
|
||||||
self._writer = writer
|
self._on_error = on_error
|
||||||
|
self._transport: Optional[asyncio.Transport] = None
|
||||||
self.read_lock = asyncio.Lock()
|
self.read_lock = asyncio.Lock()
|
||||||
self._closed_event = asyncio.Event()
|
self._closed_event = asyncio.Event()
|
||||||
|
self._connected_event = asyncio.Event()
|
||||||
|
self._buffer = bytearray()
|
||||||
|
self._pos = 0
|
||||||
|
|
||||||
@abstractproperty # pylint: disable=deprecated-decorator
|
@abstractproperty # pylint: disable=deprecated-decorator
|
||||||
def ready(self) -> bool:
|
def ready(self) -> bool:
|
||||||
"""Return if the connection is ready."""
|
"""Return if the connection is ready."""
|
||||||
|
|
||||||
|
def _init_read(self, length: int) -> Optional[bytearray]:
|
||||||
|
"""Start reading a packet from the buffer."""
|
||||||
|
self._pos = 0
|
||||||
|
return self._read_exactly(length)
|
||||||
|
|
||||||
|
def _read_exactly(self, length: int) -> Optional[bytearray]:
|
||||||
|
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
|
||||||
|
original_pos = self._pos
|
||||||
|
new_pos = original_pos + length
|
||||||
|
if len(self._buffer) < new_pos:
|
||||||
|
return None
|
||||||
|
self._pos = new_pos
|
||||||
|
return self._buffer[original_pos:new_pos]
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def close(self) -> None:
|
async def perform_handshake(self) -> None:
|
||||||
"""Close the connection."""
|
"""Perform the handshake."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def write_packet(self, packet: Packet) -> None:
|
def write_packet(self, packet: Packet) -> None:
|
||||||
"""Write a packet to the socket."""
|
"""Write a packet to the socket."""
|
||||||
|
|
||||||
@abstractmethod
|
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||||
async def read_packet_with_lock(self) -> Packet:
|
"""Handle a new connection."""
|
||||||
"""Read a packet from the socket, the caller is responsible for having the lock."""
|
self._transport = cast(asyncio.Transport, transport)
|
||||||
|
self._connected_event.set()
|
||||||
|
|
||||||
@abstractmethod
|
def _handle_error_and_close(self, exc: Exception) -> None:
|
||||||
async def wait_for_ready(self) -> None:
|
self._handle_error(exc)
|
||||||
"""Wait for the connection to be ready."""
|
self.close()
|
||||||
|
|
||||||
|
def _handle_error(self, exc: Exception) -> None:
|
||||||
|
self._closed_event.set()
|
||||||
|
self._on_error(exc)
|
||||||
|
|
||||||
|
def connection_lost(self, exc: Optional[Exception]) -> None:
|
||||||
|
self._handle_error(exc or SocketClosedAPIError("Connection lost"))
|
||||||
|
return super().connection_lost(exc)
|
||||||
|
|
||||||
|
def eof_received(self) -> Optional[bool]:
|
||||||
|
self._handle_error(SocketClosedAPIError("EOF received"))
|
||||||
|
return super().eof_received()
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close the connection."""
|
||||||
|
self._closed_event.set()
|
||||||
|
if self._transport:
|
||||||
|
self._transport.close()
|
||||||
|
|
||||||
|
|
||||||
class APIPlaintextFrameHelper(APIFrameHelper):
|
class APIPlaintextFrameHelper(APIFrameHelper):
|
||||||
"""Frame helper for plaintext API connections."""
|
"""Frame helper for plaintext API connections."""
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
"""Close the connection."""
|
|
||||||
self._closed_event.set()
|
|
||||||
self._writer.close()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ready(self) -> bool:
|
def ready(self) -> bool:
|
||||||
"""Return if the connection is ready."""
|
"""Return if the connection is ready."""
|
||||||
# Plaintext is always ready
|
return self._connected_event.is_set()
|
||||||
return True
|
|
||||||
|
def _callback_packet(self, type_: int, data: Union[bytes, bytearray]) -> None:
|
||||||
|
"""Complete reading a packet from the buffer."""
|
||||||
|
del self._buffer[: self._pos]
|
||||||
|
self._on_pkt(Packet(type_, data))
|
||||||
|
|
||||||
def write_packet(self, packet: Packet) -> None:
|
def write_packet(self, packet: Packet) -> None:
|
||||||
"""Write a packet to the socket, the caller should not have the lock.
|
"""Write a packet to the socket, the caller should not have the lock.
|
||||||
@ -90,6 +127,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
The entire packet must be written in a single call to write
|
The entire packet must be written in a single call to write
|
||||||
to avoid locking.
|
to avoid locking.
|
||||||
"""
|
"""
|
||||||
|
assert self._transport is not None, "Transport should be set"
|
||||||
data = (
|
data = (
|
||||||
b"\0"
|
b"\0"
|
||||||
+ varuint_to_bytes(len(packet.data))
|
+ varuint_to_bytes(len(packet.data))
|
||||||
@ -99,26 +137,32 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
_LOGGER.debug("Sending plaintext frame %s", data.hex())
|
_LOGGER.debug("Sending plaintext frame %s", data.hex())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._writer.write(data)
|
self._transport.write(data)
|
||||||
except (ConnectionResetError, OSError) as err:
|
except (ConnectionResetError, OSError) as err:
|
||||||
raise SocketAPIError(f"Error while writing data: {err}") from err
|
raise SocketAPIError(f"Error while writing data: {err}") from err
|
||||||
|
|
||||||
async def wait_for_ready(self) -> None:
|
async def perform_handshake(self) -> None:
|
||||||
"""Wait for the connection to be ready."""
|
"""Perform the handshake."""
|
||||||
# No handshake for plaintext
|
await self._connected_event.wait()
|
||||||
|
|
||||||
async def read_packet_with_lock(self) -> Packet:
|
def data_received(self, data: bytes) -> None:
|
||||||
"""Read a packet from the socket, the caller is responsible for having the lock."""
|
self._buffer += data
|
||||||
assert self.read_lock.locked(), "read_packet_with_lock called without lock"
|
while len(self._buffer) >= 3:
|
||||||
try:
|
|
||||||
# Read preamble, which should always 0x00
|
# Read preamble, which should always 0x00
|
||||||
# Also try to get the length and msg type
|
# Also try to get the length and msg type
|
||||||
# to avoid multiple calls to readexactly
|
# to avoid multiple calls to readexactly
|
||||||
init_bytes = await self._reader.readexactly(3)
|
init_bytes = self._init_read(3)
|
||||||
|
assert init_bytes is not None, "Buffer should have at least 3 bytes"
|
||||||
if init_bytes[0] != 0x00:
|
if init_bytes[0] != 0x00:
|
||||||
if init_bytes[0] == 0x01:
|
if init_bytes[0] == 0x01:
|
||||||
raise RequiresEncryptionAPIError("Connection requires encryption")
|
self._handle_error_and_close(
|
||||||
raise ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}")
|
RequiresEncryptionAPIError("Connection requires encryption")
|
||||||
|
)
|
||||||
|
return
|
||||||
|
self._handle_error_and_close(
|
||||||
|
ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}")
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
if init_bytes[1] & 0x80 == 0x80:
|
if init_bytes[1] & 0x80 == 0x80:
|
||||||
# Length is longer than 1 byte
|
# Length is longer than 1 byte
|
||||||
@ -133,32 +177,35 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
|
|
||||||
# If the message is long, we need to read the rest of the length
|
# If the message is long, we need to read the rest of the length
|
||||||
while length[-1] & 0x80 == 0x80:
|
while length[-1] & 0x80 == 0x80:
|
||||||
length += await self._reader.readexactly(1)
|
add_length = self._read_exactly(1)
|
||||||
|
if add_length is None:
|
||||||
|
return
|
||||||
|
length += add_length
|
||||||
|
|
||||||
# If the message length was longer than 1 byte, we need to read the
|
# If the message length was longer than 1 byte, we need to read the
|
||||||
# message type
|
# message type
|
||||||
while not msg_type or (msg_type[-1] & 0x80) == 0x80:
|
while not msg_type or (msg_type[-1] & 0x80) == 0x80:
|
||||||
msg_type += await self._reader.readexactly(1)
|
add_msg_type = self._read_exactly(1)
|
||||||
|
if add_msg_type is None:
|
||||||
|
return
|
||||||
|
msg_type += add_msg_type
|
||||||
|
|
||||||
length_int = bytes_to_varuint(length)
|
length_int = bytes_to_varuint(bytes(length))
|
||||||
assert length_int is not None
|
assert length_int is not None
|
||||||
msg_type_int = bytes_to_varuint(msg_type)
|
msg_type_int = bytes_to_varuint(bytes(msg_type))
|
||||||
assert msg_type_int is not None
|
assert msg_type_int is not None
|
||||||
|
|
||||||
if length_int == 0:
|
if length_int == 0:
|
||||||
return Packet(type=msg_type_int, data=b"")
|
self._callback_packet(msg_type_int, b"")
|
||||||
|
# If we have more data, continue processing
|
||||||
|
continue
|
||||||
|
|
||||||
data = await self._reader.readexactly(length_int)
|
packet_data = self._read_exactly(length_int)
|
||||||
return Packet(type=msg_type_int, data=data)
|
if packet_data is None:
|
||||||
except SOCKET_ERRORS as err:
|
return
|
||||||
if (
|
|
||||||
isinstance(err, asyncio.IncompleteReadError)
|
self._callback_packet(msg_type_int, packet_data)
|
||||||
and self._closed_event.is_set()
|
# If we have more data, continue processing
|
||||||
):
|
|
||||||
raise SocketClosedAPIError(
|
|
||||||
f"Socket closed while reading data: {err}"
|
|
||||||
) from err
|
|
||||||
raise SocketAPIError(f"Error while reading data: {err}") from err
|
|
||||||
|
|
||||||
|
|
||||||
def _decode_noise_psk(psk: str) -> bytes:
|
def _decode_noise_psk(psk: str) -> bytes:
|
||||||
@ -176,34 +223,46 @@ def _decode_noise_psk(psk: str) -> bytes:
|
|||||||
return psk_bytes
|
return psk_bytes
|
||||||
|
|
||||||
|
|
||||||
|
class NoiseConnectionState(Enum):
|
||||||
|
"""Noise connection state."""
|
||||||
|
|
||||||
|
HELLO = 1
|
||||||
|
HANDSHAKE = 2
|
||||||
|
READY = 3
|
||||||
|
CLOSED = 4
|
||||||
|
|
||||||
|
|
||||||
class APINoiseFrameHelper(APIFrameHelper):
|
class APINoiseFrameHelper(APIFrameHelper):
|
||||||
"""Frame helper for noise encrypted connections."""
|
"""Frame helper for noise encrypted connections."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
reader: asyncio.StreamReader,
|
on_pkt: Callable[[Packet], None],
|
||||||
writer: asyncio.StreamWriter,
|
on_error: Callable[[Exception], None],
|
||||||
noise_psk: str,
|
noise_psk: str,
|
||||||
|
expected_name: Optional[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the API frame helper."""
|
"""Initialize the API frame helper."""
|
||||||
super().__init__(reader, writer)
|
super().__init__(on_pkt, on_error)
|
||||||
self._ready_event = asyncio.Event()
|
self._ready_event = asyncio.Event()
|
||||||
self._proto: Optional[NoiseConnection] = None
|
|
||||||
self._noise_psk = noise_psk
|
self._noise_psk = noise_psk
|
||||||
|
self._expected_name = expected_name
|
||||||
|
self._state = NoiseConnectionState.HELLO
|
||||||
|
self._setup_proto()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ready(self) -> bool:
|
def ready(self) -> bool:
|
||||||
"""Return if the connection is ready."""
|
"""Return if the connection is ready."""
|
||||||
return self._ready_event.is_set()
|
return self._ready_event.is_set()
|
||||||
|
|
||||||
async def close(self) -> None:
|
def close(self) -> None:
|
||||||
"""Close the connection."""
|
"""Close the connection."""
|
||||||
# Make sure we set the ready event if its not already set
|
# Make sure we set the ready event if its not already set
|
||||||
# so that we don't block forever on the ready event if we
|
# so that we don't block forever on the ready event if we
|
||||||
# are waiting for the handshake to complete.
|
# are waiting for the handshake to complete.
|
||||||
self._ready_event.set()
|
self._ready_event.set()
|
||||||
self._closed_event.set()
|
self._state = NoiseConnectionState.CLOSED
|
||||||
self._writer.close()
|
super().close()
|
||||||
|
|
||||||
def _write_frame(self, frame: bytes) -> None:
|
def _write_frame(self, frame: bytes) -> None:
|
||||||
"""Write a packet to the socket, the caller should not have the lock.
|
"""Write a packet to the socket, the caller should not have the lock.
|
||||||
@ -212,6 +271,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
to avoid locking.
|
to avoid locking.
|
||||||
"""
|
"""
|
||||||
_LOGGER.debug("Sending frame %s", frame.hex())
|
_LOGGER.debug("Sending frame %s", frame.hex())
|
||||||
|
assert self._transport is not None, "Transport is not set"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
header = bytes(
|
header = bytes(
|
||||||
@ -221,39 +281,46 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
len(frame) & 0xFF,
|
len(frame) & 0xFF,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self._writer.write(header + frame)
|
self._transport.write(header + frame)
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
raise SocketAPIError(f"Error while writing data: {err}") from err
|
raise SocketAPIError(f"Error while writing data: {err}") from err
|
||||||
|
|
||||||
async def _read_frame_with_lock(self) -> bytes:
|
async def perform_handshake(self) -> None:
|
||||||
"""Read a frame from the socket, the caller is responsible for having the lock."""
|
"""Perform the handshake with the server."""
|
||||||
assert self.read_lock.locked(), "_read_frame_with_lock called without lock"
|
self._send_hello()
|
||||||
try:
|
try:
|
||||||
header = await self._reader.readexactly(3)
|
async with async_timeout.timeout(60.0):
|
||||||
|
await self._ready_event.wait()
|
||||||
|
except asyncio.TimeoutError as err:
|
||||||
|
raise HandshakeAPIError("Timeout during handshake") from err
|
||||||
|
|
||||||
|
def data_received(self, data: bytes) -> None:
|
||||||
|
self._buffer += data
|
||||||
|
while len(self._buffer) >= 3:
|
||||||
|
header = self._init_read(3)
|
||||||
|
assert header is not None, "Buffer should have at least 3 bytes"
|
||||||
if header[0] != 0x01:
|
if header[0] != 0x01:
|
||||||
raise ProtocolAPIError(f"Marker byte invalid: {header[0]}")
|
self._handle_error_and_close(
|
||||||
|
ProtocolAPIError(f"Marker byte invalid: {header[0]}")
|
||||||
|
)
|
||||||
msg_size = (header[1] << 8) | header[2]
|
msg_size = (header[1] << 8) | header[2]
|
||||||
frame = await self._reader.readexactly(msg_size)
|
frame = self._read_exactly(msg_size)
|
||||||
except SOCKET_ERRORS as err:
|
if frame is None:
|
||||||
if (
|
return
|
||||||
isinstance(err, asyncio.IncompleteReadError)
|
|
||||||
and self._closed_event.is_set()
|
|
||||||
):
|
|
||||||
raise SocketClosedAPIError(
|
|
||||||
f"Socket closed while reading data: {err}"
|
|
||||||
) from err
|
|
||||||
raise SocketAPIError(f"Error while reading data: {err}") from err
|
|
||||||
|
|
||||||
_LOGGER.debug("Received frame %s", frame.hex())
|
try:
|
||||||
return frame
|
self.STATE_TO_CALLABLE[self._state](self, frame)
|
||||||
|
except Exception as err: # pylint: disable=broad-except
|
||||||
|
self._handle_error_and_close(err)
|
||||||
|
finally:
|
||||||
|
del self._buffer[: self._pos]
|
||||||
|
|
||||||
async def _perform_handshake(self, expected_name: Optional[str]) -> None:
|
def _send_hello(self) -> None:
|
||||||
"""Perform the handshake with the server, the caller is responsible for having the lock."""
|
"""Send a ClientHello to the server."""
|
||||||
assert self.read_lock.locked(), "_perform_handshake called without lock"
|
|
||||||
self._write_frame(b"") # ClientHello
|
self._write_frame(b"") # ClientHello
|
||||||
prologue = b"NoiseAPIInit" + b"\x00\x00"
|
|
||||||
|
|
||||||
server_hello = await self._read_frame_with_lock() # ServerHello
|
def _handle_hello(self, server_hello: bytearray) -> None:
|
||||||
|
"""Perform the handshake with the server, the caller is responsible for having the lock."""
|
||||||
if not server_hello:
|
if not server_hello:
|
||||||
raise HandshakeAPIError("ServerHello is empty")
|
raise HandshakeAPIError("ServerHello is empty")
|
||||||
|
|
||||||
@ -273,76 +340,60 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
if server_name_i != -1:
|
if server_name_i != -1:
|
||||||
# server name found, this extension was added in 2022.2
|
# server name found, this extension was added in 2022.2
|
||||||
server_name = server_hello[1:server_name_i].decode()
|
server_name = server_hello[1:server_name_i].decode()
|
||||||
if expected_name is not None and expected_name != server_name:
|
if self._expected_name is not None and self._expected_name != server_name:
|
||||||
raise BadNameAPIError(
|
raise BadNameAPIError(
|
||||||
f"Server sent a different name '{server_name}'", server_name
|
f"Server sent a different name '{server_name}'", server_name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._state = NoiseConnectionState.HANDSHAKE
|
||||||
|
self._send_handshake()
|
||||||
|
|
||||||
|
def _setup_proto(self) -> None:
|
||||||
|
"""Set up the noise protocol."""
|
||||||
self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
|
self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
|
||||||
self._proto.set_as_initiator()
|
self._proto.set_as_initiator()
|
||||||
self._proto.set_psks(_decode_noise_psk(self._noise_psk))
|
self._proto.set_psks(_decode_noise_psk(self._noise_psk))
|
||||||
self._proto.set_prologue(prologue)
|
self._proto.set_prologue(b"NoiseAPIInit" + b"\x00\x00")
|
||||||
self._proto.start_handshake()
|
self._proto.start_handshake()
|
||||||
|
|
||||||
|
def _send_handshake(self) -> None:
|
||||||
|
"""Send the handshake message."""
|
||||||
|
self._write_frame(b"\x00" + self._proto.write_message())
|
||||||
|
|
||||||
|
def _handle_handshake(self, msg: bytearray) -> None:
|
||||||
_LOGGER.debug("Starting handshake...")
|
_LOGGER.debug("Starting handshake...")
|
||||||
do_write = True
|
if msg[0] != 0:
|
||||||
while not self._proto.handshake_finished:
|
explanation = msg[1:].decode()
|
||||||
if do_write:
|
if explanation == "Handshake MAC failure":
|
||||||
msg = self._proto.write_message()
|
raise InvalidEncryptionKeyAPIError("Invalid encryption key")
|
||||||
self._write_frame(b"\x00" + msg)
|
raise HandshakeAPIError(f"Handshake failure: {explanation}")
|
||||||
else:
|
self._proto.read_message(msg[1:])
|
||||||
msg = await self._read_frame_with_lock()
|
_LOGGER.debug("Handshake complete")
|
||||||
if not msg:
|
self._state = NoiseConnectionState.READY
|
||||||
raise HandshakeAPIError("Handshake message too short")
|
|
||||||
if msg[0] != 0:
|
|
||||||
explanation = msg[1:].decode()
|
|
||||||
if explanation == "Handshake MAC failure":
|
|
||||||
raise InvalidEncryptionKeyAPIError("Invalid encryption key")
|
|
||||||
raise HandshakeAPIError(f"Handshake failure: {explanation}")
|
|
||||||
self._proto.read_message(msg[1:])
|
|
||||||
|
|
||||||
do_write = not do_write
|
|
||||||
|
|
||||||
_LOGGER.debug("Handshake complete!")
|
|
||||||
self._ready_event.set()
|
self._ready_event.set()
|
||||||
|
|
||||||
async def perform_handshake(self, expected_name: Optional[str]) -> None:
|
|
||||||
"""Perform the handshake with the server."""
|
|
||||||
# Allow up to 60 seconds for handhsake
|
|
||||||
try:
|
|
||||||
async with self.read_lock, async_timeout.timeout(60.0):
|
|
||||||
await self._perform_handshake(expected_name)
|
|
||||||
except asyncio.TimeoutError as err:
|
|
||||||
raise HandshakeAPIError("Timeout during handshake") from err
|
|
||||||
|
|
||||||
def write_packet(self, packet: Packet) -> None:
|
def write_packet(self, packet: Packet) -> None:
|
||||||
"""Write a packet to the socket."""
|
"""Write a packet to the socket."""
|
||||||
padding = 0
|
self._write_frame(
|
||||||
data = (
|
self._proto.encrypt(
|
||||||
bytes(
|
(
|
||||||
[
|
bytes(
|
||||||
(packet.type >> 8) & 0xFF,
|
[
|
||||||
(packet.type >> 0) & 0xFF,
|
(packet.type >> 8) & 0xFF,
|
||||||
(len(packet.data) >> 8) & 0xFF,
|
(packet.type >> 0) & 0xFF,
|
||||||
(len(packet.data) >> 0) & 0xFF,
|
(len(packet.data) >> 8) & 0xFF,
|
||||||
]
|
(len(packet.data) >> 0) & 0xFF,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
+ packet.data
|
||||||
|
)
|
||||||
)
|
)
|
||||||
+ packet.data
|
|
||||||
+ b"\x00" * padding
|
|
||||||
)
|
)
|
||||||
assert self._proto is not None
|
|
||||||
frame = self._proto.encrypt(data)
|
|
||||||
self._write_frame(frame)
|
|
||||||
|
|
||||||
async def wait_for_ready(self) -> None:
|
def _handle_frame(self, frame: bytearray) -> None:
|
||||||
"""Wait for the connection to be ready."""
|
"""Handle an incoming frame."""
|
||||||
await self._ready_event.wait()
|
|
||||||
|
|
||||||
async def read_packet_with_lock(self) -> Packet:
|
|
||||||
"""Read a packet from the socket, the caller is responsible for having the lock."""
|
|
||||||
frame = await self._read_frame_with_lock()
|
|
||||||
assert self._proto is not None
|
assert self._proto is not None
|
||||||
msg = self._proto.decrypt(frame)
|
msg = self._proto.decrypt(bytes(frame))
|
||||||
if len(msg) < 4:
|
if len(msg) < 4:
|
||||||
raise ProtocolAPIError(f"Bad packet frame: {msg}")
|
raise ProtocolAPIError(f"Bad packet frame: {msg}")
|
||||||
pkt_type = (msg[0] << 8) | msg[1]
|
pkt_type = (msg[0] << 8) | msg[1]
|
||||||
@ -350,4 +401,17 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
if data_len + 4 > len(msg):
|
if data_len + 4 > len(msg):
|
||||||
raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
|
raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
|
||||||
data = msg[4 : 4 + data_len]
|
data = msg[4 : 4 + data_len]
|
||||||
return Packet(type=pkt_type, data=data)
|
return self._on_pkt(Packet(pkt_type, data))
|
||||||
|
|
||||||
|
def _handle_closed( # pylint: disable=unused-argument
|
||||||
|
self, frame: bytearray
|
||||||
|
) -> None:
|
||||||
|
"""Handle a closed frame."""
|
||||||
|
self._handle_error(ProtocolAPIError("Connection closed"))
|
||||||
|
|
||||||
|
STATE_TO_CALLABLE = {
|
||||||
|
NoiseConnectionState.HELLO: _handle_hello,
|
||||||
|
NoiseConnectionState.HANDSHAKE: _handle_handshake,
|
||||||
|
NoiseConnectionState.READY: _handle_frame,
|
||||||
|
NoiseConnectionState.CLOSED: _handle_closed,
|
||||||
|
}
|
||||||
|
@ -373,7 +373,7 @@ class APIClient:
|
|||||||
image_stream[msg.key] = data
|
image_stream[msg.key] = data
|
||||||
|
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message_callback_response(
|
self._connection.send_message_callback_response(
|
||||||
SubscribeStatesRequest(), on_msg, msg_types
|
SubscribeStatesRequest(), on_msg, msg_types
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -394,7 +394,7 @@ class APIClient:
|
|||||||
if dump_config is not None:
|
if dump_config is not None:
|
||||||
req.dump_config = dump_config
|
req.dump_config = dump_config
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message_callback_response(
|
self._connection.send_message_callback_response(
|
||||||
req, on_msg, (SubscribeLogsResponse,)
|
req, on_msg, (SubscribeLogsResponse,)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -407,7 +407,7 @@ class APIClient:
|
|||||||
on_service_call(HomeassistantServiceCall.from_pb(msg))
|
on_service_call(HomeassistantServiceCall.from_pb(msg))
|
||||||
|
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message_callback_response(
|
self._connection.send_message_callback_response(
|
||||||
SubscribeHomeassistantServicesRequest(),
|
SubscribeHomeassistantServicesRequest(),
|
||||||
on_msg,
|
on_msg,
|
||||||
(HomeassistantServiceResponse,),
|
(HomeassistantServiceResponse,),
|
||||||
@ -451,7 +451,7 @@ class APIClient:
|
|||||||
on_bluetooth_le_advertisement(BluetoothLEAdvertisement.from_pb(msg)) # type: ignore[misc]
|
on_bluetooth_le_advertisement(BluetoothLEAdvertisement.from_pb(msg)) # type: ignore[misc]
|
||||||
|
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message_callback_response(
|
self._connection.send_message_callback_response(
|
||||||
SubscribeBluetoothLEAdvertisementsRequest(), on_msg, msg_types
|
SubscribeBluetoothLEAdvertisementsRequest(), on_msg, msg_types
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -472,7 +472,7 @@ class APIClient:
|
|||||||
on_bluetooth_connections_free_update(resp.free, resp.limit)
|
on_bluetooth_connections_free_update(resp.free, resp.limit)
|
||||||
|
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message_callback_response(
|
self._connection.send_message_callback_response(
|
||||||
SubscribeBluetoothConnectionsFreeRequest(), on_msg, msg_types
|
SubscribeBluetoothConnectionsFreeRequest(), on_msg, msg_types
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -518,7 +518,7 @@ class APIClient:
|
|||||||
_LOGGER.debug("%s: Using connection version 1", address)
|
_LOGGER.debug("%s: Using connection version 1", address)
|
||||||
request_type = BluetoothDeviceRequestType.CONNECT
|
request_type = BluetoothDeviceRequestType.CONNECT
|
||||||
|
|
||||||
await self._connection.send_message_callback_response(
|
self._connection.send_message_callback_response(
|
||||||
BluetoothDeviceRequest(
|
BluetoothDeviceRequest(
|
||||||
address=address,
|
address=address,
|
||||||
request_type=request_type,
|
request_type=request_type,
|
||||||
@ -581,7 +581,7 @@ class APIClient:
|
|||||||
self._check_authenticated()
|
self._check_authenticated()
|
||||||
|
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message(
|
self._connection.send_message(
|
||||||
BluetoothDeviceRequest(
|
BluetoothDeviceRequest(
|
||||||
address=address,
|
address=address,
|
||||||
request_type=BluetoothDeviceRequestType.DISCONNECT,
|
request_type=BluetoothDeviceRequestType.DISCONNECT,
|
||||||
@ -661,7 +661,7 @@ class APIClient:
|
|||||||
|
|
||||||
if not response:
|
if not response:
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message(req)
|
self._connection.send_message(req)
|
||||||
return
|
return
|
||||||
|
|
||||||
await self._send_bluetooth_message_await_response(
|
await self._send_bluetooth_message_await_response(
|
||||||
@ -709,7 +709,7 @@ class APIClient:
|
|||||||
|
|
||||||
if not wait_for_response:
|
if not wait_for_response:
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message(req)
|
self._connection.send_message(req)
|
||||||
return
|
return
|
||||||
|
|
||||||
await self._send_bluetooth_message_await_response(
|
await self._send_bluetooth_message_await_response(
|
||||||
@ -762,7 +762,7 @@ class APIClient:
|
|||||||
|
|
||||||
self._check_authenticated()
|
self._check_authenticated()
|
||||||
|
|
||||||
await self._connection.send_message(
|
self._connection.send_message(
|
||||||
BluetoothGATTNotifyRequest(address=address, handle=handle, enable=False)
|
BluetoothGATTNotifyRequest(address=address, handle=handle, enable=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -777,7 +777,7 @@ class APIClient:
|
|||||||
on_state_sub(msg.entity_id, msg.attribute)
|
on_state_sub(msg.entity_id, msg.attribute)
|
||||||
|
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message_callback_response(
|
self._connection.send_message_callback_response(
|
||||||
SubscribeHomeAssistantStatesRequest(),
|
SubscribeHomeAssistantStatesRequest(),
|
||||||
on_msg,
|
on_msg,
|
||||||
(SubscribeHomeAssistantStateResponse,),
|
(SubscribeHomeAssistantStateResponse,),
|
||||||
@ -789,7 +789,7 @@ class APIClient:
|
|||||||
self._check_authenticated()
|
self._check_authenticated()
|
||||||
|
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message(
|
self._connection.send_message(
|
||||||
HomeAssistantStateResponse(
|
HomeAssistantStateResponse(
|
||||||
entity_id=entity_id,
|
entity_id=entity_id,
|
||||||
state=state,
|
state=state,
|
||||||
@ -829,7 +829,7 @@ class APIClient:
|
|||||||
req.legacy_command = LegacyCoverCommand.CLOSE
|
req.legacy_command = LegacyCoverCommand.CLOSE
|
||||||
req.has_legacy_command = True
|
req.has_legacy_command = True
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message(req)
|
self._connection.send_message(req)
|
||||||
|
|
||||||
async def fan_command(
|
async def fan_command(
|
||||||
self,
|
self,
|
||||||
@ -860,7 +860,7 @@ class APIClient:
|
|||||||
req.has_direction = True
|
req.has_direction = True
|
||||||
req.direction = direction
|
req.direction = direction
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message(req)
|
self._connection.send_message(req)
|
||||||
|
|
||||||
async def light_command(
|
async def light_command(
|
||||||
self,
|
self,
|
||||||
@ -921,7 +921,7 @@ class APIClient:
|
|||||||
req.has_effect = True
|
req.has_effect = True
|
||||||
req.effect = effect
|
req.effect = effect
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message(req)
|
self._connection.send_message(req)
|
||||||
|
|
||||||
async def switch_command(self, key: int, state: bool) -> None:
|
async def switch_command(self, key: int, state: bool) -> None:
|
||||||
self._check_authenticated()
|
self._check_authenticated()
|
||||||
@ -930,7 +930,7 @@ class APIClient:
|
|||||||
req.key = key
|
req.key = key
|
||||||
req.state = state
|
req.state = state
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message(req)
|
self._connection.send_message(req)
|
||||||
|
|
||||||
async def climate_command(
|
async def climate_command(
|
||||||
self,
|
self,
|
||||||
@ -982,7 +982,7 @@ class APIClient:
|
|||||||
req.has_custom_preset = True
|
req.has_custom_preset = True
|
||||||
req.custom_preset = custom_preset
|
req.custom_preset = custom_preset
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message(req)
|
self._connection.send_message(req)
|
||||||
|
|
||||||
async def number_command(self, key: int, state: float) -> None:
|
async def number_command(self, key: int, state: float) -> None:
|
||||||
self._check_authenticated()
|
self._check_authenticated()
|
||||||
@ -991,7 +991,7 @@ class APIClient:
|
|||||||
req.key = key
|
req.key = key
|
||||||
req.state = state
|
req.state = state
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message(req)
|
self._connection.send_message(req)
|
||||||
|
|
||||||
async def select_command(self, key: int, state: str) -> None:
|
async def select_command(self, key: int, state: str) -> None:
|
||||||
self._check_authenticated()
|
self._check_authenticated()
|
||||||
@ -1000,7 +1000,7 @@ class APIClient:
|
|||||||
req.key = key
|
req.key = key
|
||||||
req.state = state
|
req.state = state
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message(req)
|
self._connection.send_message(req)
|
||||||
|
|
||||||
async def siren_command(
|
async def siren_command(
|
||||||
self,
|
self,
|
||||||
@ -1027,7 +1027,7 @@ class APIClient:
|
|||||||
req.duration = duration
|
req.duration = duration
|
||||||
req.has_duration = True
|
req.has_duration = True
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message(req)
|
self._connection.send_message(req)
|
||||||
|
|
||||||
async def button_command(self, key: int) -> None:
|
async def button_command(self, key: int) -> None:
|
||||||
self._check_authenticated()
|
self._check_authenticated()
|
||||||
@ -1035,7 +1035,7 @@ class APIClient:
|
|||||||
req = ButtonCommandRequest()
|
req = ButtonCommandRequest()
|
||||||
req.key = key
|
req.key = key
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message(req)
|
self._connection.send_message(req)
|
||||||
|
|
||||||
async def lock_command(
|
async def lock_command(
|
||||||
self,
|
self,
|
||||||
@ -1051,7 +1051,7 @@ class APIClient:
|
|||||||
if code is not None:
|
if code is not None:
|
||||||
req.code = code
|
req.code = code
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message(req)
|
self._connection.send_message(req)
|
||||||
|
|
||||||
async def media_player_command(
|
async def media_player_command(
|
||||||
self,
|
self,
|
||||||
@ -1075,7 +1075,7 @@ class APIClient:
|
|||||||
req.media_url = media_url
|
req.media_url = media_url
|
||||||
req.has_media_url = True
|
req.has_media_url = True
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message(req)
|
self._connection.send_message(req)
|
||||||
|
|
||||||
async def execute_service(
|
async def execute_service(
|
||||||
self, service: UserService, data: ExecuteServiceDataType
|
self, service: UserService, data: ExecuteServiceDataType
|
||||||
@ -1113,7 +1113,7 @@ class APIClient:
|
|||||||
# pylint: disable=no-member
|
# pylint: disable=no-member
|
||||||
req.args.extend(args)
|
req.args.extend(args)
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message(req)
|
self._connection.send_message(req)
|
||||||
|
|
||||||
async def _request_image(
|
async def _request_image(
|
||||||
self, *, single: bool = False, stream: bool = False
|
self, *, single: bool = False, stream: bool = False
|
||||||
@ -1122,7 +1122,7 @@ class APIClient:
|
|||||||
req.single = single
|
req.single = single
|
||||||
req.stream = stream
|
req.stream = stream
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
await self._connection.send_message(req)
|
self._connection.send_message(req)
|
||||||
|
|
||||||
async def request_single_image(self) -> None:
|
async def request_single_image(self) -> None:
|
||||||
await self._request_image(single=True)
|
await self._request_image(single=True)
|
||||||
|
@ -5,7 +5,7 @@ import socket
|
|||||||
import time
|
import time
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from dataclasses import astuple, dataclass
|
from dataclasses import astuple, dataclass
|
||||||
from typing import Any, Callable, Coroutine, Dict, Iterable, List, Optional, Type
|
from typing import Any, Callable, Coroutine, Dict, Iterable, List, Optional, Type, Union
|
||||||
|
|
||||||
import async_timeout
|
import async_timeout
|
||||||
from google.protobuf import message
|
from google.protobuf import message
|
||||||
@ -40,7 +40,6 @@ from .core import (
|
|||||||
ReadFailedAPIError,
|
ReadFailedAPIError,
|
||||||
ResolveAPIError,
|
ResolveAPIError,
|
||||||
SocketAPIError,
|
SocketAPIError,
|
||||||
SocketClosedAPIError,
|
|
||||||
TimeoutAPIError,
|
TimeoutAPIError,
|
||||||
)
|
)
|
||||||
from .model import APIVersion
|
from .model import APIVersion
|
||||||
@ -112,10 +111,6 @@ class APIConnection:
|
|||||||
|
|
||||||
self._ping_stop_event = asyncio.Event()
|
self._ping_stop_event = asyncio.Event()
|
||||||
|
|
||||||
self._to_process: asyncio.Queue[Optional[Packet]] = asyncio.Queue()
|
|
||||||
|
|
||||||
self._process_task: Optional[asyncio.Task[None]] = None
|
|
||||||
|
|
||||||
self._connect_lock: asyncio.Lock = asyncio.Lock()
|
self._connect_lock: asyncio.Lock = asyncio.Lock()
|
||||||
self._cleanup_task: Optional[asyncio.Task[None]] = None
|
self._cleanup_task: Optional[asyncio.Task[None]] = None
|
||||||
|
|
||||||
@ -138,27 +133,10 @@ class APIConnection:
|
|||||||
async with self._connect_lock:
|
async with self._connect_lock:
|
||||||
_LOGGER.debug("Cleaning up connection to %s", self.log_name)
|
_LOGGER.debug("Cleaning up connection to %s", self.log_name)
|
||||||
|
|
||||||
# Tell the process loop to stop
|
|
||||||
self._to_process.put_nowait(None)
|
|
||||||
|
|
||||||
if self._frame_helper is not None:
|
if self._frame_helper is not None:
|
||||||
await self._frame_helper.close()
|
self._frame_helper.close()
|
||||||
self._frame_helper = None
|
self._frame_helper = None
|
||||||
|
|
||||||
if self._process_task is not None:
|
|
||||||
self._process_task.cancel()
|
|
||||||
try:
|
|
||||||
await self._process_task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
except Exception as err: # pylint: disable=broad-except
|
|
||||||
_LOGGER.error(
|
|
||||||
"Unexpected exception in process task: %s",
|
|
||||||
err,
|
|
||||||
exc_info=err,
|
|
||||||
)
|
|
||||||
self._process_task = None
|
|
||||||
|
|
||||||
if self._socket is not None:
|
if self._socket is not None:
|
||||||
self._socket.close()
|
self._socket.close()
|
||||||
self._socket = None
|
self._socket = None
|
||||||
@ -231,24 +209,31 @@ class APIConnection:
|
|||||||
|
|
||||||
async def _connect_init_frame_helper(self) -> None:
|
async def _connect_init_frame_helper(self) -> None:
|
||||||
"""Step 3 in connect process: initialize the frame helper and init read loop."""
|
"""Step 3 in connect process: initialize the frame helper and init read loop."""
|
||||||
reader, writer = await asyncio.open_connection(
|
fh: Union[APIPlaintextFrameHelper, APINoiseFrameHelper]
|
||||||
sock=self._socket, limit=BUFFER_SIZE
|
loop = asyncio.get_event_loop()
|
||||||
) # Set buffer limit to 1MB
|
|
||||||
|
|
||||||
if self._params.noise_psk is None:
|
if self._params.noise_psk is None:
|
||||||
self._frame_helper = APIPlaintextFrameHelper(reader, writer)
|
_, fh = await loop.create_connection(
|
||||||
else:
|
lambda: APIPlaintextFrameHelper(
|
||||||
fh = self._frame_helper = APINoiseFrameHelper(
|
on_pkt=self._process_packet,
|
||||||
reader, writer, self._params.noise_psk
|
on_error=self._report_fatal_error_and_cleanup_task,
|
||||||
|
),
|
||||||
|
sock=self._socket,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
_, fh = await loop.create_connection(
|
||||||
|
lambda: APINoiseFrameHelper(
|
||||||
|
noise_psk=self._params.noise_psk,
|
||||||
|
expected_name=self._params.expected_name,
|
||||||
|
on_pkt=self._process_packet,
|
||||||
|
on_error=self._report_fatal_error_and_cleanup_task,
|
||||||
|
),
|
||||||
|
sock=self._socket,
|
||||||
)
|
)
|
||||||
await fh.perform_handshake(self._params.expected_name)
|
|
||||||
|
|
||||||
|
self._frame_helper = fh
|
||||||
self._connection_state = ConnectionState.SOCKET_OPENED
|
self._connection_state = ConnectionState.SOCKET_OPENED
|
||||||
|
await fh.perform_handshake()
|
||||||
# Create read loop
|
|
||||||
asyncio.create_task(self._read_loop())
|
|
||||||
# Create process loop
|
|
||||||
self._process_task = asyncio.create_task(self._process_loop())
|
|
||||||
|
|
||||||
async def _connect_hello(self) -> None:
|
async def _connect_hello(self) -> None:
|
||||||
"""Step 4 in connect process: send hello and get api version."""
|
"""Step 4 in connect process: send hello and get api version."""
|
||||||
@ -292,10 +277,7 @@ class APIConnection:
|
|||||||
"""Step 5 in connect process: start the ping loop."""
|
"""Step 5 in connect process: start the ping loop."""
|
||||||
|
|
||||||
async def _keep_alive_loop() -> None:
|
async def _keep_alive_loop() -> None:
|
||||||
while True:
|
while self._is_socket_open:
|
||||||
if not self._is_socket_open:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Wait for keepalive seconds, or ping stop event, whichever happens first
|
# Wait for keepalive seconds, or ping stop event, whichever happens first
|
||||||
try:
|
try:
|
||||||
async with async_timeout.timeout(self._params.keepalive):
|
async with async_timeout.timeout(self._params.keepalive):
|
||||||
@ -404,22 +386,20 @@ class APIConnection:
|
|||||||
def is_authenticated(self) -> bool:
|
def is_authenticated(self) -> bool:
|
||||||
return self.is_connected and self._is_authenticated
|
return self.is_connected and self._is_authenticated
|
||||||
|
|
||||||
async def send_message(self, msg: message.Message) -> None:
|
def send_message(self, msg: message.Message) -> None:
|
||||||
"""Send a protobuf message to the remote."""
|
"""Send a protobuf message to the remote."""
|
||||||
if not self._is_socket_open:
|
if not self._is_socket_open:
|
||||||
raise APIConnectionError("Connection isn't established yet")
|
raise APIConnectionError("Connection isn't established yet")
|
||||||
|
|
||||||
|
frame_helper = self._frame_helper
|
||||||
|
assert frame_helper is not None
|
||||||
|
assert frame_helper.ready, "Frame helper not ready"
|
||||||
message_type = PROTO_TO_MESSAGE_TYPE.get(type(msg))
|
message_type = PROTO_TO_MESSAGE_TYPE.get(type(msg))
|
||||||
if not message_type:
|
if not message_type:
|
||||||
raise ValueError(f"Message type id not found for type {type(msg)}")
|
raise ValueError(f"Message type id not found for type {type(msg)}")
|
||||||
encoded = msg.SerializeToString()
|
encoded = msg.SerializeToString()
|
||||||
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
|
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
|
||||||
|
|
||||||
frame_helper = self._frame_helper
|
|
||||||
assert frame_helper is not None
|
|
||||||
if not frame_helper.ready:
|
|
||||||
await frame_helper.wait_for_ready()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
frame_helper.write_packet(
|
frame_helper.write_packet(
|
||||||
Packet(
|
Packet(
|
||||||
@ -431,7 +411,7 @@ class APIConnection:
|
|||||||
# If writing packet fails, we don't know what state the frames
|
# If writing packet fails, we don't know what state the frames
|
||||||
# are in anymore and we have to close the connection
|
# are in anymore and we have to close the connection
|
||||||
_LOGGER.info("%s: Error writing packet: %s", self.log_name, err)
|
_LOGGER.info("%s: Error writing packet: %s", self.log_name, err)
|
||||||
await self._report_fatal_error(err)
|
self._report_fatal_error_and_cleanup_task(err)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def add_message_callback(
|
def add_message_callback(
|
||||||
@ -454,7 +434,7 @@ class APIConnection:
|
|||||||
for msg_type in msg_types:
|
for msg_type in msg_types:
|
||||||
self._message_handlers[msg_type].remove(on_message)
|
self._message_handlers[msg_type].remove(on_message)
|
||||||
|
|
||||||
async def send_message_callback_response(
|
def send_message_callback_response(
|
||||||
self,
|
self,
|
||||||
send_msg: message.Message,
|
send_msg: message.Message,
|
||||||
on_message: Callable[[Any], None],
|
on_message: Callable[[Any], None],
|
||||||
@ -464,7 +444,7 @@ class APIConnection:
|
|||||||
for msg_type in msg_types:
|
for msg_type in msg_types:
|
||||||
self._message_handlers.setdefault(msg_type, []).append(on_message)
|
self._message_handlers.setdefault(msg_type, []).append(on_message)
|
||||||
try:
|
try:
|
||||||
await self.send_message(send_msg)
|
self.send_message(send_msg)
|
||||||
except (asyncio.CancelledError, Exception):
|
except (asyncio.CancelledError, Exception):
|
||||||
for msg_type in msg_types:
|
for msg_type in msg_types:
|
||||||
self._message_handlers[msg_type].remove(on_message)
|
self._message_handlers[msg_type].remove(on_message)
|
||||||
@ -514,7 +494,7 @@ class APIConnection:
|
|||||||
# the await is cancelled
|
# the await is cancelled
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.send_message(send_msg)
|
self.send_message(send_msg)
|
||||||
async with async_timeout.timeout(timeout):
|
async with async_timeout.timeout(timeout):
|
||||||
await fut
|
await fut
|
||||||
except asyncio.TimeoutError as err:
|
except asyncio.TimeoutError as err:
|
||||||
@ -545,6 +525,25 @@ class APIConnection:
|
|||||||
|
|
||||||
return res[0]
|
return res[0]
|
||||||
|
|
||||||
|
def _handle_fatal_error(self, err: Exception) -> None:
|
||||||
|
"""Handle a fatal error that occurred during an operation."""
|
||||||
|
self._connection_state = ConnectionState.CLOSED
|
||||||
|
for handler in self._read_exception_handlers[:]:
|
||||||
|
handler(err)
|
||||||
|
self._read_exception_handlers.clear()
|
||||||
|
|
||||||
|
def _report_fatal_error_and_cleanup_task(self, err: Exception) -> None:
|
||||||
|
"""Handle a fatal error that occurred during an operation.
|
||||||
|
|
||||||
|
This should only be called for errors that mean the connection
|
||||||
|
can no longer be used.
|
||||||
|
|
||||||
|
The connection will be closed, all exception handlers notified.
|
||||||
|
This method does not log the error, the call site should do so.
|
||||||
|
"""
|
||||||
|
self._handle_fatal_error(err)
|
||||||
|
asyncio.create_task(self._cleanup())
|
||||||
|
|
||||||
async def _report_fatal_error(self, err: Exception) -> None:
|
async def _report_fatal_error(self, err: Exception) -> None:
|
||||||
"""Report a fatal error that occurred during an operation.
|
"""Report a fatal error that occurred during an operation.
|
||||||
|
|
||||||
@ -554,106 +553,55 @@ class APIConnection:
|
|||||||
The connection will be closed, all exception handlers notified.
|
The connection will be closed, all exception handlers notified.
|
||||||
This method does not log the error, the call site should do so.
|
This method does not log the error, the call site should do so.
|
||||||
"""
|
"""
|
||||||
self._connection_state = ConnectionState.CLOSED
|
self._handle_fatal_error(err)
|
||||||
for handler in self._read_exception_handlers[:]:
|
|
||||||
handler(err)
|
|
||||||
await self._cleanup()
|
await self._cleanup()
|
||||||
|
|
||||||
async def _process_loop(self) -> None:
|
def _process_packet(self, pkt: Packet) -> None:
|
||||||
to_process = self._to_process
|
"""Process a packet from the socket."""
|
||||||
while True:
|
msg_type_proto = pkt.type
|
||||||
try:
|
if msg_type_proto not in MESSAGE_TYPE_TO_PROTO:
|
||||||
pkt = await to_process.get()
|
_LOGGER.debug("%s: Skipping message type %s", self.log_name, msg_type_proto)
|
||||||
except RuntimeError:
|
return
|
||||||
break
|
|
||||||
|
|
||||||
if pkt is None:
|
msg = MESSAGE_TYPE_TO_PROTO[msg_type_proto]()
|
||||||
# Socket closed but task isn't cancelled yet
|
|
||||||
break
|
|
||||||
|
|
||||||
msg_type_proto = pkt.type
|
|
||||||
if msg_type_proto not in MESSAGE_TYPE_TO_PROTO:
|
|
||||||
_LOGGER.debug(
|
|
||||||
"%s: Skipping message type %s", self.log_name, msg_type_proto
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
msg = MESSAGE_TYPE_TO_PROTO[msg_type_proto]()
|
|
||||||
try:
|
|
||||||
msg.ParseFromString(pkt.data)
|
|
||||||
except Exception as e:
|
|
||||||
_LOGGER.info(
|
|
||||||
"%s: Invalid protobuf message: type=%s data=%s: %s",
|
|
||||||
self.log_name,
|
|
||||||
pkt.type,
|
|
||||||
pkt.data,
|
|
||||||
e,
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
await self._report_fatal_error(
|
|
||||||
ProtocolAPIError(f"Invalid protobuf message: {e}")
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
msg_type = type(msg)
|
|
||||||
|
|
||||||
_LOGGER.debug(
|
|
||||||
"%s: Got message of type %s: %s", self.log_name, msg_type, msg
|
|
||||||
)
|
|
||||||
|
|
||||||
for handler in self._message_handlers.get(msg_type, [])[:]:
|
|
||||||
handler(msg)
|
|
||||||
|
|
||||||
# Pre-check the message type to avoid awaiting
|
|
||||||
# since most messages are not internal messages
|
|
||||||
if msg_type in INTERNAL_MESSAGE_TYPES:
|
|
||||||
await self._handle_internal_messages(msg)
|
|
||||||
|
|
||||||
async def _read_loop(self) -> None:
|
|
||||||
frame_helper = self._frame_helper
|
|
||||||
assert frame_helper is not None
|
|
||||||
await frame_helper.wait_for_ready()
|
|
||||||
to_process = self._to_process
|
|
||||||
try:
|
try:
|
||||||
# Once its ready, we hold the lock for the duration of the
|
msg.ParseFromString(pkt.data)
|
||||||
# connection so we don't have to keep locking/unlocking
|
except Exception as e:
|
||||||
async with frame_helper.read_lock:
|
|
||||||
while True:
|
|
||||||
to_process.put_nowait(await frame_helper.read_packet_with_lock())
|
|
||||||
except SocketClosedAPIError as err:
|
|
||||||
# don't log with info, if closed the site that closed the connection should log
|
|
||||||
_LOGGER.debug(
|
|
||||||
"%s: Socket closed, stopping read loop",
|
|
||||||
self.log_name,
|
|
||||||
)
|
|
||||||
await self._report_fatal_error(err)
|
|
||||||
except APIConnectionError as err:
|
|
||||||
_LOGGER.info(
|
_LOGGER.info(
|
||||||
"%s: Error while reading incoming messages: %s",
|
"%s: Invalid protobuf message: type=%s data=%s: %s",
|
||||||
self.log_name,
|
self.log_name,
|
||||||
err,
|
pkt.type,
|
||||||
)
|
pkt.data,
|
||||||
await self._report_fatal_error(err)
|
e,
|
||||||
except Exception as err: # pylint: disable=broad-except
|
|
||||||
_LOGGER.warning(
|
|
||||||
"%s: Unexpected error while reading incoming messages: %s",
|
|
||||||
self.log_name,
|
|
||||||
err,
|
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
await self._report_fatal_error(err)
|
self._report_fatal_error_and_cleanup_task(
|
||||||
|
ProtocolAPIError(f"Invalid protobuf message: {e}")
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
msg_type = type(msg)
|
||||||
|
|
||||||
|
_LOGGER.debug("%s: Got message of type %s: %s", self.log_name, msg_type, msg)
|
||||||
|
|
||||||
|
for handler in self._message_handlers.get(msg_type, [])[:]:
|
||||||
|
handler(msg)
|
||||||
|
|
||||||
|
# Pre-check the message type to avoid awaiting
|
||||||
|
# since most messages are not internal messages
|
||||||
|
if msg_type not in INTERNAL_MESSAGE_TYPES:
|
||||||
|
return
|
||||||
|
|
||||||
async def _handle_internal_messages(self, msg: Any) -> None:
|
|
||||||
if isinstance(msg, DisconnectRequest):
|
if isinstance(msg, DisconnectRequest):
|
||||||
await self.send_message(DisconnectResponse())
|
self.send_message(DisconnectResponse())
|
||||||
self._connection_state = ConnectionState.CLOSED
|
self._connection_state = ConnectionState.CLOSED
|
||||||
await self._cleanup()
|
asyncio.create_task(self._cleanup())
|
||||||
elif isinstance(msg, PingRequest):
|
elif isinstance(msg, PingRequest):
|
||||||
await self.send_message(PingResponse())
|
self.send_message(PingResponse())
|
||||||
elif isinstance(msg, GetTimeRequest):
|
elif isinstance(msg, GetTimeRequest):
|
||||||
resp = GetTimeResponse()
|
resp = GetTimeResponse()
|
||||||
resp.epoch_seconds = int(time.time())
|
resp.epoch_seconds = int(time.time())
|
||||||
await self.send_message(resp)
|
self.send_message(resp)
|
||||||
|
|
||||||
async def _ping(self) -> None:
|
async def _ping(self) -> None:
|
||||||
self._check_connected()
|
self._check_connected()
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
import math
|
import math
|
||||||
|
from functools import lru_cache
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1024)
|
||||||
def varuint_to_bytes(value: int) -> bytes:
|
def varuint_to_bytes(value: int) -> bytes:
|
||||||
if value <= 0x7F:
|
if value <= 0x7F:
|
||||||
return bytes([value])
|
return bytes([value])
|
||||||
@ -18,6 +20,7 @@ def varuint_to_bytes(value: int) -> bytes:
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1024)
|
||||||
def bytes_to_varuint(value: bytes) -> Optional[int]:
|
def bytes_to_varuint(value: bytes) -> Optional[int]:
|
||||||
result = 0
|
result = 0
|
||||||
bitpos = 0
|
bitpos = 0
|
||||||
|
@ -3,7 +3,7 @@ from unittest.mock import MagicMock
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper, Packet
|
||||||
from aioesphomeapi.util import varuint_to_bytes
|
from aioesphomeapi.util import varuint_to_bytes
|
||||||
|
|
||||||
PREAMBLE = b"\x00"
|
PREAMBLE = b"\x00"
|
||||||
@ -46,17 +46,20 @@ PREAMBLE = b"\x00"
|
|||||||
)
|
)
|
||||||
async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
|
async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
|
||||||
|
|
||||||
stream_reader = asyncio.StreamReader()
|
|
||||||
stream_writer = MagicMock()
|
|
||||||
|
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
|
packets = []
|
||||||
|
|
||||||
stream_reader.feed_data(in_bytes)
|
def _packet(pkt: Packet):
|
||||||
|
packets.append(pkt)
|
||||||
|
|
||||||
helper = APIPlaintextFrameHelper(stream_reader, stream_writer)
|
def _on_error(exc: Exception):
|
||||||
|
raise exc
|
||||||
|
|
||||||
async with helper.read_lock:
|
helper = APIPlaintextFrameHelper(on_pkt=_packet, on_error=_on_error)
|
||||||
pkt = await helper.read_packet_with_lock()
|
|
||||||
|
helper.data_received(in_bytes)
|
||||||
|
|
||||||
|
pkt = packets.pop()
|
||||||
|
|
||||||
assert pkt.type == pkt_type
|
assert pkt.type == pkt_type
|
||||||
assert pkt.data == pkt_data
|
assert pkt.data == pkt_data
|
||||||
|
@ -74,7 +74,7 @@ def patch_response_complex(client: APIClient, messages):
|
|||||||
def patch_response_callback(client: APIClient):
|
def patch_response_callback(client: APIClient):
|
||||||
on_message = None
|
on_message = None
|
||||||
|
|
||||||
async def patched(req, callback, msg_types):
|
def patched(req, callback, msg_types):
|
||||||
nonlocal on_message
|
nonlocal on_message
|
||||||
on_message = callback
|
on_message = callback
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ import socket
|
|||||||
import pytest
|
import pytest
|
||||||
from mock import AsyncMock, MagicMock, Mock, patch
|
from mock import AsyncMock, MagicMock, Mock, patch
|
||||||
|
|
||||||
|
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper, Packet
|
||||||
from aioesphomeapi.api_pb2 import ConnectResponse, HelloResponse
|
from aioesphomeapi.api_pb2 import ConnectResponse, HelloResponse
|
||||||
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
|
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
|
||||||
from aioesphomeapi.core import APIConnectionError, RequiresEncryptionAPIError
|
from aioesphomeapi.core import APIConnectionError, RequiresEncryptionAPIError
|
||||||
@ -50,13 +51,26 @@ def socket_socket():
|
|||||||
yield func
|
yield func
|
||||||
|
|
||||||
|
|
||||||
|
def _get_mock_protocol():
|
||||||
|
def _on_packet(pkt: Packet):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _on_error(exc: Exception):
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
protocol = APIPlaintextFrameHelper(on_pkt=_on_packet, on_error=_on_error)
|
||||||
|
protocol._connected_event.set()
|
||||||
|
protocol._transport = MagicMock()
|
||||||
|
return protocol
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_connect(conn, resolve_host, socket_socket, event_loop):
|
async def test_connect(conn, resolve_host, socket_socket, event_loop):
|
||||||
with patch.object(event_loop, "sock_connect"), patch(
|
loop = asyncio.get_event_loop()
|
||||||
"asyncio.open_connection", return_value=(None, None)
|
protocol = _get_mock_protocol()
|
||||||
), patch.object(conn, "_read_loop"), patch.object(
|
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||||
conn, "_connect_start_ping"
|
loop, "create_connection", return_value=(MagicMock(), protocol)
|
||||||
), patch.object(
|
), patch.object(conn, "_connect_start_ping"), patch.object(
|
||||||
conn, "send_message_await_response", return_value=HelloResponse()
|
conn, "send_message_await_response", return_value=HelloResponse()
|
||||||
):
|
):
|
||||||
await conn.connect(login=False)
|
await conn.connect(login=False)
|
||||||
@ -66,14 +80,15 @@ async def test_connect(conn, resolve_host, socket_socket, event_loop):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_requires_encryption_propagates(conn):
|
async def test_requires_encryption_propagates(conn):
|
||||||
with patch("asyncio.open_connection") as openc:
|
loop = asyncio.get_event_loop()
|
||||||
reader = MagicMock()
|
protocol = _get_mock_protocol()
|
||||||
writer = MagicMock()
|
with patch.object(loop, "create_connection") as create_connection, patch.object(
|
||||||
openc.return_value = (reader, writer)
|
protocol, "perform_handshake"
|
||||||
writer.drain = AsyncMock()
|
):
|
||||||
reader.readexactly = AsyncMock()
|
create_connection.return_value = (MagicMock(), protocol)
|
||||||
reader.readexactly.return_value = b"\x01"
|
|
||||||
|
|
||||||
await conn._connect_init_frame_helper()
|
await conn._connect_init_frame_helper()
|
||||||
|
|
||||||
with pytest.raises(RequiresEncryptionAPIError):
|
with pytest.raises(RequiresEncryptionAPIError):
|
||||||
|
protocol.data_received(b"\x01\x00\x00")
|
||||||
await conn._connect_hello()
|
await conn._connect_hello()
|
||||||
|
Loading…
Reference in New Issue
Block a user