diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index ddbf736..ff4db77 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -349,7 +349,7 @@ class APIClient: """Execute a coroutine and reset the _connection if it fails.""" try: await coro - except Exception: # pylint: disable=broad-except + except (Exception, asyncio.CancelledError): # pylint: disable=broad-except self._connection = None raise diff --git a/tests/test_client.py b/tests/test_client.py index 92a6440..e6251c1 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextlib import itertools import logging from functools import partial @@ -212,6 +213,37 @@ async def test_finish_connection_wraps_exceptions_as_unhandled_api_error() -> No await cli.finish_connection(False) +@pytest.mark.asyncio +async def test_connection_released_if_connecting_is_cancelled() -> None: + """Verify connection is unset if connecting is cancelled.""" + cli = APIClient("1.2.3.4", 1234, None) + loop = asyncio.get_event_loop() + + with patch.object(loop, "sock_connect", side_effect=partial(asyncio.sleep, 1)): + start_task = asyncio.create_task(cli.start_connection()) + await asyncio.sleep(0) + assert cli._connection is not None + + start_task.cancel() + with contextlib.suppress(BaseException): + await start_task + assert cli._connection is None + + with patch( + "aioesphomeapi.client.APIConnection", PatchableAPIConnection + ), patch.object(loop, "sock_connect"): + await cli.start_connection() + await asyncio.sleep(0) + + assert cli._connection is not None + task = asyncio.create_task(cli.finish_connection(False)) + await asyncio.sleep(0) + task.cancel() + with contextlib.suppress(BaseException): + await task + assert cli._connection is None + + @pytest.mark.asyncio async def test_request_while_handshaking(event_loop) -> None: """Test trying a request while handshaking raises."""