mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-26 12:45:26 +01:00
Make reconnect logic state machine switches check locks (#597)
This commit is contained in:
parent
4122dede82
commit
000ff14ac0
@ -90,6 +90,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
self._filter_alias: str | None = None
|
||||
# Flag to check if the device is connected
|
||||
self._connection_state = ReconnectLogicState.DISCONNECTED
|
||||
self._accept_zeroconf_records = True
|
||||
self._connected_lock = asyncio.Lock()
|
||||
self._is_stopped = True
|
||||
self._zc_listening = False
|
||||
@ -118,8 +119,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
# Run disconnect hook
|
||||
await self._on_disconnect_cb(expected_disconnect)
|
||||
|
||||
async with self._connected_lock:
|
||||
self._connection_state = ReconnectLogicState.DISCONNECTED
|
||||
await self._async_set_connection_state(ReconnectLogicState.DISCONNECTED)
|
||||
|
||||
wait = EXPECTED_DISCONNECT_COOLDOWN if expected_disconnect else 0
|
||||
# If we expected the disconnect we need
|
||||
@ -128,6 +128,29 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
# before its about to reboot in the event we are too fast.
|
||||
self._schedule_connect(wait)
|
||||
|
||||
async def _async_set_connection_state(self, state: ReconnectLogicState) -> None:
|
||||
"""Set the connection state."""
|
||||
async with self._connected_lock:
|
||||
self._async_set_connection_state_while_locked(state)
|
||||
|
||||
def _async_set_connection_state_while_locked(
|
||||
self, state: ReconnectLogicState
|
||||
) -> None:
|
||||
"""Set the connection state while holding the lock."""
|
||||
assert self._connected_lock.locked(), "connected_lock must be locked"
|
||||
self._async_set_connection_state_without_lock(state)
|
||||
|
||||
def _async_set_connection_state_without_lock(
|
||||
self, state: ReconnectLogicState
|
||||
) -> None:
|
||||
"""Set the connection state without holding the lock.
|
||||
|
||||
This should only be used for setting the state to DISCONNECTED
|
||||
when the state is CONNECTING.
|
||||
"""
|
||||
self._connection_state = state
|
||||
self._accept_zeroconf_records = state not in NOT_YET_CONNECTED_STATES
|
||||
|
||||
def _async_log_connection_error(self, err: Exception) -> None:
|
||||
"""Log connection errors."""
|
||||
# UnhandledAPIConnectionError is a special case in client
|
||||
@ -155,12 +178,13 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
|
||||
async def _try_connect(self) -> bool:
|
||||
"""Try connecting to the API client."""
|
||||
assert self._connected_lock.locked(), "connected_lock must be locked"
|
||||
self._connection_state = ReconnectLogicState.CONNECTING
|
||||
self._async_set_connection_state_while_locked(ReconnectLogicState.CONNECTING)
|
||||
try:
|
||||
await self._cli.start_connection(on_stop=self._on_disconnect)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
self._connection_state = ReconnectLogicState.DISCONNECTED
|
||||
self._async_set_connection_state_while_locked(
|
||||
ReconnectLogicState.DISCONNECTED
|
||||
)
|
||||
if self._on_connect_error_cb is not None:
|
||||
await self._on_connect_error_cb(err)
|
||||
self._async_log_connection_error(err)
|
||||
@ -168,11 +192,13 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
return False
|
||||
_LOGGER.info("Successfully connected to %s", self._log_name)
|
||||
self._stop_zc_listen()
|
||||
self._connection_state = ReconnectLogicState.HANDSHAKING
|
||||
self._async_set_connection_state_while_locked(ReconnectLogicState.HANDSHAKING)
|
||||
try:
|
||||
await self._cli.finish_connection(login=True)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
self._connection_state = ReconnectLogicState.DISCONNECTED
|
||||
self._async_set_connection_state_while_locked(
|
||||
ReconnectLogicState.DISCONNECTED
|
||||
)
|
||||
if self._on_connect_error_cb is not None:
|
||||
await self._on_connect_error_cb(err)
|
||||
self._async_log_connection_error(err)
|
||||
@ -185,7 +211,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
return False
|
||||
self._tries = 0
|
||||
_LOGGER.info("Successful handshake with %s", self._log_name)
|
||||
self._connection_state = ReconnectLogicState.READY
|
||||
self._async_set_connection_state_while_locked(ReconnectLogicState.READY)
|
||||
await self._on_connect_cb()
|
||||
return True
|
||||
|
||||
@ -216,7 +242,9 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
)
|
||||
self._connect_task.cancel("Scheduling new connect attempt")
|
||||
self._connect_task = None
|
||||
self._connection_state = ReconnectLogicState.DISCONNECTED
|
||||
self._async_set_connection_state_without_lock(
|
||||
ReconnectLogicState.DISCONNECTED
|
||||
)
|
||||
|
||||
self._connect_task = asyncio.create_task(
|
||||
self._connect_once_or_reschedule(),
|
||||
@ -292,6 +320,9 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
# Cancel again while holding the lock
|
||||
self._cancel_connect("Stopping")
|
||||
self._stop_zc_listen()
|
||||
self._async_set_connection_state_while_locked(
|
||||
ReconnectLogicState.DISCONNECTED
|
||||
)
|
||||
|
||||
def _start_zc_listen(self) -> None:
|
||||
"""Listen for mDNS records.
|
||||
@ -325,7 +356,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
# Check if already connected, no lock needed for this access and
|
||||
# bail if either the already stopped or we haven't received device info yet
|
||||
if (
|
||||
self._connection_state not in NOT_YET_CONNECTED_STATES
|
||||
not self._accept_zeroconf_records
|
||||
or self._is_stopped
|
||||
or self._filter_alias is None
|
||||
):
|
||||
@ -334,10 +365,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
for record_update in records:
|
||||
# We only consider PTR records and match using the alias name
|
||||
new_record = record_update.new
|
||||
if (
|
||||
new_record.type != TYPE_PTR
|
||||
or new_record.alias != self._filter_alias # type: ignore[attr-defined]
|
||||
):
|
||||
if new_record.type != TYPE_PTR or new_record.alias != self._filter_alias: # type: ignore[attr-defined]
|
||||
continue
|
||||
|
||||
# Tell connection logic to retry connection attempt now (even before connect timer finishes)
|
||||
|
@ -1,10 +1,17 @@
|
||||
from unittest.mock import MagicMock
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from zeroconf import Zeroconf
|
||||
from zeroconf.asyncio import AsyncZeroconf
|
||||
|
||||
from aioesphomeapi import APIConnectionError
|
||||
from aioesphomeapi.client import APIClient
|
||||
from aioesphomeapi.reconnect_logic import ReconnectLogic
|
||||
from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
|
||||
|
||||
|
||||
def _get_mock_zeroconf() -> MagicMock:
|
||||
return MagicMock(spec=Zeroconf)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -51,7 +58,7 @@ async def test_reconnect_logic_name_from_host_and_set():
|
||||
client=cli,
|
||||
on_disconnect=on_disconnect,
|
||||
on_connect=on_connect,
|
||||
zeroconf_instance=MagicMock(spec=AsyncZeroconf),
|
||||
zeroconf_instance=_get_mock_zeroconf(),
|
||||
name="mydevice",
|
||||
)
|
||||
assert rl._log_name == "mydevice"
|
||||
@ -77,7 +84,7 @@ async def test_reconnect_logic_name_from_address():
|
||||
client=cli,
|
||||
on_disconnect=on_disconnect,
|
||||
on_connect=on_connect,
|
||||
zeroconf_instance=MagicMock(spec=AsyncZeroconf),
|
||||
zeroconf_instance=_get_mock_zeroconf(),
|
||||
)
|
||||
assert rl._log_name == "1.2.3.4"
|
||||
assert cli._log_name == "1.2.3.4"
|
||||
@ -102,8 +109,152 @@ async def test_reconnect_logic_name_from_name():
|
||||
client=cli,
|
||||
on_disconnect=on_disconnect,
|
||||
on_connect=on_connect,
|
||||
zeroconf_instance=MagicMock(spec=AsyncZeroconf),
|
||||
zeroconf_instance=_get_mock_zeroconf(),
|
||||
name="mydevice",
|
||||
)
|
||||
assert rl._log_name == "mydevice @ 1.2.3.4"
|
||||
assert cli._log_name == "mydevice @ 1.2.3.4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_logic_state():
|
||||
"""Test that reconnect logic state changes."""
|
||||
on_disconnect_called = []
|
||||
on_connect_called = []
|
||||
on_connect_fail_called = []
|
||||
|
||||
class PatchableAPIClient(APIClient):
|
||||
pass
|
||||
|
||||
cli = PatchableAPIClient(
|
||||
address="1.2.3.4",
|
||||
port=6052,
|
||||
password=None,
|
||||
)
|
||||
|
||||
async def on_disconnect(expected_disconnect: bool) -> None:
|
||||
nonlocal on_disconnect_called
|
||||
on_disconnect_called.append(expected_disconnect)
|
||||
|
||||
async def on_connect() -> None:
|
||||
nonlocal on_connect_called
|
||||
on_connect_called.append(True)
|
||||
|
||||
async def on_connect_fail(connect_exception: Exception) -> None:
|
||||
nonlocal on_connect_called
|
||||
on_connect_fail_called.append(connect_exception)
|
||||
|
||||
rl = ReconnectLogic(
|
||||
client=cli,
|
||||
on_disconnect=on_disconnect,
|
||||
on_connect=on_connect,
|
||||
zeroconf_instance=_get_mock_zeroconf(),
|
||||
name="mydevice",
|
||||
on_connect_error=on_connect_fail,
|
||||
)
|
||||
assert rl._log_name == "mydevice @ 1.2.3.4"
|
||||
assert cli._log_name == "mydevice @ 1.2.3.4"
|
||||
|
||||
with patch.object(cli, "start_connection", side_effect=APIConnectionError):
|
||||
await rl.start()
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert len(on_disconnect_called) == 0
|
||||
assert len(on_connect_called) == 0
|
||||
assert len(on_connect_fail_called) == 1
|
||||
assert isinstance(on_connect_fail_called[-1], APIConnectionError)
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
|
||||
with patch.object(cli, "start_connection"), patch.object(
|
||||
cli, "finish_connection", side_effect=APIConnectionError
|
||||
):
|
||||
await rl.start()
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert len(on_disconnect_called) == 0
|
||||
assert len(on_connect_called) == 0
|
||||
assert len(on_connect_fail_called) == 2
|
||||
assert isinstance(on_connect_fail_called[-1], APIConnectionError)
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
|
||||
with patch.object(cli, "start_connection"), patch.object(cli, "finish_connection"):
|
||||
await rl.start()
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert len(on_disconnect_called) == 0
|
||||
assert len(on_connect_called) == 1
|
||||
assert len(on_connect_fail_called) == 2
|
||||
assert rl._connection_state is ReconnectLogicState.READY
|
||||
|
||||
await rl.stop()
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_retry():
|
||||
"""Test that reconnect logic retry."""
|
||||
on_disconnect_called = []
|
||||
on_connect_called = []
|
||||
on_connect_fail_called = []
|
||||
|
||||
class PatchableAPIClient(APIClient):
|
||||
pass
|
||||
|
||||
cli = PatchableAPIClient(
|
||||
address="1.2.3.4",
|
||||
port=6052,
|
||||
password=None,
|
||||
)
|
||||
|
||||
async def on_disconnect(expected_disconnect: bool) -> None:
|
||||
nonlocal on_disconnect_called
|
||||
on_disconnect_called.append(expected_disconnect)
|
||||
|
||||
async def on_connect() -> None:
|
||||
nonlocal on_connect_called
|
||||
on_connect_called.append(True)
|
||||
|
||||
async def on_connect_fail(connect_exception: Exception) -> None:
|
||||
nonlocal on_connect_called
|
||||
on_connect_fail_called.append(connect_exception)
|
||||
|
||||
rl = ReconnectLogic(
|
||||
client=cli,
|
||||
on_disconnect=on_disconnect,
|
||||
on_connect=on_connect,
|
||||
zeroconf_instance=_get_mock_zeroconf(),
|
||||
name="mydevice",
|
||||
on_connect_error=on_connect_fail,
|
||||
)
|
||||
assert rl._log_name == "mydevice @ 1.2.3.4"
|
||||
assert cli._log_name == "mydevice @ 1.2.3.4"
|
||||
|
||||
with patch.object(cli, "start_connection", side_effect=APIConnectionError):
|
||||
await rl.start()
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert len(on_disconnect_called) == 0
|
||||
assert len(on_connect_called) == 0
|
||||
assert len(on_connect_fail_called) == 1
|
||||
assert isinstance(on_connect_fail_called[-1], APIConnectionError)
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
|
||||
with patch.object(cli, "start_connection"), patch.object(cli, "finish_connection"):
|
||||
# Should now retry
|
||||
assert rl._connect_timer is not None
|
||||
rl._connect_timer._run()
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert len(on_disconnect_called) == 0
|
||||
assert len(on_connect_called) == 1
|
||||
assert len(on_connect_fail_called) == 1
|
||||
assert rl._connection_state is ReconnectLogicState.READY
|
||||
|
||||
await rl.stop()
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
|
Loading…
Reference in New Issue
Block a user