Fix reconnect logic cancelling the connection while handshaking (#726)

This commit is contained in:
J. Nick Koston 2023-11-26 09:14:42 -06:00 committed by GitHub
parent 68dfc868d9
commit b3d4189b07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 187 additions and 36 deletions

View File

@ -226,11 +226,11 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
def _schedule_connect(self, delay: float) -> None:
"""Schedule a connect attempt."""
self._cancel_connect("Scheduling new connect attempt")
if not delay:
self._call_connect_once()
return
_LOGGER.debug("Scheduling new connect attempt in %f seconds", delay)
self._cancel_connect_timer()
self._connect_timer = self.loop.call_at(
self.loop.time() + delay, self._call_connect_once
)
@ -240,17 +240,22 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
Must only be called from _schedule_connect.
"""
if self._connect_task:
if self._connect_task and not self._connect_task.done():
if self._connection_state != ReconnectLogicState.CONNECTING:
# Connection state is far enough along that we should
# not restart the connect task
_LOGGER.debug(
"%s: Not cancelling existing connect task as its already %s!",
self._cli.log_name,
self._connection_state,
)
return
_LOGGER.debug(
"%s: Cancelling existing connect task, to try again now!",
"%s: Cancelling existing connect task with state %s, to try again now!",
self._cli.log_name,
self._connection_state,
)
self._connect_task.cancel("Scheduling new connect attempt")
self._connect_task = None
self._cancel_connect_task("Scheduling new connect attempt")
self._async_set_connection_state_without_lock(
ReconnectLogicState.DISCONNECTED
)
@ -260,15 +265,23 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
name=f"{self._cli.log_name}: aioesphomeapi connect",
)
def _cancel_connect(self, msg: str) -> None:
"""Cancel the connect."""
def _cancel_connect_timer(self) -> None:
"""Cancel the connect timer."""
if self._connect_timer:
self._connect_timer.cancel()
self._connect_timer = None
def _cancel_connect_task(self, msg: str) -> None:
"""Cancel the connect task."""
if self._connect_task:
self._connect_task.cancel(msg)
self._connect_task = None
def _cancel_connect(self, msg: str) -> None:
"""Cancel the connect."""
self._cancel_connect_timer()
self._cancel_connect_task(msg)
async def _connect_once_or_reschedule(self) -> None:
"""Connect once or schedule connect.

View File

@ -38,6 +38,15 @@ from .conftest import _create_mock_transport_protocol
logging.getLogger("aioesphomeapi").setLevel(logging.DEBUG)
async def slow_connect_fail(*args, **kwargs):
await asyncio.sleep(10)
raise APIConnectionError
async def quick_connect_fail(*args, **kwargs):
raise APIConnectionError
@pytest.mark.asyncio
async def test_reconnect_logic_name_from_host():
"""Test that the name is set correctly from the host."""
@ -288,23 +297,33 @@ async def test_reconnect_retry(
assert len(on_connect_called) == 1
assert len(on_connect_fail_called) == 2
assert rl._connection_state is ReconnectLogicState.READY
original_when = rl._connect_timer.when()
# Ensure starting the connection logic again does not trigger a new connection
await rl.start()
# Verify no new timer is started
assert rl._connect_timer.when() == original_when
await rl.stop()
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
DNS_POINTER = DNSPointer(
"_esphomelib._tcp.local.",
_TYPE_PTR,
_CLASS_IN,
1000,
"mydevice._esphomelib._tcp.local.",
)
@pytest.mark.parametrize(
("record", "should_trigger_zeroconf", "log_text"),
("record", "should_trigger_zeroconf", "expected_state_after_trigger", "log_text"),
(
(
DNSPointer(
"_esphomelib._tcp.local.",
_TYPE_PTR,
_CLASS_IN,
1000,
"mydevice._esphomelib._tcp.local.",
),
DNS_POINTER,
True,
ReconnectLogicState.READY,
"received mDNS record",
),
(
@ -316,6 +335,7 @@ async def test_reconnect_retry(
"wrong_name._esphomelib._tcp.local.",
),
False,
ReconnectLogicState.CONNECTING,
"",
),
(
@ -327,27 +347,23 @@ async def test_reconnect_retry(
ip_address("1.2.3.4").packed,
),
True,
ReconnectLogicState.READY,
"received mDNS record",
),
),
)
@pytest.mark.asyncio
async def test_reconnect_zeroconf(
patchable_api_client: APIClient,
caplog: pytest.LogCaptureFixture,
record: DNSRecord,
should_trigger_zeroconf: bool,
expected_state_after_trigger: ReconnectLogicState,
log_text: str,
) -> None:
"""Test that reconnect logic retry."""
class PatchableAPIClient(APIClient):
pass
cli = PatchableAPIClient(
address="1.2.3.4",
port=6052,
password=None,
)
cli = patchable_api_client
mock_zeroconf = MagicMock(spec=Zeroconf)
@ -361,13 +377,6 @@ async def test_reconnect_zeroconf(
)
assert cli.log_name == "mydevice @ 1.2.3.4"
async def slow_connect_fail(*args, **kwargs):
await asyncio.sleep(10)
raise APIConnectionError
async def quick_connect_fail(*args, **kwargs):
raise APIConnectionError
with patch.object(
cli, "start_connection", side_effect=quick_connect_fail
) as mock_start_connection:
@ -379,17 +388,150 @@ async def test_reconnect_zeroconf(
with patch.object(
cli, "start_connection", side_effect=slow_connect_fail
) as mock_start_connection:
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
assert rl._accept_zeroconf_records is True
assert not rl._is_stopped
assert rl._connect_timer is not None
rl._connect_timer._run()
await asyncio.sleep(0)
assert mock_start_connection.call_count == 1
assert rl._connection_state is ReconnectLogicState.CONNECTING
assert rl._accept_zeroconf_records is True
assert not rl._is_stopped
assert mock_start_connection.call_count == 0
with patch.object(cli, "start_connection") as mock_start_connection, patch.object(
cli, "finish_connection"
):
assert rl._zc_listening is True
rl.async_update_records(
mock_zeroconf, current_time_millis(), [RecordUpdate(record, None)]
)
assert (
"Triggering connect because of received mDNS record" in caplog.text
) is should_trigger_zeroconf
assert rl._zc_listening is True # should change after one iteration of the loop
await asyncio.sleep(0)
assert rl._zc_listening is not should_trigger_zeroconf
assert mock_start_connection.call_count == int(should_trigger_zeroconf)
assert log_text in caplog.text
assert rl._connection_state is expected_state_after_trigger
await rl.stop()
assert rl._is_stopped is True
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
@pytest.mark.asyncio
async def test_reconnect_zeroconf_not_while_handshaking(
patchable_api_client: APIClient,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test that reconnect logic retry will not trigger a zeroconf reconnect while handshaking."""
cli = patchable_api_client
mock_zeroconf = MagicMock(spec=Zeroconf)
rl = ReconnectLogic(
client=cli,
on_disconnect=AsyncMock(),
on_connect=AsyncMock(),
zeroconf_instance=mock_zeroconf,
name="mydevice",
on_connect_error=AsyncMock(),
)
assert cli.log_name == "mydevice @ 1.2.3.4"
with patch.object(
cli, "start_connection", side_effect=quick_connect_fail
) as mock_start_connection:
await rl.start()
await asyncio.sleep(0)
assert mock_start_connection.call_count == 1
with patch.object(cli, "start_connection") as mock_start_connection, patch.object(
cli, "finish_connection", side_effect=slow_connect_fail
) as mock_finish_connection:
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
assert rl._accept_zeroconf_records is True
assert not rl._is_stopped
assert rl._connect_timer is not None
rl._connect_timer._run()
await asyncio.sleep(0)
assert mock_start_connection.call_count == 1
assert mock_finish_connection.call_count == 1
assert rl._connection_state is ReconnectLogicState.HANDSHAKING
assert rl._accept_zeroconf_records is False
assert not rl._is_stopped
rl.async_update_records(
mock_zeroconf, current_time_millis(), [RecordUpdate(DNS_POINTER, None)]
)
assert (
"Triggering connect because of received mDNS record" in caplog.text
) is False
rl._cancel_connect("forced cancel in test")
await rl.stop()
assert rl._is_stopped is True
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
@pytest.mark.asyncio
async def test_connect_task_not_cancelled_while_handshaking(
patchable_api_client: APIClient,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test that reconnect logic will not cancel an in progress handshake."""
cli = patchable_api_client
rl = ReconnectLogic(
client=cli,
on_disconnect=AsyncMock(),
on_connect=AsyncMock(),
name="mydevice",
on_connect_error=AsyncMock(),
)
assert cli.log_name == "mydevice @ 1.2.3.4"
with patch.object(
cli, "start_connection", side_effect=quick_connect_fail
) as mock_start_connection:
await rl.start()
await asyncio.sleep(0)
assert mock_start_connection.call_count == 1
with patch.object(cli, "start_connection") as mock_start_connection, patch.object(
cli, "finish_connection", side_effect=slow_connect_fail
) as mock_finish_connection:
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
assert rl._accept_zeroconf_records is True
assert not rl._is_stopped
assert rl._connect_timer is not None
rl._connect_timer._run()
await asyncio.sleep(0)
assert mock_start_connection.call_count == 1
assert mock_finish_connection.call_count == 1
assert rl._connection_state is ReconnectLogicState.HANDSHAKING
assert rl._accept_zeroconf_records is False
assert not rl._is_stopped
caplog.clear()
# This can likely never happen in practice, but we should handle it
# in the event there is a race as the consequence is that we could
# disconnect a working connection.
rl._call_connect_once()
assert (
"Not cancelling existing connect task as its already ReconnectLogicState.HANDSHAKING"
in caplog.text
)
rl._cancel_connect("forced cancel in test")
await rl.stop()
assert rl._is_stopped is True
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
@ -434,10 +576,6 @@ async def test_reconnect_logic_stop_callback_waits_for_handshake(
)
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
async def slow_connect_fail(*args, **kwargs):
await asyncio.sleep(10)
raise APIConnectionError
with patch.object(cli, "start_connection"), patch.object(
cli, "finish_connection", side_effect=slow_connect_fail
):