Make reconnect logic state machine switches check locks (#597)

This commit is contained in:
J. Nick Koston 2023-10-21 17:46:45 -10:00 committed by GitHub
parent 4122dede82
commit 000ff14ac0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 198 additions and 19 deletions

View File

@ -90,6 +90,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
self._filter_alias: str | None = None self._filter_alias: str | None = None
# Flag to check if the device is connected # Flag to check if the device is connected
self._connection_state = ReconnectLogicState.DISCONNECTED self._connection_state = ReconnectLogicState.DISCONNECTED
self._accept_zeroconf_records = True
self._connected_lock = asyncio.Lock() self._connected_lock = asyncio.Lock()
self._is_stopped = True self._is_stopped = True
self._zc_listening = False self._zc_listening = False
@ -118,8 +119,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
# Run disconnect hook # Run disconnect hook
await self._on_disconnect_cb(expected_disconnect) await self._on_disconnect_cb(expected_disconnect)
async with self._connected_lock: await self._async_set_connection_state(ReconnectLogicState.DISCONNECTED)
self._connection_state = ReconnectLogicState.DISCONNECTED
wait = EXPECTED_DISCONNECT_COOLDOWN if expected_disconnect else 0 wait = EXPECTED_DISCONNECT_COOLDOWN if expected_disconnect else 0
# If we expected the disconnect we need # If we expected the disconnect we need
@ -128,6 +128,29 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
# before its about to reboot in the event we are too fast. # before its about to reboot in the event we are too fast.
self._schedule_connect(wait) self._schedule_connect(wait)
async def _async_set_connection_state(self, state: ReconnectLogicState) -> None:
"""Set the connection state."""
async with self._connected_lock:
self._async_set_connection_state_while_locked(state)
def _async_set_connection_state_while_locked(
self, state: ReconnectLogicState
) -> None:
"""Set the connection state while holding the lock."""
assert self._connected_lock.locked(), "connected_lock must be locked"
self._async_set_connection_state_without_lock(state)
def _async_set_connection_state_without_lock(
    self, state: ReconnectLogicState
) -> None:
    """Set the connection state without holding the lock.

    This should only be used for setting the state to DISCONNECTED
    when the state is CONNECTING.

    Besides storing ``state``, this refreshes ``_accept_zeroconf_records``,
    the flag the zeroconf record listener checks before acting on records.
    """
    self._connection_state = state
    # Zeroconf records are used to trigger an early reconnect attempt, so
    # they are only wanted while the state is one of the
    # NOT_YET_CONNECTED_STATES. The flag starts out True for the initial
    # DISCONNECTED state and must keep that polarity here; using
    # ``not in`` would make the listener ignore records exactly when a
    # reconnect is needed.
    self._accept_zeroconf_records = state in NOT_YET_CONNECTED_STATES
def _async_log_connection_error(self, err: Exception) -> None: def _async_log_connection_error(self, err: Exception) -> None:
"""Log connection errors.""" """Log connection errors."""
# UnhandledAPIConnectionError is a special case in client # UnhandledAPIConnectionError is a special case in client
@ -155,12 +178,13 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
async def _try_connect(self) -> bool: async def _try_connect(self) -> bool:
"""Try connecting to the API client.""" """Try connecting to the API client."""
assert self._connected_lock.locked(), "connected_lock must be locked" self._async_set_connection_state_while_locked(ReconnectLogicState.CONNECTING)
self._connection_state = ReconnectLogicState.CONNECTING
try: try:
await self._cli.start_connection(on_stop=self._on_disconnect) await self._cli.start_connection(on_stop=self._on_disconnect)
except Exception as err: # pylint: disable=broad-except except Exception as err: # pylint: disable=broad-except
self._connection_state = ReconnectLogicState.DISCONNECTED self._async_set_connection_state_while_locked(
ReconnectLogicState.DISCONNECTED
)
if self._on_connect_error_cb is not None: if self._on_connect_error_cb is not None:
await self._on_connect_error_cb(err) await self._on_connect_error_cb(err)
self._async_log_connection_error(err) self._async_log_connection_error(err)
@ -168,11 +192,13 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
return False return False
_LOGGER.info("Successfully connected to %s", self._log_name) _LOGGER.info("Successfully connected to %s", self._log_name)
self._stop_zc_listen() self._stop_zc_listen()
self._connection_state = ReconnectLogicState.HANDSHAKING self._async_set_connection_state_while_locked(ReconnectLogicState.HANDSHAKING)
try: try:
await self._cli.finish_connection(login=True) await self._cli.finish_connection(login=True)
except Exception as err: # pylint: disable=broad-except except Exception as err: # pylint: disable=broad-except
self._connection_state = ReconnectLogicState.DISCONNECTED self._async_set_connection_state_while_locked(
ReconnectLogicState.DISCONNECTED
)
if self._on_connect_error_cb is not None: if self._on_connect_error_cb is not None:
await self._on_connect_error_cb(err) await self._on_connect_error_cb(err)
self._async_log_connection_error(err) self._async_log_connection_error(err)
@ -185,7 +211,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
return False return False
self._tries = 0 self._tries = 0
_LOGGER.info("Successful handshake with %s", self._log_name) _LOGGER.info("Successful handshake with %s", self._log_name)
self._connection_state = ReconnectLogicState.READY self._async_set_connection_state_while_locked(ReconnectLogicState.READY)
await self._on_connect_cb() await self._on_connect_cb()
return True return True
@ -216,7 +242,9 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
) )
self._connect_task.cancel("Scheduling new connect attempt") self._connect_task.cancel("Scheduling new connect attempt")
self._connect_task = None self._connect_task = None
self._connection_state = ReconnectLogicState.DISCONNECTED self._async_set_connection_state_without_lock(
ReconnectLogicState.DISCONNECTED
)
self._connect_task = asyncio.create_task( self._connect_task = asyncio.create_task(
self._connect_once_or_reschedule(), self._connect_once_or_reschedule(),
@ -292,6 +320,9 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
# Cancel again while holding the lock # Cancel again while holding the lock
self._cancel_connect("Stopping") self._cancel_connect("Stopping")
self._stop_zc_listen() self._stop_zc_listen()
self._async_set_connection_state_while_locked(
ReconnectLogicState.DISCONNECTED
)
def _start_zc_listen(self) -> None: def _start_zc_listen(self) -> None:
"""Listen for mDNS records. """Listen for mDNS records.
@ -325,7 +356,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
# Check if already connected, no lock needed for this access and # Check if already connected, no lock needed for this access and
# bail if either the already stopped or we haven't received device info yet # bail if either the already stopped or we haven't received device info yet
if ( if (
self._connection_state not in NOT_YET_CONNECTED_STATES not self._accept_zeroconf_records
or self._is_stopped or self._is_stopped
or self._filter_alias is None or self._filter_alias is None
): ):
@ -334,10 +365,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
for record_update in records: for record_update in records:
# We only consider PTR records and match using the alias name # We only consider PTR records and match using the alias name
new_record = record_update.new new_record = record_update.new
if ( if new_record.type != TYPE_PTR or new_record.alias != self._filter_alias: # type: ignore[attr-defined]
new_record.type != TYPE_PTR
or new_record.alias != self._filter_alias # type: ignore[attr-defined]
):
continue continue
# Tell connection logic to retry connection attempt now (even before connect timer finishes) # Tell connection logic to retry connection attempt now (even before connect timer finishes)

View File

@ -1,10 +1,17 @@
from unittest.mock import MagicMock import asyncio
from unittest.mock import MagicMock, patch
import pytest import pytest
from zeroconf import Zeroconf
from zeroconf.asyncio import AsyncZeroconf from zeroconf.asyncio import AsyncZeroconf
from aioesphomeapi import APIConnectionError
from aioesphomeapi.client import APIClient from aioesphomeapi.client import APIClient
from aioesphomeapi.reconnect_logic import ReconnectLogic from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
def _get_mock_zeroconf() -> MagicMock:
    """Build a ``MagicMock`` whose attribute surface is specced on ``Zeroconf``."""
    zeroconf_mock = MagicMock(spec=Zeroconf)
    return zeroconf_mock
@pytest.mark.asyncio @pytest.mark.asyncio
@ -51,7 +58,7 @@ async def test_reconnect_logic_name_from_host_and_set():
client=cli, client=cli,
on_disconnect=on_disconnect, on_disconnect=on_disconnect,
on_connect=on_connect, on_connect=on_connect,
zeroconf_instance=MagicMock(spec=AsyncZeroconf), zeroconf_instance=_get_mock_zeroconf(),
name="mydevice", name="mydevice",
) )
assert rl._log_name == "mydevice" assert rl._log_name == "mydevice"
@ -77,7 +84,7 @@ async def test_reconnect_logic_name_from_address():
client=cli, client=cli,
on_disconnect=on_disconnect, on_disconnect=on_disconnect,
on_connect=on_connect, on_connect=on_connect,
zeroconf_instance=MagicMock(spec=AsyncZeroconf), zeroconf_instance=_get_mock_zeroconf(),
) )
assert rl._log_name == "1.2.3.4" assert rl._log_name == "1.2.3.4"
assert cli._log_name == "1.2.3.4" assert cli._log_name == "1.2.3.4"
@ -102,8 +109,152 @@ async def test_reconnect_logic_name_from_name():
client=cli, client=cli,
on_disconnect=on_disconnect, on_disconnect=on_disconnect,
on_connect=on_connect, on_connect=on_connect,
zeroconf_instance=MagicMock(spec=AsyncZeroconf), zeroconf_instance=_get_mock_zeroconf(),
name="mydevice", name="mydevice",
) )
assert rl._log_name == "mydevice @ 1.2.3.4" assert rl._log_name == "mydevice @ 1.2.3.4"
assert cli._log_name == "mydevice @ 1.2.3.4" assert cli._log_name == "mydevice @ 1.2.3.4"
@pytest.mark.asyncio
async def test_reconnect_logic_state():
    """Test that reconnect logic state changes.

    Drives ``ReconnectLogic`` through three start attempts: a failure in
    ``start_connection``, a failure in ``finish_connection``, and a fully
    successful connect. After each attempt it checks which callbacks fired
    and what ``_connection_state`` ended up as, then checks that ``stop()``
    resets the state to DISCONNECTED.
    """
    on_disconnect_called = []
    on_connect_called = []
    on_connect_fail_called = []

    # NOTE(review): presumably subclassed so patch.object can replace
    # methods on the instance; confirm against APIClient.
    class PatchableAPIClient(APIClient):
        pass

    cli = PatchableAPIClient(
        address="1.2.3.4",
        port=6052,
        password=None,
    )

    # The callbacks only mutate the recording lists (never rebind them), so
    # no ``nonlocal`` declarations are required.
    async def on_disconnect(expected_disconnect: bool) -> None:
        on_disconnect_called.append(expected_disconnect)

    async def on_connect() -> None:
        on_connect_called.append(True)

    async def on_connect_fail(connect_exception: Exception) -> None:
        on_connect_fail_called.append(connect_exception)

    rl = ReconnectLogic(
        client=cli,
        on_disconnect=on_disconnect,
        on_connect=on_connect,
        zeroconf_instance=_get_mock_zeroconf(),
        name="mydevice",
        on_connect_error=on_connect_fail,
    )
    assert rl._log_name == "mydevice @ 1.2.3.4"
    assert cli._log_name == "mydevice @ 1.2.3.4"

    # Attempt 1: opening the connection fails.
    with patch.object(cli, "start_connection", side_effect=APIConnectionError):
        await rl.start()
        # Yield to the event loop so the connect attempt gets a chance to
        # run before the assertions below.
        await asyncio.sleep(0)
        await asyncio.sleep(0)

    assert len(on_disconnect_called) == 0
    assert len(on_connect_called) == 0
    assert len(on_connect_fail_called) == 1
    assert isinstance(on_connect_fail_called[-1], APIConnectionError)
    assert rl._connection_state is ReconnectLogicState.DISCONNECTED

    # Attempt 2: the connection opens but finishing it (handshake) fails.
    with patch.object(cli, "start_connection"), patch.object(
        cli, "finish_connection", side_effect=APIConnectionError
    ):
        await rl.start()
        await asyncio.sleep(0)
        await asyncio.sleep(0)

    assert len(on_disconnect_called) == 0
    assert len(on_connect_called) == 0
    assert len(on_connect_fail_called) == 2
    assert isinstance(on_connect_fail_called[-1], APIConnectionError)
    assert rl._connection_state is ReconnectLogicState.DISCONNECTED

    # Attempt 3: both connection phases succeed.
    with patch.object(cli, "start_connection"), patch.object(cli, "finish_connection"):
        await rl.start()
        await asyncio.sleep(0)
        await asyncio.sleep(0)

    assert len(on_disconnect_called) == 0
    assert len(on_connect_called) == 1
    assert len(on_connect_fail_called) == 2
    assert rl._connection_state is ReconnectLogicState.READY

    await rl.stop()
    assert rl._connection_state is ReconnectLogicState.DISCONNECTED
