Ensure connection is released if connecting is cancelled (#784)

This commit is contained in:
J. Nick Koston 2023-11-30 09:17:08 -07:00 committed by GitHub
parent 7d92b7974d
commit e1ddf270c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 1 deletions

View File

@ -349,7 +349,7 @@ class APIClient:
"""Execute a coroutine and reset the _connection if it fails.""" """Execute a coroutine and reset the _connection if it fails."""
try: try:
await coro await coro
except Exception: # pylint: disable=broad-except except (Exception, asyncio.CancelledError): # pylint: disable=broad-except
self._connection = None self._connection = None
raise raise

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import contextlib
import itertools import itertools
import logging import logging
from functools import partial from functools import partial
@ -212,6 +213,37 @@ async def test_finish_connection_wraps_exceptions_as_unhandled_api_error() -> No
await cli.finish_connection(False) await cli.finish_connection(False)
@pytest.mark.asyncio
async def test_connection_released_if_connecting_is_cancelled() -> None:
"""Verify connection is unset if connecting is cancelled."""
cli = APIClient("1.2.3.4", 1234, None)
loop = asyncio.get_event_loop()
with patch.object(loop, "sock_connect", side_effect=partial(asyncio.sleep, 1)):
start_task = asyncio.create_task(cli.start_connection())
await asyncio.sleep(0)
assert cli._connection is not None
start_task.cancel()
with contextlib.suppress(BaseException):
await start_task
assert cli._connection is None
with patch(
"aioesphomeapi.client.APIConnection", PatchableAPIConnection
), patch.object(loop, "sock_connect"):
await cli.start_connection()
await asyncio.sleep(0)
assert cli._connection is not None
task = asyncio.create_task(cli.finish_connection(False))
await asyncio.sleep(0)
task.cancel()
with contextlib.suppress(BaseException):
await task
assert cli._connection is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_request_while_handshaking(event_loop) -> None: async def test_request_while_handshaking(event_loop) -> None:
"""Test trying a request while handshaking raises.""" """Test trying a request while handshaking raises."""