Avoid creating tasks for starting/finishing the connection (#826)

This commit is contained in:
J. Nick Koston 2024-02-16 20:47:26 -06:00 committed by GitHub
parent 939c5296e3
commit e2bbbf4da5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 210 additions and 94 deletions

View File

@ -98,8 +98,8 @@ cdef class APIConnection:
cdef object _pong_timer cdef object _pong_timer
cdef float _keep_alive_interval cdef float _keep_alive_interval
cdef float _keep_alive_timeout cdef float _keep_alive_timeout
cdef object _start_connect_task cdef object _start_connect_future
cdef object _finish_connect_task cdef object _finish_connect_future
cdef public Exception _fatal_exception cdef public Exception _fatal_exception
cdef bint _expected_disconnect cdef bint _expected_disconnect
cdef object _loop cdef object _loop
@ -154,3 +154,7 @@ cdef class APIConnection:
cdef void _register_internal_message_handlers(self) cdef void _register_internal_message_handlers(self)
cdef void _increase_recv_buffer_size(self) cdef void _increase_recv_buffer_size(self)
cdef void _set_start_connect_future(self)
cdef void _set_finish_connect_future(self)

View File

@ -16,6 +16,7 @@ from functools import lru_cache, partial
from typing import TYPE_CHECKING, Any, Callable from typing import TYPE_CHECKING, Any, Callable
import aiohappyeyeballs import aiohappyeyeballs
from async_interrupt import interrupt
from google.protobuf import message from google.protobuf import message
import aioesphomeapi.host_resolver as hr import aioesphomeapi.host_resolver as hr
@ -106,6 +107,10 @@ _bytes = bytes
_float = float _float = float
class ConnectionInterruptedError(Exception):
"""An error that is raised when a connection is interrupted."""
@dataclass @dataclass
class ConnectionParams: class ConnectionParams:
addresses: list[str] addresses: list[str]
@ -198,8 +203,8 @@ class APIConnection:
"_pong_timer", "_pong_timer",
"_keep_alive_interval", "_keep_alive_interval",
"_keep_alive_timeout", "_keep_alive_timeout",
"_start_connect_task", "_start_connect_future",
"_finish_connect_task", "_finish_connect_future",
"_fatal_exception", "_fatal_exception",
"_expected_disconnect", "_expected_disconnect",
"_loop", "_loop",
@ -242,8 +247,8 @@ class APIConnection:
self._keep_alive_interval = keepalive self._keep_alive_interval = keepalive
self._keep_alive_timeout = keepalive * KEEP_ALIVE_TIMEOUT_RATIO self._keep_alive_timeout = keepalive * KEEP_ALIVE_TIMEOUT_RATIO
self._start_connect_task: asyncio.Task[None] | None = None self._start_connect_future: asyncio.Future[None] | None = None
self._finish_connect_task: asyncio.Task[None] | None = None self._finish_connect_future: asyncio.Future[None] | None = None
self._fatal_exception: Exception | None = None self._fatal_exception: Exception | None = None
self._expected_disconnect = False self._expected_disconnect = False
self._send_pending_ping = False self._send_pending_ping = False
@ -276,28 +281,13 @@ class APIConnection:
err = self._fatal_exception or APIConnectionError("Connection closed") err = self._fatal_exception or APIConnectionError("Connection closed")
new_exc = err new_exc = err
if not isinstance(err, APIConnectionError): if not isinstance(err, APIConnectionError):
new_exc = ReadFailedAPIError("Read failed") new_exc = ReadFailedAPIError(str(err) or "Read failed")
new_exc.__cause__ = err new_exc.__cause__ = err
fut.set_exception(new_exc) fut.set_exception(new_exc)
self._read_exception_futures.clear() self._read_exception_futures.clear()
# If we are being called from do_connect we
# need to make sure we don't cancel the task
# that called us
current_task = asyncio.current_task()
if ( self._set_start_connect_future()
self._start_connect_task is not None self._set_finish_connect_future()
and self._start_connect_task is not current_task
):
self._start_connect_task.cancel("Connection cleanup")
self._start_connect_task = None
if (
self._finish_connect_task is not None
and self._finish_connect_task is not current_task
):
self._finish_connect_task.cancel("Connection cleanup")
self._finish_connect_task = None
if self._frame_helper is not None: if self._frame_helper is not None:
self._frame_helper.close() self._frame_helper.close()
@ -460,7 +450,9 @@ class APIConnection:
try: try:
await self._frame_helper.ready_future await self._frame_helper.ready_future
except asyncio_TimeoutError as err: except asyncio_TimeoutError as err:
raise TimeoutAPIError("Handshake timed out") from err raise TimeoutAPIError(
f"Handshake timed out after {HANDSHAKE_TIMEOUT}s"
) from err
except OSError as err: except OSError as err:
raise HandshakeAPIError(f"Handshake failed: {err}") from err raise HandshakeAPIError(f"Handshake failed: {err}") from err
finally: finally:
@ -475,7 +467,6 @@ class APIConnection:
messages.append(self._make_connect_request()) messages.append(self._make_connect_request())
msg_types.append(ConnectResponse) msg_types.append(ConnectResponse)
try:
responses = await self.send_messages_await_response_complex( responses = await self.send_messages_await_response_complex(
tuple(messages), tuple(messages),
None, None,
@ -484,10 +475,6 @@ class APIConnection:
tuple(msg_types), tuple(msg_types),
CONNECT_REQUEST_TIMEOUT, CONNECT_REQUEST_TIMEOUT,
) )
except TimeoutAPIError as err:
self.report_fatal_error(err)
raise TimeoutAPIError("Hello timed out") from err
resp = responses.pop(0) resp = responses.pop(0)
self._process_hello_resp(resp) self._process_hello_resp(resp)
if login: if login:
@ -605,21 +592,29 @@ class APIConnection:
"Connection can only be used once, connection is not in init state" "Connection can only be used once, connection is not in init state"
) )
start_connect_task = asyncio.create_task( self._start_connect_future = self._loop.create_future()
self._do_connect(), name=f"{self.log_name}: aioesphomeapi do_connect"
)
self._start_connect_task = start_connect_task
try: try:
await start_connect_task async with interrupt(
self._start_connect_future, ConnectionInterruptedError, None
):
await self._do_connect()
except (Exception, CancelledError) as ex: except (Exception, CancelledError) as ex:
# If the task was cancelled, we need to clean up the connection # If the task was cancelled, we need to clean up the connection
# and raise the CancelledError as APIConnectionError # and raise the CancelledError as APIConnectionError
self._cleanup() self._cleanup()
raise self._wrap_fatal_connection_exception("starting", ex) raise self._wrap_fatal_connection_exception("starting", ex)
finally: finally:
self._start_connect_task = None self._set_start_connect_future()
self._set_connection_state(CONNECTION_STATE_SOCKET_OPENED) self._set_connection_state(CONNECTION_STATE_SOCKET_OPENED)
def _set_start_connect_future(self) -> None:
if (
self._start_connect_future is not None
and not self._start_connect_future.done()
):
self._start_connect_future.set_result(None)
self._start_connect_future = None
def _wrap_fatal_connection_exception( def _wrap_fatal_connection_exception(
self, action: str, ex: BaseException self, action: str, ex: BaseException
) -> APIConnectionError: ) -> APIConnectionError:
@ -627,7 +622,7 @@ class APIConnection:
if isinstance(ex, APIConnectionError): if isinstance(ex, APIConnectionError):
return ex return ex
cause: BaseException | None = None cause: BaseException | None = None
if isinstance(ex, CancelledError): if isinstance(ex, (ConnectionInterruptedError, CancelledError)):
err_str = f"{action.title()} connection cancelled" err_str = f"{action.title()} connection cancelled"
if self._fatal_exception: if self._fatal_exception:
err_str += f" due to fatal exception: {self._fatal_exception}" err_str += f" due to fatal exception: {self._fatal_exception}"
@ -664,22 +659,29 @@ class APIConnection:
raise RuntimeError( raise RuntimeError(
"Connection must be in SOCKET_OPENED state to finish connection" "Connection must be in SOCKET_OPENED state to finish connection"
) )
finish_connect_task = asyncio.create_task( self._finish_connect_future = self._loop.create_future()
self._do_finish_connect(login),
name=f"{self.log_name}: aioesphomeapi _do_finish_connect",
)
self._finish_connect_task = finish_connect_task
try: try:
await self._finish_connect_task async with interrupt(
self._finish_connect_future, ConnectionInterruptedError, None
):
await self._do_finish_connect(login)
except (Exception, CancelledError) as ex: except (Exception, CancelledError) as ex:
# If the task was cancelled, we need to clean up the connection # If the task was cancelled, we need to clean up the connection
# and raise the CancelledError as APIConnectionError # and raise the CancelledError as APIConnectionError
self._cleanup() self._cleanup()
raise self._wrap_fatal_connection_exception("finishing", ex) raise self._wrap_fatal_connection_exception("finishing", ex)
finally: finally:
self._finish_connect_task = None self._set_finish_connect_future()
self._set_connection_state(CONNECTION_STATE_CONNECTED) self._set_connection_state(CONNECTION_STATE_CONNECTED)
def _set_finish_connect_future(self) -> None:
if (
self._finish_connect_future is not None
and not self._finish_connect_future.done()
):
self._finish_connect_future.set_result(None)
self._finish_connect_future = None
def _set_connection_state(self, state: ConnectionState) -> None: def _set_connection_state(self, state: ConnectionState) -> None:
"""Set the connection state and log the change.""" """Set the connection state and log the change."""
self.connection_state = state self.connection_state = state
@ -969,12 +971,12 @@ class APIConnection:
async def disconnect(self) -> None: async def disconnect(self) -> None:
"""Disconnect from the API.""" """Disconnect from the API."""
if self._finish_connect_task is not None: if self._finish_connect_future is not None:
# Try to wait for the handshake to finish so we can send # Try to wait for the handshake to finish so we can send
# a disconnect request. If it doesn't finish in time # a disconnect request. If it doesn't finish in time
# we will just close the socket. # we will just close the socket.
_, pending = await asyncio.wait( _, pending = await asyncio.wait(
[self._finish_connect_task], timeout=DISCONNECT_CONNECT_TIMEOUT [self._finish_connect_future], timeout=DISCONNECT_CONNECT_TIMEOUT
) )
if pending: if pending:
self._set_fatal_exception_if_unset( self._set_fatal_exception_if_unset(

View File

@ -1,4 +1,5 @@
aiohappyeyeballs>=2.3.0 aiohappyeyeballs>=2.3.0
async-interrupt>=1.1.1
protobuf>=3.19.0 protobuf>=3.19.0
zeroconf>=0.128.4,<1.0 zeroconf>=0.128.4,<1.0
chacha20poly1305-reuseable>=0.12.1 chacha20poly1305-reuseable>=0.12.1

View File

@ -22,6 +22,7 @@ from aioesphomeapi.core import (
HandshakeAPIError, HandshakeAPIError,
InvalidEncryptionKeyAPIError, InvalidEncryptionKeyAPIError,
ProtocolAPIError, ProtocolAPIError,
ReadFailedAPIError,
SocketClosedAPIError, SocketClosedAPIError,
) )
@ -725,18 +726,28 @@ async def test_eof_received_closes_connection(
await connect_task await connect_task
@pytest.mark.parametrize(
("exception_map"),
[
(OSError("original message"), ReadFailedAPIError),
(APIConnectionError("original message"), APIConnectionError),
(SocketClosedAPIError("original message"), SocketClosedAPIError),
],
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_connection_lost_closes_connection_and_logs( async def test_connection_lost_closes_connection_and_logs(
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
plaintext_connect_task_with_login: tuple[ plaintext_connect_task_with_login: tuple[
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
], ],
exception_map: tuple[Exception, Exception],
) -> None: ) -> None:
exception, raised_exception = exception_map
conn, transport, protocol, connect_task = plaintext_connect_task_with_login conn, transport, protocol, connect_task = plaintext_connect_task_with_login
protocol.connection_lost(OSError("original message")) protocol.connection_lost(exception)
assert conn.is_connected is False assert conn.is_connected is False
assert "original message" in caplog.text assert "original message" in caplog.text
with pytest.raises(APIConnectionError, match="original message"): with pytest.raises(raised_exception, match="original message"):
await connect_task await connect_task

View File

@ -24,12 +24,16 @@ from aioesphomeapi.api_pb2 import (
) )
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
from aioesphomeapi.core import ( from aioesphomeapi.core import (
APIConnectionCancelledError,
APIConnectionError, APIConnectionError,
ConnectionNotEstablishedAPIError, ConnectionNotEstablishedAPIError,
HandshakeAPIError, HandshakeAPIError,
InvalidAuthAPIError, InvalidAuthAPIError,
ReadFailedAPIError,
RequiresEncryptionAPIError, RequiresEncryptionAPIError,
ResolveAPIError, ResolveAPIError,
SocketAPIError,
SocketClosedAPIError,
TimeoutAPIError, TimeoutAPIError,
) )
@ -442,7 +446,9 @@ async def test_finish_connection_times_out(
async_fire_time_changed(utcnow() + timedelta(seconds=200)) async_fire_time_changed(utcnow() + timedelta(seconds=200))
await asyncio.sleep(0) await asyncio.sleep(0)
with pytest.raises(APIConnectionError, match="Hello timed out"): with pytest.raises(
APIConnectionError, match="Timeout waiting for HelloResponse after 30.0s"
):
await connect_task await connect_task
async_fire_time_changed(utcnow() + timedelta(seconds=600)) async_fire_time_changed(utcnow() + timedelta(seconds=600))
@ -458,6 +464,8 @@ async def test_finish_connection_times_out(
("exception_map"), ("exception_map"),
[ [
(OSError("Socket error"), HandshakeAPIError), (OSError("Socket error"), HandshakeAPIError),
(APIConnectionError, APIConnectionError),
(SocketClosedAPIError, SocketClosedAPIError),
(asyncio.TimeoutError, TimeoutAPIError), (asyncio.TimeoutError, TimeoutAPIError),
(asyncio.CancelledError, APIConnectionError), (asyncio.CancelledError, APIConnectionError),
], ],
@ -501,6 +509,21 @@ async def test_plaintext_connection_fails_handshake(
remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse)) remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
transport = MagicMock() transport = MagicMock()
call_order = []
def _socket_close_call():
call_order.append("socket_close")
def _frame_helper_close_call():
call_order.append("frame_helper_close")
async def _do_finish_connect(self, *args, **kwargs):
try:
await conn._connect_init_frame_helper()
finally:
conn._socket.close = _socket_close_call
conn._frame_helper.close = _frame_helper_close_call
with ( with (
patch( patch(
"aioesphomeapi.connection.APIPlaintextFrameHelper", "aioesphomeapi.connection.APIPlaintextFrameHelper",
@ -513,42 +536,12 @@ async def test_plaintext_connection_fails_handshake(
_create_failing_mock_transport_protocol, transport, connected _create_failing_mock_transport_protocol, transport, connected
), ),
), ),
patch.object(conn, "_do_finish_connect", _do_finish_connect),
): ):
connect_task = asyncio.create_task(connect(conn, login=False)) connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait() await connected.wait()
protocol = conn._frame_helper with (pytest.raises(raised_exception),):
assert conn._socket is not None
assert conn._frame_helper is not None
mock_data_received(
protocol,
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m',
)
mock_data_received(protocol, b"5stackatomproxy")
mock_data_received(protocol, b"\x00\x00$")
mock_data_received(protocol, b"\x00\x00\x04")
mock_data_received(
protocol,
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d',
)
mock_data_received(
protocol, b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
)
call_order = []
def _socket_close_call():
call_order.append("socket_close")
def _frame_helper_close_call():
call_order.append("frame_helper_close")
with (
patch.object(conn._socket, "close", side_effect=_socket_close_call),
patch.object(conn._frame_helper, "close", side_effect=_frame_helper_close_call),
pytest.raises(raised_exception),
):
await asyncio.sleep(0) await asyncio.sleep(0)
await connect_task await connect_task
@ -556,10 +549,6 @@ async def test_plaintext_connection_fails_handshake(
# so asyncio releases the socket # so asyncio releases the socket
assert call_order == ["frame_helper_close", "socket_close"] assert call_order == ["frame_helper_close", "socket_close"]
assert not conn.is_connected assert not conn.is_connected
assert len(messages) == 2
assert isinstance(messages[0], HelloResponse)
assert isinstance(messages[1], DeviceInfoResponse)
assert messages[1].name == "m5stackatomproxy"
remove() remove()
conn.force_disconnect() conn.force_disconnect()
await asyncio.sleep(0) await asyncio.sleep(0)
@ -655,6 +644,110 @@ async def test_force_disconnect_fails(
await asyncio.sleep(0) await asyncio.sleep(0)
@pytest.mark.parametrize(
("exception_map"),
[
(OSError("original message"), ReadFailedAPIError),
(APIConnectionError("original message"), APIConnectionError),
(SocketClosedAPIError("original message"), SocketClosedAPIError),
],
)
@pytest.mark.asyncio
async def test_connection_lost_while_connecting(
plaintext_connect_task_with_login: tuple[
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
],
exception_map: tuple[Exception, Exception],
) -> None:
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
exception, raised_exception = exception_map
protocol.connection_lost(exception)
with pytest.raises(raised_exception, match="original message"):
await connect_task
assert not conn.is_connected
@pytest.mark.parametrize(
("exception_map"),
[
(OSError("original message"), SocketAPIError),
(APIConnectionError("original message"), APIConnectionError),
(SocketClosedAPIError("original message"), SocketClosedAPIError),
],
)
@pytest.mark.asyncio
async def test_connection_error_during_hello(
conn: APIConnection,
resolve_host,
aiohappyeyeballs_start_connection,
exception_map: tuple[Exception, Exception],
) -> None:
loop = asyncio.get_event_loop()
transport = MagicMock()
connected = asyncio.Event()
exception, raised_exception = exception_map
with (
patch.object(
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
),
patch.object(conn, "_connect_hello_login", side_effect=exception),
):
connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait()
with pytest.raises(raised_exception, match="original message"):
await connect_task
assert not conn.is_connected
@pytest.mark.parametrize(
("exception_map"),
[
(OSError("original message"), APIConnectionCancelledError),
(APIConnectionError("original message"), APIConnectionError),
(SocketClosedAPIError("original message"), SocketClosedAPIError),
],
)
@pytest.mark.asyncio
async def test_connection_cancelled_during_hello(
conn: APIConnection,
resolve_host,
aiohappyeyeballs_start_connection,
exception_map: tuple[Exception, Exception],
) -> None:
loop = asyncio.get_event_loop()
transport = MagicMock()
connected = asyncio.Event()
exception, raised_exception = exception_map
async def _mock_frame_helper_error(*args, **kwargs):
conn._frame_helper.connection_lost(exception)
raise asyncio.CancelledError
with (
patch.object(
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
),
patch.object(conn, "_connect_hello_login", _mock_frame_helper_error),
):
connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait()
with pytest.raises(raised_exception, match="original message"):
await connect_task
assert not conn.is_connected
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_connect_resolver_times_out( async def test_connect_resolver_times_out(
conn: APIConnection, aiohappyeyeballs_start_connection conn: APIConnection, aiohappyeyeballs_start_connection
@ -814,7 +907,7 @@ async def test_ping_disconnects_after_no_responses(
start_time start_time
+ timedelta(seconds=KEEP_ALIVE_INTERVAL * (max_pings_to_disconnect_after + 1)) + timedelta(seconds=KEEP_ALIVE_INTERVAL * (max_pings_to_disconnect_after + 1))
) )
assert transport.write.call_count == max_pings_to_disconnect_after assert transport.write.call_count == max_pings_to_disconnect_after + 1
assert conn.is_connected is False assert conn.is_connected is False

View File

@ -252,6 +252,11 @@ async def test_log_runner_reconnects_on_subscribe_failure(
stop_task = asyncio.create_task(stop()) stop_task = asyncio.create_task(stop())
await asyncio.sleep(0) await asyncio.sleep(0)
send_plaintext_connect_response(protocol, False)
send_plaintext_hello(protocol)
disconnect_response = DisconnectResponse() disconnect_response = DisconnectResponse()
mock_data_received(protocol, generate_plaintext_packet(disconnect_response)) mock_data_received(protocol, generate_plaintext_packet(disconnect_response))
await stop_task await stop_task