mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-22 12:05:12 +01:00
Fix race running disconnect callback in reconnect logic (#666)
This commit is contained in:
parent
df0dbadae7
commit
b8427c4cbb
@ -106,13 +106,21 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||||||
|
|
||||||
async def _on_disconnect(self, expected_disconnect: bool) -> None:
|
async def _on_disconnect(self, expected_disconnect: bool) -> None:
|
||||||
"""Log and issue callbacks when disconnecting."""
|
"""Log and issue callbacks when disconnecting."""
|
||||||
if self._is_stopped:
|
|
||||||
return
|
|
||||||
# This can happen often depending on WiFi signal strength.
|
# This can happen often depending on WiFi signal strength.
|
||||||
# So therefore all these connection warnings are logged
|
# So therefore all these connection warnings are logged
|
||||||
# as infos. The "unavailable" logic will still trigger so the
|
# as infos. The "unavailable" logic will still trigger so the
|
||||||
# user knows if the device is not connected.
|
# user knows if the device is not connected.
|
||||||
disconnect_type = "expected" if expected_disconnect else "unexpected"
|
if expected_disconnect:
|
||||||
|
# If we expected the disconnect we need
|
||||||
|
# to cooldown before connecting in case the remote
|
||||||
|
# is rebooting so we don't establish a connection right
|
||||||
|
# before its about to reboot in the event we are too fast.
|
||||||
|
disconnect_type = "expected"
|
||||||
|
wait = EXPECTED_DISCONNECT_COOLDOWN
|
||||||
|
else:
|
||||||
|
disconnect_type = "unexpected"
|
||||||
|
wait = 0
|
||||||
|
|
||||||
_LOGGER.info(
|
_LOGGER.info(
|
||||||
"Processing %s disconnect from ESPHome API for %s",
|
"Processing %s disconnect from ESPHome API for %s",
|
||||||
disconnect_type,
|
disconnect_type,
|
||||||
@ -120,21 +128,14 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Run disconnect hook
|
# Run disconnect hook
|
||||||
await self._on_disconnect_cb(expected_disconnect)
|
|
||||||
|
|
||||||
await self._async_set_connection_state(ReconnectLogicState.DISCONNECTED)
|
|
||||||
|
|
||||||
wait = EXPECTED_DISCONNECT_COOLDOWN if expected_disconnect else 0
|
|
||||||
# If we expected the disconnect we need
|
|
||||||
# to cooldown before connecting in case the remote
|
|
||||||
# is rebooting so we don't establish a connection right
|
|
||||||
# before its about to reboot in the event we are too fast.
|
|
||||||
self._schedule_connect(wait)
|
|
||||||
|
|
||||||
async def _async_set_connection_state(self, state: ReconnectLogicState) -> None:
|
|
||||||
"""Set the connection state."""
|
|
||||||
async with self._connected_lock:
|
async with self._connected_lock:
|
||||||
self._async_set_connection_state_while_locked(state)
|
self._async_set_connection_state_while_locked(
|
||||||
|
ReconnectLogicState.DISCONNECTED
|
||||||
|
)
|
||||||
|
await self._on_disconnect_cb(expected_disconnect)
|
||||||
|
|
||||||
|
if not self._is_stopped:
|
||||||
|
self._schedule_connect(wait)
|
||||||
|
|
||||||
def _async_set_connection_state_while_locked(
|
def _async_set_connection_state_while_locked(
|
||||||
self, state: ReconnectLogicState
|
self, state: ReconnectLogicState
|
||||||
|
@ -18,10 +18,17 @@ from zeroconf.asyncio import AsyncZeroconf
|
|||||||
from zeroconf.const import _CLASS_IN, _TYPE_A, _TYPE_PTR
|
from zeroconf.const import _CLASS_IN, _TYPE_A, _TYPE_PTR
|
||||||
|
|
||||||
from aioesphomeapi import APIConnectionError
|
from aioesphomeapi import APIConnectionError
|
||||||
|
from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper
|
||||||
from aioesphomeapi.client import APIClient
|
from aioesphomeapi.client import APIClient
|
||||||
|
from aioesphomeapi.connection import APIConnection
|
||||||
from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
|
from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
|
||||||
|
|
||||||
from .common import get_mock_zeroconf
|
from .common import (
|
||||||
|
get_mock_async_zeroconf,
|
||||||
|
get_mock_zeroconf,
|
||||||
|
send_plaintext_connect_response,
|
||||||
|
send_plaintext_hello,
|
||||||
|
)
|
||||||
|
|
||||||
logging.getLogger("aioesphomeapi").setLevel(logging.DEBUG)
|
logging.getLogger("aioesphomeapi").setLevel(logging.DEBUG)
|
||||||
|
|
||||||
@ -443,3 +450,77 @@ async def test_reconnect_logic_stop_callback_waits_for_handshake():
|
|||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
assert rl._is_stopped is True
|
assert rl._is_stopped is True
|
||||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventLoop):
|
||||||
|
"""Test the disconnect callback fires with expected_disconnect=False."""
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
protocol: APIPlaintextFrameHelper | None = None
|
||||||
|
transport = MagicMock()
|
||||||
|
connected = asyncio.Event()
|
||||||
|
|
||||||
|
class PatchableAPIClient(APIClient):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async_zeroconf = get_mock_async_zeroconf()
|
||||||
|
|
||||||
|
cli = PatchableAPIClient(
|
||||||
|
address="1.2.3.4",
|
||||||
|
port=6052,
|
||||||
|
password=None,
|
||||||
|
noise_psk=None,
|
||||||
|
expected_name="fake",
|
||||||
|
zeroconf_instance=async_zeroconf.zeroconf,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||||
|
nonlocal protocol
|
||||||
|
protocol = create_func()
|
||||||
|
protocol.connection_made(transport)
|
||||||
|
connected.set()
|
||||||
|
return transport, protocol
|
||||||
|
|
||||||
|
connected = asyncio.Event()
|
||||||
|
on_disconnect_calls = []
|
||||||
|
|
||||||
|
async def on_disconnect(expected_disconnect: bool) -> None:
|
||||||
|
on_disconnect_calls.append(expected_disconnect)
|
||||||
|
|
||||||
|
async def on_connect() -> None:
|
||||||
|
connected.set()
|
||||||
|
|
||||||
|
logic = ReconnectLogic(
|
||||||
|
client=cli,
|
||||||
|
on_connect=on_connect,
|
||||||
|
on_disconnect=on_disconnect,
|
||||||
|
zeroconf_instance=async_zeroconf,
|
||||||
|
name="fake",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||||
|
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||||
|
):
|
||||||
|
await logic.start()
|
||||||
|
await connected.wait()
|
||||||
|
protocol = cli._connection._frame_helper
|
||||||
|
send_plaintext_hello(protocol)
|
||||||
|
send_plaintext_connect_response(protocol, False)
|
||||||
|
await connected.wait()
|
||||||
|
|
||||||
|
assert cli._connection.is_connected is True
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||||
|
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||||
|
) as mock_create_connection:
|
||||||
|
protocol.eof_received()
|
||||||
|
# Wait for the task to run
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
# Ensure we try to reconnect immediately
|
||||||
|
# since its an unexpected disconnect
|
||||||
|
assert mock_create_connection.call_count == 0
|
||||||
|
|
||||||
|
assert len(on_disconnect_calls) == 1
|
||||||
|
assert on_disconnect_calls[0] is False
|
||||||
|
await logic.stop()
|
||||||
|
Loading…
Reference in New Issue
Block a user