Make reconnect logic state machine switches check locks (#597)

This commit is contained in:
J. Nick Koston 2023-10-21 17:46:45 -10:00 committed by GitHub
parent 4122dede82
commit 000ff14ac0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 198 additions and 19 deletions

View File

@ -90,6 +90,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
self._filter_alias: str | None = None
# Flag to check if the device is connected
self._connection_state = ReconnectLogicState.DISCONNECTED
self._accept_zeroconf_records = True
self._connected_lock = asyncio.Lock()
self._is_stopped = True
self._zc_listening = False
@ -118,8 +119,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
# Run disconnect hook
await self._on_disconnect_cb(expected_disconnect)
async with self._connected_lock:
self._connection_state = ReconnectLogicState.DISCONNECTED
await self._async_set_connection_state(ReconnectLogicState.DISCONNECTED)
wait = EXPECTED_DISCONNECT_COOLDOWN if expected_disconnect else 0
# If we expected the disconnect we need
@ -128,6 +128,29 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
# before its about to reboot in the event we are too fast.
self._schedule_connect(wait)
async def _async_set_connection_state(self, state: ReconnectLogicState) -> None:
"""Set the connection state."""
async with self._connected_lock:
self._async_set_connection_state_while_locked(state)
def _async_set_connection_state_while_locked(
self, state: ReconnectLogicState
) -> None:
"""Set the connection state while holding the lock."""
assert self._connected_lock.locked(), "connected_lock must be locked"
self._async_set_connection_state_without_lock(state)
def _async_set_connection_state_without_lock(
self, state: ReconnectLogicState
) -> None:
"""Set the connection state without holding the lock.
This should only be used for setting the state to DISCONNECTED
when the state is CONNECTING.
"""
self._connection_state = state
self._accept_zeroconf_records = state not in NOT_YET_CONNECTED_STATES
def _async_log_connection_error(self, err: Exception) -> None:
"""Log connection errors."""
# UnhandledAPIConnectionError is a special case in client
@ -155,12 +178,13 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
async def _try_connect(self) -> bool:
"""Try connecting to the API client."""
assert self._connected_lock.locked(), "connected_lock must be locked"
self._connection_state = ReconnectLogicState.CONNECTING
self._async_set_connection_state_while_locked(ReconnectLogicState.CONNECTING)
try:
await self._cli.start_connection(on_stop=self._on_disconnect)
except Exception as err: # pylint: disable=broad-except
self._connection_state = ReconnectLogicState.DISCONNECTED
self._async_set_connection_state_while_locked(
ReconnectLogicState.DISCONNECTED
)
if self._on_connect_error_cb is not None:
await self._on_connect_error_cb(err)
self._async_log_connection_error(err)
@ -168,11 +192,13 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
return False
_LOGGER.info("Successfully connected to %s", self._log_name)
self._stop_zc_listen()
self._connection_state = ReconnectLogicState.HANDSHAKING
self._async_set_connection_state_while_locked(ReconnectLogicState.HANDSHAKING)
try:
await self._cli.finish_connection(login=True)
except Exception as err: # pylint: disable=broad-except
self._connection_state = ReconnectLogicState.DISCONNECTED
self._async_set_connection_state_while_locked(
ReconnectLogicState.DISCONNECTED
)
if self._on_connect_error_cb is not None:
await self._on_connect_error_cb(err)
self._async_log_connection_error(err)
@ -185,7 +211,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
return False
self._tries = 0
_LOGGER.info("Successful handshake with %s", self._log_name)
self._connection_state = ReconnectLogicState.READY
self._async_set_connection_state_while_locked(ReconnectLogicState.READY)
await self._on_connect_cb()
return True
@ -216,7 +242,9 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
)
self._connect_task.cancel("Scheduling new connect attempt")
self._connect_task = None
self._connection_state = ReconnectLogicState.DISCONNECTED
self._async_set_connection_state_without_lock(
ReconnectLogicState.DISCONNECTED
)
self._connect_task = asyncio.create_task(
self._connect_once_or_reschedule(),
@ -292,6 +320,9 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
# Cancel again while holding the lock
self._cancel_connect("Stopping")
self._stop_zc_listen()
self._async_set_connection_state_while_locked(
ReconnectLogicState.DISCONNECTED
)
def _start_zc_listen(self) -> None:
"""Listen for mDNS records.
@ -325,7 +356,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
# Check if already connected, no lock needed for this access and
# bail if either the already stopped or we haven't received device info yet
if (
self._connection_state not in NOT_YET_CONNECTED_STATES
not self._accept_zeroconf_records
or self._is_stopped
or self._filter_alias is None
):
@ -334,10 +365,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
for record_update in records:
# We only consider PTR records and match using the alias name
new_record = record_update.new
if (
new_record.type != TYPE_PTR
or new_record.alias != self._filter_alias # type: ignore[attr-defined]
):
if new_record.type != TYPE_PTR or new_record.alias != self._filter_alias: # type: ignore[attr-defined]
continue
# Tell connection logic to retry connection attempt now (even before connect timer finishes)

View File

@ -1,10 +1,17 @@
from unittest.mock import MagicMock
import asyncio
from unittest.mock import MagicMock, patch
import pytest
from zeroconf import Zeroconf
from zeroconf.asyncio import AsyncZeroconf
from aioesphomeapi import APIConnectionError
from aioesphomeapi.client import APIClient
from aioesphomeapi.reconnect_logic import ReconnectLogic
from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
def _get_mock_zeroconf() -> MagicMock:
return MagicMock(spec=Zeroconf)
@pytest.mark.asyncio
@ -51,7 +58,7 @@ async def test_reconnect_logic_name_from_host_and_set():
client=cli,
on_disconnect=on_disconnect,
on_connect=on_connect,
zeroconf_instance=MagicMock(spec=AsyncZeroconf),
zeroconf_instance=_get_mock_zeroconf(),
name="mydevice",
)
assert rl._log_name == "mydevice"
@ -77,7 +84,7 @@ async def test_reconnect_logic_name_from_address():
client=cli,
on_disconnect=on_disconnect,
on_connect=on_connect,
zeroconf_instance=MagicMock(spec=AsyncZeroconf),
zeroconf_instance=_get_mock_zeroconf(),
)
assert rl._log_name == "1.2.3.4"
assert cli._log_name == "1.2.3.4"
@ -102,8 +109,152 @@ async def test_reconnect_logic_name_from_name():
client=cli,
on_disconnect=on_disconnect,
on_connect=on_connect,
zeroconf_instance=MagicMock(spec=AsyncZeroconf),
zeroconf_instance=_get_mock_zeroconf(),
name="mydevice",
)
assert rl._log_name == "mydevice @ 1.2.3.4"
assert cli._log_name == "mydevice @ 1.2.3.4"
@pytest.mark.asyncio
async def test_reconnect_logic_state():
"""Test that reconnect logic state changes."""
on_disconnect_called = []
on_connect_called = []
on_connect_fail_called = []
class PatchableAPIClient(APIClient):
pass
cli = PatchableAPIClient(
address="1.2.3.4",
port=6052,
password=None,
)
async def on_disconnect(expected_disconnect: bool) -> None:
nonlocal on_disconnect_called
on_disconnect_called.append(expected_disconnect)
async def on_connect() -> None:
nonlocal on_connect_called
on_connect_called.append(True)
async def on_connect_fail(connect_exception: Exception) -> None:
nonlocal on_connect_called
on_connect_fail_called.append(connect_exception)
rl = ReconnectLogic(
client=cli,
on_disconnect=on_disconnect,
on_connect=on_connect,
zeroconf_instance=_get_mock_zeroconf(),
name="mydevice",
on_connect_error=on_connect_fail,
)
assert rl._log_name == "mydevice @ 1.2.3.4"
assert cli._log_name == "mydevice @ 1.2.3.4"
with patch.object(cli, "start_connection", side_effect=APIConnectionError):
await rl.start()
await asyncio.sleep(0)
await asyncio.sleep(0)
assert len(on_disconnect_called) == 0
assert len(on_connect_called) == 0
assert len(on_connect_fail_called) == 1
assert isinstance(on_connect_fail_called[-1], APIConnectionError)
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
with patch.object(cli, "start_connection"), patch.object(
cli, "finish_connection", side_effect=APIConnectionError
):
await rl.start()
await asyncio.sleep(0)
await asyncio.sleep(0)
assert len(on_disconnect_called) == 0
assert len(on_connect_called) == 0
assert len(on_connect_fail_called) == 2
assert isinstance(on_connect_fail_called[-1], APIConnectionError)
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
with patch.object(cli, "start_connection"), patch.object(cli, "finish_connection"):
await rl.start()
await asyncio.sleep(0)
await asyncio.sleep(0)
assert len(on_disconnect_called) == 0
assert len(on_connect_called) == 1
assert len(on_connect_fail_called) == 2
assert rl._connection_state is ReconnectLogicState.READY
await rl.stop()
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
@pytest.mark.asyncio
async def test_reconnect_retry():
"""Test that reconnect logic retry."""
on_disconnect_called = []
on_connect_called = []
on_connect_fail_called = []
class PatchableAPIClient(APIClient):
pass
cli = PatchableAPIClient(
address="1.2.3.4",
port=6052,
password=None,
)
async def on_disconnect(expected_disconnect: bool) -> None:
nonlocal on_disconnect_called
on_disconnect_called.append(expected_disconnect)
async def on_connect() -> None:
nonlocal on_connect_called
on_connect_called.append(True)
async def on_connect_fail(connect_exception: Exception) -> None:
nonlocal on_connect_called
on_connect_fail_called.append(connect_exception)
rl = ReconnectLogic(
client=cli,
on_disconnect=on_disconnect,
on_connect=on_connect,
zeroconf_instance=_get_mock_zeroconf(),
name="mydevice",
on_connect_error=on_connect_fail,
)
assert rl._log_name == "mydevice @ 1.2.3.4"
assert cli._log_name == "mydevice @ 1.2.3.4"
with patch.object(cli, "start_connection", side_effect=APIConnectionError):
await rl.start()
await asyncio.sleep(0)
await asyncio.sleep(0)
await asyncio.sleep(0)
assert len(on_disconnect_called) == 0
assert len(on_connect_called) == 0
assert len(on_connect_fail_called) == 1
assert isinstance(on_connect_fail_called[-1], APIConnectionError)
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
with patch.object(cli, "start_connection"), patch.object(cli, "finish_connection"):
# Should now retry
assert rl._connect_timer is not None
rl._connect_timer._run()
await asyncio.sleep(0)
await asyncio.sleep(0)
assert len(on_disconnect_called) == 0
assert len(on_connect_called) == 1
assert len(on_connect_fail_called) == 1
assert rl._connection_state is ReconnectLogicState.READY
await rl.stop()
assert rl._connection_state is ReconnectLogicState.DISCONNECTED