diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index f90a7e1..77506da 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -6,7 +6,18 @@ import socket import time from contextlib import suppress from dataclasses import astuple, dataclass -from typing import Any, Callable, Coroutine, Dict, Iterable, List, Optional, Type, Union +from typing import ( + Any, + Callable, + Coroutine, + Dict, + Iterable, + List, + Optional, + Set, + Type, + Union, +) import async_timeout from google.protobuf import message @@ -127,7 +138,7 @@ class APIConnection: "_connect_complete", "_message_handlers", "log_name", - "_read_exception_handlers", + "_read_exception_futures", "_ping_timer", "_pong_timer", "_keep_alive_interval", @@ -163,8 +174,8 @@ class APIConnection: # The friendly name to show for this connection in the logs self.log_name = log_name or params.address - # Handlers currently subscribed to exceptions in the read task - self._read_exception_handlers: List[Callable[[Exception], None]] = [] + # futures currently subscribed to exceptions in the read task + self._read_exception_futures: Set[asyncio.Future[None]] = set() self._ping_timer: Optional[asyncio.TimerHandle] = None self._pong_timer: Optional[asyncio.TimerHandle] = None @@ -601,7 +612,7 @@ class APIConnection: :raises TimeoutAPIError: if a timeout occured """ - fut = self._loop.create_future() + fut: asyncio.Future[None] = self._loop.create_future() responses = [] def on_message(resp: message.Message) -> None: @@ -610,19 +621,12 @@ class APIConnection: if do_append(resp): responses.append(resp) if do_stop(resp): - fut.set_result(responses) - - def on_read_exception(exc: Exception) -> None: - if not fut.done(): - new_exc = exc - if not isinstance(exc, APIConnectionError): - new_exc = ReadFailedAPIError("Read failed") - new_exc.__cause__ = exc - fut.set_exception(new_exc) + fut.set_result(None) for msg_type in msg_types: self._message_handlers.setdefault(msg_type, []).append(on_message) - self._read_exception_handlers.append(on_read_exception) + + self._read_exception_futures.add(fut) # We must not await without a finally or # the message could fail to be removed if the # the await is cancelled @@ -639,8 +643,7 @@ class APIConnection: for msg_type in msg_types: with suppress(ValueError): self._message_handlers[msg_type].remove(on_message) - with suppress(ValueError): - self._read_exception_handlers.remove(on_read_exception) + self._read_exception_futures.discard(fut) return responses @@ -678,9 +681,15 @@ class APIConnection: ) self._fatal_exception = err self._connection_state = ConnectionState.CLOSED - for handler in self._read_exception_handlers[:]: - handler(err) - self._read_exception_handlers.clear() + for fut in self._read_exception_futures: + if fut.done(): + continue + new_exc = err + if not isinstance(err, APIConnectionError): + new_exc = ReadFailedAPIError("Read failed") + new_exc.__cause__ = err + fut.set_exception(new_exc) + self._read_exception_futures.clear() self._cleanup() def _process_packet(self, msg_type_proto: int, data: bytes) -> None: