mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-09-28 04:27:27 +02:00
Merge branch 'main' into feature/fan_presets
This commit is contained in:
commit
2ff1b1e65e
@ -5,6 +5,7 @@ omit =
|
||||
aioesphomeapi/api_options_pb2.py
|
||||
aioesphomeapi/api_pb2.py
|
||||
aioesphomeapi/log_reader.py
|
||||
aioesphomeapi/discover.py
|
||||
bench/*.py
|
||||
|
||||
[report]
|
||||
|
@ -1,9 +1,41 @@
|
||||
# See https://pre-commit.com for more information
|
||||
# See https://pre-commit.com/hooks.html for more hooks
|
||||
---
|
||||
exclude: '^aioesphomeapi/api.*$'
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v3.2.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
- id: check-added-large-files
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v3.2.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
- id: check-added-large-files
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v2.37.1
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: [--py37-plus]
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
rev: 5.12.0
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.1.1
|
||||
hooks:
|
||||
- id: ruff
|
||||
args:
|
||||
- --fix
|
||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||
rev: 23.11.0
|
||||
hooks:
|
||||
- id: black
|
||||
args:
|
||||
- --quiet
|
||||
files: ^((aioesphomeapi|tests)/.+)?[^/]+\.py$
|
||||
- repo: https://github.com/cdce8p/python-typing-update
|
||||
rev: v0.6.0
|
||||
hooks:
|
||||
- id: python-typing-update
|
||||
stages: [manual]
|
||||
args:
|
||||
- --py39-plus
|
||||
- --force
|
||||
- --keep-updates
|
||||
files: ^(aioesphomeapi)/.+\.py$
|
||||
|
@ -135,6 +135,12 @@ A cli tool is also available for watching logs:
|
||||
|
||||
aioesphomeapi-logs --help
|
||||
|
||||
A cli tool is also available to discover devices:
|
||||
|
||||
.. code:: bash
|
||||
|
||||
aioesphomeapi-discover
|
||||
|
||||
License
|
||||
-------
|
||||
|
||||
|
@ -42,7 +42,7 @@ class APIFrameHelper:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection: "APIConnection",
|
||||
connection: APIConnection,
|
||||
client_info: str,
|
||||
log_name: str,
|
||||
) -> None:
|
||||
|
@ -83,7 +83,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection: "APIConnection",
|
||||
connection: APIConnection,
|
||||
noise_psk: str,
|
||||
expected_name: str | None,
|
||||
client_info: str,
|
||||
|
80
aioesphomeapi/discover.py
Normal file
80
aioesphomeapi/discover.py
Normal file
@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
# Helper script and aioesphomeapi to discover api devices
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from zeroconf import IPVersion, ServiceStateChange, Zeroconf
|
||||
from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
|
||||
|
||||
FORMAT = "{: <7}|{: <32}|{: <15}|{: <12}|{: <16}|{: <10}|{: <32}"
|
||||
COLUMN_NAMES = ("Status", "Name", "Address", "MAC", "Version", "Platform", "Board")
|
||||
|
||||
|
||||
def decode_bytes_or_none(data: str | bytes | None) -> str | None:
|
||||
"""Decode bytes or return None."""
|
||||
if data is None:
|
||||
return None
|
||||
if isinstance(data, bytes):
|
||||
return data.decode()
|
||||
return data
|
||||
|
||||
|
||||
def async_service_update(
|
||||
zeroconf: Zeroconf,
|
||||
service_type: str,
|
||||
name: str,
|
||||
state_change: ServiceStateChange,
|
||||
) -> None:
|
||||
"""Service state changed."""
|
||||
short_name = name.partition(".")[0]
|
||||
if state_change is ServiceStateChange.Removed:
|
||||
state = "OFFLINE"
|
||||
else:
|
||||
state = "ONLINE"
|
||||
info = AsyncServiceInfo(service_type, name)
|
||||
info.load_from_cache(zeroconf)
|
||||
properties = info.properties
|
||||
mac = decode_bytes_or_none(properties.get(b"mac"))
|
||||
version = decode_bytes_or_none(properties.get(b"version"))
|
||||
platform = decode_bytes_or_none(properties.get(b"platform"))
|
||||
board = decode_bytes_or_none(properties.get(b"board"))
|
||||
address = None
|
||||
if addresses := info.ip_addresses_by_version(IPVersion.V4Only):
|
||||
address = str(addresses[0])
|
||||
|
||||
print(FORMAT.format(state, short_name, address, mac, version, platform, board))
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s",
|
||||
level=logging.INFO,
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
aiozc = AsyncZeroconf()
|
||||
browser = AsyncServiceBrowser(
|
||||
aiozc.zeroconf, "_esphomelib._tcp.local.", handlers=[async_service_update]
|
||||
)
|
||||
print(FORMAT.format(*COLUMN_NAMES))
|
||||
print("-" * 120)
|
||||
|
||||
try:
|
||||
await asyncio.Event().wait()
|
||||
finally:
|
||||
await browser.async_cancel()
|
||||
await aiozc.async_close()
|
||||
|
||||
|
||||
def cli_entry_point() -> None:
|
||||
"""Run the CLI."""
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli_entry_point()
|
||||
sys.exit(0)
|
@ -58,6 +58,7 @@ async def _async_zeroconf_get_service_info(
|
||||
timeout: float,
|
||||
) -> AsyncServiceInfo:
|
||||
# Use or create zeroconf instance, ensure it's an AsyncZeroconf
|
||||
had_instance = zeroconf_manager.has_instance
|
||||
try:
|
||||
zc = zeroconf_manager.get_async_zeroconf().zeroconf
|
||||
except Exception as exc:
|
||||
@ -73,7 +74,8 @@ async def _async_zeroconf_get_service_info(
|
||||
f"Error resolving mDNS {service_name} via mDNS: {exc}"
|
||||
) from exc
|
||||
finally:
|
||||
await zeroconf_manager.async_close()
|
||||
if not had_instance:
|
||||
await zeroconf_manager.async_close()
|
||||
return info
|
||||
|
||||
|
||||
|
@ -39,7 +39,10 @@ async def main(argv: list[str]) -> None:
|
||||
time_ = datetime.now()
|
||||
message: bytes = msg.message
|
||||
text = message.decode("utf8", "backslashreplace")
|
||||
print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}]{text}")
|
||||
nanoseconds = time_.microsecond // 1000
|
||||
print(
|
||||
f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}.{nanoseconds:03}]{text}"
|
||||
)
|
||||
|
||||
stop = await async_run(cli, on_log)
|
||||
try:
|
||||
|
@ -19,7 +19,7 @@ from .core import (
|
||||
RequiresEncryptionAPIError,
|
||||
UnhandledAPIConnectionError,
|
||||
)
|
||||
from .util import address_is_local
|
||||
from .util import address_is_local, host_is_name_part
|
||||
from .zeroconf import ZeroconfInstanceType
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@ -79,7 +79,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
self.name: str | None = None
|
||||
if name:
|
||||
self.name = name
|
||||
elif address_is_local(client.address):
|
||||
elif host_is_name_part(client.address) or address_is_local(client.address):
|
||||
self.name = client.address.partition(".")[0]
|
||||
if self.name:
|
||||
self._cli.set_cached_name_if_unset(self.name)
|
||||
@ -93,7 +93,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
self._a_name: str | None = None
|
||||
# Flag to check if the device is connected
|
||||
self._connection_state = ReconnectLogicState.DISCONNECTED
|
||||
self._accept_zeroconf_records = True
|
||||
self._accept_zeroconf_records: bool = True
|
||||
self._connected_lock = asyncio.Lock()
|
||||
self._is_stopped = True
|
||||
self._zc_listening = False
|
||||
@ -226,11 +226,11 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
|
||||
def _schedule_connect(self, delay: float) -> None:
|
||||
"""Schedule a connect attempt."""
|
||||
self._cancel_connect("Scheduling new connect attempt")
|
||||
if not delay:
|
||||
self._call_connect_once()
|
||||
return
|
||||
_LOGGER.debug("Scheduling new connect attempt in %f seconds", delay)
|
||||
_LOGGER.debug("Scheduling new connect attempt in %.2f seconds", delay)
|
||||
self._cancel_connect_timer()
|
||||
self._connect_timer = self.loop.call_at(
|
||||
self.loop.time() + delay, self._call_connect_once
|
||||
)
|
||||
@ -240,17 +240,22 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
|
||||
Must only be called from _schedule_connect.
|
||||
"""
|
||||
if self._connect_task:
|
||||
if self._connect_task and not self._connect_task.done():
|
||||
if self._connection_state != ReconnectLogicState.CONNECTING:
|
||||
# Connection state is far enough along that we should
|
||||
# not restart the connect task
|
||||
_LOGGER.debug(
|
||||
"%s: Not cancelling existing connect task as its already %s!",
|
||||
self._cli.log_name,
|
||||
self._connection_state,
|
||||
)
|
||||
return
|
||||
_LOGGER.debug(
|
||||
"%s: Cancelling existing connect task, to try again now!",
|
||||
"%s: Cancelling existing connect task with state %s, to try again now!",
|
||||
self._cli.log_name,
|
||||
self._connection_state,
|
||||
)
|
||||
self._connect_task.cancel("Scheduling new connect attempt")
|
||||
self._connect_task = None
|
||||
self._cancel_connect_task("Scheduling new connect attempt")
|
||||
self._async_set_connection_state_without_lock(
|
||||
ReconnectLogicState.DISCONNECTED
|
||||
)
|
||||
@ -260,15 +265,23 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
name=f"{self._cli.log_name}: aioesphomeapi connect",
|
||||
)
|
||||
|
||||
def _cancel_connect(self, msg: str) -> None:
|
||||
"""Cancel the connect."""
|
||||
def _cancel_connect_timer(self) -> None:
|
||||
"""Cancel the connect timer."""
|
||||
if self._connect_timer:
|
||||
self._connect_timer.cancel()
|
||||
self._connect_timer = None
|
||||
|
||||
def _cancel_connect_task(self, msg: str) -> None:
|
||||
"""Cancel the connect task."""
|
||||
if self._connect_task:
|
||||
self._connect_task.cancel(msg)
|
||||
self._connect_task = None
|
||||
|
||||
def _cancel_connect(self, msg: str) -> None:
|
||||
"""Cancel the connect."""
|
||||
self._cancel_connect_timer()
|
||||
self._cancel_connect_task(msg)
|
||||
|
||||
async def _connect_once_or_reschedule(self) -> None:
|
||||
"""Connect once or schedule connect.
|
||||
|
||||
@ -290,7 +303,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
_LOGGER.info(
|
||||
"Trying to connect to %s in the background", self._cli.log_name
|
||||
)
|
||||
_LOGGER.debug("Retrying %s in %d seconds", self._cli.log_name, wait_time)
|
||||
_LOGGER.debug("Retrying %s in %.2f seconds", self._cli.log_name, wait_time)
|
||||
if wait_time:
|
||||
# If we are waiting, start listening for mDNS records
|
||||
self._start_zc_listen()
|
||||
@ -365,6 +378,11 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
)
|
||||
self._zc_listening = False
|
||||
|
||||
def _connect_from_zeroconf(self) -> None:
|
||||
"""Connect from zeroconf."""
|
||||
self._stop_zc_listen()
|
||||
self._schedule_connect(0.0)
|
||||
|
||||
def async_update_records(
|
||||
self,
|
||||
zc: zeroconf.Zeroconf, # pylint: disable=unused-argument
|
||||
@ -398,7 +416,13 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
# We can't stop the zeroconf listener here because we are in the middle of
|
||||
# a zeroconf callback which is iterating the listeners.
|
||||
#
|
||||
# So we schedule a stop for the next event loop iteration.
|
||||
self.loop.call_soon(self._stop_zc_listen)
|
||||
self._schedule_connect(0.0)
|
||||
# So we schedule a stop for the next event loop iteration as well as the
|
||||
# connect attempt.
|
||||
#
|
||||
# If we scheduled the connect attempt immediately, the listener could fire
|
||||
# again before the connect attempt and we cancel and reschedule the connect
|
||||
# attempt again.
|
||||
#
|
||||
self.loop.call_soon(self._connect_from_zeroconf)
|
||||
self._accept_zeroconf_records = False
|
||||
return
|
||||
|
@ -26,6 +26,11 @@ class ZeroconfManager:
|
||||
if zeroconf is not None:
|
||||
self.set_instance(zeroconf)
|
||||
|
||||
@property
|
||||
def has_instance(self) -> bool:
|
||||
"""Return True if a Zeroconf instance is set."""
|
||||
return self._aiozc is not None
|
||||
|
||||
def set_instance(self, zc: AsyncZeroconf | Zeroconf) -> None:
|
||||
"""Set the AsyncZeroconf instance."""
|
||||
if self._aiozc:
|
||||
|
@ -30,5 +30,10 @@ disable = [
|
||||
"too-many-lines",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
ignore = [
|
||||
"E721", # We want type() check for protobuf messages
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ['setuptools>=65.4.1', 'wheel', 'Cython>=3.0.2']
|
||||
|
@ -1,8 +1,8 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from subprocess import check_call
|
||||
from pathlib import Path
|
||||
import os
|
||||
from pathlib import Path
|
||||
from subprocess import check_call
|
||||
|
||||
root_dir = Path(__file__).absolute().parent.parent
|
||||
os.chdir(root_dir)
|
||||
|
13
setup.py
13
setup.py
@ -11,7 +11,7 @@ with open(os.path.join(here, "README.rst"), encoding="utf-8") as readme_file:
|
||||
long_description = readme_file.read()
|
||||
|
||||
|
||||
VERSION = "19.0.2"
|
||||
VERSION = "19.1.0"
|
||||
PROJECT_NAME = "aioesphomeapi"
|
||||
PROJECT_PACKAGE_NAME = "aioesphomeapi"
|
||||
PROJECT_LICENSE = "MIT"
|
||||
@ -23,11 +23,11 @@ PROJECT_EMAIL = "esphome@nabucasa.com"
|
||||
PROJECT_GITHUB_USERNAME = "esphome"
|
||||
PROJECT_GITHUB_REPOSITORY = "aioesphomeapi"
|
||||
|
||||
PYPI_URL = "https://pypi.python.org/pypi/{}".format(PROJECT_PACKAGE_NAME)
|
||||
GITHUB_PATH = "{}/{}".format(PROJECT_GITHUB_USERNAME, PROJECT_GITHUB_REPOSITORY)
|
||||
GITHUB_URL = "https://github.com/{}".format(GITHUB_PATH)
|
||||
PYPI_URL = f"https://pypi.python.org/pypi/{PROJECT_PACKAGE_NAME}"
|
||||
GITHUB_PATH = f"{PROJECT_GITHUB_USERNAME}/{PROJECT_GITHUB_REPOSITORY}"
|
||||
GITHUB_URL = f"https://github.com/{GITHUB_PATH}"
|
||||
|
||||
DOWNLOAD_URL = "{}/archive/{}.zip".format(GITHUB_URL, VERSION)
|
||||
DOWNLOAD_URL = f"{GITHUB_URL}/archive/{VERSION}.zip"
|
||||
|
||||
MODULES_TO_CYTHONIZE = [
|
||||
"aioesphomeapi/client_callbacks.py",
|
||||
@ -61,7 +61,8 @@ setup_kwargs = {
|
||||
"test_suite": "tests",
|
||||
"entry_points": {
|
||||
"console_scripts": [
|
||||
"aioesphomeapi-logs=aioesphomeapi.log_reader:cli_entry_point"
|
||||
"aioesphomeapi-logs=aioesphomeapi.log_reader:cli_entry_point",
|
||||
"aioesphomeapi-discover=aioesphomeapi.discover:cli_entry_point",
|
||||
],
|
||||
},
|
||||
}
|
||||
|
@ -49,6 +49,19 @@ def socket_socket():
|
||||
yield func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patchable_api_client() -> APIClient:
|
||||
class PatchableAPIClient(APIClient):
|
||||
pass
|
||||
|
||||
cli = PatchableAPIClient(
|
||||
address="1.2.3.4",
|
||||
port=6052,
|
||||
password=None,
|
||||
)
|
||||
return cli
|
||||
|
||||
|
||||
def get_mock_connection_params() -> ConnectionParams:
|
||||
return ConnectionParams(
|
||||
address="fake.address",
|
||||
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from functools import partial
|
||||
from ipaddress import ip_address
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
@ -17,10 +18,14 @@ from zeroconf import (
|
||||
from zeroconf.asyncio import AsyncZeroconf
|
||||
from zeroconf.const import _CLASS_IN, _TYPE_A, _TYPE_PTR
|
||||
|
||||
from aioesphomeapi import APIConnectionError
|
||||
from aioesphomeapi import APIConnectionError, RequiresEncryptionAPIError
|
||||
from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper
|
||||
from aioesphomeapi.client import APIClient
|
||||
from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
|
||||
from aioesphomeapi.reconnect_logic import (
|
||||
MAXIMUM_BACKOFF_TRIES,
|
||||
ReconnectLogic,
|
||||
ReconnectLogicState,
|
||||
)
|
||||
|
||||
from .common import (
|
||||
get_mock_async_zeroconf,
|
||||
@ -28,10 +33,20 @@ from .common import (
|
||||
send_plaintext_connect_response,
|
||||
send_plaintext_hello,
|
||||
)
|
||||
from .conftest import _create_mock_transport_protocol
|
||||
|
||||
logging.getLogger("aioesphomeapi").setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
async def slow_connect_fail(*args, **kwargs):
|
||||
await asyncio.sleep(10)
|
||||
raise APIConnectionError
|
||||
|
||||
|
||||
async def quick_connect_fail(*args, **kwargs):
|
||||
raise APIConnectionError
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_logic_name_from_host():
|
||||
"""Test that the name is set correctly from the host."""
|
||||
@ -71,13 +86,14 @@ async def test_reconnect_logic_name_from_host_and_set():
|
||||
async def on_connect() -> None:
|
||||
pass
|
||||
|
||||
ReconnectLogic(
|
||||
rl = ReconnectLogic(
|
||||
client=cli,
|
||||
on_disconnect=on_disconnect,
|
||||
on_connect=on_connect,
|
||||
zeroconf_instance=get_mock_zeroconf(),
|
||||
name="mydevice",
|
||||
)
|
||||
assert rl.name == "mydevice"
|
||||
assert cli.log_name == "mydevice.local"
|
||||
|
||||
|
||||
@ -131,20 +147,38 @@ async def test_reconnect_logic_name_from_name():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_logic_state():
|
||||
async def test_reconnect_logic_name_from_cli_address():
|
||||
"""Test that the name is set correctly from the address."""
|
||||
cli = APIClient(
|
||||
address="mydevice",
|
||||
port=6052,
|
||||
password=None,
|
||||
)
|
||||
|
||||
async def on_disconnect(expected_disconnect: bool) -> None:
|
||||
pass
|
||||
|
||||
async def on_connect() -> None:
|
||||
pass
|
||||
|
||||
rl = ReconnectLogic(
|
||||
client=cli,
|
||||
on_disconnect=on_disconnect,
|
||||
on_connect=on_connect,
|
||||
zeroconf_instance=get_mock_zeroconf(),
|
||||
)
|
||||
assert cli.log_name == "mydevice"
|
||||
assert rl.name == "mydevice"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_logic_state(patchable_api_client: APIClient):
|
||||
"""Test that reconnect logic state changes."""
|
||||
on_disconnect_called = []
|
||||
on_connect_called = []
|
||||
on_connect_fail_called = []
|
||||
|
||||
class PatchableAPIClient(APIClient):
|
||||
pass
|
||||
|
||||
cli = PatchableAPIClient(
|
||||
address="1.2.3.4",
|
||||
port=6052,
|
||||
password=None,
|
||||
)
|
||||
cli = patchable_api_client
|
||||
|
||||
async def on_disconnect(expected_disconnect: bool) -> None:
|
||||
nonlocal on_disconnect_called
|
||||
@ -178,9 +212,10 @@ async def test_reconnect_logic_state():
|
||||
assert len(on_connect_fail_called) == 1
|
||||
assert isinstance(on_connect_fail_called[-1], APIConnectionError)
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
assert rl._tries == 1
|
||||
|
||||
with patch.object(cli, "start_connection"), patch.object(
|
||||
cli, "finish_connection", side_effect=APIConnectionError
|
||||
cli, "finish_connection", side_effect=RequiresEncryptionAPIError
|
||||
):
|
||||
await rl.start()
|
||||
await asyncio.sleep(0)
|
||||
@ -189,8 +224,9 @@ async def test_reconnect_logic_state():
|
||||
assert len(on_disconnect_called) == 0
|
||||
assert len(on_connect_called) == 0
|
||||
assert len(on_connect_fail_called) == 2
|
||||
assert isinstance(on_connect_fail_called[-1], APIConnectionError)
|
||||
assert isinstance(on_connect_fail_called[-1], RequiresEncryptionAPIError)
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
assert rl._tries == MAXIMUM_BACKOFF_TRIES
|
||||
|
||||
with patch.object(cli, "start_connection"), patch.object(cli, "finish_connection"):
|
||||
await rl.start()
|
||||
@ -201,26 +237,20 @@ async def test_reconnect_logic_state():
|
||||
assert len(on_connect_called) == 1
|
||||
assert len(on_connect_fail_called) == 2
|
||||
assert rl._connection_state is ReconnectLogicState.READY
|
||||
|
||||
assert rl._tries == 0
|
||||
await rl.stop()
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_retry():
|
||||
async def test_reconnect_retry(
|
||||
patchable_api_client: APIClient, caplog: pytest.LogCaptureFixture
|
||||
):
|
||||
"""Test that reconnect logic retry."""
|
||||
on_disconnect_called = []
|
||||
on_connect_called = []
|
||||
on_connect_fail_called = []
|
||||
|
||||
class PatchableAPIClient(APIClient):
|
||||
pass
|
||||
|
||||
cli = PatchableAPIClient(
|
||||
address="1.2.3.4",
|
||||
port=6052,
|
||||
password=None,
|
||||
)
|
||||
cli = patchable_api_client
|
||||
|
||||
async def on_disconnect(expected_disconnect: bool) -> None:
|
||||
nonlocal on_disconnect_called
|
||||
@ -243,6 +273,7 @@ async def test_reconnect_retry():
|
||||
on_connect_error=on_connect_fail,
|
||||
)
|
||||
assert cli.log_name == "mydevice @ 1.2.3.4"
|
||||
caplog.clear()
|
||||
|
||||
with patch.object(cli, "start_connection", side_effect=APIConnectionError):
|
||||
await rl.start()
|
||||
@ -255,35 +286,70 @@ async def test_reconnect_retry():
|
||||
assert len(on_connect_fail_called) == 1
|
||||
assert isinstance(on_connect_fail_called[-1], APIConnectionError)
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
assert "connect to ESPHome API for mydevice @ 1.2.3.4" in caplog.text
|
||||
for record in caplog.records:
|
||||
if "connect to ESPHome API for mydevice @ 1.2.3.4" in record.message:
|
||||
assert record.levelno == logging.WARNING
|
||||
|
||||
with patch.object(cli, "start_connection"), patch.object(cli, "finish_connection"):
|
||||
caplog.clear()
|
||||
# Next retry should run at debug level
|
||||
with patch.object(cli, "start_connection", side_effect=APIConnectionError):
|
||||
# Should now retry
|
||||
assert rl._connect_timer is not None
|
||||
rl._connect_timer._run()
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert len(on_disconnect_called) == 0
|
||||
assert len(on_connect_called) == 0
|
||||
assert len(on_connect_fail_called) == 2
|
||||
assert isinstance(on_connect_fail_called[-1], APIConnectionError)
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
assert "connect to ESPHome API for mydevice @ 1.2.3.4" in caplog.text
|
||||
for record in caplog.records:
|
||||
if "connect to ESPHome API for mydevice @ 1.2.3.4" in record.message:
|
||||
assert record.levelno == logging.DEBUG
|
||||
|
||||
caplog.clear()
|
||||
with patch.object(cli, "start_connection"), patch.object(cli, "finish_connection"):
|
||||
# Should now retry
|
||||
assert rl._connect_timer is not None
|
||||
rl._connect_timer._run()
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert "connect to ESPHome API for mydevice @ 1.2.3.4" not in caplog.text
|
||||
assert len(on_disconnect_called) == 0
|
||||
assert len(on_connect_called) == 1
|
||||
assert len(on_connect_fail_called) == 1
|
||||
assert len(on_connect_fail_called) == 2
|
||||
assert rl._connection_state is ReconnectLogicState.READY
|
||||
original_when = rl._connect_timer.when()
|
||||
|
||||
# Ensure starting the connection logic again does not trigger a new connection
|
||||
await rl.start()
|
||||
# Verify no new timer is started
|
||||
assert rl._connect_timer.when() == original_when
|
||||
|
||||
await rl.stop()
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
|
||||
|
||||
DNS_POINTER = DNSPointer(
|
||||
"_esphomelib._tcp.local.",
|
||||
_TYPE_PTR,
|
||||
_CLASS_IN,
|
||||
1000,
|
||||
"mydevice._esphomelib._tcp.local.",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("record", "should_trigger_zeroconf", "log_text"),
|
||||
("record", "should_trigger_zeroconf", "expected_state_after_trigger", "log_text"),
|
||||
(
|
||||
(
|
||||
DNSPointer(
|
||||
"_esphomelib._tcp.local.",
|
||||
_TYPE_PTR,
|
||||
_CLASS_IN,
|
||||
1000,
|
||||
"mydevice._esphomelib._tcp.local.",
|
||||
),
|
||||
DNS_POINTER,
|
||||
True,
|
||||
ReconnectLogicState.READY,
|
||||
"received mDNS record",
|
||||
),
|
||||
(
|
||||
@ -295,6 +361,7 @@ async def test_reconnect_retry():
|
||||
"wrong_name._esphomelib._tcp.local.",
|
||||
),
|
||||
False,
|
||||
ReconnectLogicState.CONNECTING,
|
||||
"",
|
||||
),
|
||||
(
|
||||
@ -306,27 +373,23 @@ async def test_reconnect_retry():
|
||||
ip_address("1.2.3.4").packed,
|
||||
),
|
||||
True,
|
||||
ReconnectLogicState.READY,
|
||||
"received mDNS record",
|
||||
),
|
||||
),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_zeroconf(
|
||||
patchable_api_client: APIClient,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
record: DNSRecord,
|
||||
should_trigger_zeroconf: bool,
|
||||
expected_state_after_trigger: ReconnectLogicState,
|
||||
log_text: str,
|
||||
) -> None:
|
||||
"""Test that reconnect logic retry."""
|
||||
|
||||
class PatchableAPIClient(APIClient):
|
||||
pass
|
||||
|
||||
cli = PatchableAPIClient(
|
||||
address="1.2.3.4",
|
||||
port=6052,
|
||||
password=None,
|
||||
)
|
||||
cli = patchable_api_client
|
||||
|
||||
mock_zeroconf = MagicMock(spec=Zeroconf)
|
||||
|
||||
@ -340,13 +403,6 @@ async def test_reconnect_zeroconf(
|
||||
)
|
||||
assert cli.log_name == "mydevice @ 1.2.3.4"
|
||||
|
||||
async def slow_connect_fail(*args, **kwargs):
|
||||
await asyncio.sleep(10)
|
||||
raise APIConnectionError
|
||||
|
||||
async def quick_connect_fail(*args, **kwargs):
|
||||
raise APIConnectionError
|
||||
|
||||
with patch.object(
|
||||
cli, "start_connection", side_effect=quick_connect_fail
|
||||
) as mock_start_connection:
|
||||
@ -358,30 +414,203 @@ async def test_reconnect_zeroconf(
|
||||
with patch.object(
|
||||
cli, "start_connection", side_effect=slow_connect_fail
|
||||
) as mock_start_connection:
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
assert rl._accept_zeroconf_records is True
|
||||
assert not rl._is_stopped
|
||||
|
||||
assert rl._connect_timer is not None
|
||||
rl._connect_timer._run()
|
||||
await asyncio.sleep(0)
|
||||
assert mock_start_connection.call_count == 1
|
||||
assert rl._connection_state is ReconnectLogicState.CONNECTING
|
||||
assert rl._accept_zeroconf_records is True
|
||||
assert not rl._is_stopped
|
||||
|
||||
assert mock_start_connection.call_count == 0
|
||||
|
||||
caplog.clear()
|
||||
with patch.object(cli, "start_connection") as mock_start_connection, patch.object(
|
||||
cli, "finish_connection"
|
||||
):
|
||||
assert rl._zc_listening is True
|
||||
rl.async_update_records(
|
||||
mock_zeroconf, current_time_millis(), [RecordUpdate(record, None)]
|
||||
)
|
||||
assert (
|
||||
"Triggering connect because of received mDNS record" in caplog.text
|
||||
) is should_trigger_zeroconf
|
||||
assert rl._accept_zeroconf_records is not should_trigger_zeroconf
|
||||
assert rl._zc_listening is True # should change after one iteration of the loop
|
||||
await asyncio.sleep(0)
|
||||
assert rl._zc_listening is not should_trigger_zeroconf
|
||||
|
||||
# The reconnect is scheduled to run in the next loop iteration
|
||||
await asyncio.sleep(0)
|
||||
assert mock_start_connection.call_count == int(should_trigger_zeroconf)
|
||||
assert log_text in caplog.text
|
||||
|
||||
assert rl._connection_state is expected_state_after_trigger
|
||||
await rl.stop()
|
||||
assert rl._is_stopped is True
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_logic_stop_callback():
|
||||
"""Test that the stop_callback stops the ReconnectLogic."""
|
||||
cli = APIClient(
|
||||
address="1.2.3.4",
|
||||
port=6052,
|
||||
password=None,
|
||||
async def test_reconnect_zeroconf_not_while_handshaking(
|
||||
patchable_api_client: APIClient,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test that reconnect logic retry will not trigger a zeroconf reconnect while handshaking."""
|
||||
cli = patchable_api_client
|
||||
|
||||
mock_zeroconf = MagicMock(spec=Zeroconf)
|
||||
|
||||
rl = ReconnectLogic(
|
||||
client=cli,
|
||||
on_disconnect=AsyncMock(),
|
||||
on_connect=AsyncMock(),
|
||||
zeroconf_instance=mock_zeroconf,
|
||||
name="mydevice",
|
||||
on_connect_error=AsyncMock(),
|
||||
)
|
||||
assert cli.log_name == "mydevice @ 1.2.3.4"
|
||||
|
||||
with patch.object(
|
||||
cli, "start_connection", side_effect=quick_connect_fail
|
||||
) as mock_start_connection:
|
||||
await rl.start()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert mock_start_connection.call_count == 1
|
||||
|
||||
with patch.object(cli, "start_connection") as mock_start_connection, patch.object(
|
||||
cli, "finish_connection", side_effect=slow_connect_fail
|
||||
) as mock_finish_connection:
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
assert rl._accept_zeroconf_records is True
|
||||
assert not rl._is_stopped
|
||||
|
||||
assert rl._connect_timer is not None
|
||||
rl._connect_timer._run()
|
||||
await asyncio.sleep(0)
|
||||
assert mock_start_connection.call_count == 1
|
||||
assert mock_finish_connection.call_count == 1
|
||||
assert rl._connection_state is ReconnectLogicState.HANDSHAKING
|
||||
assert rl._accept_zeroconf_records is False
|
||||
assert not rl._is_stopped
|
||||
|
||||
rl.async_update_records(
|
||||
mock_zeroconf, current_time_millis(), [RecordUpdate(DNS_POINTER, None)]
|
||||
)
|
||||
assert (
|
||||
"Triggering connect because of received mDNS record" in caplog.text
|
||||
) is False
|
||||
|
||||
rl._cancel_connect("forced cancel in test")
|
||||
await rl.stop()
|
||||
assert rl._is_stopped is True
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_task_not_cancelled_while_handshaking(
|
||||
patchable_api_client: APIClient,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test that reconnect logic will not cancel an in progress handshake."""
|
||||
cli = patchable_api_client
|
||||
|
||||
rl = ReconnectLogic(
|
||||
client=cli,
|
||||
on_disconnect=AsyncMock(),
|
||||
on_connect=AsyncMock(),
|
||||
name="mydevice",
|
||||
on_connect_error=AsyncMock(),
|
||||
)
|
||||
assert cli.log_name == "mydevice @ 1.2.3.4"
|
||||
|
||||
with patch.object(
|
||||
cli, "start_connection", side_effect=quick_connect_fail
|
||||
) as mock_start_connection:
|
||||
await rl.start()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert mock_start_connection.call_count == 1
|
||||
|
||||
with patch.object(cli, "start_connection") as mock_start_connection, patch.object(
|
||||
cli, "finish_connection", side_effect=slow_connect_fail
|
||||
) as mock_finish_connection:
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
assert rl._accept_zeroconf_records is True
|
||||
assert not rl._is_stopped
|
||||
|
||||
assert rl._connect_timer is not None
|
||||
rl._connect_timer._run()
|
||||
await asyncio.sleep(0)
|
||||
assert mock_start_connection.call_count == 1
|
||||
assert mock_finish_connection.call_count == 1
|
||||
assert rl._connection_state is ReconnectLogicState.HANDSHAKING
|
||||
assert rl._accept_zeroconf_records is False
|
||||
assert not rl._is_stopped
|
||||
|
||||
caplog.clear()
|
||||
# This can likely never happen in practice, but we should handle it
|
||||
# in the event there is a race as the consequence is that we could
|
||||
# disconnect a working connection.
|
||||
rl._call_connect_once()
|
||||
assert (
|
||||
"Not cancelling existing connect task as its already ReconnectLogicState.HANDSHAKING"
|
||||
in caplog.text
|
||||
)
|
||||
|
||||
rl._cancel_connect("forced cancel in test")
|
||||
await rl.stop()
|
||||
assert rl._is_stopped is True
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_aborts_if_stopped(
|
||||
patchable_api_client: APIClient,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test that reconnect logic will abort connecting if stopped."""
|
||||
cli = patchable_api_client
|
||||
|
||||
rl = ReconnectLogic(
|
||||
client=cli,
|
||||
on_disconnect=AsyncMock(),
|
||||
on_connect=AsyncMock(),
|
||||
name="mydevice",
|
||||
on_connect_error=AsyncMock(),
|
||||
)
|
||||
assert cli.log_name == "mydevice @ 1.2.3.4"
|
||||
|
||||
with patch.object(
|
||||
cli, "start_connection", side_effect=quick_connect_fail
|
||||
) as mock_start_connection:
|
||||
await rl.start()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert mock_start_connection.call_count == 1
|
||||
|
||||
with patch.object(cli, "start_connection") as mock_start_connection:
|
||||
timer = rl._connect_timer
|
||||
assert timer is not None
|
||||
await rl.stop()
|
||||
assert rl._is_stopped is True
|
||||
rl._call_connect_once()
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# We should never try to connect again
|
||||
# once we are stopped
|
||||
assert mock_start_connection.call_count == 0
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_logic_stop_callback(patchable_api_client: APIClient):
|
||||
"""Test that the stop_callback stops the ReconnectLogic."""
|
||||
cli = patchable_api_client
|
||||
rl = ReconnectLogic(
|
||||
client=cli,
|
||||
on_disconnect=AsyncMock(),
|
||||
@ -403,17 +632,11 @@ async def test_reconnect_logic_stop_callback():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_logic_stop_callback_waits_for_handshake():
|
||||
async def test_reconnect_logic_stop_callback_waits_for_handshake(
|
||||
patchable_api_client: APIClient,
|
||||
):
|
||||
"""Test that the stop_callback waits for a handshake."""
|
||||
|
||||
class PatchableAPIClient(APIClient):
|
||||
pass
|
||||
|
||||
cli = PatchableAPIClient(
|
||||
address="1.2.3.4",
|
||||
port=6052,
|
||||
password=None,
|
||||
)
|
||||
cli = patchable_api_client
|
||||
rl = ReconnectLogic(
|
||||
client=cli,
|
||||
on_disconnect=AsyncMock(),
|
||||
@ -423,10 +646,6 @@ async def test_reconnect_logic_stop_callback_waits_for_handshake():
|
||||
)
|
||||
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
|
||||
|
||||
async def slow_connect_fail(*args, **kwargs):
|
||||
await asyncio.sleep(10)
|
||||
raise APIConnectionError
|
||||
|
||||
with patch.object(cli, "start_connection"), patch.object(
|
||||
cli, "finish_connection", side_effect=slow_connect_fail
|
||||
):
|
||||
@ -473,13 +692,6 @@ async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventL
|
||||
zeroconf_instance=async_zeroconf.zeroconf,
|
||||
)
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
protocol.connection_made(transport)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
|
||||
connected = asyncio.Event()
|
||||
on_disconnect_calls = []
|
||||
|
||||
@ -498,20 +710,23 @@ async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventL
|
||||
)
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
):
|
||||
await logic.start()
|
||||
await connected.wait()
|
||||
protocol = cli._connection._frame_helper
|
||||
send_plaintext_hello(protocol)
|
||||
send_plaintext_connect_response(protocol, False)
|
||||
await connected.wait()
|
||||
|
||||
assert cli._connection.is_connected is True
|
||||
await asyncio.sleep(0)
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
) as mock_create_connection:
|
||||
protocol.eof_received()
|
||||
# Wait for the task to run
|
||||
|
Loading…
Reference in New Issue
Block a user