mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-15 10:55:13 +01:00
Fix zeroconf reconnect logic (#613)
This commit is contained in:
parent
8357a3a0c6
commit
2ef9ed9026
@ -7,6 +7,8 @@ from enum import Enum
|
||||
from typing import Callable
|
||||
|
||||
import zeroconf
|
||||
from zeroconf.const import _TYPE_A as TYPE_A
|
||||
from zeroconf.const import _TYPE_PTR as TYPE_PTR
|
||||
|
||||
from .client import APIClient
|
||||
from .core import (
|
||||
@ -21,7 +23,6 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
EXPECTED_DISCONNECT_COOLDOWN = 5.0
|
||||
MAXIMUM_BACKOFF_TRIES = 100
|
||||
TYPE_PTR = 12
|
||||
|
||||
|
||||
class ReconnectLogicState(Enum):
|
||||
@ -87,7 +88,8 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
self._on_disconnect_cb = on_disconnect
|
||||
self._on_connect_error_cb = on_connect_error
|
||||
self._zc = zeroconf_instance
|
||||
self._filter_alias: str | None = None
|
||||
self._ptr_alias: str | None = None
|
||||
self._a_name: str | None = None
|
||||
# Flag to check if the device is connected
|
||||
self._connection_state = ReconnectLogicState.DISCONNECTED
|
||||
self._accept_zeroconf_records = True
|
||||
@ -149,7 +151,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
when the state is CONNECTING.
|
||||
"""
|
||||
self._connection_state = state
|
||||
self._accept_zeroconf_records = state not in NOT_YET_CONNECTED_STATES
|
||||
self._accept_zeroconf_records = state in NOT_YET_CONNECTED_STATES
|
||||
|
||||
def _async_log_connection_error(self, err: Exception) -> None:
|
||||
"""Log connection errors."""
|
||||
@ -332,7 +334,8 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
"""
|
||||
if not self._zc_listening and self.name:
|
||||
_LOGGER.debug("Starting zeroconf listener for %s", self.name)
|
||||
self._filter_alias = f"{self.name}._esphomelib._tcp.local."
|
||||
self._ptr_alias = f"{self.name}._esphomelib._tcp.local."
|
||||
self._a_name = f"{self.name}.local."
|
||||
self._zc.async_add_listener(self, None)
|
||||
self._zc_listening = True
|
||||
|
||||
@ -355,17 +358,16 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
"""
|
||||
# Check if already connected, no lock needed for this access and
|
||||
# bail if either the already stopped or we haven't received device info yet
|
||||
if (
|
||||
not self._accept_zeroconf_records
|
||||
or self._is_stopped
|
||||
or self._filter_alias is None
|
||||
):
|
||||
if not self._accept_zeroconf_records or self._is_stopped:
|
||||
return
|
||||
|
||||
for record_update in records:
|
||||
# We only consider PTR records and match using the alias name
|
||||
new_record = record_update.new
|
||||
if new_record.type != TYPE_PTR or new_record.alias != self._filter_alias: # type: ignore[attr-defined]
|
||||
if not (
|
||||
(new_record.type == TYPE_PTR and new_record.alias == self._ptr_alias) # type: ignore[attr-defined]
|
||||
or (new_record.type == TYPE_A and new_record.name == self._a_name)
|
||||
):
|
||||
continue
|
||||
|
||||
# Tell connection logic to retry connection attempt now (even before connect timer finishes)
|
||||
|
@ -1,14 +1,26 @@
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, patch
|
||||
import logging
|
||||
from ipaddress import ip_address
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from zeroconf import Zeroconf
|
||||
from zeroconf import (
|
||||
DNSAddress,
|
||||
DNSPointer,
|
||||
DNSRecord,
|
||||
RecordUpdate,
|
||||
Zeroconf,
|
||||
current_time_millis,
|
||||
)
|
||||
from zeroconf.asyncio import AsyncZeroconf
|
||||
from zeroconf.const import _CLASS_IN, _TYPE_A, _TYPE_PTR
|
||||
|
||||
from aioesphomeapi import APIConnectionError
|
||||
from aioesphomeapi.client import APIClient
|
||||
from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
|
||||
|
||||
logging.getLogger("aioesphomeapi").setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def _get_mock_zeroconf() -> MagicMock:
|
||||
return MagicMock(spec=Zeroconf)
|
||||
@ -258,3 +270,104 @@ async def test_reconnect_retry():
|
||||
|
||||
await rl.stop()
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("record", "should_trigger_zeroconf", "log_text"),
|
||||
(
|
||||
(
|
||||
DNSPointer(
|
||||
"_esphomelib._tcp.local.",
|
||||
_TYPE_PTR,
|
||||
_CLASS_IN,
|
||||
1000,
|
||||
"mydevice._esphomelib._tcp.local.",
|
||||
),
|
||||
True,
|
||||
"received mDNS record",
|
||||
),
|
||||
(
|
||||
DNSPointer(
|
||||
"_esphomelib._tcp.local.",
|
||||
_TYPE_PTR,
|
||||
_CLASS_IN,
|
||||
1000,
|
||||
"wrong_name._esphomelib._tcp.local.",
|
||||
),
|
||||
False,
|
||||
"",
|
||||
),
|
||||
(
|
||||
DNSAddress(
|
||||
"mydevice.local.",
|
||||
_TYPE_A,
|
||||
_CLASS_IN,
|
||||
1000,
|
||||
ip_address("1.2.3.4").packed,
|
||||
),
|
||||
True,
|
||||
"received mDNS record",
|
||||
),
|
||||
),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_zeroconf(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
record: DNSRecord,
|
||||
should_trigger_zeroconf: bool,
|
||||
log_text: str,
|
||||
) -> None:
|
||||
"""Test that reconnect logic retry."""
|
||||
|
||||
class PatchableAPIClient(APIClient):
|
||||
pass
|
||||
|
||||
cli = PatchableAPIClient(
|
||||
address="1.2.3.4",
|
||||
port=6052,
|
||||
password=None,
|
||||
)
|
||||
|
||||
mock_zeroconf = MagicMock(spec=Zeroconf)
|
||||
|
||||
rl = ReconnectLogic(
|
||||
client=cli,
|
||||
on_disconnect=AsyncMock(),
|
||||
on_connect=AsyncMock(),
|
||||
zeroconf_instance=mock_zeroconf,
|
||||
name="mydevice",
|
||||
on_connect_error=AsyncMock(),
|
||||
)
|
||||
assert rl._log_name == "mydevice @ 1.2.3.4"
|
||||
assert cli._log_name == "mydevice @ 1.2.3.4"
|
||||
|
||||
async def slow_connect_fail(*args, **kwargs):
|
||||
await asyncio.sleep(10)
|
||||
raise APIConnectionError
|
||||
|
||||
async def quick_connect_fail(*args, **kwargs):
|
||||
raise APIConnectionError
|
||||
|
||||
with patch.object(
|
||||
cli, "start_connection", side_effect=quick_connect_fail
|
||||
) as mock_start_connection:
|
||||
await rl.start()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert mock_start_connection.call_count == 1
|
||||
|
||||
with patch.object(
|
||||
cli, "start_connection", side_effect=slow_connect_fail
|
||||
) as mock_start_connection:
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert mock_start_connection.call_count == 0
|
||||
|
||||
rl.async_update_records(
|
||||
mock_zeroconf, current_time_millis(), [RecordUpdate(record, None)]
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
assert mock_start_connection.call_count == int(should_trigger_zeroconf)
|
||||
assert log_text in caplog.text
|
||||
|
||||
await rl.stop()
|
||||
|
Loading…
Reference in New Issue
Block a user