diff --git a/aioesphomeapi/_frame_helper/base.pxd b/aioesphomeapi/_frame_helper/base.pxd index 2f16b11..9b5a389 100644 --- a/aioesphomeapi/_frame_helper/base.pxd +++ b/aioesphomeapi/_frame_helper/base.pxd @@ -13,7 +13,7 @@ cdef class APIFrameHelper: cdef APIConnection _connection cdef object _transport cdef public object _writer - cdef public object _ready_future + cdef public object ready_future cdef bytes _buffer cdef unsigned int _buffer_len cdef unsigned int _pos diff --git a/aioesphomeapi/_frame_helper/base.py b/aioesphomeapi/_frame_helper/base.py index 597f41b..a63230e 100644 --- a/aioesphomeapi/_frame_helper/base.py +++ b/aioesphomeapi/_frame_helper/base.py @@ -5,7 +5,7 @@ import logging from abc import abstractmethod from typing import TYPE_CHECKING, Callable, cast -from ..core import HandshakeAPIError, SocketClosedAPIError +from ..core import SocketClosedAPIError if TYPE_CHECKING: from ..connection import APIConnection @@ -23,6 +23,7 @@ WRITE_EXCEPTIONS = (RuntimeError, ConnectionResetError, OSError) _int = int _bytes = bytes +_float = float class APIFrameHelper: @@ -33,7 +34,7 @@ class APIFrameHelper: "_connection", "_transport", "_writer", - "_ready_future", + "ready_future", "_buffer", "_buffer_len", "_pos", @@ -53,7 +54,7 @@ class APIFrameHelper: self._connection = connection self._transport: asyncio.Transport | None = None self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None - self._ready_future = self._loop.create_future() + self.ready_future = self._loop.create_future() self._buffer: bytes | None = None self._buffer_len = 0 self._pos = 0 @@ -65,8 +66,8 @@ class APIFrameHelper: self._log_name = log_name def _set_ready_future_exception(self, exc: Exception | type[Exception]) -> None: - if not self._ready_future.done(): - self._ready_future.set_exception(exc) + if not self.ready_future.done(): + self.ready_future.set_exception(exc) def _add_to_buffer(self, data: bytes | bytearray | memoryview) -> None: """Add data to the buffer.""" @@ -135,22 +136,6 @@ class APIFrameHelper: bitpos += 7 return -1 - async def perform_handshake(self, timeout: float) -> None: - """Perform the handshake with the server.""" - handshake_handle = self._loop.call_at( - self._loop.time() + timeout, - self._set_ready_future_exception, - asyncio.TimeoutError, - ) - try: - await self._ready_future - except asyncio.TimeoutError as err: - raise HandshakeAPIError( - f"{self._log_name}: Timeout during handshake" - ) from err - finally: - handshake_handle.cancel() - @abstractmethod def write_packets( self, packets: list[tuple[int, bytes]], debug_enabled: bool diff --git a/aioesphomeapi/_frame_helper/noise.py b/aioesphomeapi/_frame_helper/noise.py index 41ea06f..f633d34 100644 --- a/aioesphomeapi/_frame_helper/noise.py +++ b/aioesphomeapi/_frame_helper/noise.py @@ -279,7 +279,7 @@ class APINoiseFrameHelper(APIFrameHelper): noise_protocol.cipher_state_encrypt.encrypt_with_ad, # pylint: disable=no-member None, ) - self._ready_future.set_result(None) + self.ready_future.set_result(None) def write_packets( self, packets: list[tuple[int, bytes]], debug_enabled: bool diff --git a/aioesphomeapi/_frame_helper/plain_text.py b/aioesphomeapi/_frame_helper/plain_text.py index 7be442b..8d8cc9e 100644 --- a/aioesphomeapi/_frame_helper/plain_text.py +++ b/aioesphomeapi/_frame_helper/plain_text.py @@ -39,7 +39,7 @@ class APIPlaintextFrameHelper(APIFrameHelper): def connection_made(self, transport: asyncio.BaseTransport) -> None: """Handle a new connection.""" super().connection_made(transport) - self._ready_future.set_result(None) + self.ready_future.set_result(None) def write_packets( self, packets: list[tuple[int, bytes]], debug_enabled: bool diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 6f06b60..7117fda 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -71,7 +71,6 @@ from .api_pb2 import ( # type: ignore VoiceAssistantResponse, ) from .client_callbacks import ( - handle_timeout, on_ble_raw_advertisement_response, on_bluetooth_connections_free_response, on_bluetooth_device_connection_response, @@ -81,7 +80,7 @@ from .client_callbacks import ( on_state_msg, on_subscribe_home_assistant_state_response, ) -from .connection import APIConnection, ConnectionParams +from .connection import APIConnection, ConnectionParams, handle_timeout from .core import ( APIConnectionError, BluetoothGATTAPIError, diff --git a/aioesphomeapi/client_callbacks.pxd b/aioesphomeapi/client_callbacks.pxd index 8fad3d9..2611a2a 100644 --- a/aioesphomeapi/client_callbacks.pxd +++ b/aioesphomeapi/client_callbacks.pxd @@ -8,5 +8,3 @@ cdef object CameraImageResponse, CameraState cdef object HomeassistantServiceCall cdef object BluetoothLEAdvertisement - -cdef object asyncio_TimeoutError diff --git a/aioesphomeapi/client_callbacks.py b/aioesphomeapi/client_callbacks.py index c846c86..7b250e3 100644 --- a/aioesphomeapi/client_callbacks.py +++ b/aioesphomeapi/client_callbacks.py @@ -1,7 +1,6 @@ from __future__ import annotations from asyncio import Future -from asyncio import TimeoutError as asyncio_TimeoutError from typing import TYPE_CHECKING, Callable from google.protobuf import message @@ -98,12 +97,6 @@ def on_subscribe_home_assistant_state_response( on_state_sub(msg.entity_id, msg.attribute) -def handle_timeout(fut: Future[None]) -> None: - """Handle a timeout.""" - if not fut.done(): - fut.set_exception(asyncio_TimeoutError) - - def on_bluetooth_device_connection_response( connect_future: Future[None], address: int, diff --git a/aioesphomeapi/connection.pxd b/aioesphomeapi/connection.pxd index 9d420ac..110b919 100644 --- a/aioesphomeapi/connection.pxd +++ b/aioesphomeapi/connection.pxd @@ -9,19 +9,21 @@ cdef dict PROTO_TO_MESSAGE_TYPE cdef set OPEN_STATES cdef float KEEP_ALIVE_TIMEOUT_RATIO +cdef object HANDSHAKE_TIMEOUT cdef bint TYPE_CHECKING cdef object DISCONNECT_REQUEST_MESSAGE -cdef object DISCONNECT_RESPONSE_MESSAGE -cdef object PING_REQUEST_MESSAGE -cdef object PING_RESPONSE_MESSAGE +cdef tuple DISCONNECT_RESPONSE_MESSAGES +cdef tuple PING_REQUEST_MESSAGES +cdef tuple PING_RESPONSE_MESSAGES +cdef object NO_PASSWORD_CONNECT_REQUEST cdef object asyncio_timeout cdef object CancelledError cdef object asyncio_TimeoutError -cdef object ConnectResponse +cdef object ConnectRequest, ConnectResponse cdef object DisconnectRequest cdef object PingRequest cdef object GetTimeRequest, GetTimeResponse @@ -53,6 +55,20 @@ cdef object CONNECTION_STATE_HANDSHAKE_COMPLETE cdef object CONNECTION_STATE_CONNECTED cdef object CONNECTION_STATE_CLOSED +cdef object make_hello_request + +cpdef handle_timeout(object fut) +cpdef handle_complex_message( + object fut, + list responses, + object do_append, + object do_stop, + object resp, +) + +cdef object _handle_timeout +cdef object _handle_complex_message + @cython.dataclasses.dataclass cdef class ConnectionParams: cdef public str address @@ -91,43 +107,45 @@ cdef class APIConnection: cdef public str received_name cdef public object resolved_addr_info - cpdef send_message(self, object msg) + cpdef void send_message(self, object msg) - cdef send_messages(self, tuple messages) + cdef void send_messages(self, tuple messages) @cython.locals(handlers=set, handlers_copy=set) cpdef void process_packet(self, object msg_type_proto, object data) - cpdef _async_cancel_pong_timer(self) + cdef void _async_cancel_pong_timer(self) - cpdef _async_schedule_keep_alive(self, object now) + cdef void _async_schedule_keep_alive(self, object now) - cdef _cleanup(self) + cdef void _cleanup(self) cpdef set_log_name(self, str name) cdef _make_connect_request(self) - cdef _process_hello_resp(self, object resp) + cdef void _process_hello_resp(self, object resp) - cdef _process_login_response(self, object hello_response) + cdef void _process_login_response(self, object hello_response) - cdef _set_connection_state(self, object state) + cdef void _set_connection_state(self, object state) cpdef report_fatal_error(self, Exception err) @cython.locals(handlers=set) - cpdef _add_message_callback_without_remove(self, object on_message, tuple msg_types) + cdef void _add_message_callback_without_remove(self, object on_message, tuple msg_types) cpdef add_message_callback(self, object on_message, tuple msg_types) @cython.locals(handlers=set) - cpdef _remove_message_callback(self, object on_message, tuple msg_types) + cpdef void _remove_message_callback(self, object on_message, tuple msg_types) - cpdef _handle_disconnect_request_internal(self, object msg) + cpdef void _handle_disconnect_request_internal(self, object msg) - cpdef _handle_ping_request_internal(self, object msg) + cpdef void _handle_ping_request_internal(self, object msg) - cpdef _handle_get_time_request_internal(self, object msg) + cpdef void _handle_get_time_request_internal(self, object msg) - cdef _set_fatal_exception_if_unset(self, Exception err) + cdef void _set_fatal_exception_if_unset(self, Exception err) + + cdef void _register_internal_message_handlers(self) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 521c2a3..11f2df8 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -12,7 +12,7 @@ import time from asyncio import CancelledError from asyncio import TimeoutError as asyncio_TimeoutError from dataclasses import astuple, dataclass -from functools import partial +from functools import lru_cache, partial from typing import TYPE_CHECKING, Any, Callable from google.protobuf import message @@ -63,9 +63,10 @@ _LOGGER = logging.getLogger(__name__) BUFFER_SIZE = 1024 * 1024 * 2 # Set buffer limit to 2MB DISCONNECT_REQUEST_MESSAGE = DisconnectRequest() -DISCONNECT_RESPONSE_MESSAGE = DisconnectResponse() -PING_REQUEST_MESSAGE = PingRequest() -PING_RESPONSE_MESSAGE = PingResponse() +DISCONNECT_RESPONSE_MESSAGES = (DisconnectResponse(),) +PING_REQUEST_MESSAGES = (PingRequest(),) +PING_RESPONSE_MESSAGES = (PingResponse(),) +NO_PASSWORD_CONNECT_REQUEST = ConnectRequest() PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()} @@ -131,6 +132,44 @@ CONNECTION_STATE_CONNECTED = ConnectionState.CONNECTED CONNECTION_STATE_CLOSED = ConnectionState.CLOSED +def _make_hello_request(client_info: str) -> HelloRequest: + """Make a HelloRequest.""" + return HelloRequest( + client_info=client_info, api_version_major=1, api_version_minor=9 + ) + + +_cached_make_hello_request = lru_cache(maxsize=16)(_make_hello_request) +make_hello_request = _cached_make_hello_request + + +def handle_timeout(fut: asyncio.Future[None]) -> None: + """Handle a timeout.""" + if not fut.done(): + fut.set_exception(asyncio_TimeoutError) + + +_handle_timeout = handle_timeout + + +def handle_complex_message( + fut: asyncio.Future[None], + responses: list[message.Message], + do_append: Callable[[message.Message], bool] | None, + do_stop: Callable[[message.Message], bool] | None, + resp: message.Message, +) -> None: + """Handle a message that is part of a response.""" + if not fut.done(): + if do_append is None or do_append(resp): + responses.append(resp) + if do_stop is None or do_stop(resp): + fut.set_result(None) + + +_handle_complex_message = handle_complex_message + + class APIConnection: """This class represents _one_ connection to a remote native API device. @@ -331,12 +370,11 @@ class APIConnection: async def _connect_init_frame_helper(self) -> None: """Step 3 in connect process: initialize the frame helper and init read loop.""" fh: APIPlaintextFrameHelper | APINoiseFrameHelper - loop = self._loop if TYPE_CHECKING: assert self._socket is not None if (noise_psk := self._params.noise_psk) is None: - _, fh = await loop.create_connection( # type: ignore[type-var] + _, fh = await self._loop.create_connection( # type: ignore[type-var] lambda: APIPlaintextFrameHelper( connection=self, client_info=self._params.client_info, @@ -345,7 +383,7 @@ class APIConnection: sock=self._socket, ) else: - _, fh = await loop.create_connection( # type: ignore[type-var] + _, fh = await self._loop.create_connection( # type: ignore[type-var] lambda: APINoiseFrameHelper( noise_psk=noise_psk, expected_name=self._params.expected_name, @@ -359,24 +397,24 @@ class APIConnection: # Set the frame helper right away to ensure # the socket gets closed if we fail to handshake self._frame_helper = fh - + handshake_handle = self._loop.call_at( + self._loop.time() + HANDSHAKE_TIMEOUT, + _handle_timeout, + self._frame_helper.ready_future, + ) try: - await fh.perform_handshake(HANDSHAKE_TIMEOUT) + await self._frame_helper.ready_future except asyncio_TimeoutError as err: raise TimeoutAPIError("Handshake timed out") from err except OSError as err: raise HandshakeAPIError(f"Handshake failed: {err}") from err + finally: + handshake_handle.cancel() self._set_connection_state(CONNECTION_STATE_HANDSHAKE_COMPLETE) async def _connect_hello_login(self, login: bool) -> None: """Step 4 in connect process: send hello and login and get api version.""" - messages = [ - HelloRequest( - client_info=self._params.client_info, - api_version_major=1, - api_version_minor=9, - ) - ] + messages = [make_hello_request(self._params.client_info)] msg_types = [HelloResponse] if login: messages.append(self._make_connect_request()) @@ -447,16 +485,15 @@ class APIConnection: def _async_send_keep_alive(self) -> None: """Send a keep alive message.""" - loop = self._loop - now = loop.time() + now = self._loop.time() if self._send_pending_ping: - self.send_messages((PING_REQUEST_MESSAGE,)) + self.send_messages(PING_REQUEST_MESSAGES) if self._pong_timer is None: # Do not reset the timer if it's already set # since the only thing we want to reset the timer # is if we receive a pong. - self._pong_timer = loop.call_at( + self._pong_timer = self._loop.call_at( now + self._keep_alive_timeout, self._async_pong_not_received ) elif self._debug_enabled: @@ -600,10 +637,9 @@ class APIConnection: def _make_connect_request(self) -> ConnectRequest: """Make a ConnectRequest.""" - connect = ConnectRequest() if self._params.password is not None: - connect.password = self._params.password - return connect + return ConnectRequest(password=self._params.password) + return NO_PASSWORD_CONNECT_REQUEST def send_message(self, msg: message.Message) -> None: """Send a message to the remote.""" @@ -679,26 +715,6 @@ class APIConnection: # we register the handler after sending the message return self.add_message_callback(on_message, msg_types) - def _handle_timeout(self, fut: asyncio.Future[None]) -> None: - """Handle a timeout.""" - if not fut.done(): - fut.set_exception(asyncio_TimeoutError) - - def _handle_complex_message( - self, - fut: asyncio.Future[None], - responses: list[message.Message], - do_append: Callable[[message.Message], bool] | None, - do_stop: Callable[[message.Message], bool] | None, - resp: message.Message, - ) -> None: - """Handle a message that is part of a response.""" - if not fut.done(): - if do_append is None or do_append(resp): - responses.append(resp) - if do_stop is None or do_stop(resp): - fut.set_result(None) - async def send_messages_await_response_complex( # pylint: disable=too-many-locals self, messages: tuple[message.Message, ...], @@ -720,23 +736,23 @@ class APIConnection: # This is safe because we are not awaiting between # sending the message and registering the handler self.send_messages(messages) - loop = self._loop # Unsafe to await between sending the message and registering the handler - fut: asyncio.Future[None] = loop.create_future() + fut: asyncio.Future[None] = self._loop.create_future() responses: list[message.Message] = [] - handler = self._handle_complex_message - on_message = partial(handler, fut, responses, do_append, do_stop) - - read_exception_futures = self._read_exception_futures + on_message = partial( + _handle_complex_message, fut, responses, do_append, do_stop + ) self._add_message_callback_without_remove(on_message, msg_types) - read_exception_futures.add(fut) + self._read_exception_futures.add(fut) # Now safe to await since we have registered the handler # We must not await without a finally or # the message could fail to be removed if the # the await is cancelled - timeout_handle = loop.call_at(loop.time() + timeout, self._handle_timeout, fut) + timeout_handle = self._loop.call_at( + self._loop.time() + timeout, _handle_timeout, fut + ) timeout_expired = False try: await fut @@ -750,7 +766,7 @@ class APIConnection: if not timeout_expired: timeout_handle.cancel() self._remove_message_callback(on_message, msg_types) - read_exception_futures.discard(fut) + self._read_exception_futures.discard(fut) return responses @@ -775,7 +791,7 @@ class APIConnection: The connection will be closed, all exception handlers notified. This method does not log the error, the call site should do so. """ - if not self._fatal_exception: + if self._fatal_exception is None: if self._expected_disconnect is False: # Only log the first error _LOGGER.warning( @@ -810,7 +826,7 @@ class APIConnection: return try: - msg = klass() + msg: message.Message = klass() # MergeFromString instead of ParseFromString since # ParseFromString will clear the message first and # the msg is already empty. @@ -876,14 +892,14 @@ class APIConnection: # the response if for some reason sending the response # fails we will still mark the disconnect as expected self._expected_disconnect = True - self.send_messages((DISCONNECT_RESPONSE_MESSAGE,)) + self.send_messages(DISCONNECT_RESPONSE_MESSAGES) self._cleanup() def _handle_ping_request_internal( # pylint: disable=unused-argument self, _msg: PingRequest ) -> None: """Handle a PingRequest.""" - self.send_messages((PING_RESPONSE_MESSAGE,)) + self.send_messages(PING_RESPONSE_MESSAGES) def _handle_get_time_request_internal( # pylint: disable=unused-argument self, _msg: GetTimeRequest @@ -895,7 +911,7 @@ class APIConnection: async def disconnect(self) -> None: """Disconnect from the API.""" - if self._finish_connect_task: + if self._finish_connect_task is not None: # Try to wait for the handshake to finish so we can send # a disconnect request. If it doesn't finish in time # we will just close the socket. diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index 875e9ca..f9f2818 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -2,7 +2,6 @@ from __future__ import annotations import asyncio import base64 -from datetime import timedelta from typing import Any from unittest.mock import AsyncMock, MagicMock, patch @@ -26,12 +25,7 @@ from aioesphomeapi.core import ( SocketClosedAPIError, ) -from .common import ( - async_fire_time_changed, - get_mock_protocol, - mock_data_received, - utcnow, -) +from .common import get_mock_protocol, mock_data_received from .conftest import get_mock_connection_params PREAMBLE = b"\x00" @@ -312,7 +306,7 @@ async def test_noise_protector_event_loop(byte_type: Any) -> None: mock_data_received(helper, byte_type(bytes.fromhex(pkt))) with pytest.raises(InvalidEncryptionKeyAPIError): - await helper.perform_handshake(30) + await helper.ready_future @pytest.mark.asyncio @@ -343,7 +337,7 @@ async def test_noise_frame_helper_incorrect_key(): mock_data_received(helper, bytes.fromhex(pkt)) with pytest.raises(InvalidEncryptionKeyAPIError): - await helper.perform_handshake(30) + await helper.ready_future @pytest.mark.asyncio @@ -376,7 +370,7 @@ async def test_noise_frame_helper_incorrect_key_fragments(): mock_data_received(helper, in_pkt[i : i + 1]) with pytest.raises(InvalidEncryptionKeyAPIError): - await helper.perform_handshake(30) + await helper.ready_future @pytest.mark.asyncio @@ -407,36 +401,7 @@ async def test_noise_incorrect_name(): mock_data_received(helper, bytes.fromhex(pkt)) with pytest.raises(BadNameAPIError): - await helper.perform_handshake(30) - - -@pytest.mark.asyncio -async def test_noise_timeout(): - """Test we raise on bad name.""" - outgoing_packets = [ - "010000", # hello packet - "010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4", - ] - - connection, _ = _make_mock_connection() - - helper = MockAPINoiseFrameHelper( - connection=connection, - noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", - expected_name="wrongname", - client_info="my client", - log_name="test", - ) - - for pkt in outgoing_packets: - helper.mock_write_frame(bytes.fromhex(pkt)) - - task = asyncio.create_task(helper.perform_handshake(30)) - await asyncio.sleep(0) - async_fire_time_changed(utcnow() + timedelta(seconds=60)) - await asyncio.sleep(0) - with pytest.raises(HandshakeAPIError): - await task + await helper.ready_future VARUINT_TESTCASES = [ @@ -478,7 +443,6 @@ async def test_noise_frame_helper_handshake_failure(): proto = _mock_responder_proto(psk_bytes) - handshake_task = asyncio.create_task(helper.perform_handshake(30)) await asyncio.sleep(0) # let the task run to read the hello packet assert len(writes) == 1 @@ -502,7 +466,7 @@ async def test_noise_frame_helper_handshake_failure(): mock_data_received(helper, error_pkt_with_header) with pytest.raises(HandshakeAPIError, match="forced to fail"): - await handshake_task + await helper.ready_future @pytest.mark.asyncio @@ -528,7 +492,6 @@ async def test_noise_frame_helper_handshake_success_with_single_packet(): proto = _mock_responder_proto(psk_bytes) - handshake_task = asyncio.create_task(helper.perform_handshake(30)) await asyncio.sleep(0) # let the task run to read the hello packet assert len(writes) == 1 @@ -546,7 +509,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet(): assert not writes - await handshake_task + await helper.ready_future helper.write_packets([(1, b"to device")], True) encrypted_packet = writes.pop() header = encrypted_packet[0:1] @@ -591,7 +554,6 @@ async def test_noise_frame_helper_bad_encryption( proto = _mock_responder_proto(psk_bytes) - handshake_task = asyncio.create_task(helper.perform_handshake(30)) await asyncio.sleep(0) # let the task run to read the hello packet assert len(writes) == 1 @@ -609,7 +571,7 @@ async def test_noise_frame_helper_bad_encryption( assert not writes - await handshake_task + await helper.ready_future helper.write_packets([(1, b"to device")], True) encrypted_packet = writes.pop() header = encrypted_packet[0:1] @@ -638,7 +600,7 @@ async def test_init_plaintext_with_wrong_preamble(conn: APIConnection): conn._socket = MagicMock() await conn._connect_init_frame_helper() - loop.call_soon(conn._frame_helper._ready_future.set_result, None) + loop.call_soon(conn._frame_helper.ready_future.set_result, None) conn.connection_state = ConnectionState.CONNECTED task = asyncio.create_task(conn._connect_hello_login(login=True)) @@ -687,13 +649,12 @@ async def test_noise_frame_helper_empty_hello(): log_name="test", ) - handshake_task = asyncio.create_task(helper.perform_handshake(30)) hello_pkt_with_header = _make_noise_hello_pkt(b"") mock_data_received(helper, hello_pkt_with_header) with pytest.raises(HandshakeAPIError, match="ServerHello is empty"): - await handshake_task + await helper.ready_future @pytest.mark.asyncio @@ -708,7 +669,6 @@ async def test_noise_frame_helper_wrong_protocol(): log_name="test", ) - handshake_task = asyncio.create_task(helper.perform_handshake(30)) # wrong protocol 5 instead of 1 hello_pkt_with_header = _make_noise_hello_pkt(b"\x05servicetest\0") @@ -717,7 +677,7 @@ async def test_noise_frame_helper_wrong_protocol(): with pytest.raises( HandshakeAPIError, match="Unknown protocol selected by client 5" ): - await handshake_task + await helper.ready_future @pytest.mark.asyncio diff --git a/tests/test_connection.py b/tests/test_connection.py index 4f95593..d381dce 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -2,10 +2,9 @@ from __future__ import annotations import asyncio import logging -from collections.abc import Coroutine from datetime import timedelta from functools import partial -from typing import Any +from typing import Callable, cast from unittest.mock import AsyncMock, MagicMock, call, patch import pytest @@ -161,7 +160,7 @@ async def test_requires_encryption_propagates(conn: APIConnection): conn._socket = MagicMock() await conn._connect_init_frame_helper() - loop.call_soon(conn._frame_helper._ready_future.set_result, None) + loop.call_soon(conn._frame_helper.ready_future.set_result, None) conn.connection_state = ConnectionState.CONNECTED with pytest.raises(RequiresEncryptionAPIError): @@ -378,8 +377,18 @@ async def test_plaintext_connection_fails_handshake( class APIPlaintextFrameHelperHandshakeException(APIPlaintextFrameHelper): """Plaintext frame helper that raises exception on handshake.""" - def perform_handshake(self, timeout: float) -> Coroutine[Any, Any, None]: - raise exception + def _create_failing_mock_transport_protocol( + transport: asyncio.Transport, + connected: asyncio.Event, + create_func: Callable[[], APIPlaintextFrameHelper], + **kwargs, + ) -> tuple[asyncio.Transport, APIPlaintextFrameHelperHandshakeException]: + protocol: APIPlaintextFrameHelperHandshakeException = create_func() + protocol._transport = cast(asyncio.Transport, transport) + protocol._writer = transport.write + protocol.ready_future.set_exception(exception) + connected.set() + return transport, protocol def on_msg(msg): messages.append(msg) @@ -393,7 +402,9 @@ async def test_plaintext_connection_fails_handshake( ), patch.object( loop, "create_connection", - side_effect=partial(_create_mock_transport_protocol, transport, connected), + side_effect=partial( + _create_failing_mock_transport_protocol, transport, connected + ), ): connect_task = asyncio.create_task(connect(conn, login=False)) await connected.wait()