Fix login error cleanup (#126)

This commit is contained in:
Otto Winter 2021-10-21 19:20:05 +02:00 committed by GitHub
parent 1402ecaabc
commit 3b8b2d9d66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 7 additions and 16 deletions

View File

@ -160,26 +160,17 @@ class APIClient:
if self._connection is not None:
raise APIConnectionError(f"Already connected to {self._log_name}!")
connected = False
stopped = False
async def _on_stop() -> None:
nonlocal stopped
if stopped:
return
stopped = True
# Hook into on_stop handler to clear connection when stopped
self._connection = None
if connected and on_stop is not None:
if on_stop is not None:
await on_stop()
self._connection = APIConnection(self._params, _on_stop)
self._connection.log_name = self._log_name
try:
await self._connection.connect()
if login:
await self._connection.login()
await self._connection.connect(login=login)
except APIConnectionError:
await _on_stop()
raise
@ -189,8 +180,6 @@ class APIClient:
f"Unexpected error while connecting to {self._log_name}: {e}"
) from e
connected = True
async def disconnect(self, force: bool = False) -> None:
if self._connection is None:
return

View File

@ -249,7 +249,7 @@ class APIConnection:
asyncio.create_task(func())
async def connect(self) -> None:
async def connect(self, *, login: bool) -> None:
if self._connection_state != ConnectionState.INITIALIZED:
raise ValueError(
"Connection can only be used once, connection is not in init state"
@ -261,6 +261,8 @@ class APIConnection:
await self._connect_init_frame_helper()
await self._connect_hello()
await self._connect_start_ping()
if login:
await self.login()
except Exception: # pylint: disable=broad-except
# Always clean up the connection if an error occured during connect
self._connection_state = ConnectionState.CLOSED

View File

@ -58,7 +58,7 @@ async def test_connect(conn, resolve_host, socket_socket, event_loop):
), patch.object(
conn, "send_message_await_response", return_value=HelloResponse()
):
await conn.connect()
await conn.connect(login=False)
assert conn.is_connected