Reduce duplicate code between connection and frame helper (#763)

This commit is contained in:
J. Nick Koston 2023-11-27 23:51:38 -06:00 committed by GitHub
parent 5fb9c9243b
commit 1b51530642
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 147 additions and 167 deletions

View File

@ -13,7 +13,7 @@ cdef class APIFrameHelper:
cdef APIConnection _connection cdef APIConnection _connection
cdef object _transport cdef object _transport
cdef public object _writer cdef public object _writer
cdef public object _ready_future cdef public object ready_future
cdef bytes _buffer cdef bytes _buffer
cdef unsigned int _buffer_len cdef unsigned int _buffer_len
cdef unsigned int _pos cdef unsigned int _pos

View File

@ -5,7 +5,7 @@ import logging
from abc import abstractmethod from abc import abstractmethod
from typing import TYPE_CHECKING, Callable, cast from typing import TYPE_CHECKING, Callable, cast
from ..core import HandshakeAPIError, SocketClosedAPIError from ..core import SocketClosedAPIError
if TYPE_CHECKING: if TYPE_CHECKING:
from ..connection import APIConnection from ..connection import APIConnection
@ -23,6 +23,7 @@ WRITE_EXCEPTIONS = (RuntimeError, ConnectionResetError, OSError)
_int = int _int = int
_bytes = bytes _bytes = bytes
_float = float
class APIFrameHelper: class APIFrameHelper:
@ -33,7 +34,7 @@ class APIFrameHelper:
"_connection", "_connection",
"_transport", "_transport",
"_writer", "_writer",
"_ready_future", "ready_future",
"_buffer", "_buffer",
"_buffer_len", "_buffer_len",
"_pos", "_pos",
@ -53,7 +54,7 @@ class APIFrameHelper:
self._connection = connection self._connection = connection
self._transport: asyncio.Transport | None = None self._transport: asyncio.Transport | None = None
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
self._ready_future = self._loop.create_future() self.ready_future = self._loop.create_future()
self._buffer: bytes | None = None self._buffer: bytes | None = None
self._buffer_len = 0 self._buffer_len = 0
self._pos = 0 self._pos = 0
@ -65,8 +66,8 @@ class APIFrameHelper:
self._log_name = log_name self._log_name = log_name
def _set_ready_future_exception(self, exc: Exception | type[Exception]) -> None: def _set_ready_future_exception(self, exc: Exception | type[Exception]) -> None:
if not self._ready_future.done(): if not self.ready_future.done():
self._ready_future.set_exception(exc) self.ready_future.set_exception(exc)
def _add_to_buffer(self, data: bytes | bytearray | memoryview) -> None: def _add_to_buffer(self, data: bytes | bytearray | memoryview) -> None:
"""Add data to the buffer.""" """Add data to the buffer."""
@ -135,22 +136,6 @@ class APIFrameHelper:
bitpos += 7 bitpos += 7
return -1 return -1
async def perform_handshake(self, timeout: float) -> None:
"""Perform the handshake with the server."""
handshake_handle = self._loop.call_at(
self._loop.time() + timeout,
self._set_ready_future_exception,
asyncio.TimeoutError,
)
try:
await self._ready_future
except asyncio.TimeoutError as err:
raise HandshakeAPIError(
f"{self._log_name}: Timeout during handshake"
) from err
finally:
handshake_handle.cancel()
@abstractmethod @abstractmethod
def write_packets( def write_packets(
self, packets: list[tuple[int, bytes]], debug_enabled: bool self, packets: list[tuple[int, bytes]], debug_enabled: bool

View File

@ -279,7 +279,7 @@ class APINoiseFrameHelper(APIFrameHelper):
noise_protocol.cipher_state_encrypt.encrypt_with_ad, # pylint: disable=no-member noise_protocol.cipher_state_encrypt.encrypt_with_ad, # pylint: disable=no-member
None, None,
) )
self._ready_future.set_result(None) self.ready_future.set_result(None)
def write_packets( def write_packets(
self, packets: list[tuple[int, bytes]], debug_enabled: bool self, packets: list[tuple[int, bytes]], debug_enabled: bool

View File

@ -39,7 +39,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
def connection_made(self, transport: asyncio.BaseTransport) -> None: def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Handle a new connection.""" """Handle a new connection."""
super().connection_made(transport) super().connection_made(transport)
self._ready_future.set_result(None) self.ready_future.set_result(None)
def write_packets( def write_packets(
self, packets: list[tuple[int, bytes]], debug_enabled: bool self, packets: list[tuple[int, bytes]], debug_enabled: bool

View File

@ -71,7 +71,6 @@ from .api_pb2 import ( # type: ignore
VoiceAssistantResponse, VoiceAssistantResponse,
) )
from .client_callbacks import ( from .client_callbacks import (
handle_timeout,
on_ble_raw_advertisement_response, on_ble_raw_advertisement_response,
on_bluetooth_connections_free_response, on_bluetooth_connections_free_response,
on_bluetooth_device_connection_response, on_bluetooth_device_connection_response,
@ -81,7 +80,7 @@ from .client_callbacks import (
on_state_msg, on_state_msg,
on_subscribe_home_assistant_state_response, on_subscribe_home_assistant_state_response,
) )
from .connection import APIConnection, ConnectionParams from .connection import APIConnection, ConnectionParams, handle_timeout
from .core import ( from .core import (
APIConnectionError, APIConnectionError,
BluetoothGATTAPIError, BluetoothGATTAPIError,

View File

@ -8,5 +8,3 @@ cdef object CameraImageResponse, CameraState
cdef object HomeassistantServiceCall cdef object HomeassistantServiceCall
cdef object BluetoothLEAdvertisement cdef object BluetoothLEAdvertisement
cdef object asyncio_TimeoutError

View File

@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
from asyncio import Future from asyncio import Future
from asyncio import TimeoutError as asyncio_TimeoutError
from typing import TYPE_CHECKING, Callable from typing import TYPE_CHECKING, Callable
from google.protobuf import message from google.protobuf import message
@ -98,12 +97,6 @@ def on_subscribe_home_assistant_state_response(
on_state_sub(msg.entity_id, msg.attribute) on_state_sub(msg.entity_id, msg.attribute)
def handle_timeout(fut: Future[None]) -> None:
"""Handle a timeout."""
if not fut.done():
fut.set_exception(asyncio_TimeoutError)
def on_bluetooth_device_connection_response( def on_bluetooth_device_connection_response(
connect_future: Future[None], connect_future: Future[None],
address: int, address: int,

View File

@ -9,19 +9,21 @@ cdef dict PROTO_TO_MESSAGE_TYPE
cdef set OPEN_STATES cdef set OPEN_STATES
cdef float KEEP_ALIVE_TIMEOUT_RATIO cdef float KEEP_ALIVE_TIMEOUT_RATIO
cdef object HANDSHAKE_TIMEOUT
cdef bint TYPE_CHECKING cdef bint TYPE_CHECKING
cdef object DISCONNECT_REQUEST_MESSAGE cdef object DISCONNECT_REQUEST_MESSAGE
cdef object DISCONNECT_RESPONSE_MESSAGE cdef tuple DISCONNECT_RESPONSE_MESSAGES
cdef object PING_REQUEST_MESSAGE cdef tuple PING_REQUEST_MESSAGES
cdef object PING_RESPONSE_MESSAGE cdef tuple PING_RESPONSE_MESSAGES
cdef object NO_PASSWORD_CONNECT_REQUEST
cdef object asyncio_timeout cdef object asyncio_timeout
cdef object CancelledError cdef object CancelledError
cdef object asyncio_TimeoutError cdef object asyncio_TimeoutError
cdef object ConnectResponse cdef object ConnectRequest, ConnectResponse
cdef object DisconnectRequest cdef object DisconnectRequest
cdef object PingRequest cdef object PingRequest
cdef object GetTimeRequest, GetTimeResponse cdef object GetTimeRequest, GetTimeResponse
@ -53,6 +55,20 @@ cdef object CONNECTION_STATE_HANDSHAKE_COMPLETE
cdef object CONNECTION_STATE_CONNECTED cdef object CONNECTION_STATE_CONNECTED
cdef object CONNECTION_STATE_CLOSED cdef object CONNECTION_STATE_CLOSED
cdef object make_hello_request
cpdef handle_timeout(object fut)
cpdef handle_complex_message(
object fut,
list responses,
object do_append,
object do_stop,
object resp,
)
cdef object _handle_timeout
cdef object _handle_complex_message
@cython.dataclasses.dataclass @cython.dataclasses.dataclass
cdef class ConnectionParams: cdef class ConnectionParams:
cdef public str address cdef public str address
@ -91,43 +107,45 @@ cdef class APIConnection:
cdef public str received_name cdef public str received_name
cdef public object resolved_addr_info cdef public object resolved_addr_info
cpdef send_message(self, object msg) cpdef void send_message(self, object msg)
cdef send_messages(self, tuple messages) cdef void send_messages(self, tuple messages)
@cython.locals(handlers=set, handlers_copy=set) @cython.locals(handlers=set, handlers_copy=set)
cpdef void process_packet(self, object msg_type_proto, object data) cpdef void process_packet(self, object msg_type_proto, object data)
cpdef _async_cancel_pong_timer(self) cdef void _async_cancel_pong_timer(self)
cpdef _async_schedule_keep_alive(self, object now) cdef void _async_schedule_keep_alive(self, object now)
cdef _cleanup(self) cdef void _cleanup(self)
cpdef set_log_name(self, str name) cpdef set_log_name(self, str name)
cdef _make_connect_request(self) cdef _make_connect_request(self)
cdef _process_hello_resp(self, object resp) cdef void _process_hello_resp(self, object resp)
cdef _process_login_response(self, object hello_response) cdef void _process_login_response(self, object hello_response)
cdef _set_connection_state(self, object state) cdef void _set_connection_state(self, object state)
cpdef report_fatal_error(self, Exception err) cpdef report_fatal_error(self, Exception err)
@cython.locals(handlers=set) @cython.locals(handlers=set)
cpdef _add_message_callback_without_remove(self, object on_message, tuple msg_types) cdef void _add_message_callback_without_remove(self, object on_message, tuple msg_types)
cpdef add_message_callback(self, object on_message, tuple msg_types) cpdef add_message_callback(self, object on_message, tuple msg_types)
@cython.locals(handlers=set) @cython.locals(handlers=set)
cpdef _remove_message_callback(self, object on_message, tuple msg_types) cpdef void _remove_message_callback(self, object on_message, tuple msg_types)
cpdef _handle_disconnect_request_internal(self, object msg) cpdef void _handle_disconnect_request_internal(self, object msg)
cpdef _handle_ping_request_internal(self, object msg) cpdef void _handle_ping_request_internal(self, object msg)
cpdef _handle_get_time_request_internal(self, object msg) cpdef void _handle_get_time_request_internal(self, object msg)
cdef _set_fatal_exception_if_unset(self, Exception err) cdef void _set_fatal_exception_if_unset(self, Exception err)
cdef void _register_internal_message_handlers(self)

View File

@ -12,7 +12,7 @@ import time
from asyncio import CancelledError from asyncio import CancelledError
from asyncio import TimeoutError as asyncio_TimeoutError from asyncio import TimeoutError as asyncio_TimeoutError
from dataclasses import astuple, dataclass from dataclasses import astuple, dataclass
from functools import partial from functools import lru_cache, partial
from typing import TYPE_CHECKING, Any, Callable from typing import TYPE_CHECKING, Any, Callable
from google.protobuf import message from google.protobuf import message
@ -63,9 +63,10 @@ _LOGGER = logging.getLogger(__name__)
BUFFER_SIZE = 1024 * 1024 * 2 # Set buffer limit to 2MB BUFFER_SIZE = 1024 * 1024 * 2 # Set buffer limit to 2MB
DISCONNECT_REQUEST_MESSAGE = DisconnectRequest() DISCONNECT_REQUEST_MESSAGE = DisconnectRequest()
DISCONNECT_RESPONSE_MESSAGE = DisconnectResponse() DISCONNECT_RESPONSE_MESSAGES = (DisconnectResponse(),)
PING_REQUEST_MESSAGE = PingRequest() PING_REQUEST_MESSAGES = (PingRequest(),)
PING_RESPONSE_MESSAGE = PingResponse() PING_RESPONSE_MESSAGES = (PingResponse(),)
NO_PASSWORD_CONNECT_REQUEST = ConnectRequest()
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()} PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
@ -131,6 +132,44 @@ CONNECTION_STATE_CONNECTED = ConnectionState.CONNECTED
CONNECTION_STATE_CLOSED = ConnectionState.CLOSED CONNECTION_STATE_CLOSED = ConnectionState.CLOSED
def _make_hello_request(client_info: str) -> HelloRequest:
"""Make a HelloRequest."""
return HelloRequest(
client_info=client_info, api_version_major=1, api_version_minor=9
)
_cached_make_hello_request = lru_cache(maxsize=16)(_make_hello_request)
make_hello_request = _cached_make_hello_request
def handle_timeout(fut: asyncio.Future[None]) -> None:
"""Handle a timeout."""
if not fut.done():
fut.set_exception(asyncio_TimeoutError)
_handle_timeout = handle_timeout
def handle_complex_message(
fut: asyncio.Future[None],
responses: list[message.Message],
do_append: Callable[[message.Message], bool] | None,
do_stop: Callable[[message.Message], bool] | None,
resp: message.Message,
) -> None:
"""Handle a message that is part of a response."""
if not fut.done():
if do_append is None or do_append(resp):
responses.append(resp)
if do_stop is None or do_stop(resp):
fut.set_result(None)
_handle_complex_message = handle_complex_message
class APIConnection: class APIConnection:
"""This class represents _one_ connection to a remote native API device. """This class represents _one_ connection to a remote native API device.
@ -331,12 +370,11 @@ class APIConnection:
async def _connect_init_frame_helper(self) -> None: async def _connect_init_frame_helper(self) -> None:
"""Step 3 in connect process: initialize the frame helper and init read loop.""" """Step 3 in connect process: initialize the frame helper and init read loop."""
fh: APIPlaintextFrameHelper | APINoiseFrameHelper fh: APIPlaintextFrameHelper | APINoiseFrameHelper
loop = self._loop
if TYPE_CHECKING: if TYPE_CHECKING:
assert self._socket is not None assert self._socket is not None
if (noise_psk := self._params.noise_psk) is None: if (noise_psk := self._params.noise_psk) is None:
_, fh = await loop.create_connection( # type: ignore[type-var] _, fh = await self._loop.create_connection( # type: ignore[type-var]
lambda: APIPlaintextFrameHelper( lambda: APIPlaintextFrameHelper(
connection=self, connection=self,
client_info=self._params.client_info, client_info=self._params.client_info,
@ -345,7 +383,7 @@ class APIConnection:
sock=self._socket, sock=self._socket,
) )
else: else:
_, fh = await loop.create_connection( # type: ignore[type-var] _, fh = await self._loop.create_connection( # type: ignore[type-var]
lambda: APINoiseFrameHelper( lambda: APINoiseFrameHelper(
noise_psk=noise_psk, noise_psk=noise_psk,
expected_name=self._params.expected_name, expected_name=self._params.expected_name,
@ -359,24 +397,24 @@ class APIConnection:
# Set the frame helper right away to ensure # Set the frame helper right away to ensure
# the socket gets closed if we fail to handshake # the socket gets closed if we fail to handshake
self._frame_helper = fh self._frame_helper = fh
handshake_handle = self._loop.call_at(
self._loop.time() + HANDSHAKE_TIMEOUT,
_handle_timeout,
self._frame_helper.ready_future,
)
try: try:
await fh.perform_handshake(HANDSHAKE_TIMEOUT) await self._frame_helper.ready_future
except asyncio_TimeoutError as err: except asyncio_TimeoutError as err:
raise TimeoutAPIError("Handshake timed out") from err raise TimeoutAPIError("Handshake timed out") from err
except OSError as err: except OSError as err:
raise HandshakeAPIError(f"Handshake failed: {err}") from err raise HandshakeAPIError(f"Handshake failed: {err}") from err
finally:
handshake_handle.cancel()
self._set_connection_state(CONNECTION_STATE_HANDSHAKE_COMPLETE) self._set_connection_state(CONNECTION_STATE_HANDSHAKE_COMPLETE)
async def _connect_hello_login(self, login: bool) -> None: async def _connect_hello_login(self, login: bool) -> None:
"""Step 4 in connect process: send hello and login and get api version.""" """Step 4 in connect process: send hello and login and get api version."""
messages = [ messages = [make_hello_request(self._params.client_info)]
HelloRequest(
client_info=self._params.client_info,
api_version_major=1,
api_version_minor=9,
)
]
msg_types = [HelloResponse] msg_types = [HelloResponse]
if login: if login:
messages.append(self._make_connect_request()) messages.append(self._make_connect_request())
@ -447,16 +485,15 @@ class APIConnection:
def _async_send_keep_alive(self) -> None: def _async_send_keep_alive(self) -> None:
"""Send a keep alive message.""" """Send a keep alive message."""
loop = self._loop now = self._loop.time()
now = loop.time()
if self._send_pending_ping: if self._send_pending_ping:
self.send_messages((PING_REQUEST_MESSAGE,)) self.send_messages(PING_REQUEST_MESSAGES)
if self._pong_timer is None: if self._pong_timer is None:
# Do not reset the timer if it's already set # Do not reset the timer if it's already set
# since the only thing we want to reset the timer # since the only thing we want to reset the timer
# is if we receive a pong. # is if we receive a pong.
self._pong_timer = loop.call_at( self._pong_timer = self._loop.call_at(
now + self._keep_alive_timeout, self._async_pong_not_received now + self._keep_alive_timeout, self._async_pong_not_received
) )
elif self._debug_enabled: elif self._debug_enabled:
@ -600,10 +637,9 @@ class APIConnection:
def _make_connect_request(self) -> ConnectRequest: def _make_connect_request(self) -> ConnectRequest:
"""Make a ConnectRequest.""" """Make a ConnectRequest."""
connect = ConnectRequest()
if self._params.password is not None: if self._params.password is not None:
connect.password = self._params.password return ConnectRequest(password=self._params.password)
return connect return NO_PASSWORD_CONNECT_REQUEST
def send_message(self, msg: message.Message) -> None: def send_message(self, msg: message.Message) -> None:
"""Send a message to the remote.""" """Send a message to the remote."""
@ -679,26 +715,6 @@ class APIConnection:
# we register the handler after sending the message # we register the handler after sending the message
return self.add_message_callback(on_message, msg_types) return self.add_message_callback(on_message, msg_types)
def _handle_timeout(self, fut: asyncio.Future[None]) -> None:
"""Handle a timeout."""
if not fut.done():
fut.set_exception(asyncio_TimeoutError)
def _handle_complex_message(
self,
fut: asyncio.Future[None],
responses: list[message.Message],
do_append: Callable[[message.Message], bool] | None,
do_stop: Callable[[message.Message], bool] | None,
resp: message.Message,
) -> None:
"""Handle a message that is part of a response."""
if not fut.done():
if do_append is None or do_append(resp):
responses.append(resp)
if do_stop is None or do_stop(resp):
fut.set_result(None)
async def send_messages_await_response_complex( # pylint: disable=too-many-locals async def send_messages_await_response_complex( # pylint: disable=too-many-locals
self, self,
messages: tuple[message.Message, ...], messages: tuple[message.Message, ...],
@ -720,23 +736,23 @@ class APIConnection:
# This is safe because we are not awaiting between # This is safe because we are not awaiting between
# sending the message and registering the handler # sending the message and registering the handler
self.send_messages(messages) self.send_messages(messages)
loop = self._loop
# Unsafe to await between sending the message and registering the handler # Unsafe to await between sending the message and registering the handler
fut: asyncio.Future[None] = loop.create_future() fut: asyncio.Future[None] = self._loop.create_future()
responses: list[message.Message] = [] responses: list[message.Message] = []
handler = self._handle_complex_message on_message = partial(
on_message = partial(handler, fut, responses, do_append, do_stop) _handle_complex_message, fut, responses, do_append, do_stop
)
read_exception_futures = self._read_exception_futures
self._add_message_callback_without_remove(on_message, msg_types) self._add_message_callback_without_remove(on_message, msg_types)
read_exception_futures.add(fut) self._read_exception_futures.add(fut)
# Now safe to await since we have registered the handler # Now safe to await since we have registered the handler
# We must not await without a finally or # We must not await without a finally or
# the message could fail to be removed if the # the message could fail to be removed if the
# the await is cancelled # the await is cancelled
timeout_handle = loop.call_at(loop.time() + timeout, self._handle_timeout, fut) timeout_handle = self._loop.call_at(
self._loop.time() + timeout, _handle_timeout, fut
)
timeout_expired = False timeout_expired = False
try: try:
await fut await fut
@ -750,7 +766,7 @@ class APIConnection:
if not timeout_expired: if not timeout_expired:
timeout_handle.cancel() timeout_handle.cancel()
self._remove_message_callback(on_message, msg_types) self._remove_message_callback(on_message, msg_types)
read_exception_futures.discard(fut) self._read_exception_futures.discard(fut)
return responses return responses
@ -775,7 +791,7 @@ class APIConnection:
The connection will be closed, all exception handlers notified. The connection will be closed, all exception handlers notified.
This method does not log the error, the call site should do so. This method does not log the error, the call site should do so.
""" """
if not self._fatal_exception: if self._fatal_exception is None:
if self._expected_disconnect is False: if self._expected_disconnect is False:
# Only log the first error # Only log the first error
_LOGGER.warning( _LOGGER.warning(
@ -810,7 +826,7 @@ class APIConnection:
return return
try: try:
msg = klass() msg: message.Message = klass()
# MergeFromString instead of ParseFromString since # MergeFromString instead of ParseFromString since
# ParseFromString will clear the message first and # ParseFromString will clear the message first and
# the msg is already empty. # the msg is already empty.
@ -876,14 +892,14 @@ class APIConnection:
# the response if for some reason sending the response # the response if for some reason sending the response
# fails we will still mark the disconnect as expected # fails we will still mark the disconnect as expected
self._expected_disconnect = True self._expected_disconnect = True
self.send_messages((DISCONNECT_RESPONSE_MESSAGE,)) self.send_messages(DISCONNECT_RESPONSE_MESSAGES)
self._cleanup() self._cleanup()
def _handle_ping_request_internal( # pylint: disable=unused-argument def _handle_ping_request_internal( # pylint: disable=unused-argument
self, _msg: PingRequest self, _msg: PingRequest
) -> None: ) -> None:
"""Handle a PingRequest.""" """Handle a PingRequest."""
self.send_messages((PING_RESPONSE_MESSAGE,)) self.send_messages(PING_RESPONSE_MESSAGES)
def _handle_get_time_request_internal( # pylint: disable=unused-argument def _handle_get_time_request_internal( # pylint: disable=unused-argument
self, _msg: GetTimeRequest self, _msg: GetTimeRequest
@ -895,7 +911,7 @@ class APIConnection:
async def disconnect(self) -> None: async def disconnect(self) -> None:
"""Disconnect from the API.""" """Disconnect from the API."""
if self._finish_connect_task: if self._finish_connect_task is not None:
# Try to wait for the handshake to finish so we can send # Try to wait for the handshake to finish so we can send
# a disconnect request. If it doesn't finish in time # a disconnect request. If it doesn't finish in time
# we will just close the socket. # we will just close the socket.

View File

@ -2,7 +2,6 @@ from __future__ import annotations
import asyncio import asyncio
import base64 import base64
from datetime import timedelta
from typing import Any from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
@ -26,12 +25,7 @@ from aioesphomeapi.core import (
SocketClosedAPIError, SocketClosedAPIError,
) )
from .common import ( from .common import get_mock_protocol, mock_data_received
async_fire_time_changed,
get_mock_protocol,
mock_data_received,
utcnow,
)
from .conftest import get_mock_connection_params from .conftest import get_mock_connection_params
PREAMBLE = b"\x00" PREAMBLE = b"\x00"
@ -312,7 +306,7 @@ async def test_noise_protector_event_loop(byte_type: Any) -> None:
mock_data_received(helper, byte_type(bytes.fromhex(pkt))) mock_data_received(helper, byte_type(bytes.fromhex(pkt)))
with pytest.raises(InvalidEncryptionKeyAPIError): with pytest.raises(InvalidEncryptionKeyAPIError):
await helper.perform_handshake(30) await helper.ready_future
@pytest.mark.asyncio @pytest.mark.asyncio
@ -343,7 +337,7 @@ async def test_noise_frame_helper_incorrect_key():
mock_data_received(helper, bytes.fromhex(pkt)) mock_data_received(helper, bytes.fromhex(pkt))
with pytest.raises(InvalidEncryptionKeyAPIError): with pytest.raises(InvalidEncryptionKeyAPIError):
await helper.perform_handshake(30) await helper.ready_future
@pytest.mark.asyncio @pytest.mark.asyncio
@ -376,7 +370,7 @@ async def test_noise_frame_helper_incorrect_key_fragments():
mock_data_received(helper, in_pkt[i : i + 1]) mock_data_received(helper, in_pkt[i : i + 1])
with pytest.raises(InvalidEncryptionKeyAPIError): with pytest.raises(InvalidEncryptionKeyAPIError):
await helper.perform_handshake(30) await helper.ready_future
@pytest.mark.asyncio @pytest.mark.asyncio
@ -407,36 +401,7 @@ async def test_noise_incorrect_name():
mock_data_received(helper, bytes.fromhex(pkt)) mock_data_received(helper, bytes.fromhex(pkt))
with pytest.raises(BadNameAPIError): with pytest.raises(BadNameAPIError):
await helper.perform_handshake(30) await helper.ready_future
@pytest.mark.asyncio
async def test_noise_timeout():
"""Test we raise on bad name."""
outgoing_packets = [
"010000", # hello packet
"010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4",
]
connection, _ = _make_mock_connection()
helper = MockAPINoiseFrameHelper(
connection=connection,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="wrongname",
client_info="my client",
log_name="test",
)
for pkt in outgoing_packets:
helper.mock_write_frame(bytes.fromhex(pkt))
task = asyncio.create_task(helper.perform_handshake(30))
await asyncio.sleep(0)
async_fire_time_changed(utcnow() + timedelta(seconds=60))
await asyncio.sleep(0)
with pytest.raises(HandshakeAPIError):
await task
VARUINT_TESTCASES = [ VARUINT_TESTCASES = [
@ -478,7 +443,6 @@ async def test_noise_frame_helper_handshake_failure():
proto = _mock_responder_proto(psk_bytes) proto = _mock_responder_proto(psk_bytes)
handshake_task = asyncio.create_task(helper.perform_handshake(30))
await asyncio.sleep(0) # let the task run to read the hello packet await asyncio.sleep(0) # let the task run to read the hello packet
assert len(writes) == 1 assert len(writes) == 1
@ -502,7 +466,7 @@ async def test_noise_frame_helper_handshake_failure():
mock_data_received(helper, error_pkt_with_header) mock_data_received(helper, error_pkt_with_header)
with pytest.raises(HandshakeAPIError, match="forced to fail"): with pytest.raises(HandshakeAPIError, match="forced to fail"):
await handshake_task await helper.ready_future
@pytest.mark.asyncio @pytest.mark.asyncio
@ -528,7 +492,6 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
proto = _mock_responder_proto(psk_bytes) proto = _mock_responder_proto(psk_bytes)
handshake_task = asyncio.create_task(helper.perform_handshake(30))
await asyncio.sleep(0) # let the task run to read the hello packet await asyncio.sleep(0) # let the task run to read the hello packet
assert len(writes) == 1 assert len(writes) == 1
@ -546,7 +509,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
assert not writes assert not writes
await handshake_task await helper.ready_future
helper.write_packets([(1, b"to device")], True) helper.write_packets([(1, b"to device")], True)
encrypted_packet = writes.pop() encrypted_packet = writes.pop()
header = encrypted_packet[0:1] header = encrypted_packet[0:1]
@ -591,7 +554,6 @@ async def test_noise_frame_helper_bad_encryption(
proto = _mock_responder_proto(psk_bytes) proto = _mock_responder_proto(psk_bytes)
handshake_task = asyncio.create_task(helper.perform_handshake(30))
await asyncio.sleep(0) # let the task run to read the hello packet await asyncio.sleep(0) # let the task run to read the hello packet
assert len(writes) == 1 assert len(writes) == 1
@ -609,7 +571,7 @@ async def test_noise_frame_helper_bad_encryption(
assert not writes assert not writes
await handshake_task await helper.ready_future
helper.write_packets([(1, b"to device")], True) helper.write_packets([(1, b"to device")], True)
encrypted_packet = writes.pop() encrypted_packet = writes.pop()
header = encrypted_packet[0:1] header = encrypted_packet[0:1]
@ -638,7 +600,7 @@ async def test_init_plaintext_with_wrong_preamble(conn: APIConnection):
conn._socket = MagicMock() conn._socket = MagicMock()
await conn._connect_init_frame_helper() await conn._connect_init_frame_helper()
loop.call_soon(conn._frame_helper._ready_future.set_result, None) loop.call_soon(conn._frame_helper.ready_future.set_result, None)
conn.connection_state = ConnectionState.CONNECTED conn.connection_state = ConnectionState.CONNECTED
task = asyncio.create_task(conn._connect_hello_login(login=True)) task = asyncio.create_task(conn._connect_hello_login(login=True))
@ -687,13 +649,12 @@ async def test_noise_frame_helper_empty_hello():
log_name="test", log_name="test",
) )
handshake_task = asyncio.create_task(helper.perform_handshake(30))
hello_pkt_with_header = _make_noise_hello_pkt(b"") hello_pkt_with_header = _make_noise_hello_pkt(b"")
mock_data_received(helper, hello_pkt_with_header) mock_data_received(helper, hello_pkt_with_header)
with pytest.raises(HandshakeAPIError, match="ServerHello is empty"): with pytest.raises(HandshakeAPIError, match="ServerHello is empty"):
await handshake_task await helper.ready_future
@pytest.mark.asyncio @pytest.mark.asyncio
@ -708,7 +669,6 @@ async def test_noise_frame_helper_wrong_protocol():
log_name="test", log_name="test",
) )
handshake_task = asyncio.create_task(helper.perform_handshake(30))
# wrong protocol 5 instead of 1 # wrong protocol 5 instead of 1
hello_pkt_with_header = _make_noise_hello_pkt(b"\x05servicetest\0") hello_pkt_with_header = _make_noise_hello_pkt(b"\x05servicetest\0")
@ -717,7 +677,7 @@ async def test_noise_frame_helper_wrong_protocol():
with pytest.raises( with pytest.raises(
HandshakeAPIError, match="Unknown protocol selected by client 5" HandshakeAPIError, match="Unknown protocol selected by client 5"
): ):
await handshake_task await helper.ready_future
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -2,10 +2,9 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
from collections.abc import Coroutine
from datetime import timedelta from datetime import timedelta
from functools import partial from functools import partial
from typing import Any from typing import Callable, cast
from unittest.mock import AsyncMock, MagicMock, call, patch from unittest.mock import AsyncMock, MagicMock, call, patch
import pytest import pytest
@ -161,7 +160,7 @@ async def test_requires_encryption_propagates(conn: APIConnection):
conn._socket = MagicMock() conn._socket = MagicMock()
await conn._connect_init_frame_helper() await conn._connect_init_frame_helper()
loop.call_soon(conn._frame_helper._ready_future.set_result, None) loop.call_soon(conn._frame_helper.ready_future.set_result, None)
conn.connection_state = ConnectionState.CONNECTED conn.connection_state = ConnectionState.CONNECTED
with pytest.raises(RequiresEncryptionAPIError): with pytest.raises(RequiresEncryptionAPIError):
@ -378,8 +377,18 @@ async def test_plaintext_connection_fails_handshake(
class APIPlaintextFrameHelperHandshakeException(APIPlaintextFrameHelper): class APIPlaintextFrameHelperHandshakeException(APIPlaintextFrameHelper):
"""Plaintext frame helper that raises exception on handshake.""" """Plaintext frame helper that raises exception on handshake."""
def perform_handshake(self, timeout: float) -> Coroutine[Any, Any, None]: def _create_failing_mock_transport_protocol(
raise exception transport: asyncio.Transport,
connected: asyncio.Event,
create_func: Callable[[], APIPlaintextFrameHelper],
**kwargs,
) -> tuple[asyncio.Transport, APIPlaintextFrameHelperHandshakeException]:
protocol: APIPlaintextFrameHelperHandshakeException = create_func()
protocol._transport = cast(asyncio.Transport, transport)
protocol._writer = transport.write
protocol.ready_future.set_exception(exception)
connected.set()
return transport, protocol
def on_msg(msg): def on_msg(msg):
messages.append(msg) messages.append(msg)
@ -393,7 +402,9 @@ async def test_plaintext_connection_fails_handshake(
), patch.object( ), patch.object(
loop, loop,
"create_connection", "create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected), side_effect=partial(
_create_failing_mock_transport_protocol, transport, connected
),
): ):
connect_task = asyncio.create_task(connect(conn, login=False)) connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait() await connected.wait()