Simplify connection flow with an asyncio.Protocol (#352)

This commit is contained in:
J. Nick Koston 2023-01-05 18:24:10 -10:00 committed by GitHub
parent 049dc8bb56
commit 2886d361f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 348 additions and 315 deletions

View File

@ -1,9 +1,10 @@
import asyncio import asyncio
import base64 import base64
import logging import logging
from abc import ABC, abstractmethod, abstractproperty from abc import abstractmethod, abstractproperty
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from enum import Enum
from typing import Callable, Optional, Union, cast
import async_timeout import async_timeout
from noise.connection import NoiseConnection # type: ignore from noise.connection import NoiseConnection # type: ignore
@ -32,57 +33,93 @@ SOCKET_ERRORS = (
@dataclass @dataclass
class Packet: class Packet:
type: int type: int
data: bytes data: Union[bytes, bytearray]
class APIFrameHelper(ABC): class APIFrameHelper(asyncio.Protocol):
"""Helper class to handle the API frame protocol.""" """Helper class to handle the API frame protocol."""
def __init__( def __init__(
self, self,
reader: asyncio.StreamReader, on_pkt: Callable[[Packet], None],
writer: asyncio.StreamWriter, on_error: Callable[[Exception], None],
) -> None: ) -> None:
"""Initialize the API frame helper.""" """Initialize the API frame helper."""
self._reader = reader self._on_pkt = on_pkt
self._writer = writer self._on_error = on_error
self._transport: Optional[asyncio.Transport] = None
self.read_lock = asyncio.Lock() self.read_lock = asyncio.Lock()
self._closed_event = asyncio.Event() self._closed_event = asyncio.Event()
self._connected_event = asyncio.Event()
self._buffer = bytearray()
self._pos = 0
@abstractproperty # pylint: disable=deprecated-decorator @abstractproperty # pylint: disable=deprecated-decorator
def ready(self) -> bool: def ready(self) -> bool:
"""Return if the connection is ready.""" """Return if the connection is ready."""
def _init_read(self, length: int) -> Optional[bytearray]:
"""Start reading a packet from the buffer."""
self._pos = 0
return self._read_exactly(length)
def _read_exactly(self, length: int) -> Optional[bytearray]:
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
original_pos = self._pos
new_pos = original_pos + length
if len(self._buffer) < new_pos:
return None
self._pos = new_pos
return self._buffer[original_pos:new_pos]
@abstractmethod @abstractmethod
async def close(self) -> None: async def perform_handshake(self) -> None:
"""Close the connection.""" """Perform the handshake."""
@abstractmethod @abstractmethod
def write_packet(self, packet: Packet) -> None: def write_packet(self, packet: Packet) -> None:
"""Write a packet to the socket.""" """Write a packet to the socket."""
@abstractmethod def connection_made(self, transport: asyncio.BaseTransport) -> None:
async def read_packet_with_lock(self) -> Packet: """Handle a new connection."""
"""Read a packet from the socket, the caller is responsible for having the lock.""" self._transport = cast(asyncio.Transport, transport)
self._connected_event.set()
@abstractmethod def _handle_error_and_close(self, exc: Exception) -> None:
async def wait_for_ready(self) -> None: self._handle_error(exc)
"""Wait for the connection to be ready.""" self.close()
def _handle_error(self, exc: Exception) -> None:
self._closed_event.set()
self._on_error(exc)
def connection_lost(self, exc: Optional[Exception]) -> None:
self._handle_error(exc or SocketClosedAPIError("Connection lost"))
return super().connection_lost(exc)
def eof_received(self) -> Optional[bool]:
self._handle_error(SocketClosedAPIError("EOF received"))
return super().eof_received()
def close(self) -> None:
"""Close the connection."""
self._closed_event.set()
if self._transport:
self._transport.close()
class APIPlaintextFrameHelper(APIFrameHelper): class APIPlaintextFrameHelper(APIFrameHelper):
"""Frame helper for plaintext API connections.""" """Frame helper for plaintext API connections."""
async def close(self) -> None:
"""Close the connection."""
self._closed_event.set()
self._writer.close()
@property @property
def ready(self) -> bool: def ready(self) -> bool:
"""Return if the connection is ready.""" """Return if the connection is ready."""
# Plaintext is always ready return self._connected_event.is_set()
return True
def _callback_packet(self, type_: int, data: Union[bytes, bytearray]) -> None:
"""Complete reading a packet from the buffer."""
del self._buffer[: self._pos]
self._on_pkt(Packet(type_, data))
def write_packet(self, packet: Packet) -> None: def write_packet(self, packet: Packet) -> None:
"""Write a packet to the socket, the caller should not have the lock. """Write a packet to the socket, the caller should not have the lock.
@ -90,6 +127,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
The entire packet must be written in a single call to write The entire packet must be written in a single call to write
to avoid locking. to avoid locking.
""" """
assert self._transport is not None, "Transport should be set"
data = ( data = (
b"\0" b"\0"
+ varuint_to_bytes(len(packet.data)) + varuint_to_bytes(len(packet.data))
@ -99,26 +137,32 @@ class APIPlaintextFrameHelper(APIFrameHelper):
_LOGGER.debug("Sending plaintext frame %s", data.hex()) _LOGGER.debug("Sending plaintext frame %s", data.hex())
try: try:
self._writer.write(data) self._transport.write(data)
except (ConnectionResetError, OSError) as err: except (ConnectionResetError, OSError) as err:
raise SocketAPIError(f"Error while writing data: {err}") from err raise SocketAPIError(f"Error while writing data: {err}") from err
async def wait_for_ready(self) -> None: async def perform_handshake(self) -> None:
"""Wait for the connection to be ready.""" """Perform the handshake."""
# No handshake for plaintext await self._connected_event.wait()
async def read_packet_with_lock(self) -> Packet: def data_received(self, data: bytes) -> None:
"""Read a packet from the socket, the caller is responsible for having the lock.""" self._buffer += data
assert self.read_lock.locked(), "read_packet_with_lock called without lock" while len(self._buffer) >= 3:
try:
# Read preamble, which should always 0x00 # Read preamble, which should always 0x00
# Also try to get the length and msg type # Also try to get the length and msg type
# to avoid multiple calls to readexactly # to avoid multiple calls to readexactly
init_bytes = await self._reader.readexactly(3) init_bytes = self._init_read(3)
assert init_bytes is not None, "Buffer should have at least 3 bytes"
if init_bytes[0] != 0x00: if init_bytes[0] != 0x00:
if init_bytes[0] == 0x01: if init_bytes[0] == 0x01:
raise RequiresEncryptionAPIError("Connection requires encryption") self._handle_error_and_close(
raise ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}") RequiresEncryptionAPIError("Connection requires encryption")
)
return
self._handle_error_and_close(
ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}")
)
return
if init_bytes[1] & 0x80 == 0x80: if init_bytes[1] & 0x80 == 0x80:
# Length is longer than 1 byte # Length is longer than 1 byte
@ -133,32 +177,35 @@ class APIPlaintextFrameHelper(APIFrameHelper):
# If the message is long, we need to read the rest of the length # If the message is long, we need to read the rest of the length
while length[-1] & 0x80 == 0x80: while length[-1] & 0x80 == 0x80:
length += await self._reader.readexactly(1) add_length = self._read_exactly(1)
if add_length is None:
return
length += add_length
# If the message length was longer than 1 byte, we need to read the # If the message length was longer than 1 byte, we need to read the
# message type # message type
while not msg_type or (msg_type[-1] & 0x80) == 0x80: while not msg_type or (msg_type[-1] & 0x80) == 0x80:
msg_type += await self._reader.readexactly(1) add_msg_type = self._read_exactly(1)
if add_msg_type is None:
return
msg_type += add_msg_type
length_int = bytes_to_varuint(length) length_int = bytes_to_varuint(bytes(length))
assert length_int is not None assert length_int is not None
msg_type_int = bytes_to_varuint(msg_type) msg_type_int = bytes_to_varuint(bytes(msg_type))
assert msg_type_int is not None assert msg_type_int is not None
if length_int == 0: if length_int == 0:
return Packet(type=msg_type_int, data=b"") self._callback_packet(msg_type_int, b"")
# If we have more data, continue processing
continue
data = await self._reader.readexactly(length_int) packet_data = self._read_exactly(length_int)
return Packet(type=msg_type_int, data=data) if packet_data is None:
except SOCKET_ERRORS as err: return
if (
isinstance(err, asyncio.IncompleteReadError) self._callback_packet(msg_type_int, packet_data)
and self._closed_event.is_set() # If we have more data, continue processing
):
raise SocketClosedAPIError(
f"Socket closed while reading data: {err}"
) from err
raise SocketAPIError(f"Error while reading data: {err}") from err
def _decode_noise_psk(psk: str) -> bytes: def _decode_noise_psk(psk: str) -> bytes:
@ -176,34 +223,46 @@ def _decode_noise_psk(psk: str) -> bytes:
return psk_bytes return psk_bytes
class NoiseConnectionState(Enum):
"""Noise connection state."""
HELLO = 1
HANDSHAKE = 2
READY = 3
CLOSED = 4
class APINoiseFrameHelper(APIFrameHelper): class APINoiseFrameHelper(APIFrameHelper):
"""Frame helper for noise encrypted connections.""" """Frame helper for noise encrypted connections."""
def __init__( def __init__(
self, self,
reader: asyncio.StreamReader, on_pkt: Callable[[Packet], None],
writer: asyncio.StreamWriter, on_error: Callable[[Exception], None],
noise_psk: str, noise_psk: str,
expected_name: Optional[str],
) -> None: ) -> None:
"""Initialize the API frame helper.""" """Initialize the API frame helper."""
super().__init__(reader, writer) super().__init__(on_pkt, on_error)
self._ready_event = asyncio.Event() self._ready_event = asyncio.Event()
self._proto: Optional[NoiseConnection] = None
self._noise_psk = noise_psk self._noise_psk = noise_psk
self._expected_name = expected_name
self._state = NoiseConnectionState.HELLO
self._setup_proto()
@property @property
def ready(self) -> bool: def ready(self) -> bool:
"""Return if the connection is ready.""" """Return if the connection is ready."""
return self._ready_event.is_set() return self._ready_event.is_set()
async def close(self) -> None: def close(self) -> None:
"""Close the connection.""" """Close the connection."""
# Make sure we set the ready event if its not already set # Make sure we set the ready event if its not already set
# so that we don't block forever on the ready event if we # so that we don't block forever on the ready event if we
# are waiting for the handshake to complete. # are waiting for the handshake to complete.
self._ready_event.set() self._ready_event.set()
self._closed_event.set() self._state = NoiseConnectionState.CLOSED
self._writer.close() super().close()
def _write_frame(self, frame: bytes) -> None: def _write_frame(self, frame: bytes) -> None:
"""Write a packet to the socket, the caller should not have the lock. """Write a packet to the socket, the caller should not have the lock.
@ -212,6 +271,7 @@ class APINoiseFrameHelper(APIFrameHelper):
to avoid locking. to avoid locking.
""" """
_LOGGER.debug("Sending frame %s", frame.hex()) _LOGGER.debug("Sending frame %s", frame.hex())
assert self._transport is not None, "Transport is not set"
try: try:
header = bytes( header = bytes(
@ -221,39 +281,46 @@ class APINoiseFrameHelper(APIFrameHelper):
len(frame) & 0xFF, len(frame) & 0xFF,
] ]
) )
self._writer.write(header + frame) self._transport.write(header + frame)
except OSError as err: except OSError as err:
raise SocketAPIError(f"Error while writing data: {err}") from err raise SocketAPIError(f"Error while writing data: {err}") from err
async def _read_frame_with_lock(self) -> bytes: async def perform_handshake(self) -> None:
"""Read a frame from the socket, the caller is responsible for having the lock.""" """Perform the handshake with the server."""
assert self.read_lock.locked(), "_read_frame_with_lock called without lock" self._send_hello()
try: try:
header = await self._reader.readexactly(3) async with async_timeout.timeout(60.0):
await self._ready_event.wait()
except asyncio.TimeoutError as err:
raise HandshakeAPIError("Timeout during handshake") from err
def data_received(self, data: bytes) -> None:
self._buffer += data
while len(self._buffer) >= 3:
header = self._init_read(3)
assert header is not None, "Buffer should have at least 3 bytes"
if header[0] != 0x01: if header[0] != 0x01:
raise ProtocolAPIError(f"Marker byte invalid: {header[0]}") self._handle_error_and_close(
ProtocolAPIError(f"Marker byte invalid: {header[0]}")
)
msg_size = (header[1] << 8) | header[2] msg_size = (header[1] << 8) | header[2]
frame = await self._reader.readexactly(msg_size) frame = self._read_exactly(msg_size)
except SOCKET_ERRORS as err: if frame is None:
if ( return
isinstance(err, asyncio.IncompleteReadError)
and self._closed_event.is_set()
):
raise SocketClosedAPIError(
f"Socket closed while reading data: {err}"
) from err
raise SocketAPIError(f"Error while reading data: {err}") from err
_LOGGER.debug("Received frame %s", frame.hex()) try:
return frame self.STATE_TO_CALLABLE[self._state](self, frame)
except Exception as err: # pylint: disable=broad-except
self._handle_error_and_close(err)
finally:
del self._buffer[: self._pos]
async def _perform_handshake(self, expected_name: Optional[str]) -> None: def _send_hello(self) -> None:
"""Perform the handshake with the server, the caller is responsible for having the lock.""" """Send a ClientHello to the server."""
assert self.read_lock.locked(), "_perform_handshake called without lock"
self._write_frame(b"") # ClientHello self._write_frame(b"") # ClientHello
prologue = b"NoiseAPIInit" + b"\x00\x00"
server_hello = await self._read_frame_with_lock() # ServerHello def _handle_hello(self, server_hello: bytearray) -> None:
"""Perform the handshake with the server, the caller is responsible for having the lock."""
if not server_hello: if not server_hello:
raise HandshakeAPIError("ServerHello is empty") raise HandshakeAPIError("ServerHello is empty")
@ -273,76 +340,60 @@ class APINoiseFrameHelper(APIFrameHelper):
if server_name_i != -1: if server_name_i != -1:
# server name found, this extension was added in 2022.2 # server name found, this extension was added in 2022.2
server_name = server_hello[1:server_name_i].decode() server_name = server_hello[1:server_name_i].decode()
if expected_name is not None and expected_name != server_name: if self._expected_name is not None and self._expected_name != server_name:
raise BadNameAPIError( raise BadNameAPIError(
f"Server sent a different name '{server_name}'", server_name f"Server sent a different name '{server_name}'", server_name
) )
self._state = NoiseConnectionState.HANDSHAKE
self._send_handshake()
def _setup_proto(self) -> None:
"""Set up the noise protocol."""
self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256") self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
self._proto.set_as_initiator() self._proto.set_as_initiator()
self._proto.set_psks(_decode_noise_psk(self._noise_psk)) self._proto.set_psks(_decode_noise_psk(self._noise_psk))
self._proto.set_prologue(prologue) self._proto.set_prologue(b"NoiseAPIInit" + b"\x00\x00")
self._proto.start_handshake() self._proto.start_handshake()
def _send_handshake(self) -> None:
"""Send the handshake message."""
self._write_frame(b"\x00" + self._proto.write_message())
def _handle_handshake(self, msg: bytearray) -> None:
_LOGGER.debug("Starting handshake...") _LOGGER.debug("Starting handshake...")
do_write = True if msg[0] != 0:
while not self._proto.handshake_finished: explanation = msg[1:].decode()
if do_write: if explanation == "Handshake MAC failure":
msg = self._proto.write_message() raise InvalidEncryptionKeyAPIError("Invalid encryption key")
self._write_frame(b"\x00" + msg) raise HandshakeAPIError(f"Handshake failure: {explanation}")
else: self._proto.read_message(msg[1:])
msg = await self._read_frame_with_lock() _LOGGER.debug("Handshake complete")
if not msg: self._state = NoiseConnectionState.READY
raise HandshakeAPIError("Handshake message too short")
if msg[0] != 0:
explanation = msg[1:].decode()
if explanation == "Handshake MAC failure":
raise InvalidEncryptionKeyAPIError("Invalid encryption key")
raise HandshakeAPIError(f"Handshake failure: {explanation}")
self._proto.read_message(msg[1:])
do_write = not do_write
_LOGGER.debug("Handshake complete!")
self._ready_event.set() self._ready_event.set()
async def perform_handshake(self, expected_name: Optional[str]) -> None:
"""Perform the handshake with the server."""
# Allow up to 60 seconds for handhsake
try:
async with self.read_lock, async_timeout.timeout(60.0):
await self._perform_handshake(expected_name)
except asyncio.TimeoutError as err:
raise HandshakeAPIError("Timeout during handshake") from err
def write_packet(self, packet: Packet) -> None: def write_packet(self, packet: Packet) -> None:
"""Write a packet to the socket.""" """Write a packet to the socket."""
padding = 0 self._write_frame(
data = ( self._proto.encrypt(
bytes( (
[ bytes(
(packet.type >> 8) & 0xFF, [
(packet.type >> 0) & 0xFF, (packet.type >> 8) & 0xFF,
(len(packet.data) >> 8) & 0xFF, (packet.type >> 0) & 0xFF,
(len(packet.data) >> 0) & 0xFF, (len(packet.data) >> 8) & 0xFF,
] (len(packet.data) >> 0) & 0xFF,
]
)
+ packet.data
)
) )
+ packet.data
+ b"\x00" * padding
) )
assert self._proto is not None
frame = self._proto.encrypt(data)
self._write_frame(frame)
async def wait_for_ready(self) -> None: def _handle_frame(self, frame: bytearray) -> None:
"""Wait for the connection to be ready.""" """Handle an incoming frame."""
await self._ready_event.wait()
async def read_packet_with_lock(self) -> Packet:
"""Read a packet from the socket, the caller is responsible for having the lock."""
frame = await self._read_frame_with_lock()
assert self._proto is not None assert self._proto is not None
msg = self._proto.decrypt(frame) msg = self._proto.decrypt(bytes(frame))
if len(msg) < 4: if len(msg) < 4:
raise ProtocolAPIError(f"Bad packet frame: {msg}") raise ProtocolAPIError(f"Bad packet frame: {msg}")
pkt_type = (msg[0] << 8) | msg[1] pkt_type = (msg[0] << 8) | msg[1]
@ -350,4 +401,17 @@ class APINoiseFrameHelper(APIFrameHelper):
if data_len + 4 > len(msg): if data_len + 4 > len(msg):
raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}") raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
data = msg[4 : 4 + data_len] data = msg[4 : 4 + data_len]
return Packet(type=pkt_type, data=data) return self._on_pkt(Packet(pkt_type, data))
def _handle_closed( # pylint: disable=unused-argument
self, frame: bytearray
) -> None:
"""Handle a closed frame."""
self._handle_error(ProtocolAPIError("Connection closed"))
STATE_TO_CALLABLE = {
NoiseConnectionState.HELLO: _handle_hello,
NoiseConnectionState.HANDSHAKE: _handle_handshake,
NoiseConnectionState.READY: _handle_frame,
NoiseConnectionState.CLOSED: _handle_closed,
}