@pytest.mark.asyncio
async def test_reconnect_retry():
    """Test that reconnect logic retry.

    The first connect attempt fails in ``start_connection``. The test then
    fires the pending ``_connect_timer`` by hand with both connection
    phases patched to succeed, and checks that the state reaches READY
    and that ``stop()`` resets it to DISCONNECTED.
    """
    on_disconnect_called = []
    on_connect_called = []
    on_connect_fail_called = []

    # NOTE(review): presumably subclassed so patch.object can replace
    # methods on the instance; confirm against APIClient.
    class PatchableAPIClient(APIClient):
        pass

    cli = PatchableAPIClient(
        address="1.2.3.4",
        port=6052,
        password=None,
    )

    # The callbacks only mutate the recording lists (never rebind them), so
    # no ``nonlocal`` declarations are required.
    async def on_disconnect(expected_disconnect: bool) -> None:
        on_disconnect_called.append(expected_disconnect)

    async def on_connect() -> None:
        on_connect_called.append(True)

    async def on_connect_fail(connect_exception: Exception) -> None:
        on_connect_fail_called.append(connect_exception)

    rl = ReconnectLogic(
        client=cli,
        on_disconnect=on_disconnect,
        on_connect=on_connect,
        zeroconf_instance=_get_mock_zeroconf(),
        name="mydevice",
        on_connect_error=on_connect_fail,
    )
    assert rl._log_name == "mydevice @ 1.2.3.4"
    assert cli._log_name == "mydevice @ 1.2.3.4"

    # Initial attempt: opening the connection fails.
    with patch.object(cli, "start_connection", side_effect=APIConnectionError):
        await rl.start()
        # Yield to the event loop so the connect attempt gets a chance to
        # run before the assertions below.
        await asyncio.sleep(0)
        await asyncio.sleep(0)
        await asyncio.sleep(0)

    assert len(on_disconnect_called) == 0
    assert len(on_connect_called) == 0
    assert len(on_connect_fail_called) == 1
    assert isinstance(on_connect_fail_called[-1], APIConnectionError)
    assert rl._connection_state is ReconnectLogicState.DISCONNECTED

    with patch.object(cli, "start_connection"), patch.object(cli, "finish_connection"):
        # Should now retry
        assert rl._connect_timer is not None
        # Fire the pending timer callback directly instead of waiting for
        # it to elapse.
        rl._connect_timer._run()
        await asyncio.sleep(0)
        await asyncio.sleep(0)

    assert len(on_disconnect_called) == 0
    assert len(on_connect_called) == 1
    assert len(on_connect_fail_called) == 1
    assert rl._connection_state is ReconnectLogicState.READY

    await rl.stop()
    assert rl._connection_state is ReconnectLogicState.DISCONNECTED