From d950d902455ad4185b5161b1588887c8086507b4 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 14 Oct 2023 17:04:27 -1000 Subject: [PATCH] Improve consistency of name logging (#577) --- aioesphomeapi/client.py | 32 +++++++-- aioesphomeapi/reconnect_logic.py | 5 +- tests/test_reconnect_logic.py | 109 +++++++++++++++++++++++++++++++ 3 files changed, 137 insertions(+), 9 deletions(-) create mode 100644 tests/test_reconnect_logic.py diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 068fb8d..2ca0e23 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -239,7 +239,14 @@ ExecuteServiceDataType = dict[ # pylint: disable=too-many-public-methods class APIClient: - __slots__ = ("_params", "_connection", "_cached_name", "_background_tasks", "_loop") + __slots__ = ( + "_params", + "_connection", + "_cached_name", + "_background_tasks", + "_loop", + "_log_name", + ) def __init__( self, @@ -283,6 +290,7 @@ class APIClient: self._cached_name: str | None = None self._background_tasks: set[asyncio.Task[Any]] = set() self._loop = asyncio.get_event_loop() + self._set_log_name() @property def expected_name(self) -> str | None: @@ -296,16 +304,27 @@ class APIClient: def address(self) -> str: return self._params.address - @property - def _log_name(self) -> str: - if self._cached_name is not None and not self.address.endswith(".local"): - return f"{self._cached_name} @ {self.address}" - return self.address + def _get_log_name(self) -> str: + """Get the log name of the device.""" + address = self.address + address_is_host = address.endswith(".local") + if self._cached_name is not None: + if address_is_host: + return self._cached_name + return f"{self._cached_name} @ {address}" + if address_is_host: + return address[:-6] + return address + + def _set_log_name(self) -> None: + """Set the log name of the device.""" + self._log_name = self._get_log_name() def set_cached_name_if_unset(self, name: str) -> None: """Set the cached name of the device if not set.""" if not self._cached_name: self._cached_name = name + self._set_log_name() async def connect( self, @@ -390,6 +409,7 @@ class APIClient: info = DeviceInfo.from_pb(resp) self._cached_name = info.name connection.set_log_name(self._log_name) + self._set_log_name() return info async def list_entities_services( diff --git a/aioesphomeapi/reconnect_logic.py b/aioesphomeapi/reconnect_logic.py index c9e23af..55595f8 100644 --- a/aioesphomeapi/reconnect_logic.py +++ b/aioesphomeapi/reconnect_logic.py @@ -78,7 +78,8 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): self._log_name = self.name elif name: self.name = name - self._log_name = f"{self.name} @ {self._cli.address}" + self._log_name = f"{name} @ {self._cli.address}" + self._cli.set_cached_name_if_unset(name) else: self.name = None self._log_name = client.address @@ -276,8 +277,6 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): async def start(self) -> None: """Start the connecting logic background task.""" - if self.name: - self._cli.set_cached_name_if_unset(self.name) async with self._connected_lock: self._is_stopped = False if self._connection_state != ReconnectLogicState.DISCONNECTED: diff --git a/tests/test_reconnect_logic.py b/tests/test_reconnect_logic.py new file mode 100644 index 0000000..eb9162a --- /dev/null +++ b/tests/test_reconnect_logic.py @@ -0,0 +1,109 @@ +from unittest.mock import MagicMock + +import pytest +from zeroconf.asyncio import AsyncZeroconf + +from aioesphomeapi.client import APIClient +from aioesphomeapi.reconnect_logic import ReconnectLogic + + +@pytest.mark.asyncio +async def test_reconnect_logic_name_from_host(): + """Test that the name is set correctly from the host.""" + cli = APIClient( + address="mydevice.local", + port=6052, + password=None, + ) + + async def on_disconnect(expected_disconnect: bool) -> None: + pass + + async def on_connect() -> None: + pass + + rl = ReconnectLogic( + client=cli, + on_disconnect=on_disconnect, + on_connect=on_connect, + zeroconf_instance=MagicMock(spec=AsyncZeroconf), + ) + assert rl._log_name == "mydevice" + assert cli._log_name == "mydevice" + + +@pytest.mark.asyncio +async def test_reconnect_logic_name_from_host_and_set(): + """Test that the name is set correctly from the host.""" + cli = APIClient( + address="mydevice.local", + port=6052, + password=None, + ) + + async def on_disconnect(expected_disconnect: bool) -> None: + pass + + async def on_connect() -> None: + pass + + rl = ReconnectLogic( + client=cli, + on_disconnect=on_disconnect, + on_connect=on_connect, + zeroconf_instance=MagicMock(spec=AsyncZeroconf), + name="mydevice", + ) + assert rl._log_name == "mydevice" + assert cli._log_name == "mydevice" + + +@pytest.mark.asyncio +async def test_reconnect_logic_name_from_address(): + """Test that the name is set correctly from the address.""" + cli = APIClient( + address="1.2.3.4", + port=6052, + password=None, + ) + + async def on_disconnect(expected_disconnect: bool) -> None: + pass + + async def on_connect() -> None: + pass + + rl = ReconnectLogic( + client=cli, + on_disconnect=on_disconnect, + on_connect=on_connect, + zeroconf_instance=MagicMock(spec=AsyncZeroconf), + ) + assert rl._log_name == "1.2.3.4" + assert cli._log_name == "1.2.3.4" + + +@pytest.mark.asyncio +async def test_reconnect_logic_name_from_name(): + """Test that the name is set correctly from the address.""" + cli = APIClient( + address="1.2.3.4", + port=6052, + password=None, + ) + + async def on_disconnect(expected_disconnect: bool) -> None: + pass + + async def on_connect() -> None: + pass + + rl = ReconnectLogic( + client=cli, + on_disconnect=on_disconnect, + on_connect=on_connect, + zeroconf_instance=MagicMock(spec=AsyncZeroconf), + name="mydevice", + ) + assert rl._log_name == "mydevice @ 1.2.3.4" + assert cli._log_name == "mydevice @ 1.2.3.4"