View File

@ -373,7 +373,7 @@ class APIClient:
image_stream[msg.key] = data image_stream[msg.key] = data
assert self._connection is not None assert self._connection is not None
await self._connection.send_message_callback_response( self._connection.send_message_callback_response(
SubscribeStatesRequest(), on_msg, msg_types SubscribeStatesRequest(), on_msg, msg_types
) )
@ -394,7 +394,7 @@ class APIClient:
if dump_config is not None: if dump_config is not None:
req.dump_config = dump_config req.dump_config = dump_config
assert self._connection is not None assert self._connection is not None
await self._connection.send_message_callback_response( self._connection.send_message_callback_response(
req, on_msg, (SubscribeLogsResponse,) req, on_msg, (SubscribeLogsResponse,)
) )
@ -407,7 +407,7 @@ class APIClient:
on_service_call(HomeassistantServiceCall.from_pb(msg)) on_service_call(HomeassistantServiceCall.from_pb(msg))
assert self._connection is not None assert self._connection is not None
await self._connection.send_message_callback_response( self._connection.send_message_callback_response(
SubscribeHomeassistantServicesRequest(), SubscribeHomeassistantServicesRequest(),
on_msg, on_msg,
(HomeassistantServiceResponse,), (HomeassistantServiceResponse,),
@ -451,7 +451,7 @@ class APIClient:
on_bluetooth_le_advertisement(BluetoothLEAdvertisement.from_pb(msg)) # type: ignore[misc] on_bluetooth_le_advertisement(BluetoothLEAdvertisement.from_pb(msg)) # type: ignore[misc]
assert self._connection is not None assert self._connection is not None
await self._connection.send_message_callback_response( self._connection.send_message_callback_response(
SubscribeBluetoothLEAdvertisementsRequest(), on_msg, msg_types SubscribeBluetoothLEAdvertisementsRequest(), on_msg, msg_types
) )
@ -472,7 +472,7 @@ class APIClient:
on_bluetooth_connections_free_update(resp.free, resp.limit) on_bluetooth_connections_free_update(resp.free, resp.limit)
assert self._connection is not None assert self._connection is not None
await self._connection.send_message_callback_response( self._connection.send_message_callback_response(
SubscribeBluetoothConnectionsFreeRequest(), on_msg, msg_types SubscribeBluetoothConnectionsFreeRequest(), on_msg, msg_types
) )
@ -518,7 +518,7 @@ class APIClient:
_LOGGER.debug("%s: Using connection version 1", address) _LOGGER.debug("%s: Using connection version 1", address)
request_type = BluetoothDeviceRequestType.CONNECT request_type = BluetoothDeviceRequestType.CONNECT
await self._connection.send_message_callback_response( self._connection.send_message_callback_response(
BluetoothDeviceRequest( BluetoothDeviceRequest(
address=address, address=address,
request_type=request_type, request_type=request_type,
@ -581,7 +581,7 @@ class APIClient:
self._check_authenticated() self._check_authenticated()
assert self._connection is not None assert self._connection is not None
await self._connection.send_message( self._connection.send_message(
BluetoothDeviceRequest( BluetoothDeviceRequest(
address=address, address=address,
request_type=BluetoothDeviceRequestType.DISCONNECT, request_type=BluetoothDeviceRequestType.DISCONNECT,
@ -661,7 +661,7 @@ class APIClient:
if not response: if not response:
assert self._connection is not None assert self._connection is not None
await self._connection.send_message(req) self._connection.send_message(req)
return return
await self._send_bluetooth_message_await_response( await self._send_bluetooth_message_await_response(
@ -709,7 +709,7 @@ class APIClient:
if not wait_for_response: if not wait_for_response:
assert self._connection is not None assert self._connection is not None
await self._connection.send_message(req) self._connection.send_message(req)
return return
await self._send_bluetooth_message_await_response( await self._send_bluetooth_message_await_response(
@ -762,7 +762,7 @@ class APIClient:
self._check_authenticated() self._check_authenticated()
await self._connection.send_message( self._connection.send_message(
BluetoothGATTNotifyRequest(address=address, handle=handle, enable=False) BluetoothGATTNotifyRequest(address=address, handle=handle, enable=False)
) )
@ -777,7 +777,7 @@ class APIClient:
on_state_sub(msg.entity_id, msg.attribute) on_state_sub(msg.entity_id, msg.attribute)
assert self._connection is not None assert self._connection is not None
await self._connection.send_message_callback_response( self._connection.send_message_callback_response(
SubscribeHomeAssistantStatesRequest(), SubscribeHomeAssistantStatesRequest(),
on_msg, on_msg,
(SubscribeHomeAssistantStateResponse,), (SubscribeHomeAssistantStateResponse,),
@ -789,7 +789,7 @@ class APIClient:
self._check_authenticated() self._check_authenticated()
assert self._connection is not None assert self._connection is not None
await self._connection.send_message( self._connection.send_message(
HomeAssistantStateResponse( HomeAssistantStateResponse(
entity_id=entity_id, entity_id=entity_id,
state=state, state=state,
@ -829,7 +829,7 @@ class APIClient:
req.legacy_command = LegacyCoverCommand.CLOSE req.legacy_command = LegacyCoverCommand.CLOSE
req.has_legacy_command = True req.has_legacy_command = True
assert self._connection is not None assert self._connection is not None
await self._connection.send_message(req) self._connection.send_message(req)
async def fan_command( async def fan_command(
self, self,
@ -860,7 +860,7 @@ class APIClient:
req.has_direction = True req.has_direction = True
req.direction = direction req.direction = direction
assert self._connection is not None assert self._connection is not None
await self._connection.send_message(req) self._connection.send_message(req)
async def light_command( async def light_command(
self, self,
@ -921,7 +921,7 @@ class APIClient:
req.has_effect = True req.has_effect = True
req.effect = effect req.effect = effect
assert self._connection is not None assert self._connection is not None
await self._connection.send_message(req) self._connection.send_message(req)
async def switch_command(self, key: int, state: bool) -> None: async def switch_command(self, key: int, state: bool) -> None:
self._check_authenticated() self._check_authenticated()
@ -930,7 +930,7 @@ class APIClient:
req.key = key req.key = key
req.state = state req.state = state
assert self._connection is not None assert self._connection is not None
await self._connection.send_message(req) self._connection.send_message(req)
async def climate_command( async def climate_command(
self, self,
@ -982,7 +982,7 @@ class APIClient:
req.has_custom_preset = True req.has_custom_preset = True
req.custom_preset = custom_preset req.custom_preset = custom_preset
assert self._connection is not None assert self._connection is not None
await self._connection.send_message(req) self._connection.send_message(req)
async def number_command(self, key: int, state: float) -> None: async def number_command(self, key: int, state: float) -> None:
self._check_authenticated() self._check_authenticated()
@ -991,7 +991,7 @@ class APIClient:
req.key = key req.key = key
req.state = state req.state = state
assert self._connection is not None assert self._connection is not None
await self._connection.send_message(req) self._connection.send_message(req)
async def select_command(self, key: int, state: str) -> None: async def select_command(self, key: int, state: str) -> None:
self._check_authenticated() self._check_authenticated()
@ -1000,7 +1000,7 @@ class APIClient:
req.key = key req.key = key
req.state = state req.state = state
assert self._connection is not None assert self._connection is not None
await self._connection.send_message(req) self._connection.send_message(req)
async def siren_command( async def siren_command(
self, self,
@ -1027,7 +1027,7 @@ class APIClient:
req.duration = duration req.duration = duration
req.has_duration = True req.has_duration = True
assert self._connection is not None assert self._connection is not None
await self._connection.send_message(req) self._connection.send_message(req)
async def button_command(self, key: int) -> None: async def button_command(self, key: int) -> None:
self._check_authenticated() self._check_authenticated()
@ -1035,7 +1035,7 @@ class APIClient:
req = ButtonCommandRequest() req = ButtonCommandRequest()
req.key = key req.key = key
assert self._connection is not None assert self._connection is not None
await self._connection.send_message(req) self._connection.send_message(req)
async def lock_command( async def lock_command(
self, self,
@ -1051,7 +1051,7 @@ class APIClient:
if code is not None: if code is not None:
req.code = code req.code = code
assert self._connection is not None assert self._connection is not None
await self._connection.send_message(req) self._connection.send_message(req)
async def media_player_command( async def media_player_command(
self, self,
@ -1075,7 +1075,7 @@ class APIClient:
req.media_url = media_url req.media_url = media_url
req.has_media_url = True req.has_media_url = True
assert self._connection is not None assert self._connection is not None
await self._connection.send_message(req) self._connection.send_message(req)
async def execute_service( async def execute_service(
self, service: UserService, data: ExecuteServiceDataType self, service: UserService, data: ExecuteServiceDataType
@ -1113,7 +1113,7 @@ class APIClient:
# pylint: disable=no-member # pylint: disable=no-member
req.args.extend(args) req.args.extend(args)
assert self._connection is not None assert self._connection is not None
await self._connection.send_message(req) self._connection.send_message(req)
async def _request_image( async def _request_image(
self, *, single: bool = False, stream: bool = False self, *, single: bool = False, stream: bool = False
@ -1122,7 +1122,7 @@ class APIClient:
req.single = single req.single = single
req.stream = stream req.stream = stream
assert self._connection is not None assert self._connection is not None
await self._connection.send_message(req) self._connection.send_message(req)
async def request_single_image(self) -> None: async def request_single_image(self) -> None:
await self._request_image(single=True) await self._request_image(single=True)

View File

@ -5,7 +5,7 @@ import socket
import time import time
from contextlib import suppress from contextlib import suppress
from dataclasses import astuple, dataclass from dataclasses import astuple, dataclass
from typing import Any, Callable, Coroutine, Dict, Iterable, List, Optional, Type from typing import Any, Callable, Coroutine, Dict, Iterable, List, Optional, Type, Union
import async_timeout import async_timeout
from google.protobuf import message from google.protobuf import message
@ -40,7 +40,6 @@ from .core import (
ReadFailedAPIError, ReadFailedAPIError,
ResolveAPIError, ResolveAPIError,
SocketAPIError, SocketAPIError,
SocketClosedAPIError,
TimeoutAPIError, TimeoutAPIError,
) )
from .model import APIVersion from .model import APIVersion
@ -112,10 +111,6 @@ class APIConnection:
self._ping_stop_event = asyncio.Event() self._ping_stop_event = asyncio.Event()
self._to_process: asyncio.Queue[Optional[Packet]] = asyncio.Queue()
self._process_task: Optional[asyncio.Task[None]] = None
self._connect_lock: asyncio.Lock = asyncio.Lock() self._connect_lock: asyncio.Lock = asyncio.Lock()
self._cleanup_task: Optional[asyncio.Task[None]] = None self._cleanup_task: Optional[asyncio.Task[None]] = None
@ -138,27 +133,10 @@ class APIConnection:
async with self._connect_lock: async with self._connect_lock:
_LOGGER.debug("Cleaning up connection to %s", self.log_name) _LOGGER.debug("Cleaning up connection to %s", self.log_name)
# Tell the process loop to stop
self._to_process.put_nowait(None)
if self._frame_helper is not None: if self._frame_helper is not None:
await self._frame_helper.close() self._frame_helper.close()
self._frame_helper = None self._frame_helper = None
if self._process_task is not None:
self._process_task.cancel()
try:
await self._process_task
except asyncio.CancelledError:
pass
except Exception as err: # pylint: disable=broad-except
_LOGGER.error(
"Unexpected exception in process task: %s",
err,
exc_info=err,
)
self._process_task = None
if self._socket is not None: if self._socket is not None:
self._socket.close() self._socket.close()
self._socket = None self._socket = None
@ -231,24 +209,31 @@ class APIConnection:
async def _connect_init_frame_helper(self) -> None: async def _connect_init_frame_helper(self) -> None:
"""Step 3 in connect process: initialize the frame helper and init read loop.""" """Step 3 in connect process: initialize the frame helper and init read loop."""
reader, writer = await asyncio.open_connection( fh: Union[APIPlaintextFrameHelper, APINoiseFrameHelper]
sock=self._socket, limit=BUFFER_SIZE loop = asyncio.get_event_loop()
) # Set buffer limit to 1MB
if self._params.noise_psk is None: if self._params.noise_psk is None:
self._frame_helper = APIPlaintextFrameHelper(reader, writer) _, fh = await loop.create_connection(
else: lambda: APIPlaintextFrameHelper(
fh = self._frame_helper = APINoiseFrameHelper( on_pkt=self._process_packet,
reader, writer, self._params.noise_psk on_error=self._report_fatal_error_and_cleanup_task,
),
sock=self._socket,
)
else:
_, fh = await loop.create_connection(
lambda: APINoiseFrameHelper(
noise_psk=self._params.noise_psk,
expected_name=self._params.expected_name,
on_pkt=self._process_packet,
on_error=self._report_fatal_error_and_cleanup_task,
),
sock=self._socket,
) )
await fh.perform_handshake(self._params.expected_name)
self._frame_helper = fh
self._connection_state = ConnectionState.SOCKET_OPENED self._connection_state = ConnectionState.SOCKET_OPENED
await fh.perform_handshake()
# Create read loop
asyncio.create_task(self._read_loop())
# Create process loop
self._process_task = asyncio.create_task(self._process_loop())
async def _connect_hello(self) -> None: async def _connect_hello(self) -> None:
"""Step 4 in connect process: send hello and get api version.""" """Step 4 in connect process: send hello and get api version."""
@ -292,10 +277,7 @@ class APIConnection:
"""Step 5 in connect process: start the ping loop.""" """Step 5 in connect process: start the ping loop."""
async def _keep_alive_loop() -> None: async def _keep_alive_loop() -> None:
while True: while self._is_socket_open:
if not self._is_socket_open:
return
# Wait for keepalive seconds, or ping stop event, whichever happens first # Wait for keepalive seconds, or ping stop event, whichever happens first
try: try:
async with async_timeout.timeout(self._params.keepalive): async with async_timeout.timeout(self._params.keepalive):
@ -404,22 +386,20 @@ class APIConnection:
def is_authenticated(self) -> bool: def is_authenticated(self) -> bool:
return self.is_connected and self._is_authenticated return self.is_connected and self._is_authenticated
async def send_message(self, msg: message.Message) -> None: def send_message(self, msg: message.Message) -> None:
"""Send a protobuf message to the remote.""" """Send a protobuf message to the remote."""
if not self._is_socket_open: if not self._is_socket_open:
raise APIConnectionError("Connection isn't established yet") raise APIConnectionError("Connection isn't established yet")
frame_helper = self._frame_helper
assert frame_helper is not None
assert frame_helper.ready, "Frame helper not ready"
message_type = PROTO_TO_MESSAGE_TYPE.get(type(msg)) message_type = PROTO_TO_MESSAGE_TYPE.get(type(msg))
if not message_type: if not message_type:
raise ValueError(f"Message type id not found for type {type(msg)}") raise ValueError(f"Message type id not found for type {type(msg)}")
encoded = msg.SerializeToString() encoded = msg.SerializeToString()
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg)) _LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
frame_helper = self._frame_helper
assert frame_helper is not None
if not frame_helper.ready:
await frame_helper.wait_for_ready()
try: try:
frame_helper.write_packet( frame_helper.write_packet(
Packet( Packet(
@ -431,7 +411,7 @@ class APIConnection:
# If writing packet fails, we don't know what state the frames # If writing packet fails, we don't know what state the frames
# are in anymore and we have to close the connection # are in anymore and we have to close the connection
_LOGGER.info("%s: Error writing packet: %s", self.log_name, err) _LOGGER.info("%s: Error writing packet: %s", self.log_name, err)
await self._report_fatal_error(err) self._report_fatal_error_and_cleanup_task(err)
raise raise
def add_message_callback( def add_message_callback(
@ -454,7 +434,7 @@ class APIConnection:
for msg_type in msg_types: for msg_type in msg_types:
self._message_handlers[msg_type].remove(on_message) self._message_handlers[msg_type].remove(on_message)
async def send_message_callback_response( def send_message_callback_response(
self, self,
send_msg: message.Message, send_msg: message.Message,
on_message: Callable[[Any], None], on_message: Callable[[Any], None],
@ -464,7 +444,7 @@ class APIConnection:
for msg_type in msg_types: for msg_type in msg_types:
self._message_handlers.setdefault(msg_type, []).append(on_message) self._message_handlers.setdefault(msg_type, []).append(on_message)
try: try:
await self.send_message(send_msg) self.send_message(send_msg)
except (asyncio.CancelledError, Exception): except (asyncio.CancelledError, Exception):
for msg_type in msg_types: for msg_type in msg_types:
self._message_handlers[msg_type].remove(on_message) self._message_handlers[msg_type].remove(on_message)
@ -514,7 +494,7 @@ class APIConnection:
# the await is cancelled # the await is cancelled
try: try:
await self.send_message(send_msg) self.send_message(send_msg)
async with async_timeout.timeout(timeout): async with async_timeout.timeout(timeout):
await fut await fut
except asyncio.TimeoutError as err: except asyncio.TimeoutError as err:
@ -545,6 +525,25 @@ class APIConnection:
return res[0] return res[0]
def _handle_fatal_error(self, err: Exception) -> None:
"""Handle a fatal error that occurred during an operation."""
self._connection_state = ConnectionState.CLOSED
for handler in self._read_exception_handlers[:]:
handler(err)
self._read_exception_handlers.clear()
def _report_fatal_error_and_cleanup_task(self, err: Exception) -> None:
"""Handle a fatal error that occurred during an operation.
This should only be called for errors that mean the connection
can no longer be used.
The connection will be closed, all exception handlers notified.
This method does not log the error, the call site should do so.
"""
self._handle_fatal_error(err)
asyncio.create_task(self._cleanup())
async def _report_fatal_error(self, err: Exception) -> None: async def _report_fatal_error(self, err: Exception) -> None:
"""Report a fatal error that occurred during an operation. """Report a fatal error that occurred during an operation.
@ -554,106 +553,55 @@ class APIConnection:
The connection will be closed, all exception handlers notified. The connection will be closed, all exception handlers notified.
This method does not log the error, the call site should do so. This method does not log the error, the call site should do so.
""" """
self._connection_state = ConnectionState.CLOSED self._handle_fatal_error(err)
for handler in self._read_exception_handlers[:]:
handler(err)
await self._cleanup() await self._cleanup()
async def _process_loop(self) -> None: def _process_packet(self, pkt: Packet) -> None:
to_process = self._to_process """Process a packet from the socket."""
while True: msg_type_proto = pkt.type
try: if msg_type_proto not in MESSAGE_TYPE_TO_PROTO:
pkt = await to_process.get() _LOGGER.debug("%s: Skipping message type %s", self.log_name, msg_type_proto)
except RuntimeError: return
break
if pkt is None: msg = MESSAGE_TYPE_TO_PROTO[msg_type_proto]()
# Socket closed but task isn't cancelled yet
break
msg_type_proto = pkt.type
if msg_type_proto not in MESSAGE_TYPE_TO_PROTO:
_LOGGER.debug(
"%s: Skipping message type %s", self.log_name, msg_type_proto
)
continue
msg = MESSAGE_TYPE_TO_PROTO[msg_type_proto]()
try:
msg.ParseFromString(pkt.data)
except Exception as e:
_LOGGER.info(
"%s: Invalid protobuf message: type=%s data=%s: %s",
self.log_name,
pkt.type,
pkt.data,
e,
exc_info=True,
)
await self._report_fatal_error(
ProtocolAPIError(f"Invalid protobuf message: {e}")
)
raise
msg_type = type(msg)
_LOGGER.debug(
"%s: Got message of type %s: %s", self.log_name, msg_type, msg
)
for handler in self._message_handlers.get(msg_type, [])[:]:
handler(msg)
# Pre-check the message type to avoid awaiting
# since most messages are not internal messages
if msg_type in INTERNAL_MESSAGE_TYPES:
await self._handle_internal_messages(msg)
async def _read_loop(self) -> None:
frame_helper = self._frame_helper
assert frame_helper is not None
await frame_helper.wait_for_ready()
to_process = self._to_process
try: try:
# Once its ready, we hold the lock for the duration of the msg.ParseFromString(pkt.data)
# connection so we don't have to keep locking/unlocking except Exception as e:
async with frame_helper.read_lock:
while True:
to_process.put_nowait(await frame_helper.read_packet_with_lock())
except SocketClosedAPIError as err:
# don't log with info, if closed the site that closed the connection should log
_LOGGER.debug(
"%s: Socket closed, stopping read loop",
self.log_name,
)
await self._report_fatal_error(err)
except APIConnectionError as err:
_LOGGER.info( _LOGGER.info(
"%s: Error while reading incoming messages: %s", "%s: Invalid protobuf message: type=%s data=%s: %s",
self.log_name, self.log_name,
err, pkt.type,
) pkt.data,
await self._report_fatal_error(err) e,
except Exception as err: # pylint: disable=broad-except
_LOGGER.warning(
"%s: Unexpected error while reading incoming messages: %s",
self.log_name,
err,
exc_info=True, exc_info=True,
) )
await self._report_fatal_error(err) self._report_fatal_error_and_cleanup_task(
ProtocolAPIError(f"Invalid protobuf message: {e}")
)
raise
msg_type = type(msg)
_LOGGER.debug("%s: Got message of type %s: %s", self.log_name, msg_type, msg)
for handler in self._message_handlers.get(msg_type, [])[:]:
handler(msg)
# Pre-check the message type to avoid awaiting
# since most messages are not internal messages
if msg_type not in INTERNAL_MESSAGE_TYPES:
return
async def _handle_internal_messages(self, msg: Any) -> None:
if isinstance(msg, DisconnectRequest): if isinstance(msg, DisconnectRequest):
await self.send_message(DisconnectResponse()) self.send_message(DisconnectResponse())
self._connection_state = ConnectionState.CLOSED self._connection_state = ConnectionState.CLOSED
await self._cleanup() asyncio.create_task(self._cleanup())
elif isinstance(msg, PingRequest): elif isinstance(msg, PingRequest):
await self.send_message(PingResponse()) self.send_message(PingResponse())
elif isinstance(msg, GetTimeRequest): elif isinstance(msg, GetTimeRequest):
resp = GetTimeResponse() resp = GetTimeResponse()
resp.epoch_seconds = int(time.time()) resp.epoch_seconds = int(time.time())
await self.send_message(resp) self.send_message(resp)
async def _ping(self) -> None: async def _ping(self) -> None:
self._check_connected() self._check_connected()

View File

@ -1,7 +1,9 @@
import math import math
from functools import lru_cache
from typing import Optional from typing import Optional
@lru_cache(maxsize=1024)
def varuint_to_bytes(value: int) -> bytes: def varuint_to_bytes(value: int) -> bytes:
if value <= 0x7F: if value <= 0x7F:
return bytes([value]) return bytes([value])
@ -18,6 +20,7 @@ def varuint_to_bytes(value: int) -> bytes:
return ret return ret
@lru_cache(maxsize=1024)
def bytes_to_varuint(value: bytes) -> Optional[int]: def bytes_to_varuint(value: bytes) -> Optional[int]:
result = 0 result = 0
bitpos = 0 bitpos = 0

View File

@ -3,7 +3,7 @@ from unittest.mock import MagicMock
import pytest import pytest
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper from aioesphomeapi._frame_helper import APIPlaintextFrameHelper, Packet
from aioesphomeapi.util import varuint_to_bytes from aioesphomeapi.util import varuint_to_bytes
PREAMBLE = b"\x00" PREAMBLE = b"\x00"
@ -46,17 +46,20 @@ PREAMBLE = b"\x00"
) )
async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type): async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
stream_reader = asyncio.StreamReader()
stream_writer = MagicMock()
for _ in range(5): for _ in range(5):
packets = []
stream_reader.feed_data(in_bytes) def _packet(pkt: Packet):
packets.append(pkt)
helper = APIPlaintextFrameHelper(stream_reader, stream_writer) def _on_error(exc: Exception):
raise exc
async with helper.read_lock: helper = APIPlaintextFrameHelper(on_pkt=_packet, on_error=_on_error)
pkt = await helper.read_packet_with_lock()
helper.data_received(in_bytes)
pkt = packets.pop()
assert pkt.type == pkt_type assert pkt.type == pkt_type
assert pkt.data == pkt_data assert pkt.data == pkt_data

View File

@ -74,7 +74,7 @@ def patch_response_complex(client: APIClient, messages):
def patch_response_callback(client: APIClient): def patch_response_callback(client: APIClient):
on_message = None on_message = None
async def patched(req, callback, msg_types): def patched(req, callback, msg_types):
nonlocal on_message nonlocal on_message
on_message = callback on_message = callback

View File

@ -4,6 +4,7 @@ import socket
import pytest import pytest
from mock import AsyncMock, MagicMock, Mock, patch from mock import AsyncMock, MagicMock, Mock, patch
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper, Packet
from aioesphomeapi.api_pb2 import ConnectResponse, HelloResponse from aioesphomeapi.api_pb2 import ConnectResponse, HelloResponse
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
from aioesphomeapi.core import APIConnectionError, RequiresEncryptionAPIError from aioesphomeapi.core import APIConnectionError, RequiresEncryptionAPIError
@ -50,13 +51,26 @@ def socket_socket():
yield func yield func
def _get_mock_protocol():
def _on_packet(pkt: Packet):
pass
def _on_error(exc: Exception):
raise exc
protocol = APIPlaintextFrameHelper(on_pkt=_on_packet, on_error=_on_error)
protocol._connected_event.set()
protocol._transport = MagicMock()
return protocol
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_connect(conn, resolve_host, socket_socket, event_loop): async def test_connect(conn, resolve_host, socket_socket, event_loop):
with patch.object(event_loop, "sock_connect"), patch( loop = asyncio.get_event_loop()
"asyncio.open_connection", return_value=(None, None) protocol = _get_mock_protocol()
), patch.object(conn, "_read_loop"), patch.object( with patch.object(event_loop, "sock_connect"), patch.object(
conn, "_connect_start_ping" loop, "create_connection", return_value=(MagicMock(), protocol)
), patch.object( ), patch.object(conn, "_connect_start_ping"), patch.object(
conn, "send_message_await_response", return_value=HelloResponse() conn, "send_message_await_response", return_value=HelloResponse()
): ):
await conn.connect(login=False) await conn.connect(login=False)
@ -66,14 +80,15 @@ async def test_connect(conn, resolve_host, socket_socket, event_loop):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_requires_encryption_propagates(conn): async def test_requires_encryption_propagates(conn):
with patch("asyncio.open_connection") as openc: loop = asyncio.get_event_loop()
reader = MagicMock() protocol = _get_mock_protocol()
writer = MagicMock() with patch.object(loop, "create_connection") as create_connection, patch.object(
openc.return_value = (reader, writer) protocol, "perform_handshake"
writer.drain = AsyncMock() ):
reader.readexactly = AsyncMock() create_connection.return_value = (MagicMock(), protocol)
reader.readexactly.return_value = b"\x01"
await conn._connect_init_frame_helper() await conn._connect_init_frame_helper()
with pytest.raises(RequiresEncryptionAPIError): with pytest.raises(RequiresEncryptionAPIError):
protocol.data_received(b"\x01\x00\x00")
await conn._connect_hello() await conn._connect_hello()