Fix race in cleaning up connection (#698)
This commit is contained in:
parent
c0a153c9f3
commit
e01f22d99a
|
@ -174,6 +174,7 @@ class APIClient:
|
|||
"cached_name",
|
||||
"_background_tasks",
|
||||
"_loop",
|
||||
"_on_stop_task",
|
||||
"log_name",
|
||||
)
|
||||
|
||||
|
@ -219,6 +220,7 @@ class APIClient:
|
|||
self.cached_name: str | None = None
|
||||
self._background_tasks: set[asyncio.Task[Any]] = set()
|
||||
self._loop = asyncio.get_event_loop()
|
||||
self._on_stop_task: asyncio.Task[None] | None = None
|
||||
self._set_log_name()
|
||||
|
||||
@property
|
||||
|
@ -258,13 +260,36 @@ class APIClient:
|
|||
|
||||
async def connect(
|
||||
self,
|
||||
on_stop: Callable[[bool], Awaitable[None]] | None = None,
|
||||
on_stop: Callable[[bool], Coroutine[Any, Any, None]] | None = None,
|
||||
login: bool = False,
|
||||
) -> None:
|
||||
"""Connect to the device."""
|
||||
await self.start_connection(on_stop)
|
||||
await self.finish_connection(login)
|
||||
|
||||
def _on_stop(
|
||||
self,
|
||||
on_stop: Callable[[bool], Coroutine[Any, Any, None]] | None,
|
||||
expected_disconnect: bool,
|
||||
) -> None:
|
||||
# Hook into on_stop handler to clear connection when stopped
|
||||
self._connection = None
|
||||
if on_stop:
|
||||
self._on_stop_task = asyncio.create_task(
|
||||
on_stop(expected_disconnect),
|
||||
name=f"{self.log_name} aioesphomeapi on_stop",
|
||||
)
|
||||
self._on_stop_task.add_done_callback(self._remove_on_stop_task)
|
||||
|
||||
def _remove_on_stop_task(self, _fut: asyncio.Future[None]) -> None:
|
||||
"""Remove the stop task.
|
||||
|
||||
We need to do this because the asyncio does not hold
|
||||
a strong reference to the task, so it can be garbage
|
||||
collected unexpectedly.
|
||||
"""
|
||||
self._on_stop_task = None
|
||||
|
||||
async def start_connection(
|
||||
self,
|
||||
on_stop: Callable[[bool], Awaitable[None]] | None = None,
|
||||
|
@ -273,13 +298,9 @@ class APIClient:
|
|||
if self._connection is not None:
|
||||
raise APIConnectionError(f"Already connected to {self.log_name}!")
|
||||
|
||||
async def _on_stop(expected_disconnect: bool) -> None:
|
||||
# Hook into on_stop handler to clear connection when stopped
|
||||
self._connection = None
|
||||
if on_stop is not None:
|
||||
await on_stop(expected_disconnect)
|
||||
|
||||
self._connection = APIConnection(self._params, _on_stop, log_name=self.log_name)
|
||||
self._connection = APIConnection(
|
||||
self._params, partial(self._on_stop, on_stop), log_name=self.log_name
|
||||
)
|
||||
|
||||
try:
|
||||
await self._connection.start_connection()
|
||||
|
|
|
@ -60,7 +60,6 @@ cdef class APIConnection:
|
|||
|
||||
cdef ConnectionParams _params
|
||||
cdef public object on_stop
|
||||
cdef object _on_stop_task
|
||||
cdef public object _socket
|
||||
cdef public APIFrameHelper _frame_helper
|
||||
cdef public object api_version
|
||||
|
|
|
@ -11,7 +11,6 @@ import time
|
|||
# instead of the one from asyncio since they are the same in Python 3.11+
|
||||
from asyncio import CancelledError
|
||||
from asyncio import TimeoutError as asyncio_TimeoutError
|
||||
from collections.abc import Coroutine
|
||||
from dataclasses import astuple, dataclass
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
@ -134,7 +133,6 @@ class APIConnection:
|
|||
__slots__ = (
|
||||
"_params",
|
||||
"on_stop",
|
||||
"_on_stop_task",
|
||||
"_socket",
|
||||
"_frame_helper",
|
||||
"api_version",
|
||||
|
@ -162,12 +160,11 @@ class APIConnection:
|
|||
def __init__(
|
||||
self,
|
||||
params: ConnectionParams,
|
||||
on_stop: Callable[[bool], Coroutine[Any, Any, None]],
|
||||
on_stop: Callable[[bool], None],
|
||||
log_name: str | None = None,
|
||||
) -> None:
|
||||
self._params = params
|
||||
self.on_stop: Callable[[bool], Coroutine[Any, Any, None]] | None = on_stop
|
||||
self._on_stop_task: asyncio.Task[None] | None = None
|
||||
self.on_stop: Callable[[bool], None] | None = on_stop
|
||||
self._socket: socket.socket | None = None
|
||||
self._frame_helper: None | (
|
||||
APINoiseFrameHelper | APIPlaintextFrameHelper
|
||||
|
@ -261,23 +258,9 @@ class APIConnection:
|
|||
self._ping_timer.cancel()
|
||||
self._ping_timer = None
|
||||
|
||||
if self.on_stop is not None and was_connected:
|
||||
# Ensure on_stop is called only once
|
||||
self._on_stop_task = asyncio.create_task(
|
||||
self.on_stop(self._expected_disconnect),
|
||||
name=f"{self.log_name} aioesphomeapi connection on_stop",
|
||||
)
|
||||
self._on_stop_task.add_done_callback(self._remove_on_stop_task)
|
||||
if (on_stop := self.on_stop) is not None and was_connected:
|
||||
self.on_stop = None
|
||||
|
||||
def _remove_on_stop_task(self, _fut: asyncio.Future[None]) -> None:
|
||||
"""Remove the stop task.
|
||||
|
||||
We need to do this because the asyncio does not hold
|
||||
a strong reference to the task, so it can be garbage
|
||||
collected unexpectedly.
|
||||
"""
|
||||
self._on_stop_task = None
|
||||
on_stop(self._expected_disconnect)
|
||||
|
||||
async def _connect_resolve_host(self) -> hr.AddrInfo:
|
||||
"""Step 1 in connect process: resolve the address."""
|
||||
|
|
|
@ -4,12 +4,14 @@ import asyncio
|
|||
import time
|
||||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
from typing import Awaitable, Callable
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from google.protobuf import message
|
||||
from zeroconf import Zeroconf
|
||||
from zeroconf.asyncio import AsyncZeroconf
|
||||
|
||||
from aioesphomeapi import APIClient
|
||||
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
|
||||
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
|
||||
from aioesphomeapi.api_pb2 import (
|
||||
|
@ -117,6 +119,16 @@ async def connect(conn: APIConnection, login: bool = True):
|
|||
await conn.finish_connection(login=login)
|
||||
|
||||
|
||||
async def connect_client(
|
||||
client: APIClient,
|
||||
login: bool = True,
|
||||
on_stop: Callable[[bool], Awaitable[None]] | None = None,
|
||||
) -> None:
|
||||
"""Wrapper for connection logic to do both parts."""
|
||||
await client.start_connection(on_stop=on_stop)
|
||||
await client.finish_connection(login=login)
|
||||
|
||||
|
||||
def send_plaintext_hello(protocol: APIPlaintextFrameHelper) -> None:
|
||||
hello_response: message.Message = HelloResponse()
|
||||
hello_response.api_version_major = 1
|
||||
|
|
|
@ -24,13 +24,13 @@ from aioesphomeapi.core import (
|
|||
HandshakeAPIError,
|
||||
InvalidAuthAPIError,
|
||||
RequiresEncryptionAPIError,
|
||||
SocketAPIError,
|
||||
TimeoutAPIError,
|
||||
)
|
||||
|
||||
from .common import (
|
||||
async_fire_time_changed,
|
||||
connect,
|
||||
connect_client,
|
||||
generate_plaintext_packet,
|
||||
get_mock_protocol,
|
||||
mock_data_received,
|
||||
|
@ -515,8 +515,6 @@ async def test_disconnect_fails_to_send_response(
|
|||
nonlocal expected_disconnect
|
||||
expected_disconnect = _expected_disconnect
|
||||
|
||||
conn = APIConnection(connection_params, _on_stop)
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
|
@ -527,10 +525,11 @@ async def test_disconnect_fails_to_send_response(
|
|||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
connect_task = asyncio.create_task(
|
||||
connect_client(client, login=False, on_stop=_on_stop)
|
||||
)
|
||||
await connected.wait()
|
||||
send_plaintext_hello(protocol)
|
||||
client._connection = conn
|
||||
await connect_task
|
||||
transport.reset_mock()
|
||||
|
||||
|
@ -538,7 +537,7 @@ async def test_disconnect_fails_to_send_response(
|
|||
send_plaintext_connect_response(protocol, False)
|
||||
|
||||
await connect_task
|
||||
assert conn.is_connected
|
||||
assert client._connection.is_connected
|
||||
|
||||
with patch.object(protocol, "_writer", side_effect=OSError):
|
||||
disconnect_request = DisconnectRequest()
|
||||
|
@ -571,8 +570,6 @@ async def test_disconnect_success_case(
|
|||
nonlocal expected_disconnect
|
||||
expected_disconnect = _expected_disconnect
|
||||
|
||||
conn = APIConnection(connection_params, _on_stop)
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
|
@ -583,10 +580,11 @@ async def test_disconnect_success_case(
|
|||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
connect_task = asyncio.create_task(
|
||||
connect_client(client, login=False, on_stop=_on_stop)
|
||||
)
|
||||
await connected.wait()
|
||||
send_plaintext_hello(protocol)
|
||||
client._connection = conn
|
||||
await connect_task
|
||||
transport.reset_mock()
|
||||
|
||||
|
@ -594,7 +592,7 @@ async def test_disconnect_success_case(
|
|||
send_plaintext_connect_response(protocol, False)
|
||||
|
||||
await connect_task
|
||||
assert conn.is_connected
|
||||
assert client._connection.is_connected
|
||||
|
||||
disconnect_request = DisconnectRequest()
|
||||
mock_data_received(protocol, generate_plaintext_packet(disconnect_request))
|
||||
|
@ -602,7 +600,7 @@ async def test_disconnect_success_case(
|
|||
# Wait one loop iteration for the disconnect to be processed
|
||||
await asyncio.sleep(0)
|
||||
assert expected_disconnect is True
|
||||
assert not conn.is_connected
|
||||
assert not client._connection
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
@ -20,7 +20,6 @@ from zeroconf.const import _CLASS_IN, _TYPE_A, _TYPE_PTR
|
|||
from aioesphomeapi import APIConnectionError
|
||||
from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper
|
||||
from aioesphomeapi.client import APIClient
|
||||
from aioesphomeapi.connection import APIConnection
|
||||
from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
|
||||
|
||||
from .common import (
|
||||
|
|
Loading…
Reference in New Issue