Simplify connection flow with an asyncio.Protocol (#352)

This commit is contained in:
J. Nick Koston 2023-01-05 18:24:10 -10:00 committed by GitHub
parent 049dc8bb56
commit 2886d361f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 348 additions and 315 deletions

View File

@ -1,9 +1,10 @@
import asyncio
import base64
import logging
from abc import ABC, abstractmethod, abstractproperty
from abc import abstractmethod, abstractproperty
from dataclasses import dataclass
from typing import Optional
from enum import Enum
from typing import Callable, Optional, Union, cast
import async_timeout
from noise.connection import NoiseConnection # type: ignore
@ -32,57 +33,93 @@ SOCKET_ERRORS = (
@dataclass
class Packet:
type: int
data: bytes
data: Union[bytes, bytearray]
class APIFrameHelper(ABC):
class APIFrameHelper(asyncio.Protocol):
"""Helper class to handle the API frame protocol."""
def __init__(
self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
on_pkt: Callable[[Packet], None],
on_error: Callable[[Exception], None],
) -> None:
"""Initialize the API frame helper."""
self._reader = reader
self._writer = writer
self._on_pkt = on_pkt
self._on_error = on_error
self._transport: Optional[asyncio.Transport] = None
self.read_lock = asyncio.Lock()
self._closed_event = asyncio.Event()
self._connected_event = asyncio.Event()
self._buffer = bytearray()
self._pos = 0
@abstractproperty # pylint: disable=deprecated-decorator
def ready(self) -> bool:
"""Return if the connection is ready."""
def _init_read(self, length: int) -> Optional[bytearray]:
"""Start reading a packet from the buffer."""
self._pos = 0
return self._read_exactly(length)
def _read_exactly(self, length: int) -> Optional[bytearray]:
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
original_pos = self._pos
new_pos = original_pos + length
if len(self._buffer) < new_pos:
return None
self._pos = new_pos
return self._buffer[original_pos:new_pos]
@abstractmethod
async def close(self) -> None:
"""Close the connection."""
async def perform_handshake(self) -> None:
"""Perform the handshake."""
@abstractmethod
def write_packet(self, packet: Packet) -> None:
"""Write a packet to the socket."""
@abstractmethod
async def read_packet_with_lock(self) -> Packet:
"""Read a packet from the socket, the caller is responsible for having the lock."""
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Handle a new connection."""
self._transport = cast(asyncio.Transport, transport)
self._connected_event.set()
@abstractmethod
async def wait_for_ready(self) -> None:
"""Wait for the connection to be ready."""
def _handle_error_and_close(self, exc: Exception) -> None:
self._handle_error(exc)
self.close()
def _handle_error(self, exc: Exception) -> None:
self._closed_event.set()
self._on_error(exc)
def connection_lost(self, exc: Optional[Exception]) -> None:
self._handle_error(exc or SocketClosedAPIError("Connection lost"))
return super().connection_lost(exc)
def eof_received(self) -> Optional[bool]:
self._handle_error(SocketClosedAPIError("EOF received"))
return super().eof_received()
def close(self) -> None:
"""Close the connection."""
self._closed_event.set()
if self._transport:
self._transport.close()
class APIPlaintextFrameHelper(APIFrameHelper):
"""Frame helper for plaintext API connections."""
async def close(self) -> None:
"""Close the connection."""
self._closed_event.set()
self._writer.close()
@property
def ready(self) -> bool:
"""Return if the connection is ready."""
# Plaintext is always ready
return True
return self._connected_event.is_set()
def _callback_packet(self, type_: int, data: Union[bytes, bytearray]) -> None:
"""Complete reading a packet from the buffer."""
del self._buffer[: self._pos]
self._on_pkt(Packet(type_, data))
def write_packet(self, packet: Packet) -> None:
"""Write a packet to the socket, the caller should not have the lock.
@ -90,6 +127,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
The entire packet must be written in a single call to write
to avoid locking.
"""
assert self._transport is not None, "Transport should be set"
data = (
b"\0"
+ varuint_to_bytes(len(packet.data))
@ -99,26 +137,32 @@ class APIPlaintextFrameHelper(APIFrameHelper):
_LOGGER.debug("Sending plaintext frame %s", data.hex())
try:
self._writer.write(data)
self._transport.write(data)
except (ConnectionResetError, OSError) as err:
raise SocketAPIError(f"Error while writing data: {err}") from err
async def wait_for_ready(self) -> None:
"""Wait for the connection to be ready."""
# No handshake for plaintext
async def perform_handshake(self) -> None:
"""Perform the handshake."""
await self._connected_event.wait()
async def read_packet_with_lock(self) -> Packet:
"""Read a packet from the socket, the caller is responsible for having the lock."""
assert self.read_lock.locked(), "read_packet_with_lock called without lock"
try:
def data_received(self, data: bytes) -> None:
self._buffer += data
while len(self._buffer) >= 3:
# Read preamble, which should always 0x00
# Also try to get the length and msg type
# to avoid multiple calls to readexactly
init_bytes = await self._reader.readexactly(3)
init_bytes = self._init_read(3)
assert init_bytes is not None, "Buffer should have at least 3 bytes"
if init_bytes[0] != 0x00:
if init_bytes[0] == 0x01:
raise RequiresEncryptionAPIError("Connection requires encryption")
raise ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}")
self._handle_error_and_close(
RequiresEncryptionAPIError("Connection requires encryption")
)
return
self._handle_error_and_close(
ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}")
)
return
if init_bytes[1] & 0x80 == 0x80:
# Length is longer than 1 byte
@ -133,32 +177,35 @@ class APIPlaintextFrameHelper(APIFrameHelper):
# If the message is long, we need to read the rest of the length
while length[-1] & 0x80 == 0x80:
length += await self._reader.readexactly(1)
add_length = self._read_exactly(1)
if add_length is None:
return
length += add_length
# If the message length was longer than 1 byte, we need to read the
# message type
while not msg_type or (msg_type[-1] & 0x80) == 0x80:
msg_type += await self._reader.readexactly(1)
add_msg_type = self._read_exactly(1)
if add_msg_type is None:
return
msg_type += add_msg_type
length_int = bytes_to_varuint(length)
length_int = bytes_to_varuint(bytes(length))
assert length_int is not None
msg_type_int = bytes_to_varuint(msg_type)
msg_type_int = bytes_to_varuint(bytes(msg_type))
assert msg_type_int is not None
if length_int == 0:
return Packet(type=msg_type_int, data=b"")
self._callback_packet(msg_type_int, b"")
# If we have more data, continue processing
continue
data = await self._reader.readexactly(length_int)
return Packet(type=msg_type_int, data=data)
except SOCKET_ERRORS as err:
if (
isinstance(err, asyncio.IncompleteReadError)
and self._closed_event.is_set()
):
raise SocketClosedAPIError(
f"Socket closed while reading data: {err}"
) from err
raise SocketAPIError(f"Error while reading data: {err}") from err
packet_data = self._read_exactly(length_int)
if packet_data is None:
return
self._callback_packet(msg_type_int, packet_data)
# If we have more data, continue processing
def _decode_noise_psk(psk: str) -> bytes:
@ -176,34 +223,46 @@ def _decode_noise_psk(psk: str) -> bytes:
return psk_bytes
class NoiseConnectionState(Enum):
"""Noise connection state."""
HELLO = 1
HANDSHAKE = 2
READY = 3
CLOSED = 4
class APINoiseFrameHelper(APIFrameHelper):
"""Frame helper for noise encrypted connections."""
def __init__(
self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
on_pkt: Callable[[Packet], None],
on_error: Callable[[Exception], None],
noise_psk: str,
expected_name: Optional[str],
) -> None:
"""Initialize the API frame helper."""
super().__init__(reader, writer)
super().__init__(on_pkt, on_error)
self._ready_event = asyncio.Event()
self._proto: Optional[NoiseConnection] = None
self._noise_psk = noise_psk
self._expected_name = expected_name
self._state = NoiseConnectionState.HELLO
self._setup_proto()
@property
def ready(self) -> bool:
"""Return if the connection is ready."""
return self._ready_event.is_set()
async def close(self) -> None:
def close(self) -> None:
"""Close the connection."""
# Make sure we set the ready event if its not already set
# so that we don't block forever on the ready event if we
# are waiting for the handshake to complete.
self._ready_event.set()
self._closed_event.set()
self._writer.close()
self._state = NoiseConnectionState.CLOSED
super().close()
def _write_frame(self, frame: bytes) -> None:
"""Write a packet to the socket, the caller should not have the lock.
@ -212,6 +271,7 @@ class APINoiseFrameHelper(APIFrameHelper):
to avoid locking.
"""
_LOGGER.debug("Sending frame %s", frame.hex())
assert self._transport is not None, "Transport is not set"
try:
header = bytes(
@ -221,39 +281,46 @@ class APINoiseFrameHelper(APIFrameHelper):
len(frame) & 0xFF,
]
)
self._writer.write(header + frame)
self._transport.write(header + frame)
except OSError as err:
raise SocketAPIError(f"Error while writing data: {err}") from err
async def _read_frame_with_lock(self) -> bytes:
"""Read a frame from the socket, the caller is responsible for having the lock."""
assert self.read_lock.locked(), "_read_frame_with_lock called without lock"
async def perform_handshake(self) -> None:
"""Perform the handshake with the server."""
self._send_hello()
try:
header = await self._reader.readexactly(3)
async with async_timeout.timeout(60.0):
await self._ready_event.wait()
except asyncio.TimeoutError as err:
raise HandshakeAPIError("Timeout during handshake") from err
def data_received(self, data: bytes) -> None:
self._buffer += data
while len(self._buffer) >= 3:
header = self._init_read(3)
assert header is not None, "Buffer should have at least 3 bytes"
if header[0] != 0x01:
raise ProtocolAPIError(f"Marker byte invalid: {header[0]}")
self._handle_error_and_close(
ProtocolAPIError(f"Marker byte invalid: {header[0]}")
)
msg_size = (header[1] << 8) | header[2]
frame = await self._reader.readexactly(msg_size)
except SOCKET_ERRORS as err:
if (
isinstance(err, asyncio.IncompleteReadError)
and self._closed_event.is_set()
):
raise SocketClosedAPIError(
f"Socket closed while reading data: {err}"
) from err
raise SocketAPIError(f"Error while reading data: {err}") from err
frame = self._read_exactly(msg_size)
if frame is None:
return
_LOGGER.debug("Received frame %s", frame.hex())
return frame
try:
self.STATE_TO_CALLABLE[self._state](self, frame)
except Exception as err: # pylint: disable=broad-except
self._handle_error_and_close(err)
finally:
del self._buffer[: self._pos]
async def _perform_handshake(self, expected_name: Optional[str]) -> None:
"""Perform the handshake with the server, the caller is responsible for having the lock."""
assert self.read_lock.locked(), "_perform_handshake called without lock"
def _send_hello(self) -> None:
"""Send a ClientHello to the server."""
self._write_frame(b"") # ClientHello
prologue = b"NoiseAPIInit" + b"\x00\x00"
server_hello = await self._read_frame_with_lock() # ServerHello
def _handle_hello(self, server_hello: bytearray) -> None:
"""Perform the handshake with the server, the caller is responsible for having the lock."""
if not server_hello:
raise HandshakeAPIError("ServerHello is empty")
@ -273,76 +340,60 @@ class APINoiseFrameHelper(APIFrameHelper):
if server_name_i != -1:
# server name found, this extension was added in 2022.2
server_name = server_hello[1:server_name_i].decode()
if expected_name is not None and expected_name != server_name:
if self._expected_name is not None and self._expected_name != server_name:
raise BadNameAPIError(
f"Server sent a different name '{server_name}'", server_name
)
self._state = NoiseConnectionState.HANDSHAKE
self._send_handshake()
def _setup_proto(self) -> None:
"""Set up the noise protocol."""
self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
self._proto.set_as_initiator()
self._proto.set_psks(_decode_noise_psk(self._noise_psk))
self._proto.set_prologue(prologue)
self._proto.set_prologue(b"NoiseAPIInit" + b"\x00\x00")
self._proto.start_handshake()
def _send_handshake(self) -> None:
"""Send the handshake message."""
self._write_frame(b"\x00" + self._proto.write_message())
def _handle_handshake(self, msg: bytearray) -> None:
_LOGGER.debug("Starting handshake...")
do_write = True
while not self._proto.handshake_finished:
if do_write:
msg = self._proto.write_message()
self._write_frame(b"\x00" + msg)
else:
msg = await self._read_frame_with_lock()
if not msg:
raise HandshakeAPIError("Handshake message too short")
if msg[0] != 0:
explanation = msg[1:].decode()
if explanation == "Handshake MAC failure":
raise InvalidEncryptionKeyAPIError("Invalid encryption key")
raise HandshakeAPIError(f"Handshake failure: {explanation}")
self._proto.read_message(msg[1:])
do_write = not do_write
_LOGGER.debug("Handshake complete!")
if msg[0] != 0:
explanation = msg[1:].decode()
if explanation == "Handshake MAC failure":
raise InvalidEncryptionKeyAPIError("Invalid encryption key")
raise HandshakeAPIError(f"Handshake failure: {explanation}")
self._proto.read_message(msg[1:])
_LOGGER.debug("Handshake complete")
self._state = NoiseConnectionState.READY
self._ready_event.set()
async def perform_handshake(self, expected_name: Optional[str]) -> None:
"""Perform the handshake with the server."""
# Allow up to 60 seconds for handhsake
try:
async with self.read_lock, async_timeout.timeout(60.0):
await self._perform_handshake(expected_name)
except asyncio.TimeoutError as err:
raise HandshakeAPIError("Timeout during handshake") from err
def write_packet(self, packet: Packet) -> None:
"""Write a packet to the socket."""
padding = 0
data = (
bytes(
[
(packet.type >> 8) & 0xFF,
(packet.type >> 0) & 0xFF,
(len(packet.data) >> 8) & 0xFF,
(len(packet.data) >> 0) & 0xFF,
]
self._write_frame(
self._proto.encrypt(
(
bytes(
[
(packet.type >> 8) & 0xFF,
(packet.type >> 0) & 0xFF,
(len(packet.data) >> 8) & 0xFF,
(len(packet.data) >> 0) & 0xFF,
]
)
+ packet.data
)
)
+ packet.data
+ b"\x00" * padding
)
assert self._proto is not None
frame = self._proto.encrypt(data)
self._write_frame(frame)
async def wait_for_ready(self) -> None:
"""Wait for the connection to be ready."""
await self._ready_event.wait()
async def read_packet_with_lock(self) -> Packet:
"""Read a packet from the socket, the caller is responsible for having the lock."""
frame = await self._read_frame_with_lock()
def _handle_frame(self, frame: bytearray) -> None:
"""Handle an incoming frame."""
assert self._proto is not None
msg = self._proto.decrypt(frame)
msg = self._proto.decrypt(bytes(frame))
if len(msg) < 4:
raise ProtocolAPIError(f"Bad packet frame: {msg}")
pkt_type = (msg[0] << 8) | msg[1]
@ -350,4 +401,17 @@ class APINoiseFrameHelper(APIFrameHelper):
if data_len + 4 > len(msg):
raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
data = msg[4 : 4 + data_len]
return Packet(type=pkt_type, data=data)
return self._on_pkt(Packet(pkt_type, data))
def _handle_closed( # pylint: disable=unused-argument
self, frame: bytearray
) -> None:
"""Handle a closed frame."""
self._handle_error(ProtocolAPIError("Connection closed"))
STATE_TO_CALLABLE = {
NoiseConnectionState.HELLO: _handle_hello,
NoiseConnectionState.HANDSHAKE: _handle_handshake,
NoiseConnectionState.READY: _handle_frame,
NoiseConnectionState.CLOSED: _handle_closed,
}

View File

@ -373,7 +373,7 @@ class APIClient:
image_stream[msg.key] = data
assert self._connection is not None
await self._connection.send_message_callback_response(
self._connection.send_message_callback_response(
SubscribeStatesRequest(), on_msg, msg_types
)
@ -394,7 +394,7 @@ class APIClient:
if dump_config is not None:
req.dump_config = dump_config
assert self._connection is not None
await self._connection.send_message_callback_response(
self._connection.send_message_callback_response(
req, on_msg, (SubscribeLogsResponse,)
)
@ -407,7 +407,7 @@ class APIClient:
on_service_call(HomeassistantServiceCall.from_pb(msg))
assert self._connection is not None
await self._connection.send_message_callback_response(
self._connection.send_message_callback_response(
SubscribeHomeassistantServicesRequest(),
on_msg,
(HomeassistantServiceResponse,),
@ -451,7 +451,7 @@ class APIClient:
on_bluetooth_le_advertisement(BluetoothLEAdvertisement.from_pb(msg)) # type: ignore[misc]
assert self._connection is not None
await self._connection.send_message_callback_response(
self._connection.send_message_callback_response(
SubscribeBluetoothLEAdvertisementsRequest(), on_msg, msg_types
)
@ -472,7 +472,7 @@ class APIClient:
on_bluetooth_connections_free_update(resp.free, resp.limit)
assert self._connection is not None
await self._connection.send_message_callback_response(
self._connection.send_message_callback_response(
SubscribeBluetoothConnectionsFreeRequest(), on_msg, msg_types
)
@ -518,7 +518,7 @@ class APIClient:
_LOGGER.debug("%s: Using connection version 1", address)
request_type = BluetoothDeviceRequestType.CONNECT
await self._connection.send_message_callback_response(
self._connection.send_message_callback_response(
BluetoothDeviceRequest(
address=address,
request_type=request_type,
@ -581,7 +581,7 @@ class APIClient:
self._check_authenticated()
assert self._connection is not None
await self._connection.send_message(
self._connection.send_message(
BluetoothDeviceRequest(
address=address,
request_type=BluetoothDeviceRequestType.DISCONNECT,
@ -661,7 +661,7 @@ class APIClient:
if not response:
assert self._connection is not None
await self._connection.send_message(req)
self._connection.send_message(req)
return
await self._send_bluetooth_message_await_response(
@ -709,7 +709,7 @@ class APIClient:
if not wait_for_response:
assert self._connection is not None
await self._connection.send_message(req)
self._connection.send_message(req)
return
await self._send_bluetooth_message_await_response(
@ -762,7 +762,7 @@ class APIClient:
self._check_authenticated()
await self._connection.send_message(
self._connection.send_message(
BluetoothGATTNotifyRequest(address=address, handle=handle, enable=False)
)
@ -777,7 +777,7 @@ class APIClient:
on_state_sub(msg.entity_id, msg.attribute)
assert self._connection is not None
await self._connection.send_message_callback_response(
self._connection.send_message_callback_response(
SubscribeHomeAssistantStatesRequest(),
on_msg,
(SubscribeHomeAssistantStateResponse,),
@ -789,7 +789,7 @@ class APIClient:
self._check_authenticated()
assert self._connection is not None
await self._connection.send_message(
self._connection.send_message(
HomeAssistantStateResponse(
entity_id=entity_id,
state=state,
@ -829,7 +829,7 @@ class APIClient:
req.legacy_command = LegacyCoverCommand.CLOSE
req.has_legacy_command = True
assert self._connection is not None
await self._connection.send_message(req)
self._connection.send_message(req)
async def fan_command(
self,
@ -860,7 +860,7 @@ class APIClient:
req.has_direction = True
req.direction = direction
assert self._connection is not None
await self._connection.send_message(req)
self._connection.send_message(req)
async def light_command(
self,
@ -921,7 +921,7 @@ class APIClient:
req.has_effect = True
req.effect = effect
assert self._connection is not None
await self._connection.send_message(req)
self._connection.send_message(req)
async def switch_command(self, key: int, state: bool) -> None:
self._check_authenticated()
@ -930,7 +930,7 @@ class APIClient:
req.key = key
req.state = state
assert self._connection is not None
await self._connection.send_message(req)
self._connection.send_message(req)
async def climate_command(
self,
@ -982,7 +982,7 @@ class APIClient:
req.has_custom_preset = True
req.custom_preset = custom_preset
assert self._connection is not None
await self._connection.send_message(req)
self._connection.send_message(req)
async def number_command(self, key: int, state: float) -> None:
self._check_authenticated()
@ -991,7 +991,7 @@ class APIClient:
req.key = key
req.state = state
assert self._connection is not None
await self._connection.send_message(req)
self._connection.send_message(req)
async def select_command(self, key: int, state: str) -> None:
self._check_authenticated()
@ -1000,7 +1000,7 @@ class APIClient:
req.key = key
req.state = state
assert self._connection is not None
await self._connection.send_message(req)
self._connection.send_message(req)
async def siren_command(
self,
@ -1027,7 +1027,7 @@ class APIClient:
req.duration = duration
req.has_duration = True
assert self._connection is not None
await self._connection.send_message(req)
self._connection.send_message(req)
async def button_command(self, key: int) -> None:
self._check_authenticated()
@ -1035,7 +1035,7 @@ class APIClient:
req = ButtonCommandRequest()
req.key = key
assert self._connection is not None
await self._connection.send_message(req)
self._connection.send_message(req)
async def lock_command(
self,
@ -1051,7 +1051,7 @@ class APIClient:
if code is not None:
req.code = code
assert self._connection is not None
await self._connection.send_message(req)
self._connection.send_message(req)
async def media_player_command(
self,
@ -1075,7 +1075,7 @@ class APIClient:
req.media_url = media_url
req.has_media_url = True
assert self._connection is not None
await self._connection.send_message(req)
self._connection.send_message(req)
async def execute_service(
self, service: UserService, data: ExecuteServiceDataType
@ -1113,7 +1113,7 @@ class APIClient:
# pylint: disable=no-member
req.args.extend(args)
assert self._connection is not None
await self._connection.send_message(req)
self._connection.send_message(req)
async def _request_image(
self, *, single: bool = False, stream: bool = False
@ -1122,7 +1122,7 @@ class APIClient:
req.single = single
req.stream = stream
assert self._connection is not None
await self._connection.send_message(req)
self._connection.send_message(req)
async def request_single_image(self) -> None:
await self._request_image(single=True)

View File

@ -5,7 +5,7 @@ import socket
import time
from contextlib import suppress
from dataclasses import astuple, dataclass
from typing import Any, Callable, Coroutine, Dict, Iterable, List, Optional, Type
from typing import Any, Callable, Coroutine, Dict, Iterable, List, Optional, Type, Union
import async_timeout
from google.protobuf import message
@ -40,7 +40,6 @@ from .core import (
ReadFailedAPIError,
ResolveAPIError,
SocketAPIError,
SocketClosedAPIError,
TimeoutAPIError,
)
from .model import APIVersion
@ -112,10 +111,6 @@ class APIConnection:
self._ping_stop_event = asyncio.Event()
self._to_process: asyncio.Queue[Optional[Packet]] = asyncio.Queue()
self._process_task: Optional[asyncio.Task[None]] = None
self._connect_lock: asyncio.Lock = asyncio.Lock()
self._cleanup_task: Optional[asyncio.Task[None]] = None
@ -138,27 +133,10 @@ class APIConnection:
async with self._connect_lock:
_LOGGER.debug("Cleaning up connection to %s", self.log_name)
# Tell the process loop to stop
self._to_process.put_nowait(None)
if self._frame_helper is not None:
await self._frame_helper.close()
self._frame_helper.close()
self._frame_helper = None
if self._process_task is not None:
self._process_task.cancel()
try:
await self._process_task
except asyncio.CancelledError:
pass
except Exception as err: # pylint: disable=broad-except
_LOGGER.error(
"Unexpected exception in process task: %s",
err,
exc_info=err,
)
self._process_task = None
if self._socket is not None:
self._socket.close()
self._socket = None
@ -231,24 +209,31 @@ class APIConnection:
async def _connect_init_frame_helper(self) -> None:
"""Step 3 in connect process: initialize the frame helper and init read loop."""
reader, writer = await asyncio.open_connection(
sock=self._socket, limit=BUFFER_SIZE
) # Set buffer limit to 1MB
fh: Union[APIPlaintextFrameHelper, APINoiseFrameHelper]
loop = asyncio.get_event_loop()
if self._params.noise_psk is None:
self._frame_helper = APIPlaintextFrameHelper(reader, writer)
else:
fh = self._frame_helper = APINoiseFrameHelper(
reader, writer, self._params.noise_psk
_, fh = await loop.create_connection(
lambda: APIPlaintextFrameHelper(
on_pkt=self._process_packet,
on_error=self._report_fatal_error_and_cleanup_task,
),
sock=self._socket,
)
else:
_, fh = await loop.create_connection(
lambda: APINoiseFrameHelper(
noise_psk=self._params.noise_psk,
expected_name=self._params.expected_name,
on_pkt=self._process_packet,
on_error=self._report_fatal_error_and_cleanup_task,
),
sock=self._socket,
)
await fh.perform_handshake(self._params.expected_name)
self._frame_helper = fh
self._connection_state = ConnectionState.SOCKET_OPENED
# Create read loop
asyncio.create_task(self._read_loop())
# Create process loop
self._process_task = asyncio.create_task(self._process_loop())
await fh.perform_handshake()
async def _connect_hello(self) -> None:
"""Step 4 in connect process: send hello and get api version."""
@ -292,10 +277,7 @@ class APIConnection:
"""Step 5 in connect process: start the ping loop."""
async def _keep_alive_loop() -> None:
while True:
if not self._is_socket_open:
return
while self._is_socket_open:
# Wait for keepalive seconds, or ping stop event, whichever happens first
try:
async with async_timeout.timeout(self._params.keepalive):
@ -404,22 +386,20 @@ class APIConnection:
def is_authenticated(self) -> bool:
return self.is_connected and self._is_authenticated
async def send_message(self, msg: message.Message) -> None:
def send_message(self, msg: message.Message) -> None:
"""Send a protobuf message to the remote."""
if not self._is_socket_open:
raise APIConnectionError("Connection isn't established yet")
frame_helper = self._frame_helper
assert frame_helper is not None
assert frame_helper.ready, "Frame helper not ready"
message_type = PROTO_TO_MESSAGE_TYPE.get(type(msg))
if not message_type:
raise ValueError(f"Message type id not found for type {type(msg)}")
encoded = msg.SerializeToString()
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
frame_helper = self._frame_helper
assert frame_helper is not None
if not frame_helper.ready:
await frame_helper.wait_for_ready()
try:
frame_helper.write_packet(
Packet(
@ -431,7 +411,7 @@ class APIConnection:
# If writing packet fails, we don't know what state the frames
# are in anymore and we have to close the connection
_LOGGER.info("%s: Error writing packet: %s", self.log_name, err)
await self._report_fatal_error(err)
self._report_fatal_error_and_cleanup_task(err)
raise
def add_message_callback(
@ -454,7 +434,7 @@ class APIConnection:
for msg_type in msg_types:
self._message_handlers[msg_type].remove(on_message)
async def send_message_callback_response(
def send_message_callback_response(
self,
send_msg: message.Message,
on_message: Callable[[Any], None],
@ -464,7 +444,7 @@ class APIConnection:
for msg_type in msg_types:
self._message_handlers.setdefault(msg_type, []).append(on_message)
try:
await self.send_message(send_msg)
self.send_message(send_msg)
except (asyncio.CancelledError, Exception):
for msg_type in msg_types:
self._message_handlers[msg_type].remove(on_message)
@ -514,7 +494,7 @@ class APIConnection:
# the await is cancelled
try:
await self.send_message(send_msg)
self.send_message(send_msg)
async with async_timeout.timeout(timeout):
await fut
except asyncio.TimeoutError as err:
@ -545,6 +525,25 @@ class APIConnection:
return res[0]
def _handle_fatal_error(self, err: Exception) -> None:
"""Handle a fatal error that occurred during an operation."""
self._connection_state = ConnectionState.CLOSED
for handler in self._read_exception_handlers[:]:
handler(err)
self._read_exception_handlers.clear()
def _report_fatal_error_and_cleanup_task(self, err: Exception) -> None:
"""Handle a fatal error that occurred during an operation.
This should only be called for errors that mean the connection
can no longer be used.
The connection will be closed, all exception handlers notified.
This method does not log the error, the call site should do so.
"""
self._handle_fatal_error(err)
asyncio.create_task(self._cleanup())
async def _report_fatal_error(self, err: Exception) -> None:
"""Report a fatal error that occurred during an operation.
@ -554,106 +553,55 @@ class APIConnection:
The connection will be closed, all exception handlers notified.
This method does not log the error, the call site should do so.
"""
self._connection_state = ConnectionState.CLOSED
for handler in self._read_exception_handlers[:]:
handler(err)
self._handle_fatal_error(err)
await self._cleanup()
async def _process_loop(self) -> None:
to_process = self._to_process
while True:
try:
pkt = await to_process.get()
except RuntimeError:
break
def _process_packet(self, pkt: Packet) -> None:
"""Process a packet from the socket."""
msg_type_proto = pkt.type
if msg_type_proto not in MESSAGE_TYPE_TO_PROTO:
_LOGGER.debug("%s: Skipping message type %s", self.log_name, msg_type_proto)
return
if pkt is None:
# Socket closed but task isn't cancelled yet
break
msg_type_proto = pkt.type
if msg_type_proto not in MESSAGE_TYPE_TO_PROTO:
_LOGGER.debug(
"%s: Skipping message type %s", self.log_name, msg_type_proto
)
continue
msg = MESSAGE_TYPE_TO_PROTO[msg_type_proto]()
try:
msg.ParseFromString(pkt.data)
except Exception as e:
_LOGGER.info(
"%s: Invalid protobuf message: type=%s data=%s: %s",
self.log_name,
pkt.type,
pkt.data,
e,
exc_info=True,
)
await self._report_fatal_error(
ProtocolAPIError(f"Invalid protobuf message: {e}")
)
raise
msg_type = type(msg)
_LOGGER.debug(
"%s: Got message of type %s: %s", self.log_name, msg_type, msg
)
for handler in self._message_handlers.get(msg_type, [])[:]:
handler(msg)
# Pre-check the message type to avoid awaiting
# since most messages are not internal messages
if msg_type in INTERNAL_MESSAGE_TYPES:
await self._handle_internal_messages(msg)
async def _read_loop(self) -> None:
frame_helper = self._frame_helper
assert frame_helper is not None
await frame_helper.wait_for_ready()
to_process = self._to_process
msg = MESSAGE_TYPE_TO_PROTO[msg_type_proto]()
try:
# Once its ready, we hold the lock for the duration of the
# connection so we don't have to keep locking/unlocking
async with frame_helper.read_lock:
while True:
to_process.put_nowait(await frame_helper.read_packet_with_lock())
except SocketClosedAPIError as err:
# don't log with info, if closed the site that closed the connection should log
_LOGGER.debug(
"%s: Socket closed, stopping read loop",
self.log_name,
)
await self._report_fatal_error(err)
except APIConnectionError as err:
msg.ParseFromString(pkt.data)
except Exception as e:
_LOGGER.info(
"%s: Error while reading incoming messages: %s",
"%s: Invalid protobuf message: type=%s data=%s: %s",
self.log_name,
err,
)
await self._report_fatal_error(err)
except Exception as err: # pylint: disable=broad-except
_LOGGER.warning(
"%s: Unexpected error while reading incoming messages: %s",
self.log_name,
err,
pkt.type,
pkt.data,
e,
exc_info=True,
)
await self._report_fatal_error(err)
self._report_fatal_error_and_cleanup_task(
ProtocolAPIError(f"Invalid protobuf message: {e}")
)
raise
msg_type = type(msg)
_LOGGER.debug("%s: Got message of type %s: %s", self.log_name, msg_type, msg)
for handler in self._message_handlers.get(msg_type, [])[:]:
handler(msg)
# Pre-check the message type to avoid awaiting
# since most messages are not internal messages
if msg_type not in INTERNAL_MESSAGE_TYPES:
return
async def _handle_internal_messages(self, msg: Any) -> None:
if isinstance(msg, DisconnectRequest):
await self.send_message(DisconnectResponse())
self.send_message(DisconnectResponse())
self._connection_state = ConnectionState.CLOSED
await self._cleanup()
asyncio.create_task(self._cleanup())
elif isinstance(msg, PingRequest):
await self.send_message(PingResponse())
self.send_message(PingResponse())
elif isinstance(msg, GetTimeRequest):
resp = GetTimeResponse()
resp.epoch_seconds = int(time.time())
await self.send_message(resp)
self.send_message(resp)
async def _ping(self) -> None:
self._check_connected()

View File

@ -1,7 +1,9 @@
import math
from functools import lru_cache
from typing import Optional
@lru_cache(maxsize=1024)
def varuint_to_bytes(value: int) -> bytes:
if value <= 0x7F:
return bytes([value])
@ -18,6 +20,7 @@ def varuint_to_bytes(value: int) -> bytes:
return ret
@lru_cache(maxsize=1024)
def bytes_to_varuint(value: bytes) -> Optional[int]:
result = 0
bitpos = 0

View File

@ -3,7 +3,7 @@ from unittest.mock import MagicMock
import pytest
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper, Packet
from aioesphomeapi.util import varuint_to_bytes
PREAMBLE = b"\x00"
@ -46,17 +46,20 @@ PREAMBLE = b"\x00"
)
async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
stream_reader = asyncio.StreamReader()
stream_writer = MagicMock()
for _ in range(5):
packets = []
stream_reader.feed_data(in_bytes)
def _packet(pkt: Packet):
packets.append(pkt)
helper = APIPlaintextFrameHelper(stream_reader, stream_writer)
def _on_error(exc: Exception):
raise exc
async with helper.read_lock:
pkt = await helper.read_packet_with_lock()
helper = APIPlaintextFrameHelper(on_pkt=_packet, on_error=_on_error)
helper.data_received(in_bytes)
pkt = packets.pop()
assert pkt.type == pkt_type
assert pkt.data == pkt_data

View File

@ -74,7 +74,7 @@ def patch_response_complex(client: APIClient, messages):
def patch_response_callback(client: APIClient):
on_message = None
async def patched(req, callback, msg_types):
def patched(req, callback, msg_types):
nonlocal on_message
on_message = callback

View File

@ -4,6 +4,7 @@ import socket
import pytest
from mock import AsyncMock, MagicMock, Mock, patch
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper, Packet
from aioesphomeapi.api_pb2 import ConnectResponse, HelloResponse
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
from aioesphomeapi.core import APIConnectionError, RequiresEncryptionAPIError
@ -50,13 +51,26 @@ def socket_socket():
yield func
def _get_mock_protocol():
def _on_packet(pkt: Packet):
pass
def _on_error(exc: Exception):
raise exc
protocol = APIPlaintextFrameHelper(on_pkt=_on_packet, on_error=_on_error)
protocol._connected_event.set()
protocol._transport = MagicMock()
return protocol
@pytest.mark.asyncio
async def test_connect(conn, resolve_host, socket_socket, event_loop):
with patch.object(event_loop, "sock_connect"), patch(
"asyncio.open_connection", return_value=(None, None)
), patch.object(conn, "_read_loop"), patch.object(
conn, "_connect_start_ping"
), patch.object(
loop = asyncio.get_event_loop()
protocol = _get_mock_protocol()
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", return_value=(MagicMock(), protocol)
), patch.object(conn, "_connect_start_ping"), patch.object(
conn, "send_message_await_response", return_value=HelloResponse()
):
await conn.connect(login=False)
@ -66,14 +80,15 @@ async def test_connect(conn, resolve_host, socket_socket, event_loop):
@pytest.mark.asyncio
async def test_requires_encryption_propagates(conn):
with patch("asyncio.open_connection") as openc:
reader = MagicMock()
writer = MagicMock()
openc.return_value = (reader, writer)
writer.drain = AsyncMock()
reader.readexactly = AsyncMock()
reader.readexactly.return_value = b"\x01"
loop = asyncio.get_event_loop()
protocol = _get_mock_protocol()
with patch.object(loop, "create_connection") as create_connection, patch.object(
protocol, "perform_handshake"
):
create_connection.return_value = (MagicMock(), protocol)
await conn._connect_init_frame_helper()
with pytest.raises(RequiresEncryptionAPIError):
protocol.data_received(b"\x01\x00\x00")
await conn._connect_hello()