From e2bbbf4da51b22a08d40a785eef71bc9fa73e6e4 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 16 Feb 2024 20:47:26 -0600 Subject: [PATCH] Avoid creating tasks for starting/finishing the connection (#826) --- aioesphomeapi/connection.pxd | 8 +- aioesphomeapi/connection.py | 106 +++++++++++----------- requirements.txt | 1 + tests/test__frame_helper.py | 15 +++- tests/test_connection.py | 169 +++++++++++++++++++++++++++-------- tests/test_log_runner.py | 5 ++ 6 files changed, 210 insertions(+), 94 deletions(-) diff --git a/aioesphomeapi/connection.pxd b/aioesphomeapi/connection.pxd index 941c799..8cdc7b4 100644 --- a/aioesphomeapi/connection.pxd +++ b/aioesphomeapi/connection.pxd @@ -98,8 +98,8 @@ cdef class APIConnection: cdef object _pong_timer cdef float _keep_alive_interval cdef float _keep_alive_timeout - cdef object _start_connect_task - cdef object _finish_connect_task + cdef object _start_connect_future + cdef object _finish_connect_future cdef public Exception _fatal_exception cdef bint _expected_disconnect cdef object _loop @@ -154,3 +154,7 @@ cdef class APIConnection: cdef void _register_internal_message_handlers(self) cdef void _increase_recv_buffer_size(self) + + cdef void _set_start_connect_future(self) + + cdef void _set_finish_connect_future(self) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 9f08f53..408cd24 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -16,6 +16,7 @@ from functools import lru_cache, partial from typing import TYPE_CHECKING, Any, Callable import aiohappyeyeballs +from async_interrupt import interrupt from google.protobuf import message import aioesphomeapi.host_resolver as hr @@ -106,6 +107,10 @@ _bytes = bytes _float = float +class ConnectionInterruptedError(Exception): + """An error that is raised when a connection is interrupted.""" + + @dataclass class ConnectionParams: addresses: list[str] @@ -198,8 +203,8 @@ class APIConnection: "_pong_timer", "_keep_alive_interval", "_keep_alive_timeout", - "_start_connect_task", - "_finish_connect_task", + "_start_connect_future", + "_finish_connect_future", "_fatal_exception", "_expected_disconnect", "_loop", @@ -242,8 +247,8 @@ class APIConnection: self._keep_alive_interval = keepalive self._keep_alive_timeout = keepalive * KEEP_ALIVE_TIMEOUT_RATIO - self._start_connect_task: asyncio.Task[None] | None = None - self._finish_connect_task: asyncio.Task[None] | None = None + self._start_connect_future: asyncio.Future[None] | None = None + self._finish_connect_future: asyncio.Future[None] | None = None self._fatal_exception: Exception | None = None self._expected_disconnect = False self._send_pending_ping = False @@ -276,28 +281,13 @@ class APIConnection: err = self._fatal_exception or APIConnectionError("Connection closed") new_exc = err if not isinstance(err, APIConnectionError): - new_exc = ReadFailedAPIError("Read failed") + new_exc = ReadFailedAPIError(str(err) or "Read failed") new_exc.__cause__ = err fut.set_exception(new_exc) self._read_exception_futures.clear() - # If we are being called from do_connect we - # need to make sure we don't cancel the task - # that called us - current_task = asyncio.current_task() - if ( - self._start_connect_task is not None - and self._start_connect_task is not current_task - ): - self._start_connect_task.cancel("Connection cleanup") - self._start_connect_task = None - - if ( - self._finish_connect_task is not None - and self._finish_connect_task is not current_task - ): - self._finish_connect_task.cancel("Connection cleanup") - self._finish_connect_task = None + self._set_start_connect_future() + self._set_finish_connect_future() if self._frame_helper is not None: self._frame_helper.close() @@ -460,7 +450,9 @@ class APIConnection: try: await self._frame_helper.ready_future except asyncio_TimeoutError as err: - raise TimeoutAPIError("Handshake timed out") from err + raise TimeoutAPIError( + f"Handshake timed out after {HANDSHAKE_TIMEOUT}s" + ) from err except OSError as err: raise HandshakeAPIError(f"Handshake failed: {err}") from err finally: @@ -475,19 +467,14 @@ class APIConnection: messages.append(self._make_connect_request()) msg_types.append(ConnectResponse) - try: - responses = await self.send_messages_await_response_complex( - tuple(messages), - None, - lambda resp: type(resp) # pylint: disable=unidiomatic-typecheck - is msg_types[-1], - tuple(msg_types), - CONNECT_REQUEST_TIMEOUT, - ) - except TimeoutAPIError as err: - self.report_fatal_error(err) - raise TimeoutAPIError("Hello timed out") from err - + responses = await self.send_messages_await_response_complex( + tuple(messages), + None, + lambda resp: type(resp) # pylint: disable=unidiomatic-typecheck + is msg_types[-1], + tuple(msg_types), + CONNECT_REQUEST_TIMEOUT, + ) resp = responses.pop(0) self._process_hello_resp(resp) if login: @@ -605,21 +592,29 @@ class APIConnection: "Connection can only be used once, connection is not in init state" ) - start_connect_task = asyncio.create_task( - self._do_connect(), name=f"{self.log_name}: aioesphomeapi do_connect" - ) - self._start_connect_task = start_connect_task + self._start_connect_future = self._loop.create_future() try: - await start_connect_task + async with interrupt( + self._start_connect_future, ConnectionInterruptedError, None + ): + await self._do_connect() except (Exception, CancelledError) as ex: # If the task was cancelled, we need to clean up the connection # and raise the CancelledError as APIConnectionError self._cleanup() raise self._wrap_fatal_connection_exception("starting", ex) finally: - self._start_connect_task = None + self._set_start_connect_future() self._set_connection_state(CONNECTION_STATE_SOCKET_OPENED) + def _set_start_connect_future(self) -> None: + if ( + self._start_connect_future is not None + and not self._start_connect_future.done() + ): + self._start_connect_future.set_result(None) + self._start_connect_future = None + def _wrap_fatal_connection_exception( self, action: str, ex: BaseException ) -> APIConnectionError: @@ -627,7 +622,7 @@ class APIConnection: if isinstance(ex, APIConnectionError): return ex cause: BaseException | None = None - if isinstance(ex, CancelledError): + if isinstance(ex, (ConnectionInterruptedError, CancelledError)): err_str = f"{action.title()} connection cancelled" if self._fatal_exception: err_str += f" due to fatal exception: {self._fatal_exception}" @@ -664,22 +659,29 @@ class APIConnection: raise RuntimeError( "Connection must be in SOCKET_OPENED state to finish connection" ) - finish_connect_task = asyncio.create_task( - self._do_finish_connect(login), - name=f"{self.log_name}: aioesphomeapi _do_finish_connect", - ) - self._finish_connect_task = finish_connect_task + self._finish_connect_future = self._loop.create_future() try: - await self._finish_connect_task + async with interrupt( + self._finish_connect_future, ConnectionInterruptedError, None + ): + await self._do_finish_connect(login) except (Exception, CancelledError) as ex: # If the task was cancelled, we need to clean up the connection # and raise the CancelledError as APIConnectionError self._cleanup() raise self._wrap_fatal_connection_exception("finishing", ex) finally: - self._finish_connect_task = None + self._set_finish_connect_future() self._set_connection_state(CONNECTION_STATE_CONNECTED) + def _set_finish_connect_future(self) -> None: + if ( + self._finish_connect_future is not None + and not self._finish_connect_future.done() + ): + self._finish_connect_future.set_result(None) + self._finish_connect_future = None + def _set_connection_state(self, state: ConnectionState) -> None: """Set the connection state and log the change.""" self.connection_state = state @@ -969,12 +971,12 @@ class APIConnection: async def disconnect(self) -> None: """Disconnect from the API.""" - if self._finish_connect_task is not None: + if self._finish_connect_future is not None: # Try to wait for the handshake to finish so we can send # a disconnect request. If it doesn't finish in time # we will just close the socket. _, pending = await asyncio.wait( - [self._finish_connect_task], timeout=DISCONNECT_CONNECT_TIMEOUT + [self._finish_connect_future], timeout=DISCONNECT_CONNECT_TIMEOUT ) if pending: self._set_fatal_exception_if_unset( diff --git a/requirements.txt b/requirements.txt index 06ccb08..e09dac0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ aiohappyeyeballs>=2.3.0 +async-interrupt>=1.1.1 protobuf>=3.19.0 zeroconf>=0.128.4,<1.0 chacha20poly1305-reuseable>=0.12.1 diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index 640f48c..d4417f6 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -22,6 +22,7 @@ from aioesphomeapi.core import ( HandshakeAPIError, InvalidEncryptionKeyAPIError, ProtocolAPIError, + ReadFailedAPIError, SocketClosedAPIError, ) @@ -725,18 +726,28 @@ async def test_eof_received_closes_connection( await connect_task +@pytest.mark.parametrize( + ("exception_map"), + [ + (OSError("original message"), ReadFailedAPIError), + (APIConnectionError("original message"), APIConnectionError), + (SocketClosedAPIError("original message"), SocketClosedAPIError), + ], +) @pytest.mark.asyncio async def test_connection_lost_closes_connection_and_logs( caplog: pytest.LogCaptureFixture, plaintext_connect_task_with_login: tuple[ APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task ], + exception_map: tuple[Exception, Exception], ) -> None: + exception, raised_exception = exception_map conn, transport, protocol, connect_task = plaintext_connect_task_with_login - protocol.connection_lost(OSError("original message")) + protocol.connection_lost(exception) assert conn.is_connected is False assert "original message" in caplog.text - with pytest.raises(APIConnectionError, match="original message"): + with pytest.raises(raised_exception, match="original message"): await connect_task diff --git a/tests/test_connection.py b/tests/test_connection.py index 2b84748..e71763d 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -24,12 +24,16 @@ from aioesphomeapi.api_pb2 import ( ) from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState from aioesphomeapi.core import ( + APIConnectionCancelledError, APIConnectionError, ConnectionNotEstablishedAPIError, HandshakeAPIError, InvalidAuthAPIError, + ReadFailedAPIError, RequiresEncryptionAPIError, ResolveAPIError, + SocketAPIError, + SocketClosedAPIError, TimeoutAPIError, ) @@ -442,7 +446,9 @@ async def test_finish_connection_times_out( async_fire_time_changed(utcnow() + timedelta(seconds=200)) await asyncio.sleep(0) - with pytest.raises(APIConnectionError, match="Hello timed out"): + with pytest.raises( + APIConnectionError, match="Timeout waiting for HelloResponse after 30.0s" + ): await connect_task async_fire_time_changed(utcnow() + timedelta(seconds=600)) @@ -458,6 +464,8 @@ async def test_finish_connection_times_out( ("exception_map"), [ (OSError("Socket error"), HandshakeAPIError), + (APIConnectionError, APIConnectionError), + (SocketClosedAPIError, SocketClosedAPIError), (asyncio.TimeoutError, TimeoutAPIError), (asyncio.CancelledError, APIConnectionError), ], @@ -501,6 +509,21 @@ async def test_plaintext_connection_fails_handshake( remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse)) transport = MagicMock() + call_order = [] + + def _socket_close_call(): + call_order.append("socket_close") + + def _frame_helper_close_call(): + call_order.append("frame_helper_close") + + async def _do_finish_connect(self, *args, **kwargs): + try: + await conn._connect_init_frame_helper() + finally: + conn._socket.close = _socket_close_call + conn._frame_helper.close = _frame_helper_close_call + with ( patch( "aioesphomeapi.connection.APIPlaintextFrameHelper", @@ -513,42 +536,12 @@ async def test_plaintext_connection_fails_handshake( _create_failing_mock_transport_protocol, transport, connected ), ), + patch.object(conn, "_do_finish_connect", _do_finish_connect), ): connect_task = asyncio.create_task(connect(conn, login=False)) await connected.wait() - protocol = conn._frame_helper - assert conn._socket is not None - assert conn._frame_helper is not None - - mock_data_received( - protocol, - b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m', - ) - mock_data_received(protocol, b"5stackatomproxy") - mock_data_received(protocol, b"\x00\x00$") - mock_data_received(protocol, b"\x00\x00\x04") - mock_data_received( - protocol, - b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d', - ) - mock_data_received( - protocol, b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif" - ) - - call_order = [] - - def _socket_close_call(): - call_order.append("socket_close") - - def _frame_helper_close_call(): - call_order.append("frame_helper_close") - - with ( - patch.object(conn._socket, "close", side_effect=_socket_close_call), - patch.object(conn._frame_helper, "close", side_effect=_frame_helper_close_call), - pytest.raises(raised_exception), - ): + with (pytest.raises(raised_exception),): await asyncio.sleep(0) await connect_task @@ -556,10 +549,6 @@ async def test_plaintext_connection_fails_handshake( # so asyncio releases the socket assert call_order == ["frame_helper_close", "socket_close"] assert not conn.is_connected - assert len(messages) == 2 - assert isinstance(messages[0], HelloResponse) - assert isinstance(messages[1], DeviceInfoResponse) - assert messages[1].name == "m5stackatomproxy" remove() conn.force_disconnect() await asyncio.sleep(0) @@ -655,6 +644,110 @@ async def test_force_disconnect_fails( await asyncio.sleep(0) +@pytest.mark.parametrize( + ("exception_map"), + [ + (OSError("original message"), ReadFailedAPIError), + (APIConnectionError("original message"), APIConnectionError), + (SocketClosedAPIError("original message"), SocketClosedAPIError), + ], +) +@pytest.mark.asyncio +async def test_connection_lost_while_connecting( + plaintext_connect_task_with_login: tuple[ + APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task + ], + exception_map: tuple[Exception, Exception], +) -> None: + conn, transport, protocol, connect_task = plaintext_connect_task_with_login + + exception, raised_exception = exception_map + protocol.connection_lost(exception) + + with pytest.raises(raised_exception, match="original message"): + await connect_task + + assert not conn.is_connected + + +@pytest.mark.parametrize( + ("exception_map"), + [ + (OSError("original message"), SocketAPIError), + (APIConnectionError("original message"), APIConnectionError), + (SocketClosedAPIError("original message"), SocketClosedAPIError), + ], +) +@pytest.mark.asyncio +async def test_connection_error_during_hello( + conn: APIConnection, + resolve_host, + aiohappyeyeballs_start_connection, + exception_map: tuple[Exception, Exception], +) -> None: + loop = asyncio.get_event_loop() + transport = MagicMock() + connected = asyncio.Event() + exception, raised_exception = exception_map + + with ( + patch.object( + loop, + "create_connection", + side_effect=partial(_create_mock_transport_protocol, transport, connected), + ), + patch.object(conn, "_connect_hello_login", side_effect=exception), + ): + connect_task = asyncio.create_task(connect(conn, login=False)) + await connected.wait() + + with pytest.raises(raised_exception, match="original message"): + await connect_task + + assert not conn.is_connected + + +@pytest.mark.parametrize( + ("exception_map"), + [ + (OSError("original message"), APIConnectionCancelledError), + (APIConnectionError("original message"), APIConnectionError), + (SocketClosedAPIError("original message"), SocketClosedAPIError), + ], +) +@pytest.mark.asyncio +async def test_connection_cancelled_during_hello( + conn: APIConnection, + resolve_host, + aiohappyeyeballs_start_connection, + exception_map: tuple[Exception, Exception], +) -> None: + loop = asyncio.get_event_loop() + transport = MagicMock() + connected = asyncio.Event() + exception, raised_exception = exception_map + + async def _mock_frame_helper_error(*args, **kwargs): + conn._frame_helper.connection_lost(exception) + raise asyncio.CancelledError + + with ( + patch.object( + loop, + "create_connection", + side_effect=partial(_create_mock_transport_protocol, transport, connected), + ), + patch.object(conn, "_connect_hello_login", _mock_frame_helper_error), + ): + connect_task = asyncio.create_task(connect(conn, login=False)) + await connected.wait() + + with pytest.raises(raised_exception, match="original message"): + await connect_task + + assert not conn.is_connected + + @pytest.mark.asyncio async def test_connect_resolver_times_out( conn: APIConnection, aiohappyeyeballs_start_connection @@ -814,7 +907,7 @@ async def test_ping_disconnects_after_no_responses( start_time + timedelta(seconds=KEEP_ALIVE_INTERVAL * (max_pings_to_disconnect_after + 1)) ) - assert transport.write.call_count == max_pings_to_disconnect_after + assert transport.write.call_count == max_pings_to_disconnect_after + 1 assert conn.is_connected is False diff --git a/tests/test_log_runner.py b/tests/test_log_runner.py index e95f5c5..addd429 100644 --- a/tests/test_log_runner.py +++ b/tests/test_log_runner.py @@ -252,6 +252,11 @@ async def test_log_runner_reconnects_on_subscribe_failure( stop_task = asyncio.create_task(stop()) await asyncio.sleep(0) + + send_plaintext_connect_response(protocol, False) + send_plaintext_hello(protocol) + disconnect_response = DisconnectResponse() mock_data_received(protocol, generate_plaintext_packet(disconnect_response)) + await stop_task