Ensure connection is released if connecting is cancelled

This commit is contained in:
J. Nick Koston 2023-11-30 06:15:20 -10:00
parent 7d92b7974d
commit a5f81c7cda
No known key found for this signature in database
2 changed files with 33 additions and 1 deletions

View File

@ -349,7 +349,7 @@ class APIClient:
"""Execute a coroutine and reset the _connection if it fails."""
try:
await coro
except Exception: # pylint: disable=broad-except
except (Exception, asyncio.CancelledError): # pylint: disable=broad-except
self._connection = None
raise

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from functools import partial
@ -212,6 +213,37 @@ async def test_finish_connection_wraps_exceptions_as_unhandled_api_error() -> No
await cli.finish_connection(False)
@pytest.mark.asyncio
async def test_connection_released_if_connecting_is_cancelled() -> None:
"""Verify connection is unset if connecting is cancelled."""
cli = APIClient("1.2.3.4", 1234, None)
loop = asyncio.get_event_loop()
with patch.object(loop, "sock_connect", side_effect=partial(asyncio.sleep, 1)):
start_task = asyncio.create_task(cli.start_connection())
await asyncio.sleep(0)
assert cli._connection is not None
start_task.cancel()
with contextlib.suppress(BaseException):
await start_task
assert cli._connection is None
with patch(
"aioesphomeapi.client.APIConnection", PatchableAPIConnection
), patch.object(loop, "sock_connect"):
await cli.start_connection()
await asyncio.sleep(0)
assert cli._connection is not None
task = asyncio.create_task(cli.finish_connection(False))
await asyncio.sleep(0)
task.cancel()
with contextlib.suppress(BaseException):
await task
assert cli._connection is None
@pytest.mark.asyncio
async def test_request_while_handshaking(event_loop) -> None:
"""Test trying a request while handshaking raises."""