mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-15 10:55:13 +01:00
Reduce duplicate code between connection and frame helper (#763)
This commit is contained in:
parent
5fb9c9243b
commit
1b51530642
@ -13,7 +13,7 @@ cdef class APIFrameHelper:
|
|||||||
cdef APIConnection _connection
|
cdef APIConnection _connection
|
||||||
cdef object _transport
|
cdef object _transport
|
||||||
cdef public object _writer
|
cdef public object _writer
|
||||||
cdef public object _ready_future
|
cdef public object ready_future
|
||||||
cdef bytes _buffer
|
cdef bytes _buffer
|
||||||
cdef unsigned int _buffer_len
|
cdef unsigned int _buffer_len
|
||||||
cdef unsigned int _pos
|
cdef unsigned int _pos
|
||||||
|
@ -5,7 +5,7 @@ import logging
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import TYPE_CHECKING, Callable, cast
|
from typing import TYPE_CHECKING, Callable, cast
|
||||||
|
|
||||||
from ..core import HandshakeAPIError, SocketClosedAPIError
|
from ..core import SocketClosedAPIError
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..connection import APIConnection
|
from ..connection import APIConnection
|
||||||
@ -23,6 +23,7 @@ WRITE_EXCEPTIONS = (RuntimeError, ConnectionResetError, OSError)
|
|||||||
|
|
||||||
_int = int
|
_int = int
|
||||||
_bytes = bytes
|
_bytes = bytes
|
||||||
|
_float = float
|
||||||
|
|
||||||
|
|
||||||
class APIFrameHelper:
|
class APIFrameHelper:
|
||||||
@ -33,7 +34,7 @@ class APIFrameHelper:
|
|||||||
"_connection",
|
"_connection",
|
||||||
"_transport",
|
"_transport",
|
||||||
"_writer",
|
"_writer",
|
||||||
"_ready_future",
|
"ready_future",
|
||||||
"_buffer",
|
"_buffer",
|
||||||
"_buffer_len",
|
"_buffer_len",
|
||||||
"_pos",
|
"_pos",
|
||||||
@ -53,7 +54,7 @@ class APIFrameHelper:
|
|||||||
self._connection = connection
|
self._connection = connection
|
||||||
self._transport: asyncio.Transport | None = None
|
self._transport: asyncio.Transport | None = None
|
||||||
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
|
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
|
||||||
self._ready_future = self._loop.create_future()
|
self.ready_future = self._loop.create_future()
|
||||||
self._buffer: bytes | None = None
|
self._buffer: bytes | None = None
|
||||||
self._buffer_len = 0
|
self._buffer_len = 0
|
||||||
self._pos = 0
|
self._pos = 0
|
||||||
@ -65,8 +66,8 @@ class APIFrameHelper:
|
|||||||
self._log_name = log_name
|
self._log_name = log_name
|
||||||
|
|
||||||
def _set_ready_future_exception(self, exc: Exception | type[Exception]) -> None:
|
def _set_ready_future_exception(self, exc: Exception | type[Exception]) -> None:
|
||||||
if not self._ready_future.done():
|
if not self.ready_future.done():
|
||||||
self._ready_future.set_exception(exc)
|
self.ready_future.set_exception(exc)
|
||||||
|
|
||||||
def _add_to_buffer(self, data: bytes | bytearray | memoryview) -> None:
|
def _add_to_buffer(self, data: bytes | bytearray | memoryview) -> None:
|
||||||
"""Add data to the buffer."""
|
"""Add data to the buffer."""
|
||||||
@ -135,22 +136,6 @@ class APIFrameHelper:
|
|||||||
bitpos += 7
|
bitpos += 7
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
async def perform_handshake(self, timeout: float) -> None:
|
|
||||||
"""Perform the handshake with the server."""
|
|
||||||
handshake_handle = self._loop.call_at(
|
|
||||||
self._loop.time() + timeout,
|
|
||||||
self._set_ready_future_exception,
|
|
||||||
asyncio.TimeoutError,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
await self._ready_future
|
|
||||||
except asyncio.TimeoutError as err:
|
|
||||||
raise HandshakeAPIError(
|
|
||||||
f"{self._log_name}: Timeout during handshake"
|
|
||||||
) from err
|
|
||||||
finally:
|
|
||||||
handshake_handle.cancel()
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def write_packets(
|
def write_packets(
|
||||||
self, packets: list[tuple[int, bytes]], debug_enabled: bool
|
self, packets: list[tuple[int, bytes]], debug_enabled: bool
|
||||||
|
@ -279,7 +279,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
noise_protocol.cipher_state_encrypt.encrypt_with_ad, # pylint: disable=no-member
|
noise_protocol.cipher_state_encrypt.encrypt_with_ad, # pylint: disable=no-member
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
self._ready_future.set_result(None)
|
self.ready_future.set_result(None)
|
||||||
|
|
||||||
def write_packets(
|
def write_packets(
|
||||||
self, packets: list[tuple[int, bytes]], debug_enabled: bool
|
self, packets: list[tuple[int, bytes]], debug_enabled: bool
|
||||||
|
@ -39,7 +39,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||||
"""Handle a new connection."""
|
"""Handle a new connection."""
|
||||||
super().connection_made(transport)
|
super().connection_made(transport)
|
||||||
self._ready_future.set_result(None)
|
self.ready_future.set_result(None)
|
||||||
|
|
||||||
def write_packets(
|
def write_packets(
|
||||||
self, packets: list[tuple[int, bytes]], debug_enabled: bool
|
self, packets: list[tuple[int, bytes]], debug_enabled: bool
|
||||||
|
@ -71,7 +71,6 @@ from .api_pb2 import ( # type: ignore
|
|||||||
VoiceAssistantResponse,
|
VoiceAssistantResponse,
|
||||||
)
|
)
|
||||||
from .client_callbacks import (
|
from .client_callbacks import (
|
||||||
handle_timeout,
|
|
||||||
on_ble_raw_advertisement_response,
|
on_ble_raw_advertisement_response,
|
||||||
on_bluetooth_connections_free_response,
|
on_bluetooth_connections_free_response,
|
||||||
on_bluetooth_device_connection_response,
|
on_bluetooth_device_connection_response,
|
||||||
@ -81,7 +80,7 @@ from .client_callbacks import (
|
|||||||
on_state_msg,
|
on_state_msg,
|
||||||
on_subscribe_home_assistant_state_response,
|
on_subscribe_home_assistant_state_response,
|
||||||
)
|
)
|
||||||
from .connection import APIConnection, ConnectionParams
|
from .connection import APIConnection, ConnectionParams, handle_timeout
|
||||||
from .core import (
|
from .core import (
|
||||||
APIConnectionError,
|
APIConnectionError,
|
||||||
BluetoothGATTAPIError,
|
BluetoothGATTAPIError,
|
||||||
|
@ -8,5 +8,3 @@ cdef object CameraImageResponse, CameraState
|
|||||||
cdef object HomeassistantServiceCall
|
cdef object HomeassistantServiceCall
|
||||||
|
|
||||||
cdef object BluetoothLEAdvertisement
|
cdef object BluetoothLEAdvertisement
|
||||||
|
|
||||||
cdef object asyncio_TimeoutError
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from asyncio import Future
|
from asyncio import Future
|
||||||
from asyncio import TimeoutError as asyncio_TimeoutError
|
|
||||||
from typing import TYPE_CHECKING, Callable
|
from typing import TYPE_CHECKING, Callable
|
||||||
|
|
||||||
from google.protobuf import message
|
from google.protobuf import message
|
||||||
@ -98,12 +97,6 @@ def on_subscribe_home_assistant_state_response(
|
|||||||
on_state_sub(msg.entity_id, msg.attribute)
|
on_state_sub(msg.entity_id, msg.attribute)
|
||||||
|
|
||||||
|
|
||||||
def handle_timeout(fut: Future[None]) -> None:
|
|
||||||
"""Handle a timeout."""
|
|
||||||
if not fut.done():
|
|
||||||
fut.set_exception(asyncio_TimeoutError)
|
|
||||||
|
|
||||||
|
|
||||||
def on_bluetooth_device_connection_response(
|
def on_bluetooth_device_connection_response(
|
||||||
connect_future: Future[None],
|
connect_future: Future[None],
|
||||||
address: int,
|
address: int,
|
||||||
|
@ -9,19 +9,21 @@ cdef dict PROTO_TO_MESSAGE_TYPE
|
|||||||
cdef set OPEN_STATES
|
cdef set OPEN_STATES
|
||||||
|
|
||||||
cdef float KEEP_ALIVE_TIMEOUT_RATIO
|
cdef float KEEP_ALIVE_TIMEOUT_RATIO
|
||||||
|
cdef object HANDSHAKE_TIMEOUT
|
||||||
|
|
||||||
cdef bint TYPE_CHECKING
|
cdef bint TYPE_CHECKING
|
||||||
|
|
||||||
cdef object DISCONNECT_REQUEST_MESSAGE
|
cdef object DISCONNECT_REQUEST_MESSAGE
|
||||||
cdef object DISCONNECT_RESPONSE_MESSAGE
|
cdef tuple DISCONNECT_RESPONSE_MESSAGES
|
||||||
cdef object PING_REQUEST_MESSAGE
|
cdef tuple PING_REQUEST_MESSAGES
|
||||||
cdef object PING_RESPONSE_MESSAGE
|
cdef tuple PING_RESPONSE_MESSAGES
|
||||||
|
cdef object NO_PASSWORD_CONNECT_REQUEST
|
||||||
|
|
||||||
cdef object asyncio_timeout
|
cdef object asyncio_timeout
|
||||||
cdef object CancelledError
|
cdef object CancelledError
|
||||||
cdef object asyncio_TimeoutError
|
cdef object asyncio_TimeoutError
|
||||||
|
|
||||||
cdef object ConnectResponse
|
cdef object ConnectRequest, ConnectResponse
|
||||||
cdef object DisconnectRequest
|
cdef object DisconnectRequest
|
||||||
cdef object PingRequest
|
cdef object PingRequest
|
||||||
cdef object GetTimeRequest, GetTimeResponse
|
cdef object GetTimeRequest, GetTimeResponse
|
||||||
@ -53,6 +55,20 @@ cdef object CONNECTION_STATE_HANDSHAKE_COMPLETE
|
|||||||
cdef object CONNECTION_STATE_CONNECTED
|
cdef object CONNECTION_STATE_CONNECTED
|
||||||
cdef object CONNECTION_STATE_CLOSED
|
cdef object CONNECTION_STATE_CLOSED
|
||||||
|
|
||||||
|
cdef object make_hello_request
|
||||||
|
|
||||||
|
cpdef handle_timeout(object fut)
|
||||||
|
cpdef handle_complex_message(
|
||||||
|
object fut,
|
||||||
|
list responses,
|
||||||
|
object do_append,
|
||||||
|
object do_stop,
|
||||||
|
object resp,
|
||||||
|
)
|
||||||
|
|
||||||
|
cdef object _handle_timeout
|
||||||
|
cdef object _handle_complex_message
|
||||||
|
|
||||||
@cython.dataclasses.dataclass
|
@cython.dataclasses.dataclass
|
||||||
cdef class ConnectionParams:
|
cdef class ConnectionParams:
|
||||||
cdef public str address
|
cdef public str address
|
||||||
@ -91,43 +107,45 @@ cdef class APIConnection:
|
|||||||
cdef public str received_name
|
cdef public str received_name
|
||||||
cdef public object resolved_addr_info
|
cdef public object resolved_addr_info
|
||||||
|
|
||||||
cpdef send_message(self, object msg)
|
cpdef void send_message(self, object msg)
|
||||||
|
|
||||||
cdef send_messages(self, tuple messages)
|
cdef void send_messages(self, tuple messages)
|
||||||
|
|
||||||
@cython.locals(handlers=set, handlers_copy=set)
|
@cython.locals(handlers=set, handlers_copy=set)
|
||||||
cpdef void process_packet(self, object msg_type_proto, object data)
|
cpdef void process_packet(self, object msg_type_proto, object data)
|
||||||
|
|
||||||
cpdef _async_cancel_pong_timer(self)
|
cdef void _async_cancel_pong_timer(self)
|
||||||
|
|
||||||
cpdef _async_schedule_keep_alive(self, object now)
|
cdef void _async_schedule_keep_alive(self, object now)
|
||||||
|
|
||||||
cdef _cleanup(self)
|
cdef void _cleanup(self)
|
||||||
|
|
||||||
cpdef set_log_name(self, str name)
|
cpdef set_log_name(self, str name)
|
||||||
|
|
||||||
cdef _make_connect_request(self)
|
cdef _make_connect_request(self)
|
||||||
|
|
||||||
cdef _process_hello_resp(self, object resp)
|
cdef void _process_hello_resp(self, object resp)
|
||||||
|
|
||||||
cdef _process_login_response(self, object hello_response)
|
cdef void _process_login_response(self, object hello_response)
|
||||||
|
|
||||||
cdef _set_connection_state(self, object state)
|
cdef void _set_connection_state(self, object state)
|
||||||
|
|
||||||
cpdef report_fatal_error(self, Exception err)
|
cpdef report_fatal_error(self, Exception err)
|
||||||
|
|
||||||
@cython.locals(handlers=set)
|
@cython.locals(handlers=set)
|
||||||
cpdef _add_message_callback_without_remove(self, object on_message, tuple msg_types)
|
cdef void _add_message_callback_without_remove(self, object on_message, tuple msg_types)
|
||||||
|
|
||||||
cpdef add_message_callback(self, object on_message, tuple msg_types)
|
cpdef add_message_callback(self, object on_message, tuple msg_types)
|
||||||
|
|
||||||
@cython.locals(handlers=set)
|
@cython.locals(handlers=set)
|
||||||
cpdef _remove_message_callback(self, object on_message, tuple msg_types)
|
cpdef void _remove_message_callback(self, object on_message, tuple msg_types)
|
||||||
|
|
||||||
cpdef _handle_disconnect_request_internal(self, object msg)
|
cpdef void _handle_disconnect_request_internal(self, object msg)
|
||||||
|
|
||||||
cpdef _handle_ping_request_internal(self, object msg)
|
cpdef void _handle_ping_request_internal(self, object msg)
|
||||||
|
|
||||||
cpdef _handle_get_time_request_internal(self, object msg)
|
cpdef void _handle_get_time_request_internal(self, object msg)
|
||||||
|
|
||||||
cdef _set_fatal_exception_if_unset(self, Exception err)
|
cdef void _set_fatal_exception_if_unset(self, Exception err)
|
||||||
|
|
||||||
|
cdef void _register_internal_message_handlers(self)
|
||||||
|
@ -12,7 +12,7 @@ import time
|
|||||||
from asyncio import CancelledError
|
from asyncio import CancelledError
|
||||||
from asyncio import TimeoutError as asyncio_TimeoutError
|
from asyncio import TimeoutError as asyncio_TimeoutError
|
||||||
from dataclasses import astuple, dataclass
|
from dataclasses import astuple, dataclass
|
||||||
from functools import partial
|
from functools import lru_cache, partial
|
||||||
from typing import TYPE_CHECKING, Any, Callable
|
from typing import TYPE_CHECKING, Any, Callable
|
||||||
|
|
||||||
from google.protobuf import message
|
from google.protobuf import message
|
||||||
@ -63,9 +63,10 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
BUFFER_SIZE = 1024 * 1024 * 2 # Set buffer limit to 2MB
|
BUFFER_SIZE = 1024 * 1024 * 2 # Set buffer limit to 2MB
|
||||||
|
|
||||||
DISCONNECT_REQUEST_MESSAGE = DisconnectRequest()
|
DISCONNECT_REQUEST_MESSAGE = DisconnectRequest()
|
||||||
DISCONNECT_RESPONSE_MESSAGE = DisconnectResponse()
|
DISCONNECT_RESPONSE_MESSAGES = (DisconnectResponse(),)
|
||||||
PING_REQUEST_MESSAGE = PingRequest()
|
PING_REQUEST_MESSAGES = (PingRequest(),)
|
||||||
PING_RESPONSE_MESSAGE = PingResponse()
|
PING_RESPONSE_MESSAGES = (PingResponse(),)
|
||||||
|
NO_PASSWORD_CONNECT_REQUEST = ConnectRequest()
|
||||||
|
|
||||||
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
|
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
|
||||||
|
|
||||||
@ -131,6 +132,44 @@ CONNECTION_STATE_CONNECTED = ConnectionState.CONNECTED
|
|||||||
CONNECTION_STATE_CLOSED = ConnectionState.CLOSED
|
CONNECTION_STATE_CLOSED = ConnectionState.CLOSED
|
||||||
|
|
||||||
|
|
||||||
|
def _make_hello_request(client_info: str) -> HelloRequest:
|
||||||
|
"""Make a HelloRequest."""
|
||||||
|
return HelloRequest(
|
||||||
|
client_info=client_info, api_version_major=1, api_version_minor=9
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_cached_make_hello_request = lru_cache(maxsize=16)(_make_hello_request)
|
||||||
|
make_hello_request = _cached_make_hello_request
|
||||||
|
|
||||||
|
|
||||||
|
def handle_timeout(fut: asyncio.Future[None]) -> None:
|
||||||
|
"""Handle a timeout."""
|
||||||
|
if not fut.done():
|
||||||
|
fut.set_exception(asyncio_TimeoutError)
|
||||||
|
|
||||||
|
|
||||||
|
_handle_timeout = handle_timeout
|
||||||
|
|
||||||
|
|
||||||
|
def handle_complex_message(
|
||||||
|
fut: asyncio.Future[None],
|
||||||
|
responses: list[message.Message],
|
||||||
|
do_append: Callable[[message.Message], bool] | None,
|
||||||
|
do_stop: Callable[[message.Message], bool] | None,
|
||||||
|
resp: message.Message,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a message that is part of a response."""
|
||||||
|
if not fut.done():
|
||||||
|
if do_append is None or do_append(resp):
|
||||||
|
responses.append(resp)
|
||||||
|
if do_stop is None or do_stop(resp):
|
||||||
|
fut.set_result(None)
|
||||||
|
|
||||||
|
|
||||||
|
_handle_complex_message = handle_complex_message
|
||||||
|
|
||||||
|
|
||||||
class APIConnection:
|
class APIConnection:
|
||||||
"""This class represents _one_ connection to a remote native API device.
|
"""This class represents _one_ connection to a remote native API device.
|
||||||
|
|
||||||
@ -331,12 +370,11 @@ class APIConnection:
|
|||||||
async def _connect_init_frame_helper(self) -> None:
|
async def _connect_init_frame_helper(self) -> None:
|
||||||
"""Step 3 in connect process: initialize the frame helper and init read loop."""
|
"""Step 3 in connect process: initialize the frame helper and init read loop."""
|
||||||
fh: APIPlaintextFrameHelper | APINoiseFrameHelper
|
fh: APIPlaintextFrameHelper | APINoiseFrameHelper
|
||||||
loop = self._loop
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
assert self._socket is not None
|
assert self._socket is not None
|
||||||
|
|
||||||
if (noise_psk := self._params.noise_psk) is None:
|
if (noise_psk := self._params.noise_psk) is None:
|
||||||
_, fh = await loop.create_connection( # type: ignore[type-var]
|
_, fh = await self._loop.create_connection( # type: ignore[type-var]
|
||||||
lambda: APIPlaintextFrameHelper(
|
lambda: APIPlaintextFrameHelper(
|
||||||
connection=self,
|
connection=self,
|
||||||
client_info=self._params.client_info,
|
client_info=self._params.client_info,
|
||||||
@ -345,7 +383,7 @@ class APIConnection:
|
|||||||
sock=self._socket,
|
sock=self._socket,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
_, fh = await loop.create_connection( # type: ignore[type-var]
|
_, fh = await self._loop.create_connection( # type: ignore[type-var]
|
||||||
lambda: APINoiseFrameHelper(
|
lambda: APINoiseFrameHelper(
|
||||||
noise_psk=noise_psk,
|
noise_psk=noise_psk,
|
||||||
expected_name=self._params.expected_name,
|
expected_name=self._params.expected_name,
|
||||||
@ -359,24 +397,24 @@ class APIConnection:
|
|||||||
# Set the frame helper right away to ensure
|
# Set the frame helper right away to ensure
|
||||||
# the socket gets closed if we fail to handshake
|
# the socket gets closed if we fail to handshake
|
||||||
self._frame_helper = fh
|
self._frame_helper = fh
|
||||||
|
handshake_handle = self._loop.call_at(
|
||||||
|
self._loop.time() + HANDSHAKE_TIMEOUT,
|
||||||
|
_handle_timeout,
|
||||||
|
self._frame_helper.ready_future,
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
await fh.perform_handshake(HANDSHAKE_TIMEOUT)
|
await self._frame_helper.ready_future
|
||||||
except asyncio_TimeoutError as err:
|
except asyncio_TimeoutError as err:
|
||||||
raise TimeoutAPIError("Handshake timed out") from err
|
raise TimeoutAPIError("Handshake timed out") from err
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
raise HandshakeAPIError(f"Handshake failed: {err}") from err
|
raise HandshakeAPIError(f"Handshake failed: {err}") from err
|
||||||
|
finally:
|
||||||
|
handshake_handle.cancel()
|
||||||
self._set_connection_state(CONNECTION_STATE_HANDSHAKE_COMPLETE)
|
self._set_connection_state(CONNECTION_STATE_HANDSHAKE_COMPLETE)
|
||||||
|
|
||||||
async def _connect_hello_login(self, login: bool) -> None:
|
async def _connect_hello_login(self, login: bool) -> None:
|
||||||
"""Step 4 in connect process: send hello and login and get api version."""
|
"""Step 4 in connect process: send hello and login and get api version."""
|
||||||
messages = [
|
messages = [make_hello_request(self._params.client_info)]
|
||||||
HelloRequest(
|
|
||||||
client_info=self._params.client_info,
|
|
||||||
api_version_major=1,
|
|
||||||
api_version_minor=9,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
msg_types = [HelloResponse]
|
msg_types = [HelloResponse]
|
||||||
if login:
|
if login:
|
||||||
messages.append(self._make_connect_request())
|
messages.append(self._make_connect_request())
|
||||||
@ -447,16 +485,15 @@ class APIConnection:
|
|||||||
|
|
||||||
def _async_send_keep_alive(self) -> None:
|
def _async_send_keep_alive(self) -> None:
|
||||||
"""Send a keep alive message."""
|
"""Send a keep alive message."""
|
||||||
loop = self._loop
|
now = self._loop.time()
|
||||||
now = loop.time()
|
|
||||||
|
|
||||||
if self._send_pending_ping:
|
if self._send_pending_ping:
|
||||||
self.send_messages((PING_REQUEST_MESSAGE,))
|
self.send_messages(PING_REQUEST_MESSAGES)
|
||||||
if self._pong_timer is None:
|
if self._pong_timer is None:
|
||||||
# Do not reset the timer if it's already set
|
# Do not reset the timer if it's already set
|
||||||
# since the only thing we want to reset the timer
|
# since the only thing we want to reset the timer
|
||||||
# is if we receive a pong.
|
# is if we receive a pong.
|
||||||
self._pong_timer = loop.call_at(
|
self._pong_timer = self._loop.call_at(
|
||||||
now + self._keep_alive_timeout, self._async_pong_not_received
|
now + self._keep_alive_timeout, self._async_pong_not_received
|
||||||
)
|
)
|
||||||
elif self._debug_enabled:
|
elif self._debug_enabled:
|
||||||
@ -600,10 +637,9 @@ class APIConnection:
|
|||||||
|
|
||||||
def _make_connect_request(self) -> ConnectRequest:
|
def _make_connect_request(self) -> ConnectRequest:
|
||||||
"""Make a ConnectRequest."""
|
"""Make a ConnectRequest."""
|
||||||
connect = ConnectRequest()
|
|
||||||
if self._params.password is not None:
|
if self._params.password is not None:
|
||||||
connect.password = self._params.password
|
return ConnectRequest(password=self._params.password)
|
||||||
return connect
|
return NO_PASSWORD_CONNECT_REQUEST
|
||||||
|
|
||||||
def send_message(self, msg: message.Message) -> None:
|
def send_message(self, msg: message.Message) -> None:
|
||||||
"""Send a message to the remote."""
|
"""Send a message to the remote."""
|
||||||
@ -679,26 +715,6 @@ class APIConnection:
|
|||||||
# we register the handler after sending the message
|
# we register the handler after sending the message
|
||||||
return self.add_message_callback(on_message, msg_types)
|
return self.add_message_callback(on_message, msg_types)
|
||||||
|
|
||||||
def _handle_timeout(self, fut: asyncio.Future[None]) -> None:
|
|
||||||
"""Handle a timeout."""
|
|
||||||
if not fut.done():
|
|
||||||
fut.set_exception(asyncio_TimeoutError)
|
|
||||||
|
|
||||||
def _handle_complex_message(
|
|
||||||
self,
|
|
||||||
fut: asyncio.Future[None],
|
|
||||||
responses: list[message.Message],
|
|
||||||
do_append: Callable[[message.Message], bool] | None,
|
|
||||||
do_stop: Callable[[message.Message], bool] | None,
|
|
||||||
resp: message.Message,
|
|
||||||
) -> None:
|
|
||||||
"""Handle a message that is part of a response."""
|
|
||||||
if not fut.done():
|
|
||||||
if do_append is None or do_append(resp):
|
|
||||||
responses.append(resp)
|
|
||||||
if do_stop is None or do_stop(resp):
|
|
||||||
fut.set_result(None)
|
|
||||||
|
|
||||||
async def send_messages_await_response_complex( # pylint: disable=too-many-locals
|
async def send_messages_await_response_complex( # pylint: disable=too-many-locals
|
||||||
self,
|
self,
|
||||||
messages: tuple[message.Message, ...],
|
messages: tuple[message.Message, ...],
|
||||||
@ -720,23 +736,23 @@ class APIConnection:
|
|||||||
# This is safe because we are not awaiting between
|
# This is safe because we are not awaiting between
|
||||||
# sending the message and registering the handler
|
# sending the message and registering the handler
|
||||||
self.send_messages(messages)
|
self.send_messages(messages)
|
||||||
loop = self._loop
|
|
||||||
# Unsafe to await between sending the message and registering the handler
|
# Unsafe to await between sending the message and registering the handler
|
||||||
fut: asyncio.Future[None] = loop.create_future()
|
fut: asyncio.Future[None] = self._loop.create_future()
|
||||||
responses: list[message.Message] = []
|
responses: list[message.Message] = []
|
||||||
handler = self._handle_complex_message
|
on_message = partial(
|
||||||
on_message = partial(handler, fut, responses, do_append, do_stop)
|
_handle_complex_message, fut, responses, do_append, do_stop
|
||||||
|
)
|
||||||
read_exception_futures = self._read_exception_futures
|
|
||||||
self._add_message_callback_without_remove(on_message, msg_types)
|
self._add_message_callback_without_remove(on_message, msg_types)
|
||||||
|
|
||||||
read_exception_futures.add(fut)
|
self._read_exception_futures.add(fut)
|
||||||
# Now safe to await since we have registered the handler
|
# Now safe to await since we have registered the handler
|
||||||
|
|
||||||
# We must not await without a finally or
|
# We must not await without a finally or
|
||||||
# the message could fail to be removed if the
|
# the message could fail to be removed if the
|
||||||
# the await is cancelled
|
# the await is cancelled
|
||||||
timeout_handle = loop.call_at(loop.time() + timeout, self._handle_timeout, fut)
|
timeout_handle = self._loop.call_at(
|
||||||
|
self._loop.time() + timeout, _handle_timeout, fut
|
||||||
|
)
|
||||||
timeout_expired = False
|
timeout_expired = False
|
||||||
try:
|
try:
|
||||||
await fut
|
await fut
|
||||||
@ -750,7 +766,7 @@ class APIConnection:
|
|||||||
if not timeout_expired:
|
if not timeout_expired:
|
||||||
timeout_handle.cancel()
|
timeout_handle.cancel()
|
||||||
self._remove_message_callback(on_message, msg_types)
|
self._remove_message_callback(on_message, msg_types)
|
||||||
read_exception_futures.discard(fut)
|
self._read_exception_futures.discard(fut)
|
||||||
|
|
||||||
return responses
|
return responses
|
||||||
|
|
||||||
@ -775,7 +791,7 @@ class APIConnection:
|
|||||||
The connection will be closed, all exception handlers notified.
|
The connection will be closed, all exception handlers notified.
|
||||||
This method does not log the error, the call site should do so.
|
This method does not log the error, the call site should do so.
|
||||||
"""
|
"""
|
||||||
if not self._fatal_exception:
|
if self._fatal_exception is None:
|
||||||
if self._expected_disconnect is False:
|
if self._expected_disconnect is False:
|
||||||
# Only log the first error
|
# Only log the first error
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
@ -810,7 +826,7 @@ class APIConnection:
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
msg = klass()
|
msg: message.Message = klass()
|
||||||
# MergeFromString instead of ParseFromString since
|
# MergeFromString instead of ParseFromString since
|
||||||
# ParseFromString will clear the message first and
|
# ParseFromString will clear the message first and
|
||||||
# the msg is already empty.
|
# the msg is already empty.
|
||||||
@ -876,14 +892,14 @@ class APIConnection:
|
|||||||
# the response if for some reason sending the response
|
# the response if for some reason sending the response
|
||||||
# fails we will still mark the disconnect as expected
|
# fails we will still mark the disconnect as expected
|
||||||
self._expected_disconnect = True
|
self._expected_disconnect = True
|
||||||
self.send_messages((DISCONNECT_RESPONSE_MESSAGE,))
|
self.send_messages(DISCONNECT_RESPONSE_MESSAGES)
|
||||||
self._cleanup()
|
self._cleanup()
|
||||||
|
|
||||||
def _handle_ping_request_internal( # pylint: disable=unused-argument
|
def _handle_ping_request_internal( # pylint: disable=unused-argument
|
||||||
self, _msg: PingRequest
|
self, _msg: PingRequest
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle a PingRequest."""
|
"""Handle a PingRequest."""
|
||||||
self.send_messages((PING_RESPONSE_MESSAGE,))
|
self.send_messages(PING_RESPONSE_MESSAGES)
|
||||||
|
|
||||||
def _handle_get_time_request_internal( # pylint: disable=unused-argument
|
def _handle_get_time_request_internal( # pylint: disable=unused-argument
|
||||||
self, _msg: GetTimeRequest
|
self, _msg: GetTimeRequest
|
||||||
@ -895,7 +911,7 @@ class APIConnection:
|
|||||||
|
|
||||||
async def disconnect(self) -> None:
|
async def disconnect(self) -> None:
|
||||||
"""Disconnect from the API."""
|
"""Disconnect from the API."""
|
||||||
if self._finish_connect_task:
|
if self._finish_connect_task is not None:
|
||||||
# Try to wait for the handshake to finish so we can send
|
# Try to wait for the handshake to finish so we can send
|
||||||
# a disconnect request. If it doesn't finish in time
|
# a disconnect request. If it doesn't finish in time
|
||||||
# we will just close the socket.
|
# we will just close the socket.
|
||||||
|
@ -2,7 +2,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
from datetime import timedelta
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
@ -26,12 +25,7 @@ from aioesphomeapi.core import (
|
|||||||
SocketClosedAPIError,
|
SocketClosedAPIError,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .common import (
|
from .common import get_mock_protocol, mock_data_received
|
||||||
async_fire_time_changed,
|
|
||||||
get_mock_protocol,
|
|
||||||
mock_data_received,
|
|
||||||
utcnow,
|
|
||||||
)
|
|
||||||
from .conftest import get_mock_connection_params
|
from .conftest import get_mock_connection_params
|
||||||
|
|
||||||
PREAMBLE = b"\x00"
|
PREAMBLE = b"\x00"
|
||||||
@ -312,7 +306,7 @@ async def test_noise_protector_event_loop(byte_type: Any) -> None:
|
|||||||
mock_data_received(helper, byte_type(bytes.fromhex(pkt)))
|
mock_data_received(helper, byte_type(bytes.fromhex(pkt)))
|
||||||
|
|
||||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||||
await helper.perform_handshake(30)
|
await helper.ready_future
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -343,7 +337,7 @@ async def test_noise_frame_helper_incorrect_key():
|
|||||||
mock_data_received(helper, bytes.fromhex(pkt))
|
mock_data_received(helper, bytes.fromhex(pkt))
|
||||||
|
|
||||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||||
await helper.perform_handshake(30)
|
await helper.ready_future
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -376,7 +370,7 @@ async def test_noise_frame_helper_incorrect_key_fragments():
|
|||||||
mock_data_received(helper, in_pkt[i : i + 1])
|
mock_data_received(helper, in_pkt[i : i + 1])
|
||||||
|
|
||||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||||
await helper.perform_handshake(30)
|
await helper.ready_future
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -407,36 +401,7 @@ async def test_noise_incorrect_name():
|
|||||||
mock_data_received(helper, bytes.fromhex(pkt))
|
mock_data_received(helper, bytes.fromhex(pkt))
|
||||||
|
|
||||||
with pytest.raises(BadNameAPIError):
|
with pytest.raises(BadNameAPIError):
|
||||||
await helper.perform_handshake(30)
|
await helper.ready_future
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_noise_timeout():
|
|
||||||
"""Test we raise on bad name."""
|
|
||||||
outgoing_packets = [
|
|
||||||
"010000", # hello packet
|
|
||||||
"010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4",
|
|
||||||
]
|
|
||||||
|
|
||||||
connection, _ = _make_mock_connection()
|
|
||||||
|
|
||||||
helper = MockAPINoiseFrameHelper(
|
|
||||||
connection=connection,
|
|
||||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
|
||||||
expected_name="wrongname",
|
|
||||||
client_info="my client",
|
|
||||||
log_name="test",
|
|
||||||
)
|
|
||||||
|
|
||||||
for pkt in outgoing_packets:
|
|
||||||
helper.mock_write_frame(bytes.fromhex(pkt))
|
|
||||||
|
|
||||||
task = asyncio.create_task(helper.perform_handshake(30))
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
async_fire_time_changed(utcnow() + timedelta(seconds=60))
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
with pytest.raises(HandshakeAPIError):
|
|
||||||
await task
|
|
||||||
|
|
||||||
|
|
||||||
VARUINT_TESTCASES = [
|
VARUINT_TESTCASES = [
|
||||||
@ -478,7 +443,6 @@ async def test_noise_frame_helper_handshake_failure():
|
|||||||
|
|
||||||
proto = _mock_responder_proto(psk_bytes)
|
proto = _mock_responder_proto(psk_bytes)
|
||||||
|
|
||||||
handshake_task = asyncio.create_task(helper.perform_handshake(30))
|
|
||||||
await asyncio.sleep(0) # let the task run to read the hello packet
|
await asyncio.sleep(0) # let the task run to read the hello packet
|
||||||
|
|
||||||
assert len(writes) == 1
|
assert len(writes) == 1
|
||||||
@ -502,7 +466,7 @@ async def test_noise_frame_helper_handshake_failure():
|
|||||||
mock_data_received(helper, error_pkt_with_header)
|
mock_data_received(helper, error_pkt_with_header)
|
||||||
|
|
||||||
with pytest.raises(HandshakeAPIError, match="forced to fail"):
|
with pytest.raises(HandshakeAPIError, match="forced to fail"):
|
||||||
await handshake_task
|
await helper.ready_future
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -528,7 +492,6 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
|||||||
|
|
||||||
proto = _mock_responder_proto(psk_bytes)
|
proto = _mock_responder_proto(psk_bytes)
|
||||||
|
|
||||||
handshake_task = asyncio.create_task(helper.perform_handshake(30))
|
|
||||||
await asyncio.sleep(0) # let the task run to read the hello packet
|
await asyncio.sleep(0) # let the task run to read the hello packet
|
||||||
|
|
||||||
assert len(writes) == 1
|
assert len(writes) == 1
|
||||||
@ -546,7 +509,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
|||||||
|
|
||||||
assert not writes
|
assert not writes
|
||||||
|
|
||||||
await handshake_task
|
await helper.ready_future
|
||||||
helper.write_packets([(1, b"to device")], True)
|
helper.write_packets([(1, b"to device")], True)
|
||||||
encrypted_packet = writes.pop()
|
encrypted_packet = writes.pop()
|
||||||
header = encrypted_packet[0:1]
|
header = encrypted_packet[0:1]
|
||||||
@ -591,7 +554,6 @@ async def test_noise_frame_helper_bad_encryption(
|
|||||||
|
|
||||||
proto = _mock_responder_proto(psk_bytes)
|
proto = _mock_responder_proto(psk_bytes)
|
||||||
|
|
||||||
handshake_task = asyncio.create_task(helper.perform_handshake(30))
|
|
||||||
await asyncio.sleep(0) # let the task run to read the hello packet
|
await asyncio.sleep(0) # let the task run to read the hello packet
|
||||||
|
|
||||||
assert len(writes) == 1
|
assert len(writes) == 1
|
||||||
@ -609,7 +571,7 @@ async def test_noise_frame_helper_bad_encryption(
|
|||||||
|
|
||||||
assert not writes
|
assert not writes
|
||||||
|
|
||||||
await handshake_task
|
await helper.ready_future
|
||||||
helper.write_packets([(1, b"to device")], True)
|
helper.write_packets([(1, b"to device")], True)
|
||||||
encrypted_packet = writes.pop()
|
encrypted_packet = writes.pop()
|
||||||
header = encrypted_packet[0:1]
|
header = encrypted_packet[0:1]
|
||||||
@ -638,7 +600,7 @@ async def test_init_plaintext_with_wrong_preamble(conn: APIConnection):
|
|||||||
|
|
||||||
conn._socket = MagicMock()
|
conn._socket = MagicMock()
|
||||||
await conn._connect_init_frame_helper()
|
await conn._connect_init_frame_helper()
|
||||||
loop.call_soon(conn._frame_helper._ready_future.set_result, None)
|
loop.call_soon(conn._frame_helper.ready_future.set_result, None)
|
||||||
conn.connection_state = ConnectionState.CONNECTED
|
conn.connection_state = ConnectionState.CONNECTED
|
||||||
|
|
||||||
task = asyncio.create_task(conn._connect_hello_login(login=True))
|
task = asyncio.create_task(conn._connect_hello_login(login=True))
|
||||||
@ -687,13 +649,12 @@ async def test_noise_frame_helper_empty_hello():
|
|||||||
log_name="test",
|
log_name="test",
|
||||||
)
|
)
|
||||||
|
|
||||||
handshake_task = asyncio.create_task(helper.perform_handshake(30))
|
|
||||||
hello_pkt_with_header = _make_noise_hello_pkt(b"")
|
hello_pkt_with_header = _make_noise_hello_pkt(b"")
|
||||||
|
|
||||||
mock_data_received(helper, hello_pkt_with_header)
|
mock_data_received(helper, hello_pkt_with_header)
|
||||||
|
|
||||||
with pytest.raises(HandshakeAPIError, match="ServerHello is empty"):
|
with pytest.raises(HandshakeAPIError, match="ServerHello is empty"):
|
||||||
await handshake_task
|
await helper.ready_future
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -708,7 +669,6 @@ async def test_noise_frame_helper_wrong_protocol():
|
|||||||
log_name="test",
|
log_name="test",
|
||||||
)
|
)
|
||||||
|
|
||||||
handshake_task = asyncio.create_task(helper.perform_handshake(30))
|
|
||||||
# wrong protocol 5 instead of 1
|
# wrong protocol 5 instead of 1
|
||||||
hello_pkt_with_header = _make_noise_hello_pkt(b"\x05servicetest\0")
|
hello_pkt_with_header = _make_noise_hello_pkt(b"\x05servicetest\0")
|
||||||
|
|
||||||
@ -717,7 +677,7 @@ async def test_noise_frame_helper_wrong_protocol():
|
|||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
HandshakeAPIError, match="Unknown protocol selected by client 5"
|
HandshakeAPIError, match="Unknown protocol selected by client 5"
|
||||||
):
|
):
|
||||||
await handshake_task
|
await helper.ready_future
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -2,10 +2,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Coroutine
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any
|
from typing import Callable, cast
|
||||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -161,7 +160,7 @@ async def test_requires_encryption_propagates(conn: APIConnection):
|
|||||||
|
|
||||||
conn._socket = MagicMock()
|
conn._socket = MagicMock()
|
||||||
await conn._connect_init_frame_helper()
|
await conn._connect_init_frame_helper()
|
||||||
loop.call_soon(conn._frame_helper._ready_future.set_result, None)
|
loop.call_soon(conn._frame_helper.ready_future.set_result, None)
|
||||||
conn.connection_state = ConnectionState.CONNECTED
|
conn.connection_state = ConnectionState.CONNECTED
|
||||||
|
|
||||||
with pytest.raises(RequiresEncryptionAPIError):
|
with pytest.raises(RequiresEncryptionAPIError):
|
||||||
@ -378,8 +377,18 @@ async def test_plaintext_connection_fails_handshake(
|
|||||||
class APIPlaintextFrameHelperHandshakeException(APIPlaintextFrameHelper):
|
class APIPlaintextFrameHelperHandshakeException(APIPlaintextFrameHelper):
|
||||||
"""Plaintext frame helper that raises exception on handshake."""
|
"""Plaintext frame helper that raises exception on handshake."""
|
||||||
|
|
||||||
def perform_handshake(self, timeout: float) -> Coroutine[Any, Any, None]:
|
def _create_failing_mock_transport_protocol(
|
||||||
raise exception
|
transport: asyncio.Transport,
|
||||||
|
connected: asyncio.Event,
|
||||||
|
create_func: Callable[[], APIPlaintextFrameHelper],
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[asyncio.Transport, APIPlaintextFrameHelperHandshakeException]:
|
||||||
|
protocol: APIPlaintextFrameHelperHandshakeException = create_func()
|
||||||
|
protocol._transport = cast(asyncio.Transport, transport)
|
||||||
|
protocol._writer = transport.write
|
||||||
|
protocol.ready_future.set_exception(exception)
|
||||||
|
connected.set()
|
||||||
|
return transport, protocol
|
||||||
|
|
||||||
def on_msg(msg):
|
def on_msg(msg):
|
||||||
messages.append(msg)
|
messages.append(msg)
|
||||||
@ -393,7 +402,9 @@ async def test_plaintext_connection_fails_handshake(
|
|||||||
), patch.object(
|
), patch.object(
|
||||||
loop,
|
loop,
|
||||||
"create_connection",
|
"create_connection",
|
||||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
side_effect=partial(
|
||||||
|
_create_failing_mock_transport_protocol, transport, connected
|
||||||
|
),
|
||||||
):
|
):
|
||||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||||
await connected.wait()
|
await connected.wait()
|
||||||
|
Loading…
Reference in New Issue
Block a user