Allow the stop callback to be cancelled when already disconnected (#615)

This commit is contained in:
J. Nick Koston 2023-11-06 18:17:50 -06:00 committed by GitHub
parent a0239b7a63
commit 6458ebcf60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 1 deletions

View File

@ -316,7 +316,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
async def stop(self) -> None:
"""Stop the connecting logic background task. Does not disconnect the client."""
if self._connection_state == ReconnectLogicState.CONNECTING:
if self._connection_state in NOT_YET_CONNECTED_STATES:
# If we are still establishing a connection, we can safely
# cancel the connect task here, otherwise we need to wait
# for the connect task to finish so we can gracefully

View File

@ -401,3 +401,52 @@ async def test_reconnect_logic_stop_callback():
await asyncio.sleep(0)
assert rl._is_stopped is True
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
@pytest.mark.asyncio
async def test_reconnect_logic_stop_callback_waits_for_handshake():
"""Test that the stop_callback waits for a handshake."""
class PatchableAPIClient(APIClient):
pass
cli = PatchableAPIClient(
address="1.2.3.4",
port=6052,
password=None,
)
rl = ReconnectLogic(
client=cli,
on_disconnect=AsyncMock(),
on_connect=AsyncMock(),
zeroconf_instance=_get_mock_zeroconf(),
name="mydevice",
)
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
async def slow_connect_fail(*args, **kwargs):
await asyncio.sleep(10)
raise APIConnectionError
with patch.object(cli, "start_connection"), patch.object(
cli, "finish_connection", side_effect=slow_connect_fail
):
await rl.start()
for _ in range(3):
await asyncio.sleep(0)
assert rl._connection_state is ReconnectLogicState.HANDSHAKING
assert rl._is_stopped is False
rl.stop_callback()
# Wait for cancellation to propagate
for _ in range(4):
await asyncio.sleep(0)
assert rl._is_stopped is False
assert rl._connection_state is ReconnectLogicState.HANDSHAKING
rl._cancel_connect("forced cancel in test")
# Wait for cancellation to propagate
for _ in range(4):
await asyncio.sleep(0)
assert rl._is_stopped is True
assert rl._connection_state is ReconnectLogicState.DISCONNECTED