mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-21 16:37:41 +01:00
Add support for creating eager tasks to reduce connect latency (#921)
This commit is contained in:
parent
735a083605
commit
bf15d8e1fb
@ -139,7 +139,7 @@ from .model_conversions import (
|
||||
LIST_ENTITIES_SERVICES_RESPONSE_TYPES,
|
||||
SUBSCRIBE_STATES_RESPONSE_TYPES,
|
||||
)
|
||||
from .util import build_log_name
|
||||
from .util import build_log_name, create_eager_task
|
||||
from .zeroconf import ZeroconfInstanceType, ZeroconfManager
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@ -1311,7 +1311,7 @@ class APIClient:
|
||||
wake_word_phrase: str | None = command.wake_word_phrase
|
||||
if wake_word_phrase == "":
|
||||
wake_word_phrase = None
|
||||
start_task = asyncio.create_task(
|
||||
start_task = create_eager_task(
|
||||
handle_start(
|
||||
command.conversation_id,
|
||||
command.flags,
|
||||
@ -1370,7 +1370,7 @@ class APIClient:
|
||||
|
||||
def _create_background_task(self, coro: Coroutine[Any, Any, None]) -> None:
|
||||
"""Create a background task and add it to the background tasks set."""
|
||||
task = asyncio.create_task(coro)
|
||||
task = create_eager_task(coro)
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
|
@ -18,7 +18,7 @@ from .core import (
|
||||
RequiresEncryptionAPIError,
|
||||
UnhandledAPIConnectionError,
|
||||
)
|
||||
from .util import address_is_local, host_is_name_part
|
||||
from .util import address_is_local, create_eager_task, host_is_name_part
|
||||
from .zeroconf import ZeroconfInstanceType
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@ -259,7 +259,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
ReconnectLogicState.DISCONNECTED
|
||||
)
|
||||
|
||||
self._connect_task = asyncio.create_task(
|
||||
self._connect_task = create_eager_task(
|
||||
self._connect_once_or_reschedule(),
|
||||
name=f"{self._cli.log_name}: aioesphomeapi connect",
|
||||
)
|
||||
@ -318,7 +318,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
|
||||
def stop_callback(self) -> None:
|
||||
"""Stop the connect logic."""
|
||||
self._stop_task = asyncio.create_task(
|
||||
self._stop_task = create_eager_task(
|
||||
self.stop(),
|
||||
name=f"{self._cli.log_name}: aioesphomeapi reconnect_logic stop_callback",
|
||||
)
|
||||
|
@ -1,6 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from asyncio import AbstractEventLoop, Task, get_running_loop
|
||||
from collections.abc import Coroutine
|
||||
import math
|
||||
import sys
|
||||
from typing import Any, TypeVar
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def fix_float_single_double_conversion(value: float) -> float:
|
||||
@ -55,3 +61,31 @@ def build_log_name(
|
||||
):
|
||||
return f"{name} @ {preferred_address}"
|
||||
return preferred_address
|
||||
|
||||
|
||||
if sys.version_info >= (3, 12, 0):
|
||||
|
||||
def create_eager_task(
|
||||
coro: Coroutine[Any, Any, _T],
|
||||
*,
|
||||
name: str | None = None,
|
||||
loop: AbstractEventLoop | None = None,
|
||||
) -> Task[_T]:
|
||||
"""Create a task from a coroutine and schedule it to run immediately."""
|
||||
return Task(
|
||||
coro,
|
||||
loop=loop or get_running_loop(),
|
||||
name=name,
|
||||
eager_start=True, # type: ignore[call-arg]
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
def create_eager_task(
|
||||
coro: Coroutine[Any, Any, _T],
|
||||
*,
|
||||
name: str | None = None,
|
||||
loop: AbstractEventLoop | None = None,
|
||||
) -> Task[_T]:
|
||||
"""Create a task from a coroutine."""
|
||||
return Task(coro, loop=loop or get_running_loop(), name=name)
|
||||
|
@ -1,9 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from functools import partial
|
||||
from ipaddress import ip_address
|
||||
import logging
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -429,8 +429,14 @@ async def test_reconnect_zeroconf(
|
||||
assert not rl._is_stopped
|
||||
|
||||
caplog.clear()
|
||||
|
||||
async def delayed_connect(*args, **kwargs):
|
||||
await asyncio.sleep(0)
|
||||
|
||||
with (
|
||||
patch.object(cli, "start_connection") as mock_start_connection,
|
||||
patch.object(
|
||||
cli, "start_connection", side_effect=delayed_connect
|
||||
) as mock_start_connection,
|
||||
patch.object(cli, "finish_connection"),
|
||||
):
|
||||
assert rl._zc_listening is True
|
||||
@ -445,8 +451,11 @@ async def test_reconnect_zeroconf(
|
||||
|
||||
# The reconnect is scheduled to run in the next loop iteration
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert mock_start_connection.call_count == int(should_trigger_zeroconf)
|
||||
assert log_text in caplog.text
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert rl._connection_state is expected_state_after_trigger
|
||||
await rl.stop()
|
||||
@ -733,11 +742,15 @@ async def test_handling_unexpected_disconnect(aiohappyeyeballs_start_connection)
|
||||
assert cli._connection.is_connected is True
|
||||
await asyncio.sleep(0)
|
||||
|
||||
with patch.object(
|
||||
loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
) as mock_create_connection:
|
||||
with (
|
||||
patch.object(
|
||||
loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
) as mock_create_connection,
|
||||
patch.object(cli, "start_connection"),
|
||||
patch.object(cli, "finish_connection"),
|
||||
):
|
||||
protocol.eof_received()
|
||||
# Wait for the task to run
|
||||
await asyncio.sleep(0)
|
||||
|
@ -1,4 +1,6 @@
|
||||
import asyncio
|
||||
import math
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
@ -29,3 +31,53 @@ def test_fix_float_single_double_conversion(input, output):
|
||||
|
||||
def test_fix_float_single_double_conversion_nan():
|
||||
assert math.isnan(util.fix_float_single_double_conversion(float("nan")))
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.version_info < (3, 12), reason="Test requires Python 3.12+")
|
||||
async def test_create_eager_task_312() -> None:
|
||||
"""Test create_eager_task schedules a task eagerly in the event loop.
|
||||
|
||||
For Python 3.12+, the task is scheduled eagerly in the event loop.
|
||||
"""
|
||||
events = []
|
||||
|
||||
async def _normal_task():
|
||||
events.append("normal")
|
||||
|
||||
async def _eager_task():
|
||||
events.append("eager")
|
||||
|
||||
task1 = util.create_eager_task(_eager_task())
|
||||
task2 = asyncio.create_task(_normal_task())
|
||||
|
||||
assert events == ["eager"]
|
||||
|
||||
await asyncio.sleep(0)
|
||||
assert events == ["eager", "normal"]
|
||||
await task1
|
||||
await task2
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.version_info >= (3, 12), reason="Test requires < Python 3.12")
|
||||
async def test_create_eager_task_pre_312() -> None:
|
||||
"""Test create_eager_task schedules a task in the event loop.
|
||||
|
||||
For older python versions, the task is scheduled normally.
|
||||
"""
|
||||
events = []
|
||||
|
||||
async def _normal_task():
|
||||
events.append("normal")
|
||||
|
||||
async def _eager_task():
|
||||
events.append("eager")
|
||||
|
||||
task1 = util.create_eager_task(_eager_task())
|
||||
task2 = asyncio.create_task(_normal_task())
|
||||
|
||||
assert events == []
|
||||
|
||||
await asyncio.sleep(0)
|
||||
assert events == ["eager", "normal"]
|
||||
await task1
|
||||
await task2
|
||||
|
Loading…
Reference in New Issue
Block a user