mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-26 12:45:26 +01:00
Make reconnect logic state machine switches check locks (#597)
This commit is contained in:
parent
4122dede82
commit
000ff14ac0
@ -90,6 +90,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||||||
self._filter_alias: str | None = None
|
self._filter_alias: str | None = None
|
||||||
# Flag to check if the device is connected
|
# Flag to check if the device is connected
|
||||||
self._connection_state = ReconnectLogicState.DISCONNECTED
|
self._connection_state = ReconnectLogicState.DISCONNECTED
|
||||||
|
self._accept_zeroconf_records = True
|
||||||
self._connected_lock = asyncio.Lock()
|
self._connected_lock = asyncio.Lock()
|
||||||
self._is_stopped = True
|
self._is_stopped = True
|
||||||
self._zc_listening = False
|
self._zc_listening = False
|
||||||
@ -118,8 +119,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||||||
# Run disconnect hook
|
# Run disconnect hook
|
||||||
await self._on_disconnect_cb(expected_disconnect)
|
await self._on_disconnect_cb(expected_disconnect)
|
||||||
|
|
||||||
async with self._connected_lock:
|
await self._async_set_connection_state(ReconnectLogicState.DISCONNECTED)
|
||||||
self._connection_state = ReconnectLogicState.DISCONNECTED
|
|
||||||
|
|
||||||
wait = EXPECTED_DISCONNECT_COOLDOWN if expected_disconnect else 0
|
wait = EXPECTED_DISCONNECT_COOLDOWN if expected_disconnect else 0
|
||||||
# If we expected the disconnect we need
|
# If we expected the disconnect we need
|
||||||
@ -128,6 +128,29 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||||||
# before its about to reboot in the event we are too fast.
|
# before its about to reboot in the event we are too fast.
|
||||||
self._schedule_connect(wait)
|
self._schedule_connect(wait)
|
||||||
|
|
||||||
|
async def _async_set_connection_state(self, state: ReconnectLogicState) -> None:
|
||||||
|
"""Set the connection state."""
|
||||||
|
async with self._connected_lock:
|
||||||
|
self._async_set_connection_state_while_locked(state)
|
||||||
|
|
||||||
|
def _async_set_connection_state_while_locked(
|
||||||
|
self, state: ReconnectLogicState
|
||||||
|
) -> None:
|
||||||
|
"""Set the connection state while holding the lock."""
|
||||||
|
assert self._connected_lock.locked(), "connected_lock must be locked"
|
||||||
|
self._async_set_connection_state_without_lock(state)
|
||||||
|
|
||||||
|
def _async_set_connection_state_without_lock(
|
||||||
|
self, state: ReconnectLogicState
|
||||||
|
) -> None:
|
||||||
|
"""Set the connection state without holding the lock.
|
||||||
|
|
||||||
|
This should only be used for setting the state to DISCONNECTED
|
||||||
|
when the state is CONNECTING.
|
||||||
|
"""
|
||||||
|
self._connection_state = state
|
||||||
|
self._accept_zeroconf_records = state not in NOT_YET_CONNECTED_STATES
|
||||||
|
|
||||||
def _async_log_connection_error(self, err: Exception) -> None:
|
def _async_log_connection_error(self, err: Exception) -> None:
|
||||||
"""Log connection errors."""
|
"""Log connection errors."""
|
||||||
# UnhandledAPIConnectionError is a special case in client
|
# UnhandledAPIConnectionError is a special case in client
|
||||||
@ -155,12 +178,13 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||||||
|
|
||||||
async def _try_connect(self) -> bool:
|
async def _try_connect(self) -> bool:
|
||||||
"""Try connecting to the API client."""
|
"""Try connecting to the API client."""
|
||||||
assert self._connected_lock.locked(), "connected_lock must be locked"
|
self._async_set_connection_state_while_locked(ReconnectLogicState.CONNECTING)
|
||||||
self._connection_state = ReconnectLogicState.CONNECTING
|
|
||||||
try:
|
try:
|
||||||
await self._cli.start_connection(on_stop=self._on_disconnect)
|
await self._cli.start_connection(on_stop=self._on_disconnect)
|
||||||
except Exception as err: # pylint: disable=broad-except
|
except Exception as err: # pylint: disable=broad-except
|
||||||
self._connection_state = ReconnectLogicState.DISCONNECTED
|
self._async_set_connection_state_while_locked(
|
||||||
|
ReconnectLogicState.DISCONNECTED
|
||||||
|
)
|
||||||
if self._on_connect_error_cb is not None:
|
if self._on_connect_error_cb is not None:
|
||||||
await self._on_connect_error_cb(err)
|
await self._on_connect_error_cb(err)
|
||||||
self._async_log_connection_error(err)
|
self._async_log_connection_error(err)
|
||||||
@ -168,11 +192,13 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||||||
return False
|
return False
|
||||||
_LOGGER.info("Successfully connected to %s", self._log_name)
|
_LOGGER.info("Successfully connected to %s", self._log_name)
|
||||||
self._stop_zc_listen()
|
self._stop_zc_listen()
|
||||||
self._connection_state = ReconnectLogicState.HANDSHAKING
|
self._async_set_connection_state_while_locked(ReconnectLogicState.HANDSHAKING)
|
||||||
try:
|
try:
|
||||||
await self._cli.finish_connection(login=True)
|
await self._cli.finish_connection(login=True)
|
||||||
except Exception as err: # pylint: disable=broad-except
|
except Exception as err: # pylint: disable=broad-except
|
||||||
self._connection_state = ReconnectLogicState.DISCONNECTED
|
self._async_set_connection_state_while_locked(
|
||||||
|
ReconnectLogicState.DISCONNECTED
|
||||||
|
)
|
||||||
if self._on_connect_error_cb is not None:
|
if self._on_connect_error_cb is not None:
|
||||||
await self._on_connect_error_cb(err)
|
await self._on_connect_error_cb(err)
|
||||||
self._async_log_connection_error(err)
|
self._async_log_connection_error(err)
|
||||||
@ -185,7 +211,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||||||
return False
|
return False
|
||||||
self._tries = 0
|
self._tries = 0
|
||||||
_LOGGER.info("Successful handshake with %s", self._log_name)
|
_LOGGER.info("Successful handshake with %s", self._log_name)
|
||||||
self._connection_state = ReconnectLogicState.READY
|
self._async_set_connection_state_while_locked(ReconnectLogicState.READY)
|
||||||
await self._on_connect_cb()
|
await self._on_connect_cb()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -216,7 +242,9 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||||||
)
|
)
|
||||||
self._connect_task.cancel("Scheduling new connect attempt")
|
self._connect_task.cancel("Scheduling new connect attempt")
|
||||||
self._connect_task = None
|
self._connect_task = None
|
||||||
self._connection_state = ReconnectLogicState.DISCONNECTED
|
self._async_set_connection_state_without_lock(
|
||||||
|
ReconnectLogicState.DISCONNECTED
|
||||||
|
)
|
||||||
|
|
||||||
self._connect_task = asyncio.create_task(
|
self._connect_task = asyncio.create_task(
|
||||||
self._connect_once_or_reschedule(),
|
self._connect_once_or_reschedule(),
|
||||||
@ -292,6 +320,9 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||||||
# Cancel again while holding the lock
|
# Cancel again while holding the lock
|
||||||
self._cancel_connect("Stopping")
|
self._cancel_connect("Stopping")
|
||||||
self._stop_zc_listen()
|
self._stop_zc_listen()
|
||||||
|
self._async_set_connection_state_while_locked(
|
||||||
|
ReconnectLogicState.DISCONNECTED
|
||||||
|
)
|
||||||
|
|
||||||
def _start_zc_listen(self) -> None:
|
def _start_zc_listen(self) -> None:
|
||||||
"""Listen for mDNS records.
|
"""Listen for mDNS records.
|
||||||
@ -325,7 +356,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||||||
# Check if already connected, no lock needed for this access and
|
# Check if already connected, no lock needed for this access and
|
||||||
# bail if either the already stopped or we haven't received device info yet
|
# bail if either the already stopped or we haven't received device info yet
|
||||||
if (
|
if (
|
||||||
self._connection_state not in NOT_YET_CONNECTED_STATES
|
not self._accept_zeroconf_records
|
||||||
or self._is_stopped
|
or self._is_stopped
|
||||||
or self._filter_alias is None
|
or self._filter_alias is None
|
||||||
):
|
):
|
||||||
@ -334,10 +365,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||||||
for record_update in records:
|
for record_update in records:
|
||||||
# We only consider PTR records and match using the alias name
|
# We only consider PTR records and match using the alias name
|
||||||
new_record = record_update.new
|
new_record = record_update.new
|
||||||
if (
|
if new_record.type != TYPE_PTR or new_record.alias != self._filter_alias: # type: ignore[attr-defined]
|
||||||
new_record.type != TYPE_PTR
|
|
||||||
or new_record.alias != self._filter_alias # type: ignore[attr-defined]
|
|
||||||
):
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Tell connection logic to retry connection attempt now (even before connect timer finishes)
|
# Tell connection logic to retry connection attempt now (even before connect timer finishes)
|
||||||
|
@ -1,10 +1,17 @@
|
|||||||
from unittest.mock import MagicMock
|
import asyncio
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from zeroconf import Zeroconf
|
||||||
from zeroconf.asyncio import AsyncZeroconf
|
from zeroconf.asyncio import AsyncZeroconf
|
||||||
|
|
||||||
|
from aioesphomeapi import APIConnectionError
|
||||||
from aioesphomeapi.client import APIClient
|
from aioesphomeapi.client import APIClient
|
||||||
from aioesphomeapi.reconnect_logic import ReconnectLogic
|
from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
|
||||||
|
|
||||||
|
|
||||||
|
def _get_mock_zeroconf() -> MagicMock:
|
||||||
|
return MagicMock(spec=Zeroconf)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -51,7 +58,7 @@ async def test_reconnect_logic_name_from_host_and_set():
|
|||||||
client=cli,
|
client=cli,
|
||||||
on_disconnect=on_disconnect,
|
on_disconnect=on_disconnect,
|
||||||
on_connect=on_connect,
|
on_connect=on_connect,
|
||||||
zeroconf_instance=MagicMock(spec=AsyncZeroconf),
|
zeroconf_instance=_get_mock_zeroconf(),
|
||||||
name="mydevice",
|
name="mydevice",
|
||||||
)
|
)
|
||||||
assert rl._log_name == "mydevice"
|
assert rl._log_name == "mydevice"
|
||||||
@ -77,7 +84,7 @@ async def test_reconnect_logic_name_from_address():
|
|||||||
client=cli,
|
client=cli,
|
||||||
on_disconnect=on_disconnect,
|
on_disconnect=on_disconnect,
|
||||||
on_connect=on_connect,
|
on_connect=on_connect,
|
||||||
zeroconf_instance=MagicMock(spec=AsyncZeroconf),
|
zeroconf_instance=_get_mock_zeroconf(),
|
||||||
)
|
)
|
||||||
assert rl._log_name == "1.2.3.4"
|
assert rl._log_name == "1.2.3.4"
|
||||||
assert cli._log_name == "1.2.3.4"
|
assert cli._log_name == "1.2.3.4"
|
||||||
@ -102,8 +109,152 @@ async def test_reconnect_logic_name_from_name():
|
|||||||
client=cli,
|
client=cli,
|
||||||
on_disconnect=on_disconnect,
|
on_disconnect=on_disconnect,
|
||||||
on_connect=on_connect,
|
on_connect=on_connect,
|
||||||
zeroconf_instance=MagicMock(spec=AsyncZeroconf),
|
zeroconf_instance=_get_mock_zeroconf(),
|
||||||
name="mydevice",
|
name="mydevice",
|
||||||
)
|
)
|
||||||
assert rl._log_name == "mydevice @ 1.2.3.4"
|
assert rl._log_name == "mydevice @ 1.2.3.4"
|
||||||
assert cli._log_name == "mydevice @ 1.2.3.4"
|
assert cli._log_name == "mydevice @ 1.2.3.4"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reconnect_logic_state():
|
||||||
|
"""Test that reconnect logic state changes."""
|
||||||
|
on_disconnect_called = []
|
||||||
|
on_connect_called = []
|
||||||
|
on_connect_fail_called = []
|
||||||
|
|
||||||
|
class PatchableAPIClient(APIClient):
|
||||||
|
pass
|
||||||
|
|
||||||
|
cli = PatchableAPIClient(
|
||||||
|
address="1.2.3.4",
|
||||||
|
port=6052,
|
||||||
|
password=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_disconnect(expected_disconnect: bool) -> None:
|
||||||
|
nonlocal on_disconnect_called
|
||||||
|
on_disconnect_called.append(expected_disconnect)
|
||||||
|
|
||||||
|
async def on_connect() -> None:
|
||||||
|
nonlocal on_connect_called
|
||||||
|
on_connect_called.append(True)
|
||||||
|
|
||||||
|
async def on_connect_fail(connect_exception: Exception) -> None:
|
||||||
|
nonlocal on_connect_called
|
||||||
|
on_connect_fail_called.append(connect_exception)
|
||||||
|
|
||||||
|
rl = ReconnectLogic(
|
||||||
|
client=cli,
|
||||||
|
on_disconnect=on_disconnect,
|
||||||
|
on_connect=on_connect,
|
||||||
|
zeroconf_instance=_get_mock_zeroconf(),
|
||||||
|
name="mydevice",
|
||||||
|
on_connect_error=on_connect_fail,
|
||||||
|
)
|
||||||
|
assert rl._log_name == "mydevice @ 1.2.3.4"
|
||||||
|
assert cli._log_name == "mydevice @ 1.2.3.4"
|
||||||
|
|
||||||
|
with patch.object(cli, "start_connection", side_effect=APIConnectionError):
|
||||||
|
await rl.start()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert len(on_disconnect_called) == 0
|
||||||
|
assert len(on_connect_called) == 0
|
||||||
|
assert len(on_connect_fail_called) == 1
|
||||||
|
assert isinstance(on_connect_fail_called[-1], APIConnectionError)
|
||||||
|
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||||
|
|
||||||
|
with patch.object(cli, "start_connection"), patch.object(
|
||||||
|
cli, "finish_connection", side_effect=APIConnectionError
|
||||||
|
):
|
||||||
|
await rl.start()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert len(on_disconnect_called) == 0
|
||||||
|
assert len(on_connect_called) == 0
|
||||||
|
assert len(on_connect_fail_called) == 2
|
||||||
|
assert isinstance(on_connect_fail_called[-1], APIConnectionError)
|
||||||
|
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||||
|
|
||||||
|
with patch.object(cli, "start_connection"), patch.object(cli, "finish_connection"):
|
||||||
|
await rl.start()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert len(on_disconnect_called) == 0
|
||||||
|
assert len(on_connect_called) == 1
|
||||||
|
assert len(on_connect_fail_called) == 2
|
||||||
|
assert rl._connection_state is ReconnectLogicState.READY
|
||||||
|
|
||||||
|
await rl.stop()
|
||||||
|
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reconnect_retry():
|
||||||
|
"""Test that reconnect logic retry."""
|
||||||
|
on_disconnect_called = []
|
||||||
|
on_connect_called = []
|
||||||
|
on_connect_fail_called = []
|
||||||
|
|
||||||
|
class PatchableAPIClient(APIClient):
|
||||||
|
pass
|
||||||
|
|
||||||
|
cli = PatchableAPIClient(
|
||||||
|
address="1.2.3.4",
|
||||||
|
port=6052,
|
||||||
|
password=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_disconnect(expected_disconnect: bool) -> None:
|
||||||
|
nonlocal on_disconnect_called
|
||||||
|
on_disconnect_called.append(expected_disconnect)
|
||||||
|
|
||||||
|
async def on_connect() -> None:
|
||||||
|
nonlocal on_connect_called
|
||||||
|
on_connect_called.append(True)
|
||||||
|
|
||||||
|
async def on_connect_fail(connect_exception: Exception) -> None:
|
||||||
|
nonlocal on_connect_called
|
||||||
|
on_connect_fail_called.append(connect_exception)
|
||||||
|
|
||||||
|
rl = ReconnectLogic(
|
||||||
|
client=cli,
|
||||||
|
on_disconnect=on_disconnect,
|
||||||
|
on_connect=on_connect,
|
||||||
|
zeroconf_instance=_get_mock_zeroconf(),
|
||||||
|
name="mydevice",
|
||||||
|
on_connect_error=on_connect_fail,
|
||||||
|
)
|
||||||
|
assert rl._log_name == "mydevice @ 1.2.3.4"
|
||||||
|
assert cli._log_name == "mydevice @ 1.2.3.4"
|
||||||
|
|
||||||
|
with patch.object(cli, "start_connection", side_effect=APIConnectionError):
|
||||||
|
await rl.start()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert len(on_disconnect_called) == 0
|
||||||
|
assert len(on_connect_called) == 0
|
||||||
|
assert len(on_connect_fail_called) == 1
|
||||||
|
assert isinstance(on_connect_fail_called[-1], APIConnectionError)
|
||||||
|
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||||
|
|
||||||
|
with patch.object(cli, "start_connection"), patch.object(cli, "finish_connection"):
|
||||||
|
# Should now retry
|
||||||
|
assert rl._connect_timer is not None
|
||||||
|
rl._connect_timer._run()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert len(on_disconnect_called) == 0
|
||||||
|
assert len(on_connect_called) == 1
|
||||||
|
assert len(on_connect_fail_called) == 1
|
||||||
|
assert rl._connection_state is ReconnectLogicState.READY
|
||||||
|
|
||||||
|
await rl.stop()
|
||||||
|
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||||
|
Loading…
Reference in New Issue
Block a user