Reduce duplicate code between connection and frame helper (#763)

This commit is contained in:
J. Nick Koston 2023-11-27 23:51:38 -06:00 committed by GitHub
parent 5fb9c9243b
commit 1b51530642
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 147 additions and 167 deletions

View File

@ -13,7 +13,7 @@ cdef class APIFrameHelper:
cdef APIConnection _connection
cdef object _transport
cdef public object _writer
cdef public object _ready_future
cdef public object ready_future
cdef bytes _buffer
cdef unsigned int _buffer_len
cdef unsigned int _pos

View File

@ -5,7 +5,7 @@ import logging
from abc import abstractmethod
from typing import TYPE_CHECKING, Callable, cast
from ..core import HandshakeAPIError, SocketClosedAPIError
from ..core import SocketClosedAPIError
if TYPE_CHECKING:
from ..connection import APIConnection
@ -23,6 +23,7 @@ WRITE_EXCEPTIONS = (RuntimeError, ConnectionResetError, OSError)
_int = int
_bytes = bytes
_float = float
class APIFrameHelper:
@ -33,7 +34,7 @@ class APIFrameHelper:
"_connection",
"_transport",
"_writer",
"_ready_future",
"ready_future",
"_buffer",
"_buffer_len",
"_pos",
@ -53,7 +54,7 @@ class APIFrameHelper:
self._connection = connection
self._transport: asyncio.Transport | None = None
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
self._ready_future = self._loop.create_future()
self.ready_future = self._loop.create_future()
self._buffer: bytes | None = None
self._buffer_len = 0
self._pos = 0
@ -65,8 +66,8 @@ class APIFrameHelper:
self._log_name = log_name
def _set_ready_future_exception(self, exc: Exception | type[Exception]) -> None:
if not self._ready_future.done():
self._ready_future.set_exception(exc)
if not self.ready_future.done():
self.ready_future.set_exception(exc)
def _add_to_buffer(self, data: bytes | bytearray | memoryview) -> None:
"""Add data to the buffer."""
@ -135,22 +136,6 @@ class APIFrameHelper:
bitpos += 7
return -1
async def perform_handshake(self, timeout: float) -> None:
"""Perform the handshake with the server."""
handshake_handle = self._loop.call_at(
self._loop.time() + timeout,
self._set_ready_future_exception,
asyncio.TimeoutError,
)
try:
await self._ready_future
except asyncio.TimeoutError as err:
raise HandshakeAPIError(
f"{self._log_name}: Timeout during handshake"
) from err
finally:
handshake_handle.cancel()
@abstractmethod
def write_packets(
self, packets: list[tuple[int, bytes]], debug_enabled: bool

View File

@ -279,7 +279,7 @@ class APINoiseFrameHelper(APIFrameHelper):
noise_protocol.cipher_state_encrypt.encrypt_with_ad, # pylint: disable=no-member
None,
)
self._ready_future.set_result(None)
self.ready_future.set_result(None)
def write_packets(
self, packets: list[tuple[int, bytes]], debug_enabled: bool

View File

@ -39,7 +39,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Handle a new connection."""
super().connection_made(transport)
self._ready_future.set_result(None)
self.ready_future.set_result(None)
def write_packets(
self, packets: list[tuple[int, bytes]], debug_enabled: bool

View File

@ -71,7 +71,6 @@ from .api_pb2 import ( # type: ignore
VoiceAssistantResponse,
)
from .client_callbacks import (
handle_timeout,
on_ble_raw_advertisement_response,
on_bluetooth_connections_free_response,
on_bluetooth_device_connection_response,
@ -81,7 +80,7 @@ from .client_callbacks import (
on_state_msg,
on_subscribe_home_assistant_state_response,
)
from .connection import APIConnection, ConnectionParams
from .connection import APIConnection, ConnectionParams, handle_timeout
from .core import (
APIConnectionError,
BluetoothGATTAPIError,

View File

@ -8,5 +8,3 @@ cdef object CameraImageResponse, CameraState
cdef object HomeassistantServiceCall
cdef object BluetoothLEAdvertisement
cdef object asyncio_TimeoutError

View File

@ -1,7 +1,6 @@
from __future__ import annotations
from asyncio import Future
from asyncio import TimeoutError as asyncio_TimeoutError
from typing import TYPE_CHECKING, Callable
from google.protobuf import message
@ -98,12 +97,6 @@ def on_subscribe_home_assistant_state_response(
on_state_sub(msg.entity_id, msg.attribute)
def handle_timeout(fut: Future[None]) -> None:
"""Handle a timeout."""
if not fut.done():
fut.set_exception(asyncio_TimeoutError)
def on_bluetooth_device_connection_response(
connect_future: Future[None],
address: int,

View File

@ -9,19 +9,21 @@ cdef dict PROTO_TO_MESSAGE_TYPE
cdef set OPEN_STATES
cdef float KEEP_ALIVE_TIMEOUT_RATIO
cdef object HANDSHAKE_TIMEOUT
cdef bint TYPE_CHECKING
cdef object DISCONNECT_REQUEST_MESSAGE
cdef object DISCONNECT_RESPONSE_MESSAGE
cdef object PING_REQUEST_MESSAGE
cdef object PING_RESPONSE_MESSAGE
cdef tuple DISCONNECT_RESPONSE_MESSAGES
cdef tuple PING_REQUEST_MESSAGES
cdef tuple PING_RESPONSE_MESSAGES
cdef object NO_PASSWORD_CONNECT_REQUEST
cdef object asyncio_timeout
cdef object CancelledError
cdef object asyncio_TimeoutError
cdef object ConnectResponse
cdef object ConnectRequest, ConnectResponse
cdef object DisconnectRequest
cdef object PingRequest
cdef object GetTimeRequest, GetTimeResponse
@ -53,6 +55,20 @@ cdef object CONNECTION_STATE_HANDSHAKE_COMPLETE
cdef object CONNECTION_STATE_CONNECTED
cdef object CONNECTION_STATE_CLOSED
cdef object make_hello_request
cpdef handle_timeout(object fut)
cpdef handle_complex_message(
object fut,
list responses,
object do_append,
object do_stop,
object resp,
)
cdef object _handle_timeout
cdef object _handle_complex_message
@cython.dataclasses.dataclass
cdef class ConnectionParams:
cdef public str address
@ -91,43 +107,45 @@ cdef class APIConnection:
cdef public str received_name
cdef public object resolved_addr_info
cpdef send_message(self, object msg)
cpdef void send_message(self, object msg)
cdef send_messages(self, tuple messages)
cdef void send_messages(self, tuple messages)
@cython.locals(handlers=set, handlers_copy=set)
cpdef void process_packet(self, object msg_type_proto, object data)
cpdef _async_cancel_pong_timer(self)
cdef void _async_cancel_pong_timer(self)
cpdef _async_schedule_keep_alive(self, object now)
cdef void _async_schedule_keep_alive(self, object now)
cdef _cleanup(self)
cdef void _cleanup(self)
cpdef set_log_name(self, str name)
cdef _make_connect_request(self)
cdef _process_hello_resp(self, object resp)
cdef void _process_hello_resp(self, object resp)
cdef _process_login_response(self, object hello_response)
cdef void _process_login_response(self, object hello_response)
cdef _set_connection_state(self, object state)
cdef void _set_connection_state(self, object state)
cpdef report_fatal_error(self, Exception err)
@cython.locals(handlers=set)
cpdef _add_message_callback_without_remove(self, object on_message, tuple msg_types)
cdef void _add_message_callback_without_remove(self, object on_message, tuple msg_types)
cpdef add_message_callback(self, object on_message, tuple msg_types)
@cython.locals(handlers=set)
cpdef _remove_message_callback(self, object on_message, tuple msg_types)
cpdef void _remove_message_callback(self, object on_message, tuple msg_types)
cpdef _handle_disconnect_request_internal(self, object msg)
cpdef void _handle_disconnect_request_internal(self, object msg)
cpdef _handle_ping_request_internal(self, object msg)
cpdef void _handle_ping_request_internal(self, object msg)
cpdef _handle_get_time_request_internal(self, object msg)
cpdef void _handle_get_time_request_internal(self, object msg)
cdef _set_fatal_exception_if_unset(self, Exception err)
cdef void _set_fatal_exception_if_unset(self, Exception err)
cdef void _register_internal_message_handlers(self)

View File

@ -12,7 +12,7 @@ import time
from asyncio import CancelledError
from asyncio import TimeoutError as asyncio_TimeoutError
from dataclasses import astuple, dataclass
from functools import partial
from functools import lru_cache, partial
from typing import TYPE_CHECKING, Any, Callable
from google.protobuf import message
@ -63,9 +63,10 @@ _LOGGER = logging.getLogger(__name__)
BUFFER_SIZE = 1024 * 1024 * 2 # Set buffer limit to 2MB
DISCONNECT_REQUEST_MESSAGE = DisconnectRequest()
DISCONNECT_RESPONSE_MESSAGE = DisconnectResponse()
PING_REQUEST_MESSAGE = PingRequest()
PING_RESPONSE_MESSAGE = PingResponse()
DISCONNECT_RESPONSE_MESSAGES = (DisconnectResponse(),)
PING_REQUEST_MESSAGES = (PingRequest(),)
PING_RESPONSE_MESSAGES = (PingResponse(),)
NO_PASSWORD_CONNECT_REQUEST = ConnectRequest()
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
@ -131,6 +132,44 @@ CONNECTION_STATE_CONNECTED = ConnectionState.CONNECTED
CONNECTION_STATE_CLOSED = ConnectionState.CLOSED
def _make_hello_request(client_info: str) -> HelloRequest:
"""Make a HelloRequest."""
return HelloRequest(
client_info=client_info, api_version_major=1, api_version_minor=9
)
_cached_make_hello_request = lru_cache(maxsize=16)(_make_hello_request)
make_hello_request = _cached_make_hello_request
def handle_timeout(fut: asyncio.Future[None]) -> None:
"""Handle a timeout."""
if not fut.done():
fut.set_exception(asyncio_TimeoutError)
_handle_timeout = handle_timeout
def handle_complex_message(
fut: asyncio.Future[None],
responses: list[message.Message],
do_append: Callable[[message.Message], bool] | None,
do_stop: Callable[[message.Message], bool] | None,
resp: message.Message,
) -> None:
"""Handle a message that is part of a response."""
if not fut.done():
if do_append is None or do_append(resp):
responses.append(resp)
if do_stop is None or do_stop(resp):
fut.set_result(None)
_handle_complex_message = handle_complex_message
class APIConnection:
"""This class represents _one_ connection to a remote native API device.
@ -331,12 +370,11 @@ class APIConnection:
async def _connect_init_frame_helper(self) -> None:
"""Step 3 in connect process: initialize the frame helper and init read loop."""
fh: APIPlaintextFrameHelper | APINoiseFrameHelper
loop = self._loop
if TYPE_CHECKING:
assert self._socket is not None
if (noise_psk := self._params.noise_psk) is None:
_, fh = await loop.create_connection( # type: ignore[type-var]
_, fh = await self._loop.create_connection( # type: ignore[type-var]
lambda: APIPlaintextFrameHelper(
connection=self,
client_info=self._params.client_info,
@ -345,7 +383,7 @@ class APIConnection:
sock=self._socket,
)
else:
_, fh = await loop.create_connection( # type: ignore[type-var]
_, fh = await self._loop.create_connection( # type: ignore[type-var]
lambda: APINoiseFrameHelper(
noise_psk=noise_psk,
expected_name=self._params.expected_name,
@ -359,24 +397,24 @@ class APIConnection:
# Set the frame helper right away to ensure
# the socket gets closed if we fail to handshake
self._frame_helper = fh
handshake_handle = self._loop.call_at(
self._loop.time() + HANDSHAKE_TIMEOUT,
_handle_timeout,
self._frame_helper.ready_future,
)
try:
await fh.perform_handshake(HANDSHAKE_TIMEOUT)
await self._frame_helper.ready_future
except asyncio_TimeoutError as err:
raise TimeoutAPIError("Handshake timed out") from err
except OSError as err:
raise HandshakeAPIError(f"Handshake failed: {err}") from err
finally:
handshake_handle.cancel()
self._set_connection_state(CONNECTION_STATE_HANDSHAKE_COMPLETE)
async def _connect_hello_login(self, login: bool) -> None:
"""Step 4 in connect process: send hello and login and get api version."""
messages = [
HelloRequest(
client_info=self._params.client_info,
api_version_major=1,
api_version_minor=9,
)
]
messages = [make_hello_request(self._params.client_info)]
msg_types = [HelloResponse]
if login:
messages.append(self._make_connect_request())
@ -447,16 +485,15 @@ class APIConnection:
def _async_send_keep_alive(self) -> None:
"""Send a keep alive message."""
loop = self._loop
now = loop.time()
now = self._loop.time()
if self._send_pending_ping:
self.send_messages((PING_REQUEST_MESSAGE,))
self.send_messages(PING_REQUEST_MESSAGES)
if self._pong_timer is None:
# Do not reset the timer if it's already set
# since the only thing we want to reset the timer
# is if we receive a pong.
self._pong_timer = loop.call_at(
self._pong_timer = self._loop.call_at(
now + self._keep_alive_timeout, self._async_pong_not_received
)
elif self._debug_enabled:
@ -600,10 +637,9 @@ class APIConnection:
def _make_connect_request(self) -> ConnectRequest:
"""Make a ConnectRequest."""
connect = ConnectRequest()
if self._params.password is not None:
connect.password = self._params.password
return connect
return ConnectRequest(password=self._params.password)
return NO_PASSWORD_CONNECT_REQUEST
def send_message(self, msg: message.Message) -> None:
"""Send a message to the remote."""
@ -679,26 +715,6 @@ class APIConnection:
# we register the handler after sending the message
return self.add_message_callback(on_message, msg_types)
def _handle_timeout(self, fut: asyncio.Future[None]) -> None:
"""Handle a timeout."""
if not fut.done():
fut.set_exception(asyncio_TimeoutError)
def _handle_complex_message(
self,
fut: asyncio.Future[None],
responses: list[message.Message],
do_append: Callable[[message.Message], bool] | None,
do_stop: Callable[[message.Message], bool] | None,
resp: message.Message,
) -> None:
"""Handle a message that is part of a response."""
if not fut.done():
if do_append is None or do_append(resp):
responses.append(resp)
if do_stop is None or do_stop(resp):
fut.set_result(None)
async def send_messages_await_response_complex( # pylint: disable=too-many-locals
self,
messages: tuple[message.Message, ...],
@ -720,23 +736,23 @@ class APIConnection:
# This is safe because we are not awaiting between
# sending the message and registering the handler
self.send_messages(messages)
loop = self._loop
# Unsafe to await between sending the message and registering the handler
fut: asyncio.Future[None] = loop.create_future()
fut: asyncio.Future[None] = self._loop.create_future()
responses: list[message.Message] = []
handler = self._handle_complex_message
on_message = partial(handler, fut, responses, do_append, do_stop)
read_exception_futures = self._read_exception_futures
on_message = partial(
_handle_complex_message, fut, responses, do_append, do_stop
)
self._add_message_callback_without_remove(on_message, msg_types)
read_exception_futures.add(fut)
self._read_exception_futures.add(fut)
# Now safe to await since we have registered the handler
# We must not await without a finally or
# the message could fail to be removed if the
# the await is cancelled
timeout_handle = loop.call_at(loop.time() + timeout, self._handle_timeout, fut)
timeout_handle = self._loop.call_at(
self._loop.time() + timeout, _handle_timeout, fut
)
timeout_expired = False
try:
await fut
@ -750,7 +766,7 @@ class APIConnection:
if not timeout_expired:
timeout_handle.cancel()
self._remove_message_callback(on_message, msg_types)
read_exception_futures.discard(fut)
self._read_exception_futures.discard(fut)
return responses
@ -775,7 +791,7 @@ class APIConnection:
The connection will be closed, all exception handlers notified.
This method does not log the error, the call site should do so.
"""
if not self._fatal_exception:
if self._fatal_exception is None:
if self._expected_disconnect is False:
# Only log the first error
_LOGGER.warning(
@ -810,7 +826,7 @@ class APIConnection:
return
try:
msg = klass()
msg: message.Message = klass()
# MergeFromString instead of ParseFromString since
# ParseFromString will clear the message first and
# the msg is already empty.
@ -876,14 +892,14 @@ class APIConnection:
# the response if for some reason sending the response
# fails we will still mark the disconnect as expected
self._expected_disconnect = True
self.send_messages((DISCONNECT_RESPONSE_MESSAGE,))
self.send_messages(DISCONNECT_RESPONSE_MESSAGES)
self._cleanup()
def _handle_ping_request_internal( # pylint: disable=unused-argument
self, _msg: PingRequest
) -> None:
"""Handle a PingRequest."""
self.send_messages((PING_RESPONSE_MESSAGE,))
self.send_messages(PING_RESPONSE_MESSAGES)
def _handle_get_time_request_internal( # pylint: disable=unused-argument
self, _msg: GetTimeRequest
@ -895,7 +911,7 @@ class APIConnection:
async def disconnect(self) -> None:
"""Disconnect from the API."""
if self._finish_connect_task:
if self._finish_connect_task is not None:
# Try to wait for the handshake to finish so we can send
# a disconnect request. If it doesn't finish in time
# we will just close the socket.

View File

@ -2,7 +2,6 @@ from __future__ import annotations
import asyncio
import base64
from datetime import timedelta
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
@ -26,12 +25,7 @@ from aioesphomeapi.core import (
SocketClosedAPIError,
)
from .common import (
async_fire_time_changed,
get_mock_protocol,
mock_data_received,
utcnow,
)
from .common import get_mock_protocol, mock_data_received
from .conftest import get_mock_connection_params
PREAMBLE = b"\x00"
@ -312,7 +306,7 @@ async def test_noise_protector_event_loop(byte_type: Any) -> None:
mock_data_received(helper, byte_type(bytes.fromhex(pkt)))
with pytest.raises(InvalidEncryptionKeyAPIError):
await helper.perform_handshake(30)
await helper.ready_future
@pytest.mark.asyncio
@ -343,7 +337,7 @@ async def test_noise_frame_helper_incorrect_key():
mock_data_received(helper, bytes.fromhex(pkt))
with pytest.raises(InvalidEncryptionKeyAPIError):
await helper.perform_handshake(30)
await helper.ready_future
@pytest.mark.asyncio
@ -376,7 +370,7 @@ async def test_noise_frame_helper_incorrect_key_fragments():
mock_data_received(helper, in_pkt[i : i + 1])
with pytest.raises(InvalidEncryptionKeyAPIError):
await helper.perform_handshake(30)
await helper.ready_future
@pytest.mark.asyncio
@ -407,36 +401,7 @@ async def test_noise_incorrect_name():
mock_data_received(helper, bytes.fromhex(pkt))
with pytest.raises(BadNameAPIError):
await helper.perform_handshake(30)
@pytest.mark.asyncio
async def test_noise_timeout():
"""Test we raise on bad name."""
outgoing_packets = [
"010000", # hello packet
"010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4",
]
connection, _ = _make_mock_connection()
helper = MockAPINoiseFrameHelper(
connection=connection,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="wrongname",
client_info="my client",
log_name="test",
)
for pkt in outgoing_packets:
helper.mock_write_frame(bytes.fromhex(pkt))
task = asyncio.create_task(helper.perform_handshake(30))
await asyncio.sleep(0)
async_fire_time_changed(utcnow() + timedelta(seconds=60))
await asyncio.sleep(0)
with pytest.raises(HandshakeAPIError):
await task
await helper.ready_future
VARUINT_TESTCASES = [
@ -478,7 +443,6 @@ async def test_noise_frame_helper_handshake_failure():
proto = _mock_responder_proto(psk_bytes)
handshake_task = asyncio.create_task(helper.perform_handshake(30))
await asyncio.sleep(0) # let the task run to read the hello packet
assert len(writes) == 1
@ -502,7 +466,7 @@ async def test_noise_frame_helper_handshake_failure():
mock_data_received(helper, error_pkt_with_header)
with pytest.raises(HandshakeAPIError, match="forced to fail"):
await handshake_task
await helper.ready_future
@pytest.mark.asyncio
@ -528,7 +492,6 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
proto = _mock_responder_proto(psk_bytes)
handshake_task = asyncio.create_task(helper.perform_handshake(30))
await asyncio.sleep(0) # let the task run to read the hello packet
assert len(writes) == 1
@ -546,7 +509,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
assert not writes
await handshake_task
await helper.ready_future
helper.write_packets([(1, b"to device")], True)
encrypted_packet = writes.pop()
header = encrypted_packet[0:1]
@ -591,7 +554,6 @@ async def test_noise_frame_helper_bad_encryption(
proto = _mock_responder_proto(psk_bytes)
handshake_task = asyncio.create_task(helper.perform_handshake(30))
await asyncio.sleep(0) # let the task run to read the hello packet
assert len(writes) == 1
@ -609,7 +571,7 @@ async def test_noise_frame_helper_bad_encryption(
assert not writes
await handshake_task
await helper.ready_future
helper.write_packets([(1, b"to device")], True)
encrypted_packet = writes.pop()
header = encrypted_packet[0:1]
@ -638,7 +600,7 @@ async def test_init_plaintext_with_wrong_preamble(conn: APIConnection):
conn._socket = MagicMock()
await conn._connect_init_frame_helper()
loop.call_soon(conn._frame_helper._ready_future.set_result, None)
loop.call_soon(conn._frame_helper.ready_future.set_result, None)
conn.connection_state = ConnectionState.CONNECTED
task = asyncio.create_task(conn._connect_hello_login(login=True))
@ -687,13 +649,12 @@ async def test_noise_frame_helper_empty_hello():
log_name="test",
)
handshake_task = asyncio.create_task(helper.perform_handshake(30))
hello_pkt_with_header = _make_noise_hello_pkt(b"")
mock_data_received(helper, hello_pkt_with_header)
with pytest.raises(HandshakeAPIError, match="ServerHello is empty"):
await handshake_task
await helper.ready_future
@pytest.mark.asyncio
@ -708,7 +669,6 @@ async def test_noise_frame_helper_wrong_protocol():
log_name="test",
)
handshake_task = asyncio.create_task(helper.perform_handshake(30))
# wrong protocol 5 instead of 1
hello_pkt_with_header = _make_noise_hello_pkt(b"\x05servicetest\0")
@ -717,7 +677,7 @@ async def test_noise_frame_helper_wrong_protocol():
with pytest.raises(
HandshakeAPIError, match="Unknown protocol selected by client 5"
):
await handshake_task
await helper.ready_future
@pytest.mark.asyncio

View File

@ -2,10 +2,9 @@ from __future__ import annotations
import asyncio
import logging
from collections.abc import Coroutine
from datetime import timedelta
from functools import partial
from typing import Any
from typing import Callable, cast
from unittest.mock import AsyncMock, MagicMock, call, patch
import pytest
@ -161,7 +160,7 @@ async def test_requires_encryption_propagates(conn: APIConnection):
conn._socket = MagicMock()
await conn._connect_init_frame_helper()
loop.call_soon(conn._frame_helper._ready_future.set_result, None)
loop.call_soon(conn._frame_helper.ready_future.set_result, None)
conn.connection_state = ConnectionState.CONNECTED
with pytest.raises(RequiresEncryptionAPIError):
@ -378,8 +377,18 @@ async def test_plaintext_connection_fails_handshake(
class APIPlaintextFrameHelperHandshakeException(APIPlaintextFrameHelper):
"""Plaintext frame helper that raises exception on handshake."""
def perform_handshake(self, timeout: float) -> Coroutine[Any, Any, None]:
raise exception
def _create_failing_mock_transport_protocol(
transport: asyncio.Transport,
connected: asyncio.Event,
create_func: Callable[[], APIPlaintextFrameHelper],
**kwargs,
) -> tuple[asyncio.Transport, APIPlaintextFrameHelperHandshakeException]:
protocol: APIPlaintextFrameHelperHandshakeException = create_func()
protocol._transport = cast(asyncio.Transport, transport)
protocol._writer = transport.write
protocol.ready_future.set_exception(exception)
connected.set()
return transport, protocol
def on_msg(msg):
messages.append(msg)
@ -393,7 +402,9 @@ async def test_plaintext_connection_fails_handshake(
), patch.object(
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
side_effect=partial(
_create_failing_mock_transport_protocol, transport, connected
),
):
connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait()