Refactor to avoid creating tasks for starting/finishing the connection

This commit is contained in:
J. Nick Koston 2024-02-16 16:34:11 -06:00
parent 4d93e694e8
commit 30360de64e
No known key found for this signature in database
1 changed files with 22 additions and 6 deletions

View File

@ -286,11 +286,17 @@ class APIConnection:
fut.set_exception(new_exc)
self._read_exception_futures.clear()
if self._start_connect_future is not None:
if (
self._start_connect_future is not None
and not self._start_connect_future.done()
):
self._start_connect_future.set_result(None)
self._start_connect_future = None
if self._finish_connect_future is not None:
if (
self._finish_connect_future is not None
and not self._finish_connect_future.done()
):
self._finish_connect_future.set_result(None)
self._finish_connect_future = None
@ -612,7 +618,12 @@ class APIConnection:
self._cleanup()
raise self._wrap_fatal_connection_exception("starting", ex)
finally:
self._start_connect_future = None
if (
self._start_connect_future is not None
and not self._start_connect_future.done()
):
self._start_connect_future.set_result(None)
self._start_connect_future = None
self._set_connection_state(CONNECTION_STATE_SOCKET_OPENED)
def _wrap_fatal_connection_exception(
@ -671,7 +682,12 @@ class APIConnection:
self._cleanup()
raise self._wrap_fatal_connection_exception("finishing", ex)
finally:
self._finish_connect_future = None
if (
self._finish_connect_future is not None
and not self._finish_connect_future.done()
):
self._finish_connect_future.set_result(None)
self._finish_connect_future = None
self._set_connection_state(CONNECTION_STATE_CONNECTED)
def _set_connection_state(self, state: ConnectionState) -> None:
@ -963,12 +979,12 @@ class APIConnection:
async def disconnect(self) -> None:
"""Disconnect from the API."""
if self._finish_connect_task is not None:
if self._finish_connect_future is not None:
# Try to wait for the handshake to finish so we can send
# a disconnect request. If it doesn't finish in time
# we will just close the socket.
_, pending = await asyncio.wait(
[self._finish_connect_task], timeout=DISCONNECT_CONNECT_TIMEOUT
[self._finish_connect_future], timeout=DISCONNECT_CONNECT_TIMEOUT
)
if pending:
self._set_fatal_exception_if_unset(