mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-22 12:05:12 +01:00
Ensure connection is released if connecting is cancelled (#784)
This commit is contained in:
parent
7d92b7974d
commit
e1ddf270c5
@ -349,7 +349,7 @@ class APIClient:
|
||||
"""Execute a coroutine and reset the _connection if it fails."""
|
||||
try:
|
||||
await coro
|
||||
except Exception: # pylint: disable=broad-except
|
||||
except (Exception, asyncio.CancelledError): # pylint: disable=broad-except
|
||||
self._connection = None
|
||||
raise
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import itertools
|
||||
import logging
|
||||
from functools import partial
|
||||
@ -212,6 +213,37 @@ async def test_finish_connection_wraps_exceptions_as_unhandled_api_error() -> No
|
||||
await cli.finish_connection(False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connection_released_if_connecting_is_cancelled() -> None:
|
||||
"""Verify connection is unset if connecting is cancelled."""
|
||||
cli = APIClient("1.2.3.4", 1234, None)
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
with patch.object(loop, "sock_connect", side_effect=partial(asyncio.sleep, 1)):
|
||||
start_task = asyncio.create_task(cli.start_connection())
|
||||
await asyncio.sleep(0)
|
||||
assert cli._connection is not None
|
||||
|
||||
start_task.cancel()
|
||||
with contextlib.suppress(BaseException):
|
||||
await start_task
|
||||
assert cli._connection is None
|
||||
|
||||
with patch(
|
||||
"aioesphomeapi.client.APIConnection", PatchableAPIConnection
|
||||
), patch.object(loop, "sock_connect"):
|
||||
await cli.start_connection()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert cli._connection is not None
|
||||
task = asyncio.create_task(cli.finish_connection(False))
|
||||
await asyncio.sleep(0)
|
||||
task.cancel()
|
||||
with contextlib.suppress(BaseException):
|
||||
await task
|
||||
assert cli._connection is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_while_handshaking(event_loop) -> None:
|
||||
"""Test trying a request while handshaking raises."""
|
||||
|
Loading…
Reference in New Issue
Block a user