mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-01-02 18:38:05 +01:00
Avoid creating tasks for starting/finishing the connection (#826)
This commit is contained in:
parent
939c5296e3
commit
e2bbbf4da5
@ -98,8 +98,8 @@ cdef class APIConnection:
|
|||||||
cdef object _pong_timer
|
cdef object _pong_timer
|
||||||
cdef float _keep_alive_interval
|
cdef float _keep_alive_interval
|
||||||
cdef float _keep_alive_timeout
|
cdef float _keep_alive_timeout
|
||||||
cdef object _start_connect_task
|
cdef object _start_connect_future
|
||||||
cdef object _finish_connect_task
|
cdef object _finish_connect_future
|
||||||
cdef public Exception _fatal_exception
|
cdef public Exception _fatal_exception
|
||||||
cdef bint _expected_disconnect
|
cdef bint _expected_disconnect
|
||||||
cdef object _loop
|
cdef object _loop
|
||||||
@ -154,3 +154,7 @@ cdef class APIConnection:
|
|||||||
cdef void _register_internal_message_handlers(self)
|
cdef void _register_internal_message_handlers(self)
|
||||||
|
|
||||||
cdef void _increase_recv_buffer_size(self)
|
cdef void _increase_recv_buffer_size(self)
|
||||||
|
|
||||||
|
cdef void _set_start_connect_future(self)
|
||||||
|
|
||||||
|
cdef void _set_finish_connect_future(self)
|
||||||
|
@ -16,6 +16,7 @@ from functools import lru_cache, partial
|
|||||||
from typing import TYPE_CHECKING, Any, Callable
|
from typing import TYPE_CHECKING, Any, Callable
|
||||||
|
|
||||||
import aiohappyeyeballs
|
import aiohappyeyeballs
|
||||||
|
from async_interrupt import interrupt
|
||||||
from google.protobuf import message
|
from google.protobuf import message
|
||||||
|
|
||||||
import aioesphomeapi.host_resolver as hr
|
import aioesphomeapi.host_resolver as hr
|
||||||
@ -106,6 +107,10 @@ _bytes = bytes
|
|||||||
_float = float
|
_float = float
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionInterruptedError(Exception):
|
||||||
|
"""An error that is raised when a connection is interrupted."""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConnectionParams:
|
class ConnectionParams:
|
||||||
addresses: list[str]
|
addresses: list[str]
|
||||||
@ -198,8 +203,8 @@ class APIConnection:
|
|||||||
"_pong_timer",
|
"_pong_timer",
|
||||||
"_keep_alive_interval",
|
"_keep_alive_interval",
|
||||||
"_keep_alive_timeout",
|
"_keep_alive_timeout",
|
||||||
"_start_connect_task",
|
"_start_connect_future",
|
||||||
"_finish_connect_task",
|
"_finish_connect_future",
|
||||||
"_fatal_exception",
|
"_fatal_exception",
|
||||||
"_expected_disconnect",
|
"_expected_disconnect",
|
||||||
"_loop",
|
"_loop",
|
||||||
@ -242,8 +247,8 @@ class APIConnection:
|
|||||||
self._keep_alive_interval = keepalive
|
self._keep_alive_interval = keepalive
|
||||||
self._keep_alive_timeout = keepalive * KEEP_ALIVE_TIMEOUT_RATIO
|
self._keep_alive_timeout = keepalive * KEEP_ALIVE_TIMEOUT_RATIO
|
||||||
|
|
||||||
self._start_connect_task: asyncio.Task[None] | None = None
|
self._start_connect_future: asyncio.Future[None] | None = None
|
||||||
self._finish_connect_task: asyncio.Task[None] | None = None
|
self._finish_connect_future: asyncio.Future[None] | None = None
|
||||||
self._fatal_exception: Exception | None = None
|
self._fatal_exception: Exception | None = None
|
||||||
self._expected_disconnect = False
|
self._expected_disconnect = False
|
||||||
self._send_pending_ping = False
|
self._send_pending_ping = False
|
||||||
@ -276,28 +281,13 @@ class APIConnection:
|
|||||||
err = self._fatal_exception or APIConnectionError("Connection closed")
|
err = self._fatal_exception or APIConnectionError("Connection closed")
|
||||||
new_exc = err
|
new_exc = err
|
||||||
if not isinstance(err, APIConnectionError):
|
if not isinstance(err, APIConnectionError):
|
||||||
new_exc = ReadFailedAPIError("Read failed")
|
new_exc = ReadFailedAPIError(str(err) or "Read failed")
|
||||||
new_exc.__cause__ = err
|
new_exc.__cause__ = err
|
||||||
fut.set_exception(new_exc)
|
fut.set_exception(new_exc)
|
||||||
self._read_exception_futures.clear()
|
self._read_exception_futures.clear()
|
||||||
# If we are being called from do_connect we
|
|
||||||
# need to make sure we don't cancel the task
|
|
||||||
# that called us
|
|
||||||
current_task = asyncio.current_task()
|
|
||||||
|
|
||||||
if (
|
self._set_start_connect_future()
|
||||||
self._start_connect_task is not None
|
self._set_finish_connect_future()
|
||||||
and self._start_connect_task is not current_task
|
|
||||||
):
|
|
||||||
self._start_connect_task.cancel("Connection cleanup")
|
|
||||||
self._start_connect_task = None
|
|
||||||
|
|
||||||
if (
|
|
||||||
self._finish_connect_task is not None
|
|
||||||
and self._finish_connect_task is not current_task
|
|
||||||
):
|
|
||||||
self._finish_connect_task.cancel("Connection cleanup")
|
|
||||||
self._finish_connect_task = None
|
|
||||||
|
|
||||||
if self._frame_helper is not None:
|
if self._frame_helper is not None:
|
||||||
self._frame_helper.close()
|
self._frame_helper.close()
|
||||||
@ -460,7 +450,9 @@ class APIConnection:
|
|||||||
try:
|
try:
|
||||||
await self._frame_helper.ready_future
|
await self._frame_helper.ready_future
|
||||||
except asyncio_TimeoutError as err:
|
except asyncio_TimeoutError as err:
|
||||||
raise TimeoutAPIError("Handshake timed out") from err
|
raise TimeoutAPIError(
|
||||||
|
f"Handshake timed out after {HANDSHAKE_TIMEOUT}s"
|
||||||
|
) from err
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
raise HandshakeAPIError(f"Handshake failed: {err}") from err
|
raise HandshakeAPIError(f"Handshake failed: {err}") from err
|
||||||
finally:
|
finally:
|
||||||
@ -475,19 +467,14 @@ class APIConnection:
|
|||||||
messages.append(self._make_connect_request())
|
messages.append(self._make_connect_request())
|
||||||
msg_types.append(ConnectResponse)
|
msg_types.append(ConnectResponse)
|
||||||
|
|
||||||
try:
|
responses = await self.send_messages_await_response_complex(
|
||||||
responses = await self.send_messages_await_response_complex(
|
tuple(messages),
|
||||||
tuple(messages),
|
None,
|
||||||
None,
|
lambda resp: type(resp) # pylint: disable=unidiomatic-typecheck
|
||||||
lambda resp: type(resp) # pylint: disable=unidiomatic-typecheck
|
is msg_types[-1],
|
||||||
is msg_types[-1],
|
tuple(msg_types),
|
||||||
tuple(msg_types),
|
CONNECT_REQUEST_TIMEOUT,
|
||||||
CONNECT_REQUEST_TIMEOUT,
|
)
|
||||||
)
|
|
||||||
except TimeoutAPIError as err:
|
|
||||||
self.report_fatal_error(err)
|
|
||||||
raise TimeoutAPIError("Hello timed out") from err
|
|
||||||
|
|
||||||
resp = responses.pop(0)
|
resp = responses.pop(0)
|
||||||
self._process_hello_resp(resp)
|
self._process_hello_resp(resp)
|
||||||
if login:
|
if login:
|
||||||
@ -605,21 +592,29 @@ class APIConnection:
|
|||||||
"Connection can only be used once, connection is not in init state"
|
"Connection can only be used once, connection is not in init state"
|
||||||
)
|
)
|
||||||
|
|
||||||
start_connect_task = asyncio.create_task(
|
self._start_connect_future = self._loop.create_future()
|
||||||
self._do_connect(), name=f"{self.log_name}: aioesphomeapi do_connect"
|
|
||||||
)
|
|
||||||
self._start_connect_task = start_connect_task
|
|
||||||
try:
|
try:
|
||||||
await start_connect_task
|
async with interrupt(
|
||||||
|
self._start_connect_future, ConnectionInterruptedError, None
|
||||||
|
):
|
||||||
|
await self._do_connect()
|
||||||
except (Exception, CancelledError) as ex:
|
except (Exception, CancelledError) as ex:
|
||||||
# If the task was cancelled, we need to clean up the connection
|
# If the task was cancelled, we need to clean up the connection
|
||||||
# and raise the CancelledError as APIConnectionError
|
# and raise the CancelledError as APIConnectionError
|
||||||
self._cleanup()
|
self._cleanup()
|
||||||
raise self._wrap_fatal_connection_exception("starting", ex)
|
raise self._wrap_fatal_connection_exception("starting", ex)
|
||||||
finally:
|
finally:
|
||||||
self._start_connect_task = None
|
self._set_start_connect_future()
|
||||||
self._set_connection_state(CONNECTION_STATE_SOCKET_OPENED)
|
self._set_connection_state(CONNECTION_STATE_SOCKET_OPENED)
|
||||||
|
|
||||||
|
def _set_start_connect_future(self) -> None:
|
||||||
|
if (
|
||||||
|
self._start_connect_future is not None
|
||||||
|
and not self._start_connect_future.done()
|
||||||
|
):
|
||||||
|
self._start_connect_future.set_result(None)
|
||||||
|
self._start_connect_future = None
|
||||||
|
|
||||||
def _wrap_fatal_connection_exception(
|
def _wrap_fatal_connection_exception(
|
||||||
self, action: str, ex: BaseException
|
self, action: str, ex: BaseException
|
||||||
) -> APIConnectionError:
|
) -> APIConnectionError:
|
||||||
@ -627,7 +622,7 @@ class APIConnection:
|
|||||||
if isinstance(ex, APIConnectionError):
|
if isinstance(ex, APIConnectionError):
|
||||||
return ex
|
return ex
|
||||||
cause: BaseException | None = None
|
cause: BaseException | None = None
|
||||||
if isinstance(ex, CancelledError):
|
if isinstance(ex, (ConnectionInterruptedError, CancelledError)):
|
||||||
err_str = f"{action.title()} connection cancelled"
|
err_str = f"{action.title()} connection cancelled"
|
||||||
if self._fatal_exception:
|
if self._fatal_exception:
|
||||||
err_str += f" due to fatal exception: {self._fatal_exception}"
|
err_str += f" due to fatal exception: {self._fatal_exception}"
|
||||||
@ -664,22 +659,29 @@ class APIConnection:
|
|||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Connection must be in SOCKET_OPENED state to finish connection"
|
"Connection must be in SOCKET_OPENED state to finish connection"
|
||||||
)
|
)
|
||||||
finish_connect_task = asyncio.create_task(
|
self._finish_connect_future = self._loop.create_future()
|
||||||
self._do_finish_connect(login),
|
|
||||||
name=f"{self.log_name}: aioesphomeapi _do_finish_connect",
|
|
||||||
)
|
|
||||||
self._finish_connect_task = finish_connect_task
|
|
||||||
try:
|
try:
|
||||||
await self._finish_connect_task
|
async with interrupt(
|
||||||
|
self._finish_connect_future, ConnectionInterruptedError, None
|
||||||
|
):
|
||||||
|
await self._do_finish_connect(login)
|
||||||
except (Exception, CancelledError) as ex:
|
except (Exception, CancelledError) as ex:
|
||||||
# If the task was cancelled, we need to clean up the connection
|
# If the task was cancelled, we need to clean up the connection
|
||||||
# and raise the CancelledError as APIConnectionError
|
# and raise the CancelledError as APIConnectionError
|
||||||
self._cleanup()
|
self._cleanup()
|
||||||
raise self._wrap_fatal_connection_exception("finishing", ex)
|
raise self._wrap_fatal_connection_exception("finishing", ex)
|
||||||
finally:
|
finally:
|
||||||
self._finish_connect_task = None
|
self._set_finish_connect_future()
|
||||||
self._set_connection_state(CONNECTION_STATE_CONNECTED)
|
self._set_connection_state(CONNECTION_STATE_CONNECTED)
|
||||||
|
|
||||||
|
def _set_finish_connect_future(self) -> None:
|
||||||
|
if (
|
||||||
|
self._finish_connect_future is not None
|
||||||
|
and not self._finish_connect_future.done()
|
||||||
|
):
|
||||||
|
self._finish_connect_future.set_result(None)
|
||||||
|
self._finish_connect_future = None
|
||||||
|
|
||||||
def _set_connection_state(self, state: ConnectionState) -> None:
|
def _set_connection_state(self, state: ConnectionState) -> None:
|
||||||
"""Set the connection state and log the change."""
|
"""Set the connection state and log the change."""
|
||||||
self.connection_state = state
|
self.connection_state = state
|
||||||
@ -969,12 +971,12 @@ class APIConnection:
|
|||||||
|
|
||||||
async def disconnect(self) -> None:
|
async def disconnect(self) -> None:
|
||||||
"""Disconnect from the API."""
|
"""Disconnect from the API."""
|
||||||
if self._finish_connect_task is not None:
|
if self._finish_connect_future is not None:
|
||||||
# Try to wait for the handshake to finish so we can send
|
# Try to wait for the handshake to finish so we can send
|
||||||
# a disconnect request. If it doesn't finish in time
|
# a disconnect request. If it doesn't finish in time
|
||||||
# we will just close the socket.
|
# we will just close the socket.
|
||||||
_, pending = await asyncio.wait(
|
_, pending = await asyncio.wait(
|
||||||
[self._finish_connect_task], timeout=DISCONNECT_CONNECT_TIMEOUT
|
[self._finish_connect_future], timeout=DISCONNECT_CONNECT_TIMEOUT
|
||||||
)
|
)
|
||||||
if pending:
|
if pending:
|
||||||
self._set_fatal_exception_if_unset(
|
self._set_fatal_exception_if_unset(
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
aiohappyeyeballs>=2.3.0
|
aiohappyeyeballs>=2.3.0
|
||||||
|
async-interrupt>=1.1.1
|
||||||
protobuf>=3.19.0
|
protobuf>=3.19.0
|
||||||
zeroconf>=0.128.4,<1.0
|
zeroconf>=0.128.4,<1.0
|
||||||
chacha20poly1305-reuseable>=0.12.1
|
chacha20poly1305-reuseable>=0.12.1
|
||||||
|
@ -22,6 +22,7 @@ from aioesphomeapi.core import (
|
|||||||
HandshakeAPIError,
|
HandshakeAPIError,
|
||||||
InvalidEncryptionKeyAPIError,
|
InvalidEncryptionKeyAPIError,
|
||||||
ProtocolAPIError,
|
ProtocolAPIError,
|
||||||
|
ReadFailedAPIError,
|
||||||
SocketClosedAPIError,
|
SocketClosedAPIError,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -725,18 +726,28 @@ async def test_eof_received_closes_connection(
|
|||||||
await connect_task
|
await connect_task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("exception_map"),
|
||||||
|
[
|
||||||
|
(OSError("original message"), ReadFailedAPIError),
|
||||||
|
(APIConnectionError("original message"), APIConnectionError),
|
||||||
|
(SocketClosedAPIError("original message"), SocketClosedAPIError),
|
||||||
|
],
|
||||||
|
)
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_connection_lost_closes_connection_and_logs(
|
async def test_connection_lost_closes_connection_and_logs(
|
||||||
caplog: pytest.LogCaptureFixture,
|
caplog: pytest.LogCaptureFixture,
|
||||||
plaintext_connect_task_with_login: tuple[
|
plaintext_connect_task_with_login: tuple[
|
||||||
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
|
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
|
||||||
],
|
],
|
||||||
|
exception_map: tuple[Exception, Exception],
|
||||||
) -> None:
|
) -> None:
|
||||||
|
exception, raised_exception = exception_map
|
||||||
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
|
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
|
||||||
protocol.connection_lost(OSError("original message"))
|
protocol.connection_lost(exception)
|
||||||
assert conn.is_connected is False
|
assert conn.is_connected is False
|
||||||
assert "original message" in caplog.text
|
assert "original message" in caplog.text
|
||||||
with pytest.raises(APIConnectionError, match="original message"):
|
with pytest.raises(raised_exception, match="original message"):
|
||||||
await connect_task
|
await connect_task
|
||||||
|
|
||||||
|
|
||||||
|
@ -24,12 +24,16 @@ from aioesphomeapi.api_pb2 import (
|
|||||||
)
|
)
|
||||||
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
|
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
|
||||||
from aioesphomeapi.core import (
|
from aioesphomeapi.core import (
|
||||||
|
APIConnectionCancelledError,
|
||||||
APIConnectionError,
|
APIConnectionError,
|
||||||
ConnectionNotEstablishedAPIError,
|
ConnectionNotEstablishedAPIError,
|
||||||
HandshakeAPIError,
|
HandshakeAPIError,
|
||||||
InvalidAuthAPIError,
|
InvalidAuthAPIError,
|
||||||
|
ReadFailedAPIError,
|
||||||
RequiresEncryptionAPIError,
|
RequiresEncryptionAPIError,
|
||||||
ResolveAPIError,
|
ResolveAPIError,
|
||||||
|
SocketAPIError,
|
||||||
|
SocketClosedAPIError,
|
||||||
TimeoutAPIError,
|
TimeoutAPIError,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -442,7 +446,9 @@ async def test_finish_connection_times_out(
|
|||||||
async_fire_time_changed(utcnow() + timedelta(seconds=200))
|
async_fire_time_changed(utcnow() + timedelta(seconds=200))
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
with pytest.raises(APIConnectionError, match="Hello timed out"):
|
with pytest.raises(
|
||||||
|
APIConnectionError, match="Timeout waiting for HelloResponse after 30.0s"
|
||||||
|
):
|
||||||
await connect_task
|
await connect_task
|
||||||
|
|
||||||
async_fire_time_changed(utcnow() + timedelta(seconds=600))
|
async_fire_time_changed(utcnow() + timedelta(seconds=600))
|
||||||
@ -458,6 +464,8 @@ async def test_finish_connection_times_out(
|
|||||||
("exception_map"),
|
("exception_map"),
|
||||||
[
|
[
|
||||||
(OSError("Socket error"), HandshakeAPIError),
|
(OSError("Socket error"), HandshakeAPIError),
|
||||||
|
(APIConnectionError, APIConnectionError),
|
||||||
|
(SocketClosedAPIError, SocketClosedAPIError),
|
||||||
(asyncio.TimeoutError, TimeoutAPIError),
|
(asyncio.TimeoutError, TimeoutAPIError),
|
||||||
(asyncio.CancelledError, APIConnectionError),
|
(asyncio.CancelledError, APIConnectionError),
|
||||||
],
|
],
|
||||||
@ -501,6 +509,21 @@ async def test_plaintext_connection_fails_handshake(
|
|||||||
remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
|
remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
|
||||||
transport = MagicMock()
|
transport = MagicMock()
|
||||||
|
|
||||||
|
call_order = []
|
||||||
|
|
||||||
|
def _socket_close_call():
|
||||||
|
call_order.append("socket_close")
|
||||||
|
|
||||||
|
def _frame_helper_close_call():
|
||||||
|
call_order.append("frame_helper_close")
|
||||||
|
|
||||||
|
async def _do_finish_connect(self, *args, **kwargs):
|
||||||
|
try:
|
||||||
|
await conn._connect_init_frame_helper()
|
||||||
|
finally:
|
||||||
|
conn._socket.close = _socket_close_call
|
||||||
|
conn._frame_helper.close = _frame_helper_close_call
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"aioesphomeapi.connection.APIPlaintextFrameHelper",
|
"aioesphomeapi.connection.APIPlaintextFrameHelper",
|
||||||
@ -513,42 +536,12 @@ async def test_plaintext_connection_fails_handshake(
|
|||||||
_create_failing_mock_transport_protocol, transport, connected
|
_create_failing_mock_transport_protocol, transport, connected
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
patch.object(conn, "_do_finish_connect", _do_finish_connect),
|
||||||
):
|
):
|
||||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||||
await connected.wait()
|
await connected.wait()
|
||||||
|
|
||||||
protocol = conn._frame_helper
|
with (pytest.raises(raised_exception),):
|
||||||
assert conn._socket is not None
|
|
||||||
assert conn._frame_helper is not None
|
|
||||||
|
|
||||||
mock_data_received(
|
|
||||||
protocol,
|
|
||||||
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m',
|
|
||||||
)
|
|
||||||
mock_data_received(protocol, b"5stackatomproxy")
|
|
||||||
mock_data_received(protocol, b"\x00\x00$")
|
|
||||||
mock_data_received(protocol, b"\x00\x00\x04")
|
|
||||||
mock_data_received(
|
|
||||||
protocol,
|
|
||||||
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d',
|
|
||||||
)
|
|
||||||
mock_data_received(
|
|
||||||
protocol, b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
|
|
||||||
)
|
|
||||||
|
|
||||||
call_order = []
|
|
||||||
|
|
||||||
def _socket_close_call():
|
|
||||||
call_order.append("socket_close")
|
|
||||||
|
|
||||||
def _frame_helper_close_call():
|
|
||||||
call_order.append("frame_helper_close")
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(conn._socket, "close", side_effect=_socket_close_call),
|
|
||||||
patch.object(conn._frame_helper, "close", side_effect=_frame_helper_close_call),
|
|
||||||
pytest.raises(raised_exception),
|
|
||||||
):
|
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
await connect_task
|
await connect_task
|
||||||
|
|
||||||
@ -556,10 +549,6 @@ async def test_plaintext_connection_fails_handshake(
|
|||||||
# so asyncio releases the socket
|
# so asyncio releases the socket
|
||||||
assert call_order == ["frame_helper_close", "socket_close"]
|
assert call_order == ["frame_helper_close", "socket_close"]
|
||||||
assert not conn.is_connected
|
assert not conn.is_connected
|
||||||
assert len(messages) == 2
|
|
||||||
assert isinstance(messages[0], HelloResponse)
|
|
||||||
assert isinstance(messages[1], DeviceInfoResponse)
|
|
||||||
assert messages[1].name == "m5stackatomproxy"
|
|
||||||
remove()
|
remove()
|
||||||
conn.force_disconnect()
|
conn.force_disconnect()
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
@ -655,6 +644,110 @@ async def test_force_disconnect_fails(
|
|||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("exception_map"),
|
||||||
|
[
|
||||||
|
(OSError("original message"), ReadFailedAPIError),
|
||||||
|
(APIConnectionError("original message"), APIConnectionError),
|
||||||
|
(SocketClosedAPIError("original message"), SocketClosedAPIError),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_connection_lost_while_connecting(
|
||||||
|
plaintext_connect_task_with_login: tuple[
|
||||||
|
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
|
||||||
|
],
|
||||||
|
exception_map: tuple[Exception, Exception],
|
||||||
|
) -> None:
|
||||||
|
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
|
||||||
|
|
||||||
|
exception, raised_exception = exception_map
|
||||||
|
protocol.connection_lost(exception)
|
||||||
|
|
||||||
|
with pytest.raises(raised_exception, match="original message"):
|
||||||
|
await connect_task
|
||||||
|
|
||||||
|
assert not conn.is_connected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("exception_map"),
|
||||||
|
[
|
||||||
|
(OSError("original message"), SocketAPIError),
|
||||||
|
(APIConnectionError("original message"), APIConnectionError),
|
||||||
|
(SocketClosedAPIError("original message"), SocketClosedAPIError),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_connection_error_during_hello(
|
||||||
|
conn: APIConnection,
|
||||||
|
resolve_host,
|
||||||
|
aiohappyeyeballs_start_connection,
|
||||||
|
exception_map: tuple[Exception, Exception],
|
||||||
|
) -> None:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
transport = MagicMock()
|
||||||
|
connected = asyncio.Event()
|
||||||
|
exception, raised_exception = exception_map
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(
|
||||||
|
loop,
|
||||||
|
"create_connection",
|
||||||
|
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||||
|
),
|
||||||
|
patch.object(conn, "_connect_hello_login", side_effect=exception),
|
||||||
|
):
|
||||||
|
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||||
|
await connected.wait()
|
||||||
|
|
||||||
|
with pytest.raises(raised_exception, match="original message"):
|
||||||
|
await connect_task
|
||||||
|
|
||||||
|
assert not conn.is_connected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("exception_map"),
|
||||||
|
[
|
||||||
|
(OSError("original message"), APIConnectionCancelledError),
|
||||||
|
(APIConnectionError("original message"), APIConnectionError),
|
||||||
|
(SocketClosedAPIError("original message"), SocketClosedAPIError),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_connection_cancelled_during_hello(
|
||||||
|
conn: APIConnection,
|
||||||
|
resolve_host,
|
||||||
|
aiohappyeyeballs_start_connection,
|
||||||
|
exception_map: tuple[Exception, Exception],
|
||||||
|
) -> None:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
transport = MagicMock()
|
||||||
|
connected = asyncio.Event()
|
||||||
|
exception, raised_exception = exception_map
|
||||||
|
|
||||||
|
async def _mock_frame_helper_error(*args, **kwargs):
|
||||||
|
conn._frame_helper.connection_lost(exception)
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(
|
||||||
|
loop,
|
||||||
|
"create_connection",
|
||||||
|
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||||
|
),
|
||||||
|
patch.object(conn, "_connect_hello_login", _mock_frame_helper_error),
|
||||||
|
):
|
||||||
|
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||||
|
await connected.wait()
|
||||||
|
|
||||||
|
with pytest.raises(raised_exception, match="original message"):
|
||||||
|
await connect_task
|
||||||
|
|
||||||
|
assert not conn.is_connected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_connect_resolver_times_out(
|
async def test_connect_resolver_times_out(
|
||||||
conn: APIConnection, aiohappyeyeballs_start_connection
|
conn: APIConnection, aiohappyeyeballs_start_connection
|
||||||
@ -814,7 +907,7 @@ async def test_ping_disconnects_after_no_responses(
|
|||||||
start_time
|
start_time
|
||||||
+ timedelta(seconds=KEEP_ALIVE_INTERVAL * (max_pings_to_disconnect_after + 1))
|
+ timedelta(seconds=KEEP_ALIVE_INTERVAL * (max_pings_to_disconnect_after + 1))
|
||||||
)
|
)
|
||||||
assert transport.write.call_count == max_pings_to_disconnect_after
|
assert transport.write.call_count == max_pings_to_disconnect_after + 1
|
||||||
|
|
||||||
assert conn.is_connected is False
|
assert conn.is_connected is False
|
||||||
|
|
||||||
|
@ -252,6 +252,11 @@ async def test_log_runner_reconnects_on_subscribe_failure(
|
|||||||
|
|
||||||
stop_task = asyncio.create_task(stop())
|
stop_task = asyncio.create_task(stop())
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
send_plaintext_connect_response(protocol, False)
|
||||||
|
send_plaintext_hello(protocol)
|
||||||
|
|
||||||
disconnect_response = DisconnectResponse()
|
disconnect_response = DisconnectResponse()
|
||||||
mock_data_received(protocol, generate_plaintext_packet(disconnect_response))
|
mock_data_received(protocol, generate_plaintext_packet(disconnect_response))
|
||||||
|
|
||||||
await stop_task
|
await stop_task
|
||||||
|
Loading…
Reference in New Issue
Block a user