diff --git a/aioesphomeapi/reconnect_logic.py b/aioesphomeapi/reconnect_logic.py index 154c7d1..a665ae3 100644 --- a/aioesphomeapi/reconnect_logic.py +++ b/aioesphomeapi/reconnect_logic.py @@ -316,7 +316,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): async def stop(self) -> None: """Stop the connecting logic background task. Does not disconnect the client.""" - if self._connection_state == ReconnectLogicState.CONNECTING: + if self._connection_state in NOT_YET_CONNECTED_STATES: # If we are still establishing a connection, we can safely # cancel the connect task here, otherwise we need to wait # for the connect task to finish so we can gracefully diff --git a/tests/test_reconnect_logic.py b/tests/test_reconnect_logic.py index e3f7842..3d81431 100644 --- a/tests/test_reconnect_logic.py +++ b/tests/test_reconnect_logic.py @@ -401,3 +401,52 @@ async def test_reconnect_logic_stop_callback(): await asyncio.sleep(0) assert rl._is_stopped is True assert rl._connection_state is ReconnectLogicState.DISCONNECTED + + +@pytest.mark.asyncio +async def test_reconnect_logic_stop_callback_waits_for_handshake(): + """Test that the stop_callback waits for a handshake.""" + + class PatchableAPIClient(APIClient): + pass + + cli = PatchableAPIClient( + address="1.2.3.4", + port=6052, + password=None, + ) + rl = ReconnectLogic( + client=cli, + on_disconnect=AsyncMock(), + on_connect=AsyncMock(), + zeroconf_instance=_get_mock_zeroconf(), + name="mydevice", + ) + assert rl._connection_state is ReconnectLogicState.DISCONNECTED + + async def slow_connect_fail(*args, **kwargs): + await asyncio.sleep(10) + raise APIConnectionError + + with patch.object(cli, "start_connection"), patch.object( + cli, "finish_connection", side_effect=slow_connect_fail + ): + await rl.start() + for _ in range(3): + await asyncio.sleep(0) + + assert rl._connection_state is ReconnectLogicState.HANDSHAKING + assert rl._is_stopped is False + rl.stop_callback() + # Wait for cancellation to propagate + for _ in range(4): + await asyncio.sleep(0) + assert rl._is_stopped is False + assert rl._connection_state is ReconnectLogicState.HANDSHAKING + + rl._cancel_connect("forced cancel in test") + # Wait for cancellation to propagate + for _ in range(4): + await asyncio.sleep(0) + assert rl._is_stopped is True + assert rl._connection_state is ReconnectLogicState.DISCONNECTED