mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-30 18:08:36 +01:00
Avoid creating tasks for starting/finishing the connection (#826)
This commit is contained in:
parent
939c5296e3
commit
e2bbbf4da5
@ -98,8 +98,8 @@ cdef class APIConnection:
|
||||
cdef object _pong_timer
|
||||
cdef float _keep_alive_interval
|
||||
cdef float _keep_alive_timeout
|
||||
cdef object _start_connect_task
|
||||
cdef object _finish_connect_task
|
||||
cdef object _start_connect_future
|
||||
cdef object _finish_connect_future
|
||||
cdef public Exception _fatal_exception
|
||||
cdef bint _expected_disconnect
|
||||
cdef object _loop
|
||||
@ -154,3 +154,7 @@ cdef class APIConnection:
|
||||
cdef void _register_internal_message_handlers(self)
|
||||
|
||||
cdef void _increase_recv_buffer_size(self)
|
||||
|
||||
cdef void _set_start_connect_future(self)
|
||||
|
||||
cdef void _set_finish_connect_future(self)
|
||||
|
@ -16,6 +16,7 @@ from functools import lru_cache, partial
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
import aiohappyeyeballs
|
||||
from async_interrupt import interrupt
|
||||
from google.protobuf import message
|
||||
|
||||
import aioesphomeapi.host_resolver as hr
|
||||
@ -106,6 +107,10 @@ _bytes = bytes
|
||||
_float = float
|
||||
|
||||
|
||||
class ConnectionInterruptedError(Exception):
|
||||
"""An error that is raised when a connection is interrupted."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionParams:
|
||||
addresses: list[str]
|
||||
@ -198,8 +203,8 @@ class APIConnection:
|
||||
"_pong_timer",
|
||||
"_keep_alive_interval",
|
||||
"_keep_alive_timeout",
|
||||
"_start_connect_task",
|
||||
"_finish_connect_task",
|
||||
"_start_connect_future",
|
||||
"_finish_connect_future",
|
||||
"_fatal_exception",
|
||||
"_expected_disconnect",
|
||||
"_loop",
|
||||
@ -242,8 +247,8 @@ class APIConnection:
|
||||
self._keep_alive_interval = keepalive
|
||||
self._keep_alive_timeout = keepalive * KEEP_ALIVE_TIMEOUT_RATIO
|
||||
|
||||
self._start_connect_task: asyncio.Task[None] | None = None
|
||||
self._finish_connect_task: asyncio.Task[None] | None = None
|
||||
self._start_connect_future: asyncio.Future[None] | None = None
|
||||
self._finish_connect_future: asyncio.Future[None] | None = None
|
||||
self._fatal_exception: Exception | None = None
|
||||
self._expected_disconnect = False
|
||||
self._send_pending_ping = False
|
||||
@ -276,28 +281,13 @@ class APIConnection:
|
||||
err = self._fatal_exception or APIConnectionError("Connection closed")
|
||||
new_exc = err
|
||||
if not isinstance(err, APIConnectionError):
|
||||
new_exc = ReadFailedAPIError("Read failed")
|
||||
new_exc = ReadFailedAPIError(str(err) or "Read failed")
|
||||
new_exc.__cause__ = err
|
||||
fut.set_exception(new_exc)
|
||||
self._read_exception_futures.clear()
|
||||
# If we are being called from do_connect we
|
||||
# need to make sure we don't cancel the task
|
||||
# that called us
|
||||
current_task = asyncio.current_task()
|
||||
|
||||
if (
|
||||
self._start_connect_task is not None
|
||||
and self._start_connect_task is not current_task
|
||||
):
|
||||
self._start_connect_task.cancel("Connection cleanup")
|
||||
self._start_connect_task = None
|
||||
|
||||
if (
|
||||
self._finish_connect_task is not None
|
||||
and self._finish_connect_task is not current_task
|
||||
):
|
||||
self._finish_connect_task.cancel("Connection cleanup")
|
||||
self._finish_connect_task = None
|
||||
self._set_start_connect_future()
|
||||
self._set_finish_connect_future()
|
||||
|
||||
if self._frame_helper is not None:
|
||||
self._frame_helper.close()
|
||||
@ -460,7 +450,9 @@ class APIConnection:
|
||||
try:
|
||||
await self._frame_helper.ready_future
|
||||
except asyncio_TimeoutError as err:
|
||||
raise TimeoutAPIError("Handshake timed out") from err
|
||||
raise TimeoutAPIError(
|
||||
f"Handshake timed out after {HANDSHAKE_TIMEOUT}s"
|
||||
) from err
|
||||
except OSError as err:
|
||||
raise HandshakeAPIError(f"Handshake failed: {err}") from err
|
||||
finally:
|
||||
@ -475,19 +467,14 @@ class APIConnection:
|
||||
messages.append(self._make_connect_request())
|
||||
msg_types.append(ConnectResponse)
|
||||
|
||||
try:
|
||||
responses = await self.send_messages_await_response_complex(
|
||||
tuple(messages),
|
||||
None,
|
||||
lambda resp: type(resp) # pylint: disable=unidiomatic-typecheck
|
||||
is msg_types[-1],
|
||||
tuple(msg_types),
|
||||
CONNECT_REQUEST_TIMEOUT,
|
||||
)
|
||||
except TimeoutAPIError as err:
|
||||
self.report_fatal_error(err)
|
||||
raise TimeoutAPIError("Hello timed out") from err
|
||||
|
||||
responses = await self.send_messages_await_response_complex(
|
||||
tuple(messages),
|
||||
None,
|
||||
lambda resp: type(resp) # pylint: disable=unidiomatic-typecheck
|
||||
is msg_types[-1],
|
||||
tuple(msg_types),
|
||||
CONNECT_REQUEST_TIMEOUT,
|
||||
)
|
||||
resp = responses.pop(0)
|
||||
self._process_hello_resp(resp)
|
||||
if login:
|
||||
@ -605,21 +592,29 @@ class APIConnection:
|
||||
"Connection can only be used once, connection is not in init state"
|
||||
)
|
||||
|
||||
start_connect_task = asyncio.create_task(
|
||||
self._do_connect(), name=f"{self.log_name}: aioesphomeapi do_connect"
|
||||
)
|
||||
self._start_connect_task = start_connect_task
|
||||
self._start_connect_future = self._loop.create_future()
|
||||
try:
|
||||
await start_connect_task
|
||||
async with interrupt(
|
||||
self._start_connect_future, ConnectionInterruptedError, None
|
||||
):
|
||||
await self._do_connect()
|
||||
except (Exception, CancelledError) as ex:
|
||||
# If the task was cancelled, we need to clean up the connection
|
||||
# and raise the CancelledError as APIConnectionError
|
||||
self._cleanup()
|
||||
raise self._wrap_fatal_connection_exception("starting", ex)
|
||||
finally:
|
||||
self._start_connect_task = None
|
||||
self._set_start_connect_future()
|
||||
self._set_connection_state(CONNECTION_STATE_SOCKET_OPENED)
|
||||
|
||||
def _set_start_connect_future(self) -> None:
|
||||
if (
|
||||
self._start_connect_future is not None
|
||||
and not self._start_connect_future.done()
|
||||
):
|
||||
self._start_connect_future.set_result(None)
|
||||
self._start_connect_future = None
|
||||
|
||||
def _wrap_fatal_connection_exception(
|
||||
self, action: str, ex: BaseException
|
||||
) -> APIConnectionError:
|
||||
@ -627,7 +622,7 @@ class APIConnection:
|
||||
if isinstance(ex, APIConnectionError):
|
||||
return ex
|
||||
cause: BaseException | None = None
|
||||
if isinstance(ex, CancelledError):
|
||||
if isinstance(ex, (ConnectionInterruptedError, CancelledError)):
|
||||
err_str = f"{action.title()} connection cancelled"
|
||||
if self._fatal_exception:
|
||||
err_str += f" due to fatal exception: {self._fatal_exception}"
|
||||
@ -664,22 +659,29 @@ class APIConnection:
|
||||
raise RuntimeError(
|
||||
"Connection must be in SOCKET_OPENED state to finish connection"
|
||||
)
|
||||
finish_connect_task = asyncio.create_task(
|
||||
self._do_finish_connect(login),
|
||||
name=f"{self.log_name}: aioesphomeapi _do_finish_connect",
|
||||
)
|
||||
self._finish_connect_task = finish_connect_task
|
||||
self._finish_connect_future = self._loop.create_future()
|
||||
try:
|
||||
await self._finish_connect_task
|
||||
async with interrupt(
|
||||
self._finish_connect_future, ConnectionInterruptedError, None
|
||||
):
|
||||
await self._do_finish_connect(login)
|
||||
except (Exception, CancelledError) as ex:
|
||||
# If the task was cancelled, we need to clean up the connection
|
||||
# and raise the CancelledError as APIConnectionError
|
||||
self._cleanup()
|
||||
raise self._wrap_fatal_connection_exception("finishing", ex)
|
||||
finally:
|
||||
self._finish_connect_task = None
|
||||
self._set_finish_connect_future()
|
||||
self._set_connection_state(CONNECTION_STATE_CONNECTED)
|
||||
|
||||
def _set_finish_connect_future(self) -> None:
|
||||
if (
|
||||
self._finish_connect_future is not None
|
||||
and not self._finish_connect_future.done()
|
||||
):
|
||||
self._finish_connect_future.set_result(None)
|
||||
self._finish_connect_future = None
|
||||
|
||||
def _set_connection_state(self, state: ConnectionState) -> None:
|
||||
"""Set the connection state and log the change."""
|
||||
self.connection_state = state
|
||||
@ -969,12 +971,12 @@ class APIConnection:
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from the API."""
|
||||
if self._finish_connect_task is not None:
|
||||
if self._finish_connect_future is not None:
|
||||
# Try to wait for the handshake to finish so we can send
|
||||
# a disconnect request. If it doesn't finish in time
|
||||
# we will just close the socket.
|
||||
_, pending = await asyncio.wait(
|
||||
[self._finish_connect_task], timeout=DISCONNECT_CONNECT_TIMEOUT
|
||||
[self._finish_connect_future], timeout=DISCONNECT_CONNECT_TIMEOUT
|
||||
)
|
||||
if pending:
|
||||
self._set_fatal_exception_if_unset(
|
||||
|
@ -1,4 +1,5 @@
|
||||
aiohappyeyeballs>=2.3.0
|
||||
async-interrupt>=1.1.1
|
||||
protobuf>=3.19.0
|
||||
zeroconf>=0.128.4,<1.0
|
||||
chacha20poly1305-reuseable>=0.12.1
|
||||
|
@ -22,6 +22,7 @@ from aioesphomeapi.core import (
|
||||
HandshakeAPIError,
|
||||
InvalidEncryptionKeyAPIError,
|
||||
ProtocolAPIError,
|
||||
ReadFailedAPIError,
|
||||
SocketClosedAPIError,
|
||||
)
|
||||
|
||||
@ -725,18 +726,28 @@ async def test_eof_received_closes_connection(
|
||||
await connect_task
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exception_map"),
|
||||
[
|
||||
(OSError("original message"), ReadFailedAPIError),
|
||||
(APIConnectionError("original message"), APIConnectionError),
|
||||
(SocketClosedAPIError("original message"), SocketClosedAPIError),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_connection_lost_closes_connection_and_logs(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
plaintext_connect_task_with_login: tuple[
|
||||
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
|
||||
],
|
||||
exception_map: tuple[Exception, Exception],
|
||||
) -> None:
|
||||
exception, raised_exception = exception_map
|
||||
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
|
||||
protocol.connection_lost(OSError("original message"))
|
||||
protocol.connection_lost(exception)
|
||||
assert conn.is_connected is False
|
||||
assert "original message" in caplog.text
|
||||
with pytest.raises(APIConnectionError, match="original message"):
|
||||
with pytest.raises(raised_exception, match="original message"):
|
||||
await connect_task
|
||||
|
||||
|
||||
|
@ -24,12 +24,16 @@ from aioesphomeapi.api_pb2 import (
|
||||
)
|
||||
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
|
||||
from aioesphomeapi.core import (
|
||||
APIConnectionCancelledError,
|
||||
APIConnectionError,
|
||||
ConnectionNotEstablishedAPIError,
|
||||
HandshakeAPIError,
|
||||
InvalidAuthAPIError,
|
||||
ReadFailedAPIError,
|
||||
RequiresEncryptionAPIError,
|
||||
ResolveAPIError,
|
||||
SocketAPIError,
|
||||
SocketClosedAPIError,
|
||||
TimeoutAPIError,
|
||||
)
|
||||
|
||||
@ -442,7 +446,9 @@ async def test_finish_connection_times_out(
|
||||
async_fire_time_changed(utcnow() + timedelta(seconds=200))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
with pytest.raises(APIConnectionError, match="Hello timed out"):
|
||||
with pytest.raises(
|
||||
APIConnectionError, match="Timeout waiting for HelloResponse after 30.0s"
|
||||
):
|
||||
await connect_task
|
||||
|
||||
async_fire_time_changed(utcnow() + timedelta(seconds=600))
|
||||
@ -458,6 +464,8 @@ async def test_finish_connection_times_out(
|
||||
("exception_map"),
|
||||
[
|
||||
(OSError("Socket error"), HandshakeAPIError),
|
||||
(APIConnectionError, APIConnectionError),
|
||||
(SocketClosedAPIError, SocketClosedAPIError),
|
||||
(asyncio.TimeoutError, TimeoutAPIError),
|
||||
(asyncio.CancelledError, APIConnectionError),
|
||||
],
|
||||
@ -501,6 +509,21 @@ async def test_plaintext_connection_fails_handshake(
|
||||
remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
|
||||
transport = MagicMock()
|
||||
|
||||
call_order = []
|
||||
|
||||
def _socket_close_call():
|
||||
call_order.append("socket_close")
|
||||
|
||||
def _frame_helper_close_call():
|
||||
call_order.append("frame_helper_close")
|
||||
|
||||
async def _do_finish_connect(self, *args, **kwargs):
|
||||
try:
|
||||
await conn._connect_init_frame_helper()
|
||||
finally:
|
||||
conn._socket.close = _socket_close_call
|
||||
conn._frame_helper.close = _frame_helper_close_call
|
||||
|
||||
with (
|
||||
patch(
|
||||
"aioesphomeapi.connection.APIPlaintextFrameHelper",
|
||||
@ -513,42 +536,12 @@ async def test_plaintext_connection_fails_handshake(
|
||||
_create_failing_mock_transport_protocol, transport, connected
|
||||
),
|
||||
),
|
||||
patch.object(conn, "_do_finish_connect", _do_finish_connect),
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await connected.wait()
|
||||
|
||||
protocol = conn._frame_helper
|
||||
assert conn._socket is not None
|
||||
assert conn._frame_helper is not None
|
||||
|
||||
mock_data_received(
|
||||
protocol,
|
||||
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m',
|
||||
)
|
||||
mock_data_received(protocol, b"5stackatomproxy")
|
||||
mock_data_received(protocol, b"\x00\x00$")
|
||||
mock_data_received(protocol, b"\x00\x00\x04")
|
||||
mock_data_received(
|
||||
protocol,
|
||||
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d',
|
||||
)
|
||||
mock_data_received(
|
||||
protocol, b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
|
||||
)
|
||||
|
||||
call_order = []
|
||||
|
||||
def _socket_close_call():
|
||||
call_order.append("socket_close")
|
||||
|
||||
def _frame_helper_close_call():
|
||||
call_order.append("frame_helper_close")
|
||||
|
||||
with (
|
||||
patch.object(conn._socket, "close", side_effect=_socket_close_call),
|
||||
patch.object(conn._frame_helper, "close", side_effect=_frame_helper_close_call),
|
||||
pytest.raises(raised_exception),
|
||||
):
|
||||
with (pytest.raises(raised_exception),):
|
||||
await asyncio.sleep(0)
|
||||
await connect_task
|
||||
|
||||
@ -556,10 +549,6 @@ async def test_plaintext_connection_fails_handshake(
|
||||
# so asyncio releases the socket
|
||||
assert call_order == ["frame_helper_close", "socket_close"]
|
||||
assert not conn.is_connected
|
||||
assert len(messages) == 2
|
||||
assert isinstance(messages[0], HelloResponse)
|
||||
assert isinstance(messages[1], DeviceInfoResponse)
|
||||
assert messages[1].name == "m5stackatomproxy"
|
||||
remove()
|
||||
conn.force_disconnect()
|
||||
await asyncio.sleep(0)
|
||||
@ -655,6 +644,110 @@ async def test_force_disconnect_fails(
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exception_map"),
|
||||
[
|
||||
(OSError("original message"), ReadFailedAPIError),
|
||||
(APIConnectionError("original message"), APIConnectionError),
|
||||
(SocketClosedAPIError("original message"), SocketClosedAPIError),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_connection_lost_while_connecting(
|
||||
plaintext_connect_task_with_login: tuple[
|
||||
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
|
||||
],
|
||||
exception_map: tuple[Exception, Exception],
|
||||
) -> None:
|
||||
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
|
||||
|
||||
exception, raised_exception = exception_map
|
||||
protocol.connection_lost(exception)
|
||||
|
||||
with pytest.raises(raised_exception, match="original message"):
|
||||
await connect_task
|
||||
|
||||
assert not conn.is_connected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exception_map"),
|
||||
[
|
||||
(OSError("original message"), SocketAPIError),
|
||||
(APIConnectionError("original message"), APIConnectionError),
|
||||
(SocketClosedAPIError("original message"), SocketClosedAPIError),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_connection_error_during_hello(
|
||||
conn: APIConnection,
|
||||
resolve_host,
|
||||
aiohappyeyeballs_start_connection,
|
||||
exception_map: tuple[Exception, Exception],
|
||||
) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
exception, raised_exception = exception_map
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
),
|
||||
patch.object(conn, "_connect_hello_login", side_effect=exception),
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await connected.wait()
|
||||
|
||||
with pytest.raises(raised_exception, match="original message"):
|
||||
await connect_task
|
||||
|
||||
assert not conn.is_connected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exception_map"),
|
||||
[
|
||||
(OSError("original message"), APIConnectionCancelledError),
|
||||
(APIConnectionError("original message"), APIConnectionError),
|
||||
(SocketClosedAPIError("original message"), SocketClosedAPIError),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_connection_cancelled_during_hello(
|
||||
conn: APIConnection,
|
||||
resolve_host,
|
||||
aiohappyeyeballs_start_connection,
|
||||
exception_map: tuple[Exception, Exception],
|
||||
) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
exception, raised_exception = exception_map
|
||||
|
||||
async def _mock_frame_helper_error(*args, **kwargs):
|
||||
conn._frame_helper.connection_lost(exception)
|
||||
raise asyncio.CancelledError
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
),
|
||||
patch.object(conn, "_connect_hello_login", _mock_frame_helper_error),
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await connected.wait()
|
||||
|
||||
with pytest.raises(raised_exception, match="original message"):
|
||||
await connect_task
|
||||
|
||||
assert not conn.is_connected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_resolver_times_out(
|
||||
conn: APIConnection, aiohappyeyeballs_start_connection
|
||||
@ -814,7 +907,7 @@ async def test_ping_disconnects_after_no_responses(
|
||||
start_time
|
||||
+ timedelta(seconds=KEEP_ALIVE_INTERVAL * (max_pings_to_disconnect_after + 1))
|
||||
)
|
||||
assert transport.write.call_count == max_pings_to_disconnect_after
|
||||
assert transport.write.call_count == max_pings_to_disconnect_after + 1
|
||||
|
||||
assert conn.is_connected is False
|
||||
|
||||
|
@ -252,6 +252,11 @@ async def test_log_runner_reconnects_on_subscribe_failure(
|
||||
|
||||
stop_task = asyncio.create_task(stop())
|
||||
await asyncio.sleep(0)
|
||||
|
||||
send_plaintext_connect_response(protocol, False)
|
||||
send_plaintext_hello(protocol)
|
||||
|
||||
disconnect_response = DisconnectResponse()
|
||||
mock_data_received(protocol, generate_plaintext_packet(disconnect_response))
|
||||
|
||||
await stop_task
|
||||
|
Loading…
Reference in New Issue
Block a user