Fix race running disconnect callback in reconnect logic (#666)

This commit is contained in:
J. Nick Koston 2023-11-23 15:39:03 +01:00 committed by GitHub
parent df0dbadae7
commit b8427c4cbb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 100 additions and 18 deletions

View File

@ -106,13 +106,21 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
async def _on_disconnect(self, expected_disconnect: bool) -> None: async def _on_disconnect(self, expected_disconnect: bool) -> None:
"""Log and issue callbacks when disconnecting.""" """Log and issue callbacks when disconnecting."""
if self._is_stopped:
return
# This can happen often depending on WiFi signal strength. # This can happen often depending on WiFi signal strength.
# So therefore all these connection warnings are logged # So therefore all these connection warnings are logged
# as infos. The "unavailable" logic will still trigger so the # as infos. The "unavailable" logic will still trigger so the
# user knows if the device is not connected. # user knows if the device is not connected.
disconnect_type = "expected" if expected_disconnect else "unexpected" if expected_disconnect:
# If we expected the disconnect we need
# to cooldown before connecting in case the remote
# is rebooting so we don't establish a connection right
# before its about to reboot in the event we are too fast.
disconnect_type = "expected"
wait = EXPECTED_DISCONNECT_COOLDOWN
else:
disconnect_type = "unexpected"
wait = 0
_LOGGER.info( _LOGGER.info(
"Processing %s disconnect from ESPHome API for %s", "Processing %s disconnect from ESPHome API for %s",
disconnect_type, disconnect_type,
@ -120,21 +128,14 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
) )
# Run disconnect hook # Run disconnect hook
await self._on_disconnect_cb(expected_disconnect)
await self._async_set_connection_state(ReconnectLogicState.DISCONNECTED)
wait = EXPECTED_DISCONNECT_COOLDOWN if expected_disconnect else 0
# If we expected the disconnect we need
# to cooldown before connecting in case the remote
# is rebooting so we don't establish a connection right
# before its about to reboot in the event we are too fast.
self._schedule_connect(wait)
async def _async_set_connection_state(self, state: ReconnectLogicState) -> None:
"""Set the connection state."""
async with self._connected_lock: async with self._connected_lock:
self._async_set_connection_state_while_locked(state) self._async_set_connection_state_while_locked(
ReconnectLogicState.DISCONNECTED
)
await self._on_disconnect_cb(expected_disconnect)
if not self._is_stopped:
self._schedule_connect(wait)
def _async_set_connection_state_while_locked( def _async_set_connection_state_while_locked(
self, state: ReconnectLogicState self, state: ReconnectLogicState

View File

@ -18,10 +18,17 @@ from zeroconf.asyncio import AsyncZeroconf
from zeroconf.const import _CLASS_IN, _TYPE_A, _TYPE_PTR from zeroconf.const import _CLASS_IN, _TYPE_A, _TYPE_PTR
from aioesphomeapi import APIConnectionError from aioesphomeapi import APIConnectionError
from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper
from aioesphomeapi.client import APIClient from aioesphomeapi.client import APIClient
from aioesphomeapi.connection import APIConnection
from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
from .common import get_mock_zeroconf from .common import (
get_mock_async_zeroconf,
get_mock_zeroconf,
send_plaintext_connect_response,
send_plaintext_hello,
)
logging.getLogger("aioesphomeapi").setLevel(logging.DEBUG) logging.getLogger("aioesphomeapi").setLevel(logging.DEBUG)
@ -443,3 +450,77 @@ async def test_reconnect_logic_stop_callback_waits_for_handshake():
await asyncio.sleep(0) await asyncio.sleep(0)
assert rl._is_stopped is True assert rl._is_stopped is True
assert rl._connection_state is ReconnectLogicState.DISCONNECTED assert rl._connection_state is ReconnectLogicState.DISCONNECTED
@pytest.mark.asyncio
async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventLoop):
"""Test the disconnect callback fires with expected_disconnect=False."""
loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()
connected = asyncio.Event()
class PatchableAPIClient(APIClient):
pass
async_zeroconf = get_mock_async_zeroconf()
cli = PatchableAPIClient(
address="1.2.3.4",
port=6052,
password=None,
noise_psk=None,
expected_name="fake",
zeroconf_instance=async_zeroconf.zeroconf,
)
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
connected = asyncio.Event()
on_disconnect_calls = []
async def on_disconnect(expected_disconnect: bool) -> None:
on_disconnect_calls.append(expected_disconnect)
async def on_connect() -> None:
connected.set()
logic = ReconnectLogic(
client=cli,
on_connect=on_connect,
on_disconnect=on_disconnect,
zeroconf_instance=async_zeroconf,
name="fake",
)
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
await logic.start()
await connected.wait()
protocol = cli._connection._frame_helper
send_plaintext_hello(protocol)
send_plaintext_connect_response(protocol, False)
await connected.wait()
assert cli._connection.is_connected is True
await asyncio.sleep(0)
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
) as mock_create_connection:
protocol.eof_received()
# Wait for the task to run
await asyncio.sleep(0)
# Ensure we try to reconnect immediately
# since its an unexpected disconnect
assert mock_create_connection.call_count == 0
assert len(on_disconnect_calls) == 1
assert on_disconnect_calls[0] is False
await logic.stop()