diff --git a/aioesphomeapi/reconnect_logic.py b/aioesphomeapi/reconnect_logic.py index bb59e9d..2832c61 100644 --- a/aioesphomeapi/reconnect_logic.py +++ b/aioesphomeapi/reconnect_logic.py @@ -226,11 +226,11 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): def _schedule_connect(self, delay: float) -> None: """Schedule a connect attempt.""" - self._cancel_connect("Scheduling new connect attempt") if not delay: self._call_connect_once() return _LOGGER.debug("Scheduling new connect attempt in %f seconds", delay) + self._cancel_connect_timer() self._connect_timer = self.loop.call_at( self.loop.time() + delay, self._call_connect_once ) @@ -240,17 +240,22 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): Must only be called from _schedule_connect. """ - if self._connect_task: + if self._connect_task and not self._connect_task.done(): if self._connection_state != ReconnectLogicState.CONNECTING: # Connection state is far enough along that we should # not restart the connect task + _LOGGER.debug( + "%s: Not cancelling existing connect task as its already %s!", + self._cli.log_name, + self._connection_state, + ) return _LOGGER.debug( - "%s: Cancelling existing connect task, to try again now!", + "%s: Cancelling existing connect task with state %s, to try again now!", self._cli.log_name, + self._connection_state, ) - self._connect_task.cancel("Scheduling new connect attempt") - self._connect_task = None + self._cancel_connect_task("Scheduling new connect attempt") self._async_set_connection_state_without_lock( ReconnectLogicState.DISCONNECTED ) @@ -260,15 +265,23 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): name=f"{self._cli.log_name}: aioesphomeapi connect", ) - def _cancel_connect(self, msg: str) -> None: - """Cancel the connect.""" + def _cancel_connect_timer(self) -> None: + """Cancel the connect timer.""" if self._connect_timer: self._connect_timer.cancel() self._connect_timer = None + + def _cancel_connect_task(self, msg: str) -> None: + """Cancel the connect task.""" if self._connect_task: self._connect_task.cancel(msg) self._connect_task = None + def _cancel_connect(self, msg: str) -> None: + """Cancel the connect.""" + self._cancel_connect_timer() + self._cancel_connect_task(msg) + async def _connect_once_or_reschedule(self) -> None: """Connect once or schedule connect. diff --git a/tests/test_reconnect_logic.py b/tests/test_reconnect_logic.py index 6b20a90..7236c58 100644 --- a/tests/test_reconnect_logic.py +++ b/tests/test_reconnect_logic.py @@ -38,6 +38,15 @@ from .conftest import _create_mock_transport_protocol logging.getLogger("aioesphomeapi").setLevel(logging.DEBUG) +async def slow_connect_fail(*args, **kwargs): + await asyncio.sleep(10) + raise APIConnectionError + + +async def quick_connect_fail(*args, **kwargs): + raise APIConnectionError + + @pytest.mark.asyncio async def test_reconnect_logic_name_from_host(): """Test that the name is set correctly from the host.""" @@ -288,23 +297,33 @@ async def test_reconnect_retry( assert len(on_connect_called) == 1 assert len(on_connect_fail_called) == 2 assert rl._connection_state is ReconnectLogicState.READY + original_when = rl._connect_timer.when() + + # Ensure starting the connection logic again does not trigger a new connection + await rl.start() + # Verify no new timer is started + assert rl._connect_timer.when() == original_when await rl.stop() assert rl._connection_state is ReconnectLogicState.DISCONNECTED +DNS_POINTER = DNSPointer( + "_esphomelib._tcp.local.", + _TYPE_PTR, + _CLASS_IN, + 1000, + "mydevice._esphomelib._tcp.local.", +) + + @pytest.mark.parametrize( - ("record", "should_trigger_zeroconf", "log_text"), + ("record", "should_trigger_zeroconf", "expected_state_after_trigger", "log_text"), ( ( - DNSPointer( - "_esphomelib._tcp.local.", - _TYPE_PTR, - _CLASS_IN, - 1000, - "mydevice._esphomelib._tcp.local.", - ), + DNS_POINTER, True, + ReconnectLogicState.READY, "received mDNS record", ), ( @@ -316,6 +335,7 @@ async def test_reconnect_retry( "wrong_name._esphomelib._tcp.local.", ), False, + ReconnectLogicState.CONNECTING, "", ), ( @@ -327,27 +347,23 @@ async def test_reconnect_retry( ip_address("1.2.3.4").packed, ), True, + ReconnectLogicState.READY, "received mDNS record", ), ), ) @pytest.mark.asyncio async def test_reconnect_zeroconf( + patchable_api_client: APIClient, caplog: pytest.LogCaptureFixture, record: DNSRecord, should_trigger_zeroconf: bool, + expected_state_after_trigger: ReconnectLogicState, log_text: str, ) -> None: """Test that reconnect logic retry.""" - class PatchableAPIClient(APIClient): - pass - - cli = PatchableAPIClient( - address="1.2.3.4", - port=6052, - password=None, - ) + cli = patchable_api_client mock_zeroconf = MagicMock(spec=Zeroconf) @@ -361,13 +377,6 @@ async def test_reconnect_zeroconf( ) assert cli.log_name == "mydevice @ 1.2.3.4" - async def slow_connect_fail(*args, **kwargs): - await asyncio.sleep(10) - raise APIConnectionError - - async def quick_connect_fail(*args, **kwargs): - raise APIConnectionError - with patch.object( cli, "start_connection", side_effect=quick_connect_fail ) as mock_start_connection: @@ -379,17 +388,150 @@ async def test_reconnect_zeroconf( with patch.object( cli, "start_connection", side_effect=slow_connect_fail ) as mock_start_connection: + assert rl._connection_state is ReconnectLogicState.DISCONNECTED + assert rl._accept_zeroconf_records is True + assert not rl._is_stopped + + assert rl._connect_timer is not None + rl._connect_timer._run() await asyncio.sleep(0) + assert mock_start_connection.call_count == 1 + assert rl._connection_state is ReconnectLogicState.CONNECTING + assert rl._accept_zeroconf_records is True + assert not rl._is_stopped - assert mock_start_connection.call_count == 0 - + with patch.object(cli, "start_connection") as mock_start_connection, patch.object( + cli, "finish_connection" + ): + assert rl._zc_listening is True rl.async_update_records( mock_zeroconf, current_time_millis(), [RecordUpdate(record, None)] ) + assert ( + "Triggering connect because of received mDNS record" in caplog.text + ) is should_trigger_zeroconf + assert rl._zc_listening is True # should change after one iteration of the loop await asyncio.sleep(0) + assert rl._zc_listening is not should_trigger_zeroconf + assert mock_start_connection.call_count == int(should_trigger_zeroconf) assert log_text in caplog.text + assert rl._connection_state is expected_state_after_trigger + await rl.stop() + assert rl._is_stopped is True + assert rl._connection_state is ReconnectLogicState.DISCONNECTED + + +@pytest.mark.asyncio +async def test_reconnect_zeroconf_not_while_handshaking( + patchable_api_client: APIClient, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test that reconnect logic retry will not trigger a zeroconf reconnect while handshaking.""" + cli = patchable_api_client + + mock_zeroconf = MagicMock(spec=Zeroconf) + + rl = ReconnectLogic( + client=cli, + on_disconnect=AsyncMock(), + on_connect=AsyncMock(), + zeroconf_instance=mock_zeroconf, + name="mydevice", + on_connect_error=AsyncMock(), + ) + assert cli.log_name == "mydevice @ 1.2.3.4" + + with patch.object( + cli, "start_connection", side_effect=quick_connect_fail + ) as mock_start_connection: + await rl.start() + await asyncio.sleep(0) + + assert mock_start_connection.call_count == 1 + + with patch.object(cli, "start_connection") as mock_start_connection, patch.object( + cli, "finish_connection", side_effect=slow_connect_fail + ) as mock_finish_connection: + assert rl._connection_state is ReconnectLogicState.DISCONNECTED + assert rl._accept_zeroconf_records is True + assert not rl._is_stopped + + assert rl._connect_timer is not None + rl._connect_timer._run() + await asyncio.sleep(0) + assert mock_start_connection.call_count == 1 + assert mock_finish_connection.call_count == 1 + assert rl._connection_state is ReconnectLogicState.HANDSHAKING + assert rl._accept_zeroconf_records is False + assert not rl._is_stopped + + rl.async_update_records( + mock_zeroconf, current_time_millis(), [RecordUpdate(DNS_POINTER, None)] + ) + assert ( + "Triggering connect because of received mDNS record" in caplog.text + ) is False + + rl._cancel_connect("forced cancel in test") + await rl.stop() + assert rl._is_stopped is True + assert rl._connection_state is ReconnectLogicState.DISCONNECTED + + +@pytest.mark.asyncio +async def test_connect_task_not_cancelled_while_handshaking( + patchable_api_client: APIClient, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test that reconnect logic will not cancel an in progress handshake.""" + cli = patchable_api_client + + rl = ReconnectLogic( + client=cli, + on_disconnect=AsyncMock(), + on_connect=AsyncMock(), + name="mydevice", + on_connect_error=AsyncMock(), + ) + assert cli.log_name == "mydevice @ 1.2.3.4" + + with patch.object( + cli, "start_connection", side_effect=quick_connect_fail + ) as mock_start_connection: + await rl.start() + await asyncio.sleep(0) + + assert mock_start_connection.call_count == 1 + + with patch.object(cli, "start_connection") as mock_start_connection, patch.object( + cli, "finish_connection", side_effect=slow_connect_fail + ) as mock_finish_connection: + assert rl._connection_state is ReconnectLogicState.DISCONNECTED + assert rl._accept_zeroconf_records is True + assert not rl._is_stopped + + assert rl._connect_timer is not None + rl._connect_timer._run() + await asyncio.sleep(0) + assert mock_start_connection.call_count == 1 + assert mock_finish_connection.call_count == 1 + assert rl._connection_state is ReconnectLogicState.HANDSHAKING + assert rl._accept_zeroconf_records is False + assert not rl._is_stopped + + caplog.clear() + # This can likely never happen in practice, but we should handle it + # in the event there is a race as the consequence is that we could + # disconnect a working connection. + rl._call_connect_once() + assert ( + "Not cancelling existing connect task as its already ReconnectLogicState.HANDSHAKING" + in caplog.text + ) + + rl._cancel_connect("forced cancel in test") await rl.stop() assert rl._is_stopped is True assert rl._connection_state is ReconnectLogicState.DISCONNECTED @@ -434,10 +576,6 @@ async def test_reconnect_logic_stop_callback_waits_for_handshake( ) assert rl._connection_state is ReconnectLogicState.DISCONNECTED - async def slow_connect_fail(*args, **kwargs): - await asyncio.sleep(10) - raise APIConnectionError - with patch.object(cli, "start_connection"), patch.object( cli, "finish_connection", side_effect=slow_connect_fail ):