Fix race in cleaning up connection (#698)

This commit is contained in:
J. Nick Koston 2023-11-25 07:11:34 -06:00 committed by GitHub
parent c0a153c9f3
commit e01f22d99a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 55 additions and 43 deletions

View File

@ -174,6 +174,7 @@ class APIClient:
"cached_name",
"_background_tasks",
"_loop",
"_on_stop_task",
"log_name",
)
@ -219,6 +220,7 @@ class APIClient:
self.cached_name: str | None = None
self._background_tasks: set[asyncio.Task[Any]] = set()
self._loop = asyncio.get_event_loop()
self._on_stop_task: asyncio.Task[None] | None = None
self._set_log_name()
@property
@ -258,13 +260,36 @@ class APIClient:
async def connect(
self,
on_stop: Callable[[bool], Awaitable[None]] | None = None,
on_stop: Callable[[bool], Coroutine[Any, Any, None]] | None = None,
login: bool = False,
) -> None:
"""Connect to the device."""
await self.start_connection(on_stop)
await self.finish_connection(login)
def _on_stop(
self,
on_stop: Callable[[bool], Coroutine[Any, Any, None]] | None,
expected_disconnect: bool,
) -> None:
# Hook into on_stop handler to clear connection when stopped
self._connection = None
if on_stop:
self._on_stop_task = asyncio.create_task(
on_stop(expected_disconnect),
name=f"{self.log_name} aioesphomeapi on_stop",
)
self._on_stop_task.add_done_callback(self._remove_on_stop_task)
def _remove_on_stop_task(self, _fut: asyncio.Future[None]) -> None:
"""Remove the stop task.
We need to do this because the asyncio does not hold
a strong reference to the task, so it can be garbage
collected unexpectedly.
"""
self._on_stop_task = None
async def start_connection(
self,
on_stop: Callable[[bool], Awaitable[None]] | None = None,
@ -273,13 +298,9 @@ class APIClient:
if self._connection is not None:
raise APIConnectionError(f"Already connected to {self.log_name}!")
async def _on_stop(expected_disconnect: bool) -> None:
# Hook into on_stop handler to clear connection when stopped
self._connection = None
if on_stop is not None:
await on_stop(expected_disconnect)
self._connection = APIConnection(self._params, _on_stop, log_name=self.log_name)
self._connection = APIConnection(
self._params, partial(self._on_stop, on_stop), log_name=self.log_name
)
try:
await self._connection.start_connection()

View File

@ -60,7 +60,6 @@ cdef class APIConnection:
cdef ConnectionParams _params
cdef public object on_stop
cdef object _on_stop_task
cdef public object _socket
cdef public APIFrameHelper _frame_helper
cdef public object api_version

View File

@ -11,7 +11,6 @@ import time
# instead of the one from asyncio since they are the same in Python 3.11+
from asyncio import CancelledError
from asyncio import TimeoutError as asyncio_TimeoutError
from collections.abc import Coroutine
from dataclasses import astuple, dataclass
from functools import partial
from typing import TYPE_CHECKING, Any, Callable
@ -134,7 +133,6 @@ class APIConnection:
__slots__ = (
"_params",
"on_stop",
"_on_stop_task",
"_socket",
"_frame_helper",
"api_version",
@ -162,12 +160,11 @@ class APIConnection:
def __init__(
self,
params: ConnectionParams,
on_stop: Callable[[bool], Coroutine[Any, Any, None]],
on_stop: Callable[[bool], None],
log_name: str | None = None,
) -> None:
self._params = params
self.on_stop: Callable[[bool], Coroutine[Any, Any, None]] | None = on_stop
self._on_stop_task: asyncio.Task[None] | None = None
self.on_stop: Callable[[bool], None] | None = on_stop
self._socket: socket.socket | None = None
self._frame_helper: None | (
APINoiseFrameHelper | APIPlaintextFrameHelper
@ -261,23 +258,9 @@ class APIConnection:
self._ping_timer.cancel()
self._ping_timer = None
if self.on_stop is not None and was_connected:
# Ensure on_stop is called only once
self._on_stop_task = asyncio.create_task(
self.on_stop(self._expected_disconnect),
name=f"{self.log_name} aioesphomeapi connection on_stop",
)
self._on_stop_task.add_done_callback(self._remove_on_stop_task)
if (on_stop := self.on_stop) is not None and was_connected:
self.on_stop = None
def _remove_on_stop_task(self, _fut: asyncio.Future[None]) -> None:
"""Remove the stop task.
We need to do this because the asyncio does not hold
a strong reference to the task, so it can be garbage
collected unexpectedly.
"""
self._on_stop_task = None
on_stop(self._expected_disconnect)
async def _connect_resolve_host(self) -> hr.AddrInfo:
"""Step 1 in connect process: resolve the address."""

View File

@ -4,12 +4,14 @@ import asyncio
import time
from datetime import datetime, timezone
from functools import partial
from typing import Awaitable, Callable
from unittest.mock import AsyncMock, MagicMock, patch
from google.protobuf import message
from zeroconf import Zeroconf
from zeroconf.asyncio import AsyncZeroconf
from aioesphomeapi import APIClient
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
from aioesphomeapi.api_pb2 import (
@ -117,6 +119,16 @@ async def connect(conn: APIConnection, login: bool = True):
await conn.finish_connection(login=login)
async def connect_client(
client: APIClient,
login: bool = True,
on_stop: Callable[[bool], Awaitable[None]] | None = None,
) -> None:
"""Wrapper for connection logic to do both parts."""
await client.start_connection(on_stop=on_stop)
await client.finish_connection(login=login)
def send_plaintext_hello(protocol: APIPlaintextFrameHelper) -> None:
hello_response: message.Message = HelloResponse()
hello_response.api_version_major = 1

View File

@ -24,13 +24,13 @@ from aioesphomeapi.core import (
HandshakeAPIError,
InvalidAuthAPIError,
RequiresEncryptionAPIError,
SocketAPIError,
TimeoutAPIError,
)
from .common import (
async_fire_time_changed,
connect,
connect_client,
generate_plaintext_packet,
get_mock_protocol,
mock_data_received,
@ -515,8 +515,6 @@ async def test_disconnect_fails_to_send_response(
nonlocal expected_disconnect
expected_disconnect = _expected_disconnect
conn = APIConnection(connection_params, _on_stop)
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
@ -527,10 +525,11 @@ async def test_disconnect_fails_to_send_response(
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
connect_task = asyncio.create_task(connect(conn, login=False))
connect_task = asyncio.create_task(
connect_client(client, login=False, on_stop=_on_stop)
)
await connected.wait()
send_plaintext_hello(protocol)
client._connection = conn
await connect_task
transport.reset_mock()
@ -538,7 +537,7 @@ async def test_disconnect_fails_to_send_response(
send_plaintext_connect_response(protocol, False)
await connect_task
assert conn.is_connected
assert client._connection.is_connected
with patch.object(protocol, "_writer", side_effect=OSError):
disconnect_request = DisconnectRequest()
@ -571,8 +570,6 @@ async def test_disconnect_success_case(
nonlocal expected_disconnect
expected_disconnect = _expected_disconnect
conn = APIConnection(connection_params, _on_stop)
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
@ -583,10 +580,11 @@ async def test_disconnect_success_case(
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
connect_task = asyncio.create_task(connect(conn, login=False))
connect_task = asyncio.create_task(
connect_client(client, login=False, on_stop=_on_stop)
)
await connected.wait()
send_plaintext_hello(protocol)
client._connection = conn
await connect_task
transport.reset_mock()
@ -594,7 +592,7 @@ async def test_disconnect_success_case(
send_plaintext_connect_response(protocol, False)
await connect_task
assert conn.is_connected
assert client._connection.is_connected
disconnect_request = DisconnectRequest()
mock_data_received(protocol, generate_plaintext_packet(disconnect_request))
@ -602,7 +600,7 @@ async def test_disconnect_success_case(
# Wait one loop iteration for the disconnect to be processed
await asyncio.sleep(0)
assert expected_disconnect is True
assert not conn.is_connected
assert not client._connection
@pytest.mark.asyncio

View File

@ -20,7 +20,6 @@ from zeroconf.const import _CLASS_IN, _TYPE_A, _TYPE_PTR
from aioesphomeapi import APIConnectionError
from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper
from aioesphomeapi.client import APIClient
from aioesphomeapi.connection import APIConnection
from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
from .common import (