mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-22 16:48:04 +01:00
Reduce duplicate code between connection and frame helper (#763)
This commit is contained in:
parent
5fb9c9243b
commit
1b51530642
@ -13,7 +13,7 @@ cdef class APIFrameHelper:
|
||||
cdef APIConnection _connection
|
||||
cdef object _transport
|
||||
cdef public object _writer
|
||||
cdef public object _ready_future
|
||||
cdef public object ready_future
|
||||
cdef bytes _buffer
|
||||
cdef unsigned int _buffer_len
|
||||
cdef unsigned int _pos
|
||||
|
@ -5,7 +5,7 @@ import logging
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Callable, cast
|
||||
|
||||
from ..core import HandshakeAPIError, SocketClosedAPIError
|
||||
from ..core import SocketClosedAPIError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..connection import APIConnection
|
||||
@ -23,6 +23,7 @@ WRITE_EXCEPTIONS = (RuntimeError, ConnectionResetError, OSError)
|
||||
|
||||
_int = int
|
||||
_bytes = bytes
|
||||
_float = float
|
||||
|
||||
|
||||
class APIFrameHelper:
|
||||
@ -33,7 +34,7 @@ class APIFrameHelper:
|
||||
"_connection",
|
||||
"_transport",
|
||||
"_writer",
|
||||
"_ready_future",
|
||||
"ready_future",
|
||||
"_buffer",
|
||||
"_buffer_len",
|
||||
"_pos",
|
||||
@ -53,7 +54,7 @@ class APIFrameHelper:
|
||||
self._connection = connection
|
||||
self._transport: asyncio.Transport | None = None
|
||||
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
|
||||
self._ready_future = self._loop.create_future()
|
||||
self.ready_future = self._loop.create_future()
|
||||
self._buffer: bytes | None = None
|
||||
self._buffer_len = 0
|
||||
self._pos = 0
|
||||
@ -65,8 +66,8 @@ class APIFrameHelper:
|
||||
self._log_name = log_name
|
||||
|
||||
def _set_ready_future_exception(self, exc: Exception | type[Exception]) -> None:
|
||||
if not self._ready_future.done():
|
||||
self._ready_future.set_exception(exc)
|
||||
if not self.ready_future.done():
|
||||
self.ready_future.set_exception(exc)
|
||||
|
||||
def _add_to_buffer(self, data: bytes | bytearray | memoryview) -> None:
|
||||
"""Add data to the buffer."""
|
||||
@ -135,22 +136,6 @@ class APIFrameHelper:
|
||||
bitpos += 7
|
||||
return -1
|
||||
|
||||
async def perform_handshake(self, timeout: float) -> None:
|
||||
"""Perform the handshake with the server."""
|
||||
handshake_handle = self._loop.call_at(
|
||||
self._loop.time() + timeout,
|
||||
self._set_ready_future_exception,
|
||||
asyncio.TimeoutError,
|
||||
)
|
||||
try:
|
||||
await self._ready_future
|
||||
except asyncio.TimeoutError as err:
|
||||
raise HandshakeAPIError(
|
||||
f"{self._log_name}: Timeout during handshake"
|
||||
) from err
|
||||
finally:
|
||||
handshake_handle.cancel()
|
||||
|
||||
@abstractmethod
|
||||
def write_packets(
|
||||
self, packets: list[tuple[int, bytes]], debug_enabled: bool
|
||||
|
@ -279,7 +279,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
noise_protocol.cipher_state_encrypt.encrypt_with_ad, # pylint: disable=no-member
|
||||
None,
|
||||
)
|
||||
self._ready_future.set_result(None)
|
||||
self.ready_future.set_result(None)
|
||||
|
||||
def write_packets(
|
||||
self, packets: list[tuple[int, bytes]], debug_enabled: bool
|
||||
|
@ -39,7 +39,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
"""Handle a new connection."""
|
||||
super().connection_made(transport)
|
||||
self._ready_future.set_result(None)
|
||||
self.ready_future.set_result(None)
|
||||
|
||||
def write_packets(
|
||||
self, packets: list[tuple[int, bytes]], debug_enabled: bool
|
||||
|
@ -71,7 +71,6 @@ from .api_pb2 import ( # type: ignore
|
||||
VoiceAssistantResponse,
|
||||
)
|
||||
from .client_callbacks import (
|
||||
handle_timeout,
|
||||
on_ble_raw_advertisement_response,
|
||||
on_bluetooth_connections_free_response,
|
||||
on_bluetooth_device_connection_response,
|
||||
@ -81,7 +80,7 @@ from .client_callbacks import (
|
||||
on_state_msg,
|
||||
on_subscribe_home_assistant_state_response,
|
||||
)
|
||||
from .connection import APIConnection, ConnectionParams
|
||||
from .connection import APIConnection, ConnectionParams, handle_timeout
|
||||
from .core import (
|
||||
APIConnectionError,
|
||||
BluetoothGATTAPIError,
|
||||
|
@ -8,5 +8,3 @@ cdef object CameraImageResponse, CameraState
|
||||
cdef object HomeassistantServiceCall
|
||||
|
||||
cdef object BluetoothLEAdvertisement
|
||||
|
||||
cdef object asyncio_TimeoutError
|
||||
|
@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from asyncio import Future
|
||||
from asyncio import TimeoutError as asyncio_TimeoutError
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
from google.protobuf import message
|
||||
@ -98,12 +97,6 @@ def on_subscribe_home_assistant_state_response(
|
||||
on_state_sub(msg.entity_id, msg.attribute)
|
||||
|
||||
|
||||
def handle_timeout(fut: Future[None]) -> None:
|
||||
"""Handle a timeout."""
|
||||
if not fut.done():
|
||||
fut.set_exception(asyncio_TimeoutError)
|
||||
|
||||
|
||||
def on_bluetooth_device_connection_response(
|
||||
connect_future: Future[None],
|
||||
address: int,
|
||||
|
@ -9,19 +9,21 @@ cdef dict PROTO_TO_MESSAGE_TYPE
|
||||
cdef set OPEN_STATES
|
||||
|
||||
cdef float KEEP_ALIVE_TIMEOUT_RATIO
|
||||
cdef object HANDSHAKE_TIMEOUT
|
||||
|
||||
cdef bint TYPE_CHECKING
|
||||
|
||||
cdef object DISCONNECT_REQUEST_MESSAGE
|
||||
cdef object DISCONNECT_RESPONSE_MESSAGE
|
||||
cdef object PING_REQUEST_MESSAGE
|
||||
cdef object PING_RESPONSE_MESSAGE
|
||||
cdef tuple DISCONNECT_RESPONSE_MESSAGES
|
||||
cdef tuple PING_REQUEST_MESSAGES
|
||||
cdef tuple PING_RESPONSE_MESSAGES
|
||||
cdef object NO_PASSWORD_CONNECT_REQUEST
|
||||
|
||||
cdef object asyncio_timeout
|
||||
cdef object CancelledError
|
||||
cdef object asyncio_TimeoutError
|
||||
|
||||
cdef object ConnectResponse
|
||||
cdef object ConnectRequest, ConnectResponse
|
||||
cdef object DisconnectRequest
|
||||
cdef object PingRequest
|
||||
cdef object GetTimeRequest, GetTimeResponse
|
||||
@ -53,6 +55,20 @@ cdef object CONNECTION_STATE_HANDSHAKE_COMPLETE
|
||||
cdef object CONNECTION_STATE_CONNECTED
|
||||
cdef object CONNECTION_STATE_CLOSED
|
||||
|
||||
cdef object make_hello_request
|
||||
|
||||
cpdef handle_timeout(object fut)
|
||||
cpdef handle_complex_message(
|
||||
object fut,
|
||||
list responses,
|
||||
object do_append,
|
||||
object do_stop,
|
||||
object resp,
|
||||
)
|
||||
|
||||
cdef object _handle_timeout
|
||||
cdef object _handle_complex_message
|
||||
|
||||
@cython.dataclasses.dataclass
|
||||
cdef class ConnectionParams:
|
||||
cdef public str address
|
||||
@ -91,43 +107,45 @@ cdef class APIConnection:
|
||||
cdef public str received_name
|
||||
cdef public object resolved_addr_info
|
||||
|
||||
cpdef send_message(self, object msg)
|
||||
cpdef void send_message(self, object msg)
|
||||
|
||||
cdef send_messages(self, tuple messages)
|
||||
cdef void send_messages(self, tuple messages)
|
||||
|
||||
@cython.locals(handlers=set, handlers_copy=set)
|
||||
cpdef void process_packet(self, object msg_type_proto, object data)
|
||||
|
||||
cpdef _async_cancel_pong_timer(self)
|
||||
cdef void _async_cancel_pong_timer(self)
|
||||
|
||||
cpdef _async_schedule_keep_alive(self, object now)
|
||||
cdef void _async_schedule_keep_alive(self, object now)
|
||||
|
||||
cdef _cleanup(self)
|
||||
cdef void _cleanup(self)
|
||||
|
||||
cpdef set_log_name(self, str name)
|
||||
|
||||
cdef _make_connect_request(self)
|
||||
|
||||
cdef _process_hello_resp(self, object resp)
|
||||
cdef void _process_hello_resp(self, object resp)
|
||||
|
||||
cdef _process_login_response(self, object hello_response)
|
||||
cdef void _process_login_response(self, object hello_response)
|
||||
|
||||
cdef _set_connection_state(self, object state)
|
||||
cdef void _set_connection_state(self, object state)
|
||||
|
||||
cpdef report_fatal_error(self, Exception err)
|
||||
|
||||
@cython.locals(handlers=set)
|
||||
cpdef _add_message_callback_without_remove(self, object on_message, tuple msg_types)
|
||||
cdef void _add_message_callback_without_remove(self, object on_message, tuple msg_types)
|
||||
|
||||
cpdef add_message_callback(self, object on_message, tuple msg_types)
|
||||
|
||||
@cython.locals(handlers=set)
|
||||
cpdef _remove_message_callback(self, object on_message, tuple msg_types)
|
||||
cpdef void _remove_message_callback(self, object on_message, tuple msg_types)
|
||||
|
||||
cpdef _handle_disconnect_request_internal(self, object msg)
|
||||
cpdef void _handle_disconnect_request_internal(self, object msg)
|
||||
|
||||
cpdef _handle_ping_request_internal(self, object msg)
|
||||
cpdef void _handle_ping_request_internal(self, object msg)
|
||||
|
||||
cpdef _handle_get_time_request_internal(self, object msg)
|
||||
cpdef void _handle_get_time_request_internal(self, object msg)
|
||||
|
||||
cdef _set_fatal_exception_if_unset(self, Exception err)
|
||||
cdef void _set_fatal_exception_if_unset(self, Exception err)
|
||||
|
||||
cdef void _register_internal_message_handlers(self)
|
||||
|
@ -12,7 +12,7 @@ import time
|
||||
from asyncio import CancelledError
|
||||
from asyncio import TimeoutError as asyncio_TimeoutError
|
||||
from dataclasses import astuple, dataclass
|
||||
from functools import partial
|
||||
from functools import lru_cache, partial
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from google.protobuf import message
|
||||
@ -63,9 +63,10 @@ _LOGGER = logging.getLogger(__name__)
|
||||
BUFFER_SIZE = 1024 * 1024 * 2 # Set buffer limit to 2MB
|
||||
|
||||
DISCONNECT_REQUEST_MESSAGE = DisconnectRequest()
|
||||
DISCONNECT_RESPONSE_MESSAGE = DisconnectResponse()
|
||||
PING_REQUEST_MESSAGE = PingRequest()
|
||||
PING_RESPONSE_MESSAGE = PingResponse()
|
||||
DISCONNECT_RESPONSE_MESSAGES = (DisconnectResponse(),)
|
||||
PING_REQUEST_MESSAGES = (PingRequest(),)
|
||||
PING_RESPONSE_MESSAGES = (PingResponse(),)
|
||||
NO_PASSWORD_CONNECT_REQUEST = ConnectRequest()
|
||||
|
||||
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
|
||||
|
||||
@ -131,6 +132,44 @@ CONNECTION_STATE_CONNECTED = ConnectionState.CONNECTED
|
||||
CONNECTION_STATE_CLOSED = ConnectionState.CLOSED
|
||||
|
||||
|
||||
def _make_hello_request(client_info: str) -> HelloRequest:
|
||||
"""Make a HelloRequest."""
|
||||
return HelloRequest(
|
||||
client_info=client_info, api_version_major=1, api_version_minor=9
|
||||
)
|
||||
|
||||
|
||||
_cached_make_hello_request = lru_cache(maxsize=16)(_make_hello_request)
|
||||
make_hello_request = _cached_make_hello_request
|
||||
|
||||
|
||||
def handle_timeout(fut: asyncio.Future[None]) -> None:
|
||||
"""Handle a timeout."""
|
||||
if not fut.done():
|
||||
fut.set_exception(asyncio_TimeoutError)
|
||||
|
||||
|
||||
_handle_timeout = handle_timeout
|
||||
|
||||
|
||||
def handle_complex_message(
|
||||
fut: asyncio.Future[None],
|
||||
responses: list[message.Message],
|
||||
do_append: Callable[[message.Message], bool] | None,
|
||||
do_stop: Callable[[message.Message], bool] | None,
|
||||
resp: message.Message,
|
||||
) -> None:
|
||||
"""Handle a message that is part of a response."""
|
||||
if not fut.done():
|
||||
if do_append is None or do_append(resp):
|
||||
responses.append(resp)
|
||||
if do_stop is None or do_stop(resp):
|
||||
fut.set_result(None)
|
||||
|
||||
|
||||
_handle_complex_message = handle_complex_message
|
||||
|
||||
|
||||
class APIConnection:
|
||||
"""This class represents _one_ connection to a remote native API device.
|
||||
|
||||
@ -331,12 +370,11 @@ class APIConnection:
|
||||
async def _connect_init_frame_helper(self) -> None:
|
||||
"""Step 3 in connect process: initialize the frame helper and init read loop."""
|
||||
fh: APIPlaintextFrameHelper | APINoiseFrameHelper
|
||||
loop = self._loop
|
||||
if TYPE_CHECKING:
|
||||
assert self._socket is not None
|
||||
|
||||
if (noise_psk := self._params.noise_psk) is None:
|
||||
_, fh = await loop.create_connection( # type: ignore[type-var]
|
||||
_, fh = await self._loop.create_connection( # type: ignore[type-var]
|
||||
lambda: APIPlaintextFrameHelper(
|
||||
connection=self,
|
||||
client_info=self._params.client_info,
|
||||
@ -345,7 +383,7 @@ class APIConnection:
|
||||
sock=self._socket,
|
||||
)
|
||||
else:
|
||||
_, fh = await loop.create_connection( # type: ignore[type-var]
|
||||
_, fh = await self._loop.create_connection( # type: ignore[type-var]
|
||||
lambda: APINoiseFrameHelper(
|
||||
noise_psk=noise_psk,
|
||||
expected_name=self._params.expected_name,
|
||||
@ -359,24 +397,24 @@ class APIConnection:
|
||||
# Set the frame helper right away to ensure
|
||||
# the socket gets closed if we fail to handshake
|
||||
self._frame_helper = fh
|
||||
|
||||
handshake_handle = self._loop.call_at(
|
||||
self._loop.time() + HANDSHAKE_TIMEOUT,
|
||||
_handle_timeout,
|
||||
self._frame_helper.ready_future,
|
||||
)
|
||||
try:
|
||||
await fh.perform_handshake(HANDSHAKE_TIMEOUT)
|
||||
await self._frame_helper.ready_future
|
||||
except asyncio_TimeoutError as err:
|
||||
raise TimeoutAPIError("Handshake timed out") from err
|
||||
except OSError as err:
|
||||
raise HandshakeAPIError(f"Handshake failed: {err}") from err
|
||||
finally:
|
||||
handshake_handle.cancel()
|
||||
self._set_connection_state(CONNECTION_STATE_HANDSHAKE_COMPLETE)
|
||||
|
||||
async def _connect_hello_login(self, login: bool) -> None:
|
||||
"""Step 4 in connect process: send hello and login and get api version."""
|
||||
messages = [
|
||||
HelloRequest(
|
||||
client_info=self._params.client_info,
|
||||
api_version_major=1,
|
||||
api_version_minor=9,
|
||||
)
|
||||
]
|
||||
messages = [make_hello_request(self._params.client_info)]
|
||||
msg_types = [HelloResponse]
|
||||
if login:
|
||||
messages.append(self._make_connect_request())
|
||||
@ -447,16 +485,15 @@ class APIConnection:
|
||||
|
||||
def _async_send_keep_alive(self) -> None:
|
||||
"""Send a keep alive message."""
|
||||
loop = self._loop
|
||||
now = loop.time()
|
||||
now = self._loop.time()
|
||||
|
||||
if self._send_pending_ping:
|
||||
self.send_messages((PING_REQUEST_MESSAGE,))
|
||||
self.send_messages(PING_REQUEST_MESSAGES)
|
||||
if self._pong_timer is None:
|
||||
# Do not reset the timer if it's already set
|
||||
# since the only thing we want to reset the timer
|
||||
# is if we receive a pong.
|
||||
self._pong_timer = loop.call_at(
|
||||
self._pong_timer = self._loop.call_at(
|
||||
now + self._keep_alive_timeout, self._async_pong_not_received
|
||||
)
|
||||
elif self._debug_enabled:
|
||||
@ -600,10 +637,9 @@ class APIConnection:
|
||||
|
||||
def _make_connect_request(self) -> ConnectRequest:
|
||||
"""Make a ConnectRequest."""
|
||||
connect = ConnectRequest()
|
||||
if self._params.password is not None:
|
||||
connect.password = self._params.password
|
||||
return connect
|
||||
return ConnectRequest(password=self._params.password)
|
||||
return NO_PASSWORD_CONNECT_REQUEST
|
||||
|
||||
def send_message(self, msg: message.Message) -> None:
|
||||
"""Send a message to the remote."""
|
||||
@ -679,26 +715,6 @@ class APIConnection:
|
||||
# we register the handler after sending the message
|
||||
return self.add_message_callback(on_message, msg_types)
|
||||
|
||||
def _handle_timeout(self, fut: asyncio.Future[None]) -> None:
|
||||
"""Handle a timeout."""
|
||||
if not fut.done():
|
||||
fut.set_exception(asyncio_TimeoutError)
|
||||
|
||||
def _handle_complex_message(
|
||||
self,
|
||||
fut: asyncio.Future[None],
|
||||
responses: list[message.Message],
|
||||
do_append: Callable[[message.Message], bool] | None,
|
||||
do_stop: Callable[[message.Message], bool] | None,
|
||||
resp: message.Message,
|
||||
) -> None:
|
||||
"""Handle a message that is part of a response."""
|
||||
if not fut.done():
|
||||
if do_append is None or do_append(resp):
|
||||
responses.append(resp)
|
||||
if do_stop is None or do_stop(resp):
|
||||
fut.set_result(None)
|
||||
|
||||
async def send_messages_await_response_complex( # pylint: disable=too-many-locals
|
||||
self,
|
||||
messages: tuple[message.Message, ...],
|
||||
@ -720,23 +736,23 @@ class APIConnection:
|
||||
# This is safe because we are not awaiting between
|
||||
# sending the message and registering the handler
|
||||
self.send_messages(messages)
|
||||
loop = self._loop
|
||||
# Unsafe to await between sending the message and registering the handler
|
||||
fut: asyncio.Future[None] = loop.create_future()
|
||||
fut: asyncio.Future[None] = self._loop.create_future()
|
||||
responses: list[message.Message] = []
|
||||
handler = self._handle_complex_message
|
||||
on_message = partial(handler, fut, responses, do_append, do_stop)
|
||||
|
||||
read_exception_futures = self._read_exception_futures
|
||||
on_message = partial(
|
||||
_handle_complex_message, fut, responses, do_append, do_stop
|
||||
)
|
||||
self._add_message_callback_without_remove(on_message, msg_types)
|
||||
|
||||
read_exception_futures.add(fut)
|
||||
self._read_exception_futures.add(fut)
|
||||
# Now safe to await since we have registered the handler
|
||||
|
||||
# We must not await without a finally or
|
||||
# the message could fail to be removed if the
|
||||
# the await is cancelled
|
||||
timeout_handle = loop.call_at(loop.time() + timeout, self._handle_timeout, fut)
|
||||
timeout_handle = self._loop.call_at(
|
||||
self._loop.time() + timeout, _handle_timeout, fut
|
||||
)
|
||||
timeout_expired = False
|
||||
try:
|
||||
await fut
|
||||
@ -750,7 +766,7 @@ class APIConnection:
|
||||
if not timeout_expired:
|
||||
timeout_handle.cancel()
|
||||
self._remove_message_callback(on_message, msg_types)
|
||||
read_exception_futures.discard(fut)
|
||||
self._read_exception_futures.discard(fut)
|
||||
|
||||
return responses
|
||||
|
||||
@ -775,7 +791,7 @@ class APIConnection:
|
||||
The connection will be closed, all exception handlers notified.
|
||||
This method does not log the error, the call site should do so.
|
||||
"""
|
||||
if not self._fatal_exception:
|
||||
if self._fatal_exception is None:
|
||||
if self._expected_disconnect is False:
|
||||
# Only log the first error
|
||||
_LOGGER.warning(
|
||||
@ -810,7 +826,7 @@ class APIConnection:
|
||||
return
|
||||
|
||||
try:
|
||||
msg = klass()
|
||||
msg: message.Message = klass()
|
||||
# MergeFromString instead of ParseFromString since
|
||||
# ParseFromString will clear the message first and
|
||||
# the msg is already empty.
|
||||
@ -876,14 +892,14 @@ class APIConnection:
|
||||
# the response if for some reason sending the response
|
||||
# fails we will still mark the disconnect as expected
|
||||
self._expected_disconnect = True
|
||||
self.send_messages((DISCONNECT_RESPONSE_MESSAGE,))
|
||||
self.send_messages(DISCONNECT_RESPONSE_MESSAGES)
|
||||
self._cleanup()
|
||||
|
||||
def _handle_ping_request_internal( # pylint: disable=unused-argument
|
||||
self, _msg: PingRequest
|
||||
) -> None:
|
||||
"""Handle a PingRequest."""
|
||||
self.send_messages((PING_RESPONSE_MESSAGE,))
|
||||
self.send_messages(PING_RESPONSE_MESSAGES)
|
||||
|
||||
def _handle_get_time_request_internal( # pylint: disable=unused-argument
|
||||
self, _msg: GetTimeRequest
|
||||
@ -895,7 +911,7 @@ class APIConnection:
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from the API."""
|
||||
if self._finish_connect_task:
|
||||
if self._finish_connect_task is not None:
|
||||
# Try to wait for the handshake to finish so we can send
|
||||
# a disconnect request. If it doesn't finish in time
|
||||
# we will just close the socket.
|
||||
|
@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
@ -26,12 +25,7 @@ from aioesphomeapi.core import (
|
||||
SocketClosedAPIError,
|
||||
)
|
||||
|
||||
from .common import (
|
||||
async_fire_time_changed,
|
||||
get_mock_protocol,
|
||||
mock_data_received,
|
||||
utcnow,
|
||||
)
|
||||
from .common import get_mock_protocol, mock_data_received
|
||||
from .conftest import get_mock_connection_params
|
||||
|
||||
PREAMBLE = b"\x00"
|
||||
@ -312,7 +306,7 @@ async def test_noise_protector_event_loop(byte_type: Any) -> None:
|
||||
mock_data_received(helper, byte_type(bytes.fromhex(pkt)))
|
||||
|
||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||
await helper.perform_handshake(30)
|
||||
await helper.ready_future
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -343,7 +337,7 @@ async def test_noise_frame_helper_incorrect_key():
|
||||
mock_data_received(helper, bytes.fromhex(pkt))
|
||||
|
||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||
await helper.perform_handshake(30)
|
||||
await helper.ready_future
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -376,7 +370,7 @@ async def test_noise_frame_helper_incorrect_key_fragments():
|
||||
mock_data_received(helper, in_pkt[i : i + 1])
|
||||
|
||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||
await helper.perform_handshake(30)
|
||||
await helper.ready_future
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -407,36 +401,7 @@ async def test_noise_incorrect_name():
|
||||
mock_data_received(helper, bytes.fromhex(pkt))
|
||||
|
||||
with pytest.raises(BadNameAPIError):
|
||||
await helper.perform_handshake(30)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noise_timeout():
|
||||
"""Test we raise on bad name."""
|
||||
outgoing_packets = [
|
||||
"010000", # hello packet
|
||||
"010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4",
|
||||
]
|
||||
|
||||
connection, _ = _make_mock_connection()
|
||||
|
||||
helper = MockAPINoiseFrameHelper(
|
||||
connection=connection,
|
||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||
expected_name="wrongname",
|
||||
client_info="my client",
|
||||
log_name="test",
|
||||
)
|
||||
|
||||
for pkt in outgoing_packets:
|
||||
helper.mock_write_frame(bytes.fromhex(pkt))
|
||||
|
||||
task = asyncio.create_task(helper.perform_handshake(30))
|
||||
await asyncio.sleep(0)
|
||||
async_fire_time_changed(utcnow() + timedelta(seconds=60))
|
||||
await asyncio.sleep(0)
|
||||
with pytest.raises(HandshakeAPIError):
|
||||
await task
|
||||
await helper.ready_future
|
||||
|
||||
|
||||
VARUINT_TESTCASES = [
|
||||
@ -478,7 +443,6 @@ async def test_noise_frame_helper_handshake_failure():
|
||||
|
||||
proto = _mock_responder_proto(psk_bytes)
|
||||
|
||||
handshake_task = asyncio.create_task(helper.perform_handshake(30))
|
||||
await asyncio.sleep(0) # let the task run to read the hello packet
|
||||
|
||||
assert len(writes) == 1
|
||||
@ -502,7 +466,7 @@ async def test_noise_frame_helper_handshake_failure():
|
||||
mock_data_received(helper, error_pkt_with_header)
|
||||
|
||||
with pytest.raises(HandshakeAPIError, match="forced to fail"):
|
||||
await handshake_task
|
||||
await helper.ready_future
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -528,7 +492,6 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
||||
|
||||
proto = _mock_responder_proto(psk_bytes)
|
||||
|
||||
handshake_task = asyncio.create_task(helper.perform_handshake(30))
|
||||
await asyncio.sleep(0) # let the task run to read the hello packet
|
||||
|
||||
assert len(writes) == 1
|
||||
@ -546,7 +509,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
||||
|
||||
assert not writes
|
||||
|
||||
await handshake_task
|
||||
await helper.ready_future
|
||||
helper.write_packets([(1, b"to device")], True)
|
||||
encrypted_packet = writes.pop()
|
||||
header = encrypted_packet[0:1]
|
||||
@ -591,7 +554,6 @@ async def test_noise_frame_helper_bad_encryption(
|
||||
|
||||
proto = _mock_responder_proto(psk_bytes)
|
||||
|
||||
handshake_task = asyncio.create_task(helper.perform_handshake(30))
|
||||
await asyncio.sleep(0) # let the task run to read the hello packet
|
||||
|
||||
assert len(writes) == 1
|
||||
@ -609,7 +571,7 @@ async def test_noise_frame_helper_bad_encryption(
|
||||
|
||||
assert not writes
|
||||
|
||||
await handshake_task
|
||||
await helper.ready_future
|
||||
helper.write_packets([(1, b"to device")], True)
|
||||
encrypted_packet = writes.pop()
|
||||
header = encrypted_packet[0:1]
|
||||
@ -638,7 +600,7 @@ async def test_init_plaintext_with_wrong_preamble(conn: APIConnection):
|
||||
|
||||
conn._socket = MagicMock()
|
||||
await conn._connect_init_frame_helper()
|
||||
loop.call_soon(conn._frame_helper._ready_future.set_result, None)
|
||||
loop.call_soon(conn._frame_helper.ready_future.set_result, None)
|
||||
conn.connection_state = ConnectionState.CONNECTED
|
||||
|
||||
task = asyncio.create_task(conn._connect_hello_login(login=True))
|
||||
@ -687,13 +649,12 @@ async def test_noise_frame_helper_empty_hello():
|
||||
log_name="test",
|
||||
)
|
||||
|
||||
handshake_task = asyncio.create_task(helper.perform_handshake(30))
|
||||
hello_pkt_with_header = _make_noise_hello_pkt(b"")
|
||||
|
||||
mock_data_received(helper, hello_pkt_with_header)
|
||||
|
||||
with pytest.raises(HandshakeAPIError, match="ServerHello is empty"):
|
||||
await handshake_task
|
||||
await helper.ready_future
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -708,7 +669,6 @@ async def test_noise_frame_helper_wrong_protocol():
|
||||
log_name="test",
|
||||
)
|
||||
|
||||
handshake_task = asyncio.create_task(helper.perform_handshake(30))
|
||||
# wrong protocol 5 instead of 1
|
||||
hello_pkt_with_header = _make_noise_hello_pkt(b"\x05servicetest\0")
|
||||
|
||||
@ -717,7 +677,7 @@ async def test_noise_frame_helper_wrong_protocol():
|
||||
with pytest.raises(
|
||||
HandshakeAPIError, match="Unknown protocol selected by client 5"
|
||||
):
|
||||
await handshake_task
|
||||
await helper.ready_future
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -2,10 +2,9 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Coroutine
|
||||
from datetime import timedelta
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from typing import Callable, cast
|
||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
@ -161,7 +160,7 @@ async def test_requires_encryption_propagates(conn: APIConnection):
|
||||
|
||||
conn._socket = MagicMock()
|
||||
await conn._connect_init_frame_helper()
|
||||
loop.call_soon(conn._frame_helper._ready_future.set_result, None)
|
||||
loop.call_soon(conn._frame_helper.ready_future.set_result, None)
|
||||
conn.connection_state = ConnectionState.CONNECTED
|
||||
|
||||
with pytest.raises(RequiresEncryptionAPIError):
|
||||
@ -378,8 +377,18 @@ async def test_plaintext_connection_fails_handshake(
|
||||
class APIPlaintextFrameHelperHandshakeException(APIPlaintextFrameHelper):
|
||||
"""Plaintext frame helper that raises exception on handshake."""
|
||||
|
||||
def perform_handshake(self, timeout: float) -> Coroutine[Any, Any, None]:
|
||||
raise exception
|
||||
def _create_failing_mock_transport_protocol(
|
||||
transport: asyncio.Transport,
|
||||
connected: asyncio.Event,
|
||||
create_func: Callable[[], APIPlaintextFrameHelper],
|
||||
**kwargs,
|
||||
) -> tuple[asyncio.Transport, APIPlaintextFrameHelperHandshakeException]:
|
||||
protocol: APIPlaintextFrameHelperHandshakeException = create_func()
|
||||
protocol._transport = cast(asyncio.Transport, transport)
|
||||
protocol._writer = transport.write
|
||||
protocol.ready_future.set_exception(exception)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
|
||||
def on_msg(msg):
|
||||
messages.append(msg)
|
||||
@ -393,7 +402,9 @@ async def test_plaintext_connection_fails_handshake(
|
||||
), patch.object(
|
||||
loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
side_effect=partial(
|
||||
_create_failing_mock_transport_protocol, transport, connected
|
||||
),
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await connected.wait()
|
||||
|
Loading…
Reference in New Issue
Block a user