Add support for creating eager tasks to reduce connect latency (#921)

This commit is contained in:
J. Nick Koston 2024-08-12 15:04:18 -05:00 committed by GitHub
parent 735a083605
commit bf15d8e1fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 112 additions and 13 deletions

View File

@ -139,7 +139,7 @@ from .model_conversions import (
LIST_ENTITIES_SERVICES_RESPONSE_TYPES,
SUBSCRIBE_STATES_RESPONSE_TYPES,
)
from .util import build_log_name
from .util import build_log_name, create_eager_task
from .zeroconf import ZeroconfInstanceType, ZeroconfManager
_LOGGER = logging.getLogger(__name__)
@ -1311,7 +1311,7 @@ class APIClient:
wake_word_phrase: str | None = command.wake_word_phrase
if wake_word_phrase == "":
wake_word_phrase = None
start_task = asyncio.create_task(
start_task = create_eager_task(
handle_start(
command.conversation_id,
command.flags,
@ -1370,7 +1370,7 @@ class APIClient:
def _create_background_task(self, coro: Coroutine[Any, Any, None]) -> None:
"""Create a background task and add it to the background tasks set."""
task = asyncio.create_task(coro)
task = create_eager_task(coro)
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)

View File

@ -18,7 +18,7 @@ from .core import (
RequiresEncryptionAPIError,
UnhandledAPIConnectionError,
)
from .util import address_is_local, host_is_name_part
from .util import address_is_local, create_eager_task, host_is_name_part
from .zeroconf import ZeroconfInstanceType
_LOGGER = logging.getLogger(__name__)
@ -259,7 +259,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
ReconnectLogicState.DISCONNECTED
)
self._connect_task = asyncio.create_task(
self._connect_task = create_eager_task(
self._connect_once_or_reschedule(),
name=f"{self._cli.log_name}: aioesphomeapi connect",
)
@ -318,7 +318,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
def stop_callback(self) -> None:
"""Stop the connect logic."""
self._stop_task = asyncio.create_task(
self._stop_task = create_eager_task(
self.stop(),
name=f"{self._cli.log_name}: aioesphomeapi reconnect_logic stop_callback",
)

View File

@ -1,6 +1,12 @@
from __future__ import annotations
from asyncio import AbstractEventLoop, Task, get_running_loop
from collections.abc import Coroutine
import math
import sys
from typing import Any, TypeVar
_T = TypeVar("_T")
def fix_float_single_double_conversion(value: float) -> float:
@ -55,3 +61,31 @@ def build_log_name(
):
return f"{name} @ {preferred_address}"
return preferred_address
if sys.version_info >= (3, 12, 0):
def create_eager_task(
coro: Coroutine[Any, Any, _T],
*,
name: str | None = None,
loop: AbstractEventLoop | None = None,
) -> Task[_T]:
"""Create a task from a coroutine and schedule it to run immediately."""
return Task(
coro,
loop=loop or get_running_loop(),
name=name,
eager_start=True, # type: ignore[call-arg]
)
else:
def create_eager_task(
coro: Coroutine[Any, Any, _T],
*,
name: str | None = None,
loop: AbstractEventLoop | None = None,
) -> Task[_T]:
"""Create a task from a coroutine."""
return Task(coro, loop=loop or get_running_loop(), name=name)

View File

@ -1,9 +1,9 @@
from __future__ import annotations
import asyncio
import logging
from functools import partial
from ipaddress import ip_address
import logging
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@ -429,8 +429,14 @@ async def test_reconnect_zeroconf(
assert not rl._is_stopped
caplog.clear()
async def delayed_connect(*args, **kwargs):
await asyncio.sleep(0)
with (
patch.object(cli, "start_connection") as mock_start_connection,
patch.object(
cli, "start_connection", side_effect=delayed_connect
) as mock_start_connection,
patch.object(cli, "finish_connection"),
):
assert rl._zc_listening is True
@ -445,8 +451,11 @@ async def test_reconnect_zeroconf(
# The reconnect is scheduled to run in the next loop iteration
await asyncio.sleep(0)
await asyncio.sleep(0)
assert mock_start_connection.call_count == int(should_trigger_zeroconf)
assert log_text in caplog.text
await asyncio.sleep(0)
assert rl._connection_state is expected_state_after_trigger
await rl.stop()
@ -733,11 +742,15 @@ async def test_handling_unexpected_disconnect(aiohappyeyeballs_start_connection)
assert cli._connection.is_connected is True
await asyncio.sleep(0)
with patch.object(
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
) as mock_create_connection:
with (
patch.object(
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
) as mock_create_connection,
patch.object(cli, "start_connection"),
patch.object(cli, "finish_connection"),
):
protocol.eof_received()
# Wait for the task to run
await asyncio.sleep(0)

View File

@ -1,4 +1,6 @@
import asyncio
import math
import sys
import pytest
@ -29,3 +31,53 @@ def test_fix_float_single_double_conversion(input, output):
def test_fix_float_single_double_conversion_nan():
assert math.isnan(util.fix_float_single_double_conversion(float("nan")))
@pytest.mark.skipif(sys.version_info < (3, 12), reason="Test requires Python 3.12+")
async def test_create_eager_task_312() -> None:
"""Test create_eager_task schedules a task eagerly in the event loop.
For Python 3.12+, the task is scheduled eagerly in the event loop.
"""
events = []
async def _normal_task():
events.append("normal")
async def _eager_task():
events.append("eager")
task1 = util.create_eager_task(_eager_task())
task2 = asyncio.create_task(_normal_task())
assert events == ["eager"]
await asyncio.sleep(0)
assert events == ["eager", "normal"]
await task1
await task2
@pytest.mark.skipif(sys.version_info >= (3, 12), reason="Test requires < Python 3.12")
async def test_create_eager_task_pre_312() -> None:
"""Test create_eager_task schedules a task in the event loop.
For older python versions, the task is scheduled normally.
"""
events = []
async def _normal_task():
events.append("normal")
async def _eager_task():
events.append("eager")
task1 = util.create_eager_task(_eager_task())
task2 = asyncio.create_task(_normal_task())
assert events == []
await asyncio.sleep(0)
assert events == ["eager", "normal"]
await task1
await task2