Merge branch 'main' into climate_enhancements

This commit is contained in:
J. Nick Koston 2023-11-26 11:25:10 -06:00 committed by GitHub
commit 82989b05ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 504 additions and 117 deletions

View File

@ -5,6 +5,7 @@ omit =
aioesphomeapi/api_options_pb2.py
aioesphomeapi/api_pb2.py
aioesphomeapi/log_reader.py
aioesphomeapi/discover.py
bench/*.py
[report]

View File

@ -1,9 +1,41 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
---
exclude: '^aioesphomeapi/api.*$'
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-added-large-files
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-added-large-files
- repo: https://github.com/asottile/pyupgrade
rev: v2.37.1
hooks:
- id: pyupgrade
args: [--py37-plus]
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.1
hooks:
- id: ruff
args:
- --fix
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 23.11.0
hooks:
- id: black
args:
- --quiet
files: ^((aioesphomeapi|tests)/.+)?[^/]+\.py$
- repo: https://github.com/cdce8p/python-typing-update
rev: v0.6.0
hooks:
- id: python-typing-update
stages: [manual]
args:
- --py39-plus
- --force
- --keep-updates
files: ^(aioesphomeapi)/.+\.py$

View File

@ -135,6 +135,12 @@ A cli tool is also available for watching logs:
aioesphomeapi-logs --help
A cli tool is also available to discover devices:
.. code:: bash
aioesphomeapi-discover
License
-------

View File

@ -42,7 +42,7 @@ class APIFrameHelper:
def __init__(
self,
connection: "APIConnection",
connection: APIConnection,
client_info: str,
log_name: str,
) -> None:

View File

@ -83,7 +83,7 @@ class APINoiseFrameHelper(APIFrameHelper):
def __init__(
self,
connection: "APIConnection",
connection: APIConnection,
noise_psk: str,
expected_name: str | None,
client_info: str,

80
aioesphomeapi/discover.py Normal file
View File

@ -0,0 +1,80 @@
from __future__ import annotations
# Helper script and aioesphomeapi to discover api devices
import asyncio
import logging
import sys
from zeroconf import IPVersion, ServiceStateChange, Zeroconf
from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
FORMAT = "{: <7}|{: <32}|{: <15}|{: <12}|{: <16}|{: <10}|{: <32}"
COLUMN_NAMES = ("Status", "Name", "Address", "MAC", "Version", "Platform", "Board")
def decode_bytes_or_none(data: str | bytes | None) -> str | None:
"""Decode bytes or return None."""
if data is None:
return None
if isinstance(data, bytes):
return data.decode()
return data
def async_service_update(
zeroconf: Zeroconf,
service_type: str,
name: str,
state_change: ServiceStateChange,
) -> None:
"""Service state changed."""
short_name = name.partition(".")[0]
if state_change is ServiceStateChange.Removed:
state = "OFFLINE"
else:
state = "ONLINE"
info = AsyncServiceInfo(service_type, name)
info.load_from_cache(zeroconf)
properties = info.properties
mac = decode_bytes_or_none(properties.get(b"mac"))
version = decode_bytes_or_none(properties.get(b"version"))
platform = decode_bytes_or_none(properties.get(b"platform"))
board = decode_bytes_or_none(properties.get(b"board"))
address = None
if addresses := info.ip_addresses_by_version(IPVersion.V4Only):
address = str(addresses[0])
print(FORMAT.format(state, short_name, address, mac, version, platform, board))
async def main() -> None:
logging.basicConfig(
format="%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s",
level=logging.INFO,
datefmt="%Y-%m-%d %H:%M:%S",
)
aiozc = AsyncZeroconf()
browser = AsyncServiceBrowser(
aiozc.zeroconf, "_esphomelib._tcp.local.", handlers=[async_service_update]
)
print(FORMAT.format(*COLUMN_NAMES))
print("-" * 120)
try:
await asyncio.Event().wait()
finally:
await browser.async_cancel()
await aiozc.async_close()
def cli_entry_point() -> None:
"""Run the CLI."""
try:
asyncio.run(main())
except KeyboardInterrupt:
pass
if __name__ == "__main__":
cli_entry_point()
sys.exit(0)

View File

@ -58,6 +58,7 @@ async def _async_zeroconf_get_service_info(
timeout: float,
) -> AsyncServiceInfo:
# Use or create zeroconf instance, ensure it's an AsyncZeroconf
had_instance = zeroconf_manager.has_instance
try:
zc = zeroconf_manager.get_async_zeroconf().zeroconf
except Exception as exc:
@ -73,7 +74,8 @@ async def _async_zeroconf_get_service_info(
f"Error resolving mDNS {service_name} via mDNS: {exc}"
) from exc
finally:
await zeroconf_manager.async_close()
if not had_instance:
await zeroconf_manager.async_close()
return info

View File

@ -39,7 +39,10 @@ async def main(argv: list[str]) -> None:
time_ = datetime.now()
message: bytes = msg.message
text = message.decode("utf8", "backslashreplace")
print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}]{text}")
nanoseconds = time_.microsecond // 1000
print(
f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}.{nanoseconds:03}]{text}"
)
stop = await async_run(cli, on_log)
try:

View File

@ -19,7 +19,7 @@ from .core import (
RequiresEncryptionAPIError,
UnhandledAPIConnectionError,
)
from .util import address_is_local
from .util import address_is_local, host_is_name_part
from .zeroconf import ZeroconfInstanceType
_LOGGER = logging.getLogger(__name__)
@ -79,7 +79,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
self.name: str | None = None
if name:
self.name = name
elif address_is_local(client.address):
elif host_is_name_part(client.address) or address_is_local(client.address):
self.name = client.address.partition(".")[0]
if self.name:
self._cli.set_cached_name_if_unset(self.name)
@ -93,7 +93,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
self._a_name: str | None = None
# Flag to check if the device is connected
self._connection_state = ReconnectLogicState.DISCONNECTED
self._accept_zeroconf_records = True
self._accept_zeroconf_records: bool = True
self._connected_lock = asyncio.Lock()
self._is_stopped = True
self._zc_listening = False
@ -226,11 +226,11 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
def _schedule_connect(self, delay: float) -> None:
"""Schedule a connect attempt."""
self._cancel_connect("Scheduling new connect attempt")
if not delay:
self._call_connect_once()
return
_LOGGER.debug("Scheduling new connect attempt in %f seconds", delay)
_LOGGER.debug("Scheduling new connect attempt in %.2f seconds", delay)
self._cancel_connect_timer()
self._connect_timer = self.loop.call_at(
self.loop.time() + delay, self._call_connect_once
)
@ -240,17 +240,22 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
Must only be called from _schedule_connect.
"""
if self._connect_task:
if self._connect_task and not self._connect_task.done():
if self._connection_state != ReconnectLogicState.CONNECTING:
# Connection state is far enough along that we should
# not restart the connect task
_LOGGER.debug(
"%s: Not cancelling existing connect task as its already %s!",
self._cli.log_name,
self._connection_state,
)
return
_LOGGER.debug(
"%s: Cancelling existing connect task, to try again now!",
"%s: Cancelling existing connect task with state %s, to try again now!",
self._cli.log_name,
self._connection_state,
)
self._connect_task.cancel("Scheduling new connect attempt")
self._connect_task = None
self._cancel_connect_task("Scheduling new connect attempt")
self._async_set_connection_state_without_lock(
ReconnectLogicState.DISCONNECTED
)
@ -260,15 +265,23 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
name=f"{self._cli.log_name}: aioesphomeapi connect",
)
def _cancel_connect(self, msg: str) -> None:
"""Cancel the connect."""
def _cancel_connect_timer(self) -> None:
"""Cancel the connect timer."""
if self._connect_timer:
self._connect_timer.cancel()
self._connect_timer = None
def _cancel_connect_task(self, msg: str) -> None:
"""Cancel the connect task."""
if self._connect_task:
self._connect_task.cancel(msg)
self._connect_task = None
def _cancel_connect(self, msg: str) -> None:
"""Cancel the connect."""
self._cancel_connect_timer()
self._cancel_connect_task(msg)
async def _connect_once_or_reschedule(self) -> None:
"""Connect once or schedule connect.
@ -290,7 +303,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
_LOGGER.info(
"Trying to connect to %s in the background", self._cli.log_name
)
_LOGGER.debug("Retrying %s in %d seconds", self._cli.log_name, wait_time)
_LOGGER.debug("Retrying %s in %.2f seconds", self._cli.log_name, wait_time)
if wait_time:
# If we are waiting, start listening for mDNS records
self._start_zc_listen()
@ -365,6 +378,11 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
)
self._zc_listening = False
def _connect_from_zeroconf(self) -> None:
"""Connect from zeroconf."""
self._stop_zc_listen()
self._schedule_connect(0.0)
def async_update_records(
self,
zc: zeroconf.Zeroconf, # pylint: disable=unused-argument
@ -398,7 +416,13 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
# We can't stop the zeroconf listener here because we are in the middle of
# a zeroconf callback which is iterating the listeners.
#
# So we schedule a stop for the next event loop iteration.
self.loop.call_soon(self._stop_zc_listen)
self._schedule_connect(0.0)
# So we schedule a stop for the next event loop iteration as well as the
# connect attempt.
#
# If we scheduled the connect attempt immediately, the listener could fire
# again before the connect attempt and we cancel and reschedule the connect
# attempt again.
#
self.loop.call_soon(self._connect_from_zeroconf)
self._accept_zeroconf_records = False
return

View File

@ -26,6 +26,11 @@ class ZeroconfManager:
if zeroconf is not None:
self.set_instance(zeroconf)
@property
def has_instance(self) -> bool:
"""Return True if a Zeroconf instance is set."""
return self._aiozc is not None
def set_instance(self, zc: AsyncZeroconf | Zeroconf) -> None:
"""Set the AsyncZeroconf instance."""
if self._aiozc:

View File

@ -30,5 +30,10 @@ disable = [
"too-many-lines",
]
[tool.ruff]
ignore = [
"E721", # We want type() check for protobuf messages
]
[build-system]
requires = ['setuptools>=65.4.1', 'wheel', 'Cython>=3.0.2']

View File

@ -1,8 +1,8 @@
#!/usr/bin/env python3
from subprocess import check_call
from pathlib import Path
import os
from pathlib import Path
from subprocess import check_call
root_dir = Path(__file__).absolute().parent.parent
os.chdir(root_dir)

View File

@ -11,7 +11,7 @@ with open(os.path.join(here, "README.rst"), encoding="utf-8") as readme_file:
long_description = readme_file.read()
VERSION = "19.0.2"
VERSION = "19.1.0"
PROJECT_NAME = "aioesphomeapi"
PROJECT_PACKAGE_NAME = "aioesphomeapi"
PROJECT_LICENSE = "MIT"
@ -23,11 +23,11 @@ PROJECT_EMAIL = "esphome@nabucasa.com"
PROJECT_GITHUB_USERNAME = "esphome"
PROJECT_GITHUB_REPOSITORY = "aioesphomeapi"
PYPI_URL = "https://pypi.python.org/pypi/{}".format(PROJECT_PACKAGE_NAME)
GITHUB_PATH = "{}/{}".format(PROJECT_GITHUB_USERNAME, PROJECT_GITHUB_REPOSITORY)
GITHUB_URL = "https://github.com/{}".format(GITHUB_PATH)
PYPI_URL = f"https://pypi.python.org/pypi/{PROJECT_PACKAGE_NAME}"
GITHUB_PATH = f"{PROJECT_GITHUB_USERNAME}/{PROJECT_GITHUB_REPOSITORY}"
GITHUB_URL = f"https://github.com/{GITHUB_PATH}"
DOWNLOAD_URL = "{}/archive/{}.zip".format(GITHUB_URL, VERSION)
DOWNLOAD_URL = f"{GITHUB_URL}/archive/{VERSION}.zip"
MODULES_TO_CYTHONIZE = [
"aioesphomeapi/client_callbacks.py",
@ -61,7 +61,8 @@ setup_kwargs = {
"test_suite": "tests",
"entry_points": {
"console_scripts": [
"aioesphomeapi-logs=aioesphomeapi.log_reader:cli_entry_point"
"aioesphomeapi-logs=aioesphomeapi.log_reader:cli_entry_point",
"aioesphomeapi-discover=aioesphomeapi.discover:cli_entry_point",
],
},
}

View File

@ -49,6 +49,19 @@ def socket_socket():
yield func
@pytest.fixture
def patchable_api_client() -> APIClient:
class PatchableAPIClient(APIClient):
pass
cli = PatchableAPIClient(
address="1.2.3.4",
port=6052,
password=None,
)
return cli
def get_mock_connection_params() -> ConnectionParams:
return ConnectionParams(
address="fake.address",

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import logging
from functools import partial
from ipaddress import ip_address
from unittest.mock import AsyncMock, MagicMock, patch
@ -17,10 +18,14 @@ from zeroconf import (
from zeroconf.asyncio import AsyncZeroconf
from zeroconf.const import _CLASS_IN, _TYPE_A, _TYPE_PTR
from aioesphomeapi import APIConnectionError
from aioesphomeapi import APIConnectionError, RequiresEncryptionAPIError
from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper
from aioesphomeapi.client import APIClient
from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
from aioesphomeapi.reconnect_logic import (
MAXIMUM_BACKOFF_TRIES,
ReconnectLogic,
ReconnectLogicState,
)
from .common import (
get_mock_async_zeroconf,
@ -28,10 +33,20 @@ from .common import (
send_plaintext_connect_response,
send_plaintext_hello,
)
from .conftest import _create_mock_transport_protocol
logging.getLogger("aioesphomeapi").setLevel(logging.DEBUG)
async def slow_connect_fail(*args, **kwargs):
await asyncio.sleep(10)
raise APIConnectionError
async def quick_connect_fail(*args, **kwargs):
raise APIConnectionError
@pytest.mark.asyncio
async def test_reconnect_logic_name_from_host():
"""Test that the name is set correctly from the host."""
@ -71,13 +86,14 @@ async def test_reconnect_logic_name_from_host_and_set():
async def on_connect() -> None:
pass
ReconnectLogic(
rl = ReconnectLogic(
client=cli,
on_disconnect=on_disconnect,
on_connect=on_connect,
zeroconf_instance=get_mock_zeroconf(),
name="mydevice",
)
assert rl.name == "mydevice"
assert cli.log_name == "mydevice.local"
@ -131,20 +147,38 @@ async def test_reconnect_logic_name_from_name():
@pytest.mark.asyncio
async def test_reconnect_logic_state():
async def test_reconnect_logic_name_from_cli_address():
"""Test that the name is set correctly from the address."""
cli = APIClient(
address="mydevice",
port=6052,
password=None,
)
async def on_disconnect(expected_disconnect: bool) -> None:
pass
async def on_connect() -> None:
pass
rl = ReconnectLogic(
client=cli,
on_disconnect=on_disconnect,
on_connect=on_connect,
zeroconf_instance=get_mock_zeroconf(),
)
assert cli.log_name == "mydevice"
assert rl.name == "mydevice"
@pytest.mark.asyncio
async def test_reconnect_logic_state(patchable_api_client: APIClient):
"""Test that reconnect logic state changes."""
on_disconnect_called = []
on_connect_called = []
on_connect_fail_called = []
class PatchableAPIClient(APIClient):
pass
cli = PatchableAPIClient(
address="1.2.3.4",
port=6052,
password=None,
)
cli = patchable_api_client
async def on_disconnect(expected_disconnect: bool) -> None:
nonlocal on_disconnect_called
@ -178,9 +212,10 @@ async def test_reconnect_logic_state():
assert len(on_connect_fail_called) == 1
assert isinstance(on_connect_fail_called[-1], APIConnectionError)
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
assert rl._tries == 1
with patch.object(cli, "start_connection"), patch.object(
cli, "finish_connection", side_effect=APIConnectionError
cli, "finish_connection", side_effect=RequiresEncryptionAPIError
):
await rl.start()
await asyncio.sleep(0)
@ -189,8 +224,9 @@ async def test_reconnect_logic_state():
assert len(on_disconnect_called) == 0
assert len(on_connect_called) == 0
assert len(on_connect_fail_called) == 2
assert isinstance(on_connect_fail_called[-1], APIConnectionError)
assert isinstance(on_connect_fail_called[-1], RequiresEncryptionAPIError)
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
assert rl._tries == MAXIMUM_BACKOFF_TRIES
with patch.object(cli, "start_connection"), patch.object(cli, "finish_connection"):
await rl.start()
@ -201,26 +237,20 @@ async def test_reconnect_logic_state():
assert len(on_connect_called) == 1
assert len(on_connect_fail_called) == 2
assert rl._connection_state is ReconnectLogicState.READY
assert rl._tries == 0
await rl.stop()
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
@pytest.mark.asyncio
async def test_reconnect_retry():
async def test_reconnect_retry(
patchable_api_client: APIClient, caplog: pytest.LogCaptureFixture
):
"""Test that reconnect logic retry."""
on_disconnect_called = []
on_connect_called = []
on_connect_fail_called = []
class PatchableAPIClient(APIClient):
pass
cli = PatchableAPIClient(
address="1.2.3.4",
port=6052,
password=None,
)
cli = patchable_api_client
async def on_disconnect(expected_disconnect: bool) -> None:
nonlocal on_disconnect_called
@ -243,6 +273,7 @@ async def test_reconnect_retry():
on_connect_error=on_connect_fail,
)
assert cli.log_name == "mydevice @ 1.2.3.4"
caplog.clear()
with patch.object(cli, "start_connection", side_effect=APIConnectionError):
await rl.start()
@ -255,35 +286,70 @@ async def test_reconnect_retry():
assert len(on_connect_fail_called) == 1
assert isinstance(on_connect_fail_called[-1], APIConnectionError)
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
assert "connect to ESPHome API for mydevice @ 1.2.3.4" in caplog.text
for record in caplog.records:
if "connect to ESPHome API for mydevice @ 1.2.3.4" in record.message:
assert record.levelno == logging.WARNING
with patch.object(cli, "start_connection"), patch.object(cli, "finish_connection"):
caplog.clear()
# Next retry should run at debug level
with patch.object(cli, "start_connection", side_effect=APIConnectionError):
# Should now retry
assert rl._connect_timer is not None
rl._connect_timer._run()
await asyncio.sleep(0)
await asyncio.sleep(0)
assert len(on_disconnect_called) == 0
assert len(on_connect_called) == 0
assert len(on_connect_fail_called) == 2
assert isinstance(on_connect_fail_called[-1], APIConnectionError)
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
assert "connect to ESPHome API for mydevice @ 1.2.3.4" in caplog.text
for record in caplog.records:
if "connect to ESPHome API for mydevice @ 1.2.3.4" in record.message:
assert record.levelno == logging.DEBUG
caplog.clear()
with patch.object(cli, "start_connection"), patch.object(cli, "finish_connection"):
# Should now retry
assert rl._connect_timer is not None
rl._connect_timer._run()
await asyncio.sleep(0)
await asyncio.sleep(0)
assert "connect to ESPHome API for mydevice @ 1.2.3.4" not in caplog.text
assert len(on_disconnect_called) == 0
assert len(on_connect_called) == 1
assert len(on_connect_fail_called) == 1
assert len(on_connect_fail_called) == 2
assert rl._connection_state is ReconnectLogicState.READY
original_when = rl._connect_timer.when()
# Ensure starting the connection logic again does not trigger a new connection
await rl.start()
# Verify no new timer is started
assert rl._connect_timer.when() == original_when
await rl.stop()
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
DNS_POINTER = DNSPointer(
"_esphomelib._tcp.local.",
_TYPE_PTR,
_CLASS_IN,
1000,
"mydevice._esphomelib._tcp.local.",
)
@pytest.mark.parametrize(
("record", "should_trigger_zeroconf", "log_text"),
("record", "should_trigger_zeroconf", "expected_state_after_trigger", "log_text"),
(
(
DNSPointer(
"_esphomelib._tcp.local.",
_TYPE_PTR,
_CLASS_IN,
1000,
"mydevice._esphomelib._tcp.local.",
),
DNS_POINTER,
True,
ReconnectLogicState.READY,
"received mDNS record",
),
(
@ -295,6 +361,7 @@ async def test_reconnect_retry():
"wrong_name._esphomelib._tcp.local.",
),
False,
ReconnectLogicState.CONNECTING,
"",
),
(
@ -306,27 +373,23 @@ async def test_reconnect_retry():
ip_address("1.2.3.4").packed,
),
True,
ReconnectLogicState.READY,
"received mDNS record",
),
),
)
@pytest.mark.asyncio
async def test_reconnect_zeroconf(
patchable_api_client: APIClient,
caplog: pytest.LogCaptureFixture,
record: DNSRecord,
should_trigger_zeroconf: bool,
expected_state_after_trigger: ReconnectLogicState,
log_text: str,
) -> None:
"""Test that reconnect logic retry."""
class PatchableAPIClient(APIClient):
pass
cli = PatchableAPIClient(
address="1.2.3.4",
port=6052,
password=None,
)
cli = patchable_api_client
mock_zeroconf = MagicMock(spec=Zeroconf)
@ -340,13 +403,6 @@ async def test_reconnect_zeroconf(
)
assert cli.log_name == "mydevice @ 1.2.3.4"
async def slow_connect_fail(*args, **kwargs):
await asyncio.sleep(10)
raise APIConnectionError
async def quick_connect_fail(*args, **kwargs):
raise APIConnectionError
with patch.object(
cli, "start_connection", side_effect=quick_connect_fail
) as mock_start_connection:
@ -358,30 +414,203 @@ async def test_reconnect_zeroconf(
with patch.object(
cli, "start_connection", side_effect=slow_connect_fail
) as mock_start_connection:
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
assert rl._accept_zeroconf_records is True
assert not rl._is_stopped
assert rl._connect_timer is not None
rl._connect_timer._run()
await asyncio.sleep(0)
assert mock_start_connection.call_count == 1
assert rl._connection_state is ReconnectLogicState.CONNECTING
assert rl._accept_zeroconf_records is True
assert not rl._is_stopped
assert mock_start_connection.call_count == 0
caplog.clear()
with patch.object(cli, "start_connection") as mock_start_connection, patch.object(
cli, "finish_connection"
):
assert rl._zc_listening is True
rl.async_update_records(
mock_zeroconf, current_time_millis(), [RecordUpdate(record, None)]
)
assert (
"Triggering connect because of received mDNS record" in caplog.text
) is should_trigger_zeroconf
assert rl._accept_zeroconf_records is not should_trigger_zeroconf
assert rl._zc_listening is True # should change after one iteration of the loop
await asyncio.sleep(0)
assert rl._zc_listening is not should_trigger_zeroconf
# The reconnect is scheduled to run in the next loop iteration
await asyncio.sleep(0)
assert mock_start_connection.call_count == int(should_trigger_zeroconf)
assert log_text in caplog.text
assert rl._connection_state is expected_state_after_trigger
await rl.stop()
assert rl._is_stopped is True
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
@pytest.mark.asyncio
async def test_reconnect_logic_stop_callback():
"""Test that the stop_callback stops the ReconnectLogic."""
cli = APIClient(
address="1.2.3.4",
port=6052,
password=None,
async def test_reconnect_zeroconf_not_while_handshaking(
patchable_api_client: APIClient,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test that reconnect logic retry will not trigger a zeroconf reconnect while handshaking."""
cli = patchable_api_client
mock_zeroconf = MagicMock(spec=Zeroconf)
rl = ReconnectLogic(
client=cli,
on_disconnect=AsyncMock(),
on_connect=AsyncMock(),
zeroconf_instance=mock_zeroconf,
name="mydevice",
on_connect_error=AsyncMock(),
)
assert cli.log_name == "mydevice @ 1.2.3.4"
with patch.object(
cli, "start_connection", side_effect=quick_connect_fail
) as mock_start_connection:
await rl.start()
await asyncio.sleep(0)
assert mock_start_connection.call_count == 1
with patch.object(cli, "start_connection") as mock_start_connection, patch.object(
cli, "finish_connection", side_effect=slow_connect_fail
) as mock_finish_connection:
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
assert rl._accept_zeroconf_records is True
assert not rl._is_stopped
assert rl._connect_timer is not None
rl._connect_timer._run()
await asyncio.sleep(0)
assert mock_start_connection.call_count == 1
assert mock_finish_connection.call_count == 1
assert rl._connection_state is ReconnectLogicState.HANDSHAKING
assert rl._accept_zeroconf_records is False
assert not rl._is_stopped
rl.async_update_records(
mock_zeroconf, current_time_millis(), [RecordUpdate(DNS_POINTER, None)]
)
assert (
"Triggering connect because of received mDNS record" in caplog.text
) is False
rl._cancel_connect("forced cancel in test")
await rl.stop()
assert rl._is_stopped is True
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
@pytest.mark.asyncio
async def test_connect_task_not_cancelled_while_handshaking(
patchable_api_client: APIClient,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test that reconnect logic will not cancel an in progress handshake."""
cli = patchable_api_client
rl = ReconnectLogic(
client=cli,
on_disconnect=AsyncMock(),
on_connect=AsyncMock(),
name="mydevice",
on_connect_error=AsyncMock(),
)
assert cli.log_name == "mydevice @ 1.2.3.4"
with patch.object(
cli, "start_connection", side_effect=quick_connect_fail
) as mock_start_connection:
await rl.start()
await asyncio.sleep(0)
assert mock_start_connection.call_count == 1
with patch.object(cli, "start_connection") as mock_start_connection, patch.object(
cli, "finish_connection", side_effect=slow_connect_fail
) as mock_finish_connection:
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
assert rl._accept_zeroconf_records is True
assert not rl._is_stopped
assert rl._connect_timer is not None
rl._connect_timer._run()
await asyncio.sleep(0)
assert mock_start_connection.call_count == 1
assert mock_finish_connection.call_count == 1
assert rl._connection_state is ReconnectLogicState.HANDSHAKING
assert rl._accept_zeroconf_records is False
assert not rl._is_stopped
caplog.clear()
# This can likely never happen in practice, but we should handle it
# in the event there is a race as the consequence is that we could
# disconnect a working connection.
rl._call_connect_once()
assert (
"Not cancelling existing connect task as its already ReconnectLogicState.HANDSHAKING"
in caplog.text
)
rl._cancel_connect("forced cancel in test")
await rl.stop()
assert rl._is_stopped is True
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
@pytest.mark.asyncio
async def test_connect_aborts_if_stopped(
patchable_api_client: APIClient,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test that reconnect logic will abort connecting if stopped."""
cli = patchable_api_client
rl = ReconnectLogic(
client=cli,
on_disconnect=AsyncMock(),
on_connect=AsyncMock(),
name="mydevice",
on_connect_error=AsyncMock(),
)
assert cli.log_name == "mydevice @ 1.2.3.4"
with patch.object(
cli, "start_connection", side_effect=quick_connect_fail
) as mock_start_connection:
await rl.start()
await asyncio.sleep(0)
assert mock_start_connection.call_count == 1
with patch.object(cli, "start_connection") as mock_start_connection:
timer = rl._connect_timer
assert timer is not None
await rl.stop()
assert rl._is_stopped is True
rl._call_connect_once()
await asyncio.sleep(0)
await asyncio.sleep(0)
# We should never try to connect again
# once we are stopped
assert mock_start_connection.call_count == 0
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
@pytest.mark.asyncio
async def test_reconnect_logic_stop_callback(patchable_api_client: APIClient):
"""Test that the stop_callback stops the ReconnectLogic."""
cli = patchable_api_client
rl = ReconnectLogic(
client=cli,
on_disconnect=AsyncMock(),
@ -403,17 +632,11 @@ async def test_reconnect_logic_stop_callback():
@pytest.mark.asyncio
async def test_reconnect_logic_stop_callback_waits_for_handshake():
async def test_reconnect_logic_stop_callback_waits_for_handshake(
patchable_api_client: APIClient,
):
"""Test that the stop_callback waits for a handshake."""
class PatchableAPIClient(APIClient):
pass
cli = PatchableAPIClient(
address="1.2.3.4",
port=6052,
password=None,
)
cli = patchable_api_client
rl = ReconnectLogic(
client=cli,
on_disconnect=AsyncMock(),
@ -423,10 +646,6 @@ async def test_reconnect_logic_stop_callback_waits_for_handshake():
)
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
async def slow_connect_fail(*args, **kwargs):
await asyncio.sleep(10)
raise APIConnectionError
with patch.object(cli, "start_connection"), patch.object(
cli, "finish_connection", side_effect=slow_connect_fail
):
@ -473,13 +692,6 @@ async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventL
zeroconf_instance=async_zeroconf.zeroconf,
)
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
connected = asyncio.Event()
on_disconnect_calls = []
@ -498,20 +710,23 @@ async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventL
)
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
):
await logic.start()
await connected.wait()
protocol = cli._connection._frame_helper
send_plaintext_hello(protocol)
send_plaintext_connect_response(protocol, False)
await connected.wait()
assert cli._connection.is_connected is True
await asyncio.sleep(0)
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
) as mock_create_connection:
protocol.eof_received()
# Wait for the task to run