Fix zeroconf reconnect logic (#613)

This commit is contained in:
J. Nick Koston 2023-11-06 15:04:09 -06:00 committed by GitHub
parent 8357a3a0c6
commit 2ef9ed9026
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 127 additions and 12 deletions

View File

@ -7,6 +7,8 @@ from enum import Enum
from typing import Callable
import zeroconf
from zeroconf.const import _TYPE_A as TYPE_A
from zeroconf.const import _TYPE_PTR as TYPE_PTR
from .client import APIClient
from .core import (
@ -21,7 +23,6 @@ _LOGGER = logging.getLogger(__name__)
EXPECTED_DISCONNECT_COOLDOWN = 5.0
MAXIMUM_BACKOFF_TRIES = 100
TYPE_PTR = 12
class ReconnectLogicState(Enum):
@ -87,7 +88,8 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
self._on_disconnect_cb = on_disconnect
self._on_connect_error_cb = on_connect_error
self._zc = zeroconf_instance
self._filter_alias: str | None = None
self._ptr_alias: str | None = None
self._a_name: str | None = None
# Flag to check if the device is connected
self._connection_state = ReconnectLogicState.DISCONNECTED
self._accept_zeroconf_records = True
@ -149,7 +151,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
when the state is CONNECTING.
"""
self._connection_state = state
self._accept_zeroconf_records = state not in NOT_YET_CONNECTED_STATES
self._accept_zeroconf_records = state in NOT_YET_CONNECTED_STATES
def _async_log_connection_error(self, err: Exception) -> None:
"""Log connection errors."""
@ -332,7 +334,8 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
"""
if not self._zc_listening and self.name:
_LOGGER.debug("Starting zeroconf listener for %s", self.name)
self._filter_alias = f"{self.name}._esphomelib._tcp.local."
self._ptr_alias = f"{self.name}._esphomelib._tcp.local."
self._a_name = f"{self.name}.local."
self._zc.async_add_listener(self, None)
self._zc_listening = True
@ -355,17 +358,16 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
"""
# Check if already connected, no lock needed for this access and
# bail if we are either already stopped or haven't received device info yet
if (
not self._accept_zeroconf_records
or self._is_stopped
or self._filter_alias is None
):
if not self._accept_zeroconf_records or self._is_stopped:
return
for record_update in records:
# We only consider PTR records (matched by alias) and A records (matched by name)
new_record = record_update.new
if new_record.type != TYPE_PTR or new_record.alias != self._filter_alias: # type: ignore[attr-defined]
if not (
(new_record.type == TYPE_PTR and new_record.alias == self._ptr_alias) # type: ignore[attr-defined]
or (new_record.type == TYPE_A and new_record.name == self._a_name)
):
continue
# Tell connection logic to retry connection attempt now (even before connect timer finishes)

View File

@ -1,14 +1,26 @@
import asyncio
from unittest.mock import MagicMock, patch
import logging
from ipaddress import ip_address
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from zeroconf import Zeroconf
from zeroconf import (
DNSAddress,
DNSPointer,
DNSRecord,
RecordUpdate,
Zeroconf,
current_time_millis,
)
from zeroconf.asyncio import AsyncZeroconf
from zeroconf.const import _CLASS_IN, _TYPE_A, _TYPE_PTR
from aioesphomeapi import APIConnectionError
from aioesphomeapi.client import APIClient
from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
logging.getLogger("aioesphomeapi").setLevel(logging.DEBUG)
def _get_mock_zeroconf() -> MagicMock:
    """Build a MagicMock whose attribute surface is restricted to ``Zeroconf``."""
    zeroconf_mock = MagicMock(spec=Zeroconf)
    return zeroconf_mock
@ -258,3 +270,104 @@ async def test_reconnect_retry():
await rl.stop()
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
@pytest.mark.parametrize(
    ("record", "should_trigger_zeroconf", "log_text"),
    (
        # PTR record whose alias matches the device name: must trigger.
        (
            DNSPointer(
                "_esphomelib._tcp.local.",
                _TYPE_PTR,
                _CLASS_IN,
                1000,
                "mydevice._esphomelib._tcp.local.",
            ),
            True,
            "received mDNS record",
        ),
        # PTR record for a different device: must be ignored.
        (
            DNSPointer(
                "_esphomelib._tcp.local.",
                _TYPE_PTR,
                _CLASS_IN,
                1000,
                "wrong_name._esphomelib._tcp.local.",
            ),
            False,
            "",
        ),
        # A record whose name matches the device host name: must trigger.
        (
            DNSAddress(
                "mydevice.local.",
                _TYPE_A,
                _CLASS_IN,
                1000,
                ip_address("1.2.3.4").packed,
            ),
            True,
            "received mDNS record",
        ),
    ),
)
@pytest.mark.asyncio
async def test_reconnect_zeroconf(
    caplog: pytest.LogCaptureFixture,
    record: DNSRecord,
    should_trigger_zeroconf: bool,
    log_text: str,
) -> None:
    """Test that a matching zeroconf record triggers a reconnect attempt.

    After a first connection attempt fails, a record update is fed to the
    reconnect logic. A new connection attempt is expected if and only if
    ``should_trigger_zeroconf`` is true, and ``log_text`` must then be
    present in the captured log output.
    """

    # NOTE(review): presumably subclassed so ``patch.object`` can replace
    # ``start_connection`` on the instance - confirm against APIClient.
    class PatchableAPIClient(APIClient):
        pass

    cli = PatchableAPIClient(
        address="1.2.3.4",
        port=6052,
        password=None,
    )

    # Reuse the module helper so every test builds the mock the same way.
    mock_zeroconf = _get_mock_zeroconf()

    rl = ReconnectLogic(
        client=cli,
        on_disconnect=AsyncMock(),
        on_connect=AsyncMock(),
        zeroconf_instance=mock_zeroconf,
        name="mydevice",
        on_connect_error=AsyncMock(),
    )
    assert rl._log_name == "mydevice @ 1.2.3.4"
    assert cli._log_name == "mydevice @ 1.2.3.4"

    async def slow_connect_fail(*args, **kwargs):
        # Sleeps before failing so the attempt is still pending while the
        # assertions below run.
        await asyncio.sleep(10)
        raise APIConnectionError

    async def quick_connect_fail(*args, **kwargs):
        raise APIConnectionError

    # Phase 1: the initial connection attempt fails immediately.
    with patch.object(
        cli, "start_connection", side_effect=quick_connect_fail
    ) as mock_start_connection:
        await rl.start()
        await asyncio.sleep(0)

    assert mock_start_connection.call_count == 1

    # Phase 2: deliver the record and check whether a new attempt is made.
    with patch.object(
        cli, "start_connection", side_effect=slow_connect_fail
    ) as mock_start_connection:
        await asyncio.sleep(0)

        # Nothing has called the fresh mock before the record arrives.
        assert mock_start_connection.call_count == 0

        rl.async_update_records(
            mock_zeroconf, current_time_millis(), [RecordUpdate(record, None)]
        )
        await asyncio.sleep(0)
        assert mock_start_connection.call_count == int(should_trigger_zeroconf)
        assert log_text in caplog.text
        # ``"" in text`` is always true, so the negative case needs an
        # explicit check that the trigger message was not logged.
        assert ("received mDNS record" in caplog.text) is should_trigger_zeroconf

    await rl.stop()