mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-13 10:43:59 +01:00
Fix reconnect logic cancelling the connection while handshaking (#726)
This commit is contained in:
parent
68dfc868d9
commit
b3d4189b07
@ -226,11 +226,11 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
|
||||
def _schedule_connect(self, delay: float) -> None:
|
||||
"""Schedule a connect attempt."""
|
||||
self._cancel_connect("Scheduling new connect attempt")
|
||||
if not delay:
|
||||
self._call_connect_once()
|
||||
return
|
||||
_LOGGER.debug("Scheduling new connect attempt in %f seconds", delay)
|
||||
self._cancel_connect_timer()
|
||||
self._connect_timer = self.loop.call_at(
|
||||
self.loop.time() + delay, self._call_connect_once
|
||||
)
|
||||
@ -240,17 +240,22 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
|
||||
Must only be called from _schedule_connect.
|
||||
"""
|
||||
if self._connect_task:
|
||||
if self._connect_task and not self._connect_task.done():
|
||||
if self._connection_state != ReconnectLogicState.CONNECTING:
|
||||
# Connection state is far enough along that we should
|
||||
# not restart the connect task
|
||||
_LOGGER.debug(
|
||||
"%s: Not cancelling existing connect task as its already %s!",
|
||||
self._cli.log_name,
|
||||
self._connection_state,
|
||||
)
|
||||
return
|
||||
_LOGGER.debug(
|
||||
"%s: Cancelling existing connect task, to try again now!",
|
||||
"%s: Cancelling existing connect task with state %s, to try again now!",
|
||||
self._cli.log_name,
|
||||
self._connection_state,
|
||||
)
|
||||
self._connect_task.cancel("Scheduling new connect attempt")
|
||||
self._connect_task = None
|
||||
self._cancel_connect_task("Scheduling new connect attempt")
|
||||
self._async_set_connection_state_without_lock(
|
||||
ReconnectLogicState.DISCONNECTED
|
||||
)
|
||||
@ -260,15 +265,23 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
name=f"{self._cli.log_name}: aioesphomeapi connect",
|
||||
)
|
||||
|
||||
def _cancel_connect(self, msg: str) -> None:
|
||||
"""Cancel the connect."""
|
||||
def _cancel_connect_timer(self) -> None:
|
||||
"""Cancel the connect timer."""
|
||||
if self._connect_timer:
|
||||
self._connect_timer.cancel()
|
||||
self._connect_timer = None
|
||||
|
||||
def _cancel_connect_task(self, msg: str) -> None:
|
||||
"""Cancel the connect task."""
|
||||
if self._connect_task:
|
||||
self._connect_task.cancel(msg)
|
||||
self._connect_task = None
|
||||
|
||||
def _cancel_connect(self, msg: str) -> None:
|
||||
"""Cancel the connect."""
|
||||
self._cancel_connect_timer()
|
||||
self._cancel_connect_task(msg)
|
||||
|
||||
async def _connect_once_or_reschedule(self) -> None:
|
||||
"""Connect once or schedule connect.
|
||||
|
||||
|
@ -38,6 +38,15 @@ from .conftest import _create_mock_transport_protocol
|
||||
logging.getLogger("aioesphomeapi").setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
async def slow_connect_fail(*args, **kwargs):
|
||||
await asyncio.sleep(10)
|
||||
raise APIConnectionError
|
||||
|
||||
|
||||
async def quick_connect_fail(*args, **kwargs):
|
||||
raise APIConnectionError
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_logic_name_from_host():
|
||||
"""Test that the name is set correctly from the host."""
|
||||
@ -288,23 +297,33 @@ async def test_reconnect_retry(
|
||||
assert len(on_connect_called) == 1
|
||||
assert len(on_connect_fail_called) == 2
|
||||
assert rl._connection_state is ReconnectLogicState.READY
|
||||
original_when = rl._connect_timer.when()
|
||||
|
||||
# Ensure starting the connection logic again does not trigger a new connection
|
||||
await rl.start()
|
||||
# Verify no new timer is started
|
||||
assert rl._connect_timer.when() == original_when
|
||||
|
||||
await rl.stop()
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
|
||||
|
||||
DNS_POINTER = DNSPointer(
|
||||
"_esphomelib._tcp.local.",
|
||||
_TYPE_PTR,
|
||||
_CLASS_IN,
|
||||
1000,
|
||||
"mydevice._esphomelib._tcp.local.",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("record", "should_trigger_zeroconf", "log_text"),
|
||||
("record", "should_trigger_zeroconf", "expected_state_after_trigger", "log_text"),
|
||||
(
|
||||
(
|
||||
DNSPointer(
|
||||
"_esphomelib._tcp.local.",
|
||||
_TYPE_PTR,
|
||||
_CLASS_IN,
|
||||
1000,
|
||||
"mydevice._esphomelib._tcp.local.",
|
||||
),
|
||||
DNS_POINTER,
|
||||
True,
|
||||
ReconnectLogicState.READY,
|
||||
"received mDNS record",
|
||||
),
|
||||
(
|
||||
@ -316,6 +335,7 @@ async def test_reconnect_retry(
|
||||
"wrong_name._esphomelib._tcp.local.",
|
||||
),
|
||||
False,
|
||||
ReconnectLogicState.CONNECTING,
|
||||
"",
|
||||
),
|
||||
(
|
||||
@ -327,27 +347,23 @@ async def test_reconnect_retry(
|
||||
ip_address("1.2.3.4").packed,
|
||||
),
|
||||
True,
|
||||
ReconnectLogicState.READY,
|
||||
"received mDNS record",
|
||||
),
|
||||
),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_zeroconf(
|
||||
patchable_api_client: APIClient,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
record: DNSRecord,
|
||||
should_trigger_zeroconf: bool,
|
||||
expected_state_after_trigger: ReconnectLogicState,
|
||||
log_text: str,
|
||||
) -> None:
|
||||
"""Test that reconnect logic retry."""
|
||||
|
||||
class PatchableAPIClient(APIClient):
|
||||
pass
|
||||
|
||||
cli = PatchableAPIClient(
|
||||
address="1.2.3.4",
|
||||
port=6052,
|
||||
password=None,
|
||||
)
|
||||
cli = patchable_api_client
|
||||
|
||||
mock_zeroconf = MagicMock(spec=Zeroconf)
|
||||
|
||||
@ -361,13 +377,6 @@ async def test_reconnect_zeroconf(
|
||||
)
|
||||
assert cli.log_name == "mydevice @ 1.2.3.4"
|
||||
|
||||
async def slow_connect_fail(*args, **kwargs):
|
||||
await asyncio.sleep(10)
|
||||
raise APIConnectionError
|
||||
|
||||
async def quick_connect_fail(*args, **kwargs):
|
||||
raise APIConnectionError
|
||||
|
||||
with patch.object(
|
||||
cli, "start_connection", side_effect=quick_connect_fail
|
||||
) as mock_start_connection:
|
||||
@ -379,17 +388,150 @@ async def test_reconnect_zeroconf(
|
||||
with patch.object(
|
||||
cli, "start_connection", side_effect=slow_connect_fail
|
||||
) as mock_start_connection:
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
assert rl._accept_zeroconf_records is True
|
||||
assert not rl._is_stopped
|
||||
|
||||
assert rl._connect_timer is not None
|
||||
rl._connect_timer._run()
|
||||
await asyncio.sleep(0)
|
||||
assert mock_start_connection.call_count == 1
|
||||
assert rl._connection_state is ReconnectLogicState.CONNECTING
|
||||
assert rl._accept_zeroconf_records is True
|
||||
assert not rl._is_stopped
|
||||
|
||||
assert mock_start_connection.call_count == 0
|
||||
|
||||
with patch.object(cli, "start_connection") as mock_start_connection, patch.object(
|
||||
cli, "finish_connection"
|
||||
):
|
||||
assert rl._zc_listening is True
|
||||
rl.async_update_records(
|
||||
mock_zeroconf, current_time_millis(), [RecordUpdate(record, None)]
|
||||
)
|
||||
assert (
|
||||
"Triggering connect because of received mDNS record" in caplog.text
|
||||
) is should_trigger_zeroconf
|
||||
assert rl._zc_listening is True # should change after one iteration of the loop
|
||||
await asyncio.sleep(0)
|
||||
assert rl._zc_listening is not should_trigger_zeroconf
|
||||
|
||||
assert mock_start_connection.call_count == int(should_trigger_zeroconf)
|
||||
assert log_text in caplog.text
|
||||
|
||||
assert rl._connection_state is expected_state_after_trigger
|
||||
await rl.stop()
|
||||
assert rl._is_stopped is True
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_zeroconf_not_while_handshaking(
|
||||
patchable_api_client: APIClient,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test that reconnect logic retry will not trigger a zeroconf reconnect while handshaking."""
|
||||
cli = patchable_api_client
|
||||
|
||||
mock_zeroconf = MagicMock(spec=Zeroconf)
|
||||
|
||||
rl = ReconnectLogic(
|
||||
client=cli,
|
||||
on_disconnect=AsyncMock(),
|
||||
on_connect=AsyncMock(),
|
||||
zeroconf_instance=mock_zeroconf,
|
||||
name="mydevice",
|
||||
on_connect_error=AsyncMock(),
|
||||
)
|
||||
assert cli.log_name == "mydevice @ 1.2.3.4"
|
||||
|
||||
with patch.object(
|
||||
cli, "start_connection", side_effect=quick_connect_fail
|
||||
) as mock_start_connection:
|
||||
await rl.start()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert mock_start_connection.call_count == 1
|
||||
|
||||
with patch.object(cli, "start_connection") as mock_start_connection, patch.object(
|
||||
cli, "finish_connection", side_effect=slow_connect_fail
|
||||
) as mock_finish_connection:
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
assert rl._accept_zeroconf_records is True
|
||||
assert not rl._is_stopped
|
||||
|
||||
assert rl._connect_timer is not None
|
||||
rl._connect_timer._run()
|
||||
await asyncio.sleep(0)
|
||||
assert mock_start_connection.call_count == 1
|
||||
assert mock_finish_connection.call_count == 1
|
||||
assert rl._connection_state is ReconnectLogicState.HANDSHAKING
|
||||
assert rl._accept_zeroconf_records is False
|
||||
assert not rl._is_stopped
|
||||
|
||||
rl.async_update_records(
|
||||
mock_zeroconf, current_time_millis(), [RecordUpdate(DNS_POINTER, None)]
|
||||
)
|
||||
assert (
|
||||
"Triggering connect because of received mDNS record" in caplog.text
|
||||
) is False
|
||||
|
||||
rl._cancel_connect("forced cancel in test")
|
||||
await rl.stop()
|
||||
assert rl._is_stopped is True
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_task_not_cancelled_while_handshaking(
|
||||
patchable_api_client: APIClient,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test that reconnect logic will not cancel an in progress handshake."""
|
||||
cli = patchable_api_client
|
||||
|
||||
rl = ReconnectLogic(
|
||||
client=cli,
|
||||
on_disconnect=AsyncMock(),
|
||||
on_connect=AsyncMock(),
|
||||
name="mydevice",
|
||||
on_connect_error=AsyncMock(),
|
||||
)
|
||||
assert cli.log_name == "mydevice @ 1.2.3.4"
|
||||
|
||||
with patch.object(
|
||||
cli, "start_connection", side_effect=quick_connect_fail
|
||||
) as mock_start_connection:
|
||||
await rl.start()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert mock_start_connection.call_count == 1
|
||||
|
||||
with patch.object(cli, "start_connection") as mock_start_connection, patch.object(
|
||||
cli, "finish_connection", side_effect=slow_connect_fail
|
||||
) as mock_finish_connection:
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
assert rl._accept_zeroconf_records is True
|
||||
assert not rl._is_stopped
|
||||
|
||||
assert rl._connect_timer is not None
|
||||
rl._connect_timer._run()
|
||||
await asyncio.sleep(0)
|
||||
assert mock_start_connection.call_count == 1
|
||||
assert mock_finish_connection.call_count == 1
|
||||
assert rl._connection_state is ReconnectLogicState.HANDSHAKING
|
||||
assert rl._accept_zeroconf_records is False
|
||||
assert not rl._is_stopped
|
||||
|
||||
caplog.clear()
|
||||
# This can likely never happen in practice, but we should handle it
|
||||
# in the event there is a race as the consequence is that we could
|
||||
# disconnect a working connection.
|
||||
rl._call_connect_once()
|
||||
assert (
|
||||
"Not cancelling existing connect task as its already ReconnectLogicState.HANDSHAKING"
|
||||
in caplog.text
|
||||
)
|
||||
|
||||
rl._cancel_connect("forced cancel in test")
|
||||
await rl.stop()
|
||||
assert rl._is_stopped is True
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
@ -434,10 +576,6 @@ async def test_reconnect_logic_stop_callback_waits_for_handshake(
|
||||
)
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
|
||||
async def slow_connect_fail(*args, **kwargs):
|
||||
await asyncio.sleep(10)
|
||||
raise APIConnectionError
|
||||
|
||||
with patch.object(cli, "start_connection"), patch.object(
|
||||
cli, "finish_connection", side_effect=slow_connect_fail
|
||||
):
|
||||
|
Loading…
Reference in New Issue
Block a user