Refactor cleanup to be a normal function (#355)

This commit is contained in:
J. Nick Koston 2023-01-06 16:42:39 -10:00 committed by GitHub
parent 2886d361f0
commit 284b767d8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 105 additions and 107 deletions

View File

@ -1,7 +1,7 @@
import asyncio
import base64
import logging
from abc import abstractmethod, abstractproperty
from abc import abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional, Union, cast
@ -49,15 +49,10 @@ class APIFrameHelper(asyncio.Protocol):
self._on_error = on_error
self._transport: Optional[asyncio.Transport] = None
self.read_lock = asyncio.Lock()
self._closed_event = asyncio.Event()
self._connected_event = asyncio.Event()
self._buffer = bytearray()
self._pos = 0
@abstractproperty # pylint: disable=deprecated-decorator
def ready(self) -> bool:
"""Return if the connection is ready."""
def _init_read(self, length: int) -> Optional[bytearray]:
"""Start reading a packet from the buffer."""
self._pos = 0
@ -90,7 +85,6 @@ class APIFrameHelper(asyncio.Protocol):
self.close()
def _handle_error(self, exc: Exception) -> None:
self._closed_event.set()
self._on_error(exc)
def connection_lost(self, exc: Optional[Exception]) -> None:
@ -103,7 +97,6 @@ class APIFrameHelper(asyncio.Protocol):
def close(self) -> None:
"""Close the connection."""
self._closed_event.set()
if self._transport:
self._transport.close()
@ -111,11 +104,6 @@ class APIFrameHelper(asyncio.Protocol):
class APIPlaintextFrameHelper(APIFrameHelper):
"""Frame helper for plaintext API connections."""
@property
def ready(self) -> bool:
"""Return if the connection is ready."""
return self._connected_event.is_set()
def _callback_packet(self, type_: int, data: Union[bytes, bytearray]) -> None:
"""Complete reading a packet from the buffer."""
del self._buffer[: self._pos]
@ -138,7 +126,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
try:
self._transport.write(data)
except (ConnectionResetError, OSError) as err:
except (RuntimeError, ConnectionResetError, OSError) as err:
raise SocketAPIError(f"Error while writing data: {err}") from err
async def perform_handshake(self) -> None:
@ -250,11 +238,6 @@ class APINoiseFrameHelper(APIFrameHelper):
self._state = NoiseConnectionState.HELLO
self._setup_proto()
@property
def ready(self) -> bool:
"""Return if the connection is ready."""
return self._ready_event.is_set()
def close(self) -> None:
"""Close the connection."""
# Make sure we set the ready event if its not already set
@ -282,7 +265,7 @@ class APINoiseFrameHelper(APIFrameHelper):
]
)
self._transport.write(header + frame)
except OSError as err:
except (RuntimeError, ConnectionResetError, OSError) as err:
raise SocketAPIError(f"Error while writing data: {err}") from err
async def perform_handshake(self) -> None:

View File

@ -557,7 +557,8 @@ class APIClient:
unsub()
except (KeyError, ValueError):
_LOGGER.warning(
"%s: Bluetooth device connection timed out but already unsubscribed",
"%s: Bluetooth device connection timed out but already unsubscribed "
"(likely due to unexpected disconnect)",
addr,
)
raise TimeoutAPIError(

View File

@ -1,4 +1,5 @@
import asyncio
import contextvars
import enum
import logging
import socket
@ -34,6 +35,7 @@ from .core import (
MESSAGE_TYPE_TO_PROTO,
APIConnectionError,
BadNameAPIError,
HandshakeAPIError,
InvalidAuthAPIError,
PingFailedAPIError,
ProtocolAPIError,
@ -52,6 +54,10 @@ INTERNAL_MESSAGE_TYPES = {GetTimeRequest, PingRequest, DisconnectRequest}
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
in_do_connect: contextvars.ContextVar[Optional[bool]] = contextvars.ContextVar(
"in_do_connect"
)
@dataclass
class ConnectionParams:
@ -111,8 +117,9 @@ class APIConnection:
self._ping_stop_event = asyncio.Event()
self._connect_lock: asyncio.Lock = asyncio.Lock()
self._cleanup_task: Optional[asyncio.Task[None]] = None
self._connect_task: Optional[asyncio.Task[None]] = None
self._fatal_exception: Optional[Exception] = None
self._expected_disconnect = False
@property
def connection_state(self) -> ConnectionState:
@ -123,36 +130,36 @@ class APIConnection:
"""Set the friendly log name for this connection."""
self.log_name = name
async def _cleanup(self) -> None:
def _cleanup(self) -> None:
"""Clean up all resources that have been allocated.
Safe to call multiple times.
"""
_LOGGER.debug("Cleaning up connection to %s", self.log_name)
# If we are being called from do_connect we
# need to make sure we don't cancel the task
# that called us
if self._connect_task is not None and not in_do_connect.get(False):
self._connect_task.cancel()
self._connect_task = None
async def _do_cleanup() -> None:
async with self._connect_lock:
_LOGGER.debug("Cleaning up connection to %s", self.log_name)
if self._frame_helper is not None:
self._frame_helper.close()
self._frame_helper = None
if self._frame_helper is not None:
self._frame_helper.close()
self._frame_helper = None
if self._socket is not None:
self._socket.close()
self._socket = None
if self._socket is not None:
self._socket.close()
self._socket = None
if not self._on_stop_called and self._connect_complete:
# Ensure on_stop is called
asyncio.create_task(self.on_stop())
self._on_stop_called = True
if not self._on_stop_called and self._connect_complete:
# Ensure on_stop is called
asyncio.create_task(self.on_stop())
self._on_stop_called = True
# Note: we don't explicitly cancel the ping/read task here
# That's because if not written right the ping/read task could cancel
# themselves, effectively ending execution after _cleanup which may be unexpected
self._ping_stop_event.set()
if not self._cleanup_task or self._cleanup_task.done():
self._cleanup_task = asyncio.create_task(_do_cleanup())
# Note: we don't explicitly cancel the ping/read task here
# That's because if not written right the ping/read task could cancel
# themselves, effectively ending execution after _cleanup which may be unexpected
self._ping_stop_event.set()
async def _connect_resolve_host(self) -> hr.AddrInfo:
"""Step 1 in connect process: resolve the address."""
@ -216,7 +223,7 @@ class APIConnection:
_, fh = await loop.create_connection(
lambda: APIPlaintextFrameHelper(
on_pkt=self._process_packet,
on_error=self._report_fatal_error_and_cleanup_task,
on_error=self._report_fatal_error,
),
sock=self._socket,
)
@ -226,14 +233,20 @@ class APIConnection:
noise_psk=self._params.noise_psk,
expected_name=self._params.expected_name,
on_pkt=self._process_packet,
on_error=self._report_fatal_error_and_cleanup_task,
on_error=self._report_fatal_error,
),
sock=self._socket,
)
self._frame_helper = fh
self._connection_state = ConnectionState.SOCKET_OPENED
await fh.perform_handshake()
try:
async with async_timeout.timeout(30.0):
await fh.perform_handshake()
except OSError as err:
raise HandshakeAPIError(f"Handshake failed: {err}") from err
except asyncio.TimeoutError as err:
raise TimeoutAPIError("Handshake timed out") from err
async def _connect_hello(self) -> None:
"""Step 4 in connect process: send hello and get api version."""
@ -271,8 +284,6 @@ class APIConnection:
f"Server sent a different name '{resp.name}'", resp.name
)
self._connection_state = ConnectionState.CONNECTED
async def _connect_start_ping(self) -> None:
"""Step 5 in connect process: start the ping loop."""
@ -286,26 +297,26 @@ class APIConnection:
pass
# Re-check connection state
if not self._is_socket_open:
return # type: ignore[unreachable]
if not self._is_socket_open or self._ping_stop_event.is_set():
return
try:
await self._ping()
except TimeoutAPIError:
_LOGGER.info("%s: Ping timed out!", self.log_name)
await self._report_fatal_error(PingFailedAPIError())
_LOGGER.debug("%s: Ping timed out!", self.log_name)
self._report_fatal_error(PingFailedAPIError())
return
except APIConnectionError as err:
_LOGGER.info("%s: Ping Failed: %s", self.log_name, err)
await self._report_fatal_error(err)
_LOGGER.debug("%s: Ping Failed: %s", self.log_name, err)
self._report_fatal_error(err)
return
except Exception as err: # pylint: disable=broad-except
_LOGGER.info(
_LOGGER.error(
"%s: Unexpected error during ping:",
self.log_name,
exc_info=True,
)
await self._report_fatal_error(err)
self._report_fatal_error(err)
return
asyncio.create_task(_keep_alive_loop())
@ -317,35 +328,45 @@ class APIConnection:
)
async def _do_connect() -> None:
in_do_connect.set(True)
addr = await self._connect_resolve_host()
await self._connect_socket_connect(addr)
await self._connect_init_frame_helper()
await self._connect_hello()
await self._connect_start_ping()
if login:
await self.login()
await self.login(check_connected=False)
# A connection lock must be created to avoid potential issues where
# connect has succeeded but not yet returned, followed by a disconnect.
# See esphome/aioesphomeapi#258 for more information
async with self._connect_lock:
try:
# Allow 2 minutes for connect; this is only as a last measure
# to protect from issues if some part of the connect process mistakenly
# does not have a timeout
async with async_timeout.timeout(120.0):
await _do_connect()
except Exception: # pylint: disable=broad-except
# Always clean up the connection if an error occured during connect
self._connection_state = ConnectionState.CLOSED
await self._cleanup()
raise
self._connect_task = asyncio.create_task(_do_connect())
try:
# Allow 2 minutes for connect; this is only as a last measure
# to protect from issues if some part of the connect process mistakenly
# does not have a timeout
async with async_timeout.timeout(120.0):
await self._connect_task
except asyncio.CancelledError:
# If the task was cancelled, we need to clean up the connection
# and raise the CancelledError
self._connection_state = ConnectionState.CLOSED
self._cleanup()
raise self._fatal_exception or APIConnectionError("Connection cancelled")
except Exception: # pylint: disable=broad-except
# Always clean up the connection if an error occured during connect
self._connection_state = ConnectionState.CLOSED
self._cleanup()
raise
else:
self._connection_state = ConnectionState.CONNECTED
self._connect_complete = True
async def login(self) -> None:
async def login(self, check_connected: bool = True) -> None:
"""Send a login (ConnectRequest) and await the response."""
self._check_connected()
if check_connected:
# On first connect, we don't want to check if we're connected
# because we don't set the connection state until after login
# is complete
self._check_connected()
if self._is_authenticated:
raise APIConnectionError("Already logged in!")
@ -359,7 +380,7 @@ class APIConnection:
# We don't know what state the device may be in after ConnectRequest
# was already sent
_LOGGER.debug("%s: Login timed out", self.log_name)
await self._report_fatal_error(err)
self._report_fatal_error(err)
raise
if resp.invalid_password:
@ -393,7 +414,6 @@ class APIConnection:
frame_helper = self._frame_helper
assert frame_helper is not None
assert frame_helper.ready, "Frame helper not ready"
message_type = PROTO_TO_MESSAGE_TYPE.get(type(msg))
if not message_type:
raise ValueError(f"Message type id not found for type {type(msg)}")
@ -411,7 +431,7 @@ class APIConnection:
# If writing packet fails, we don't know what state the frames
# are in anymore and we have to close the connection
_LOGGER.info("%s: Error writing packet: %s", self.log_name, err)
self._report_fatal_error_and_cleanup_task(err)
self._report_fatal_error(err)
raise
def add_message_callback(
@ -525,26 +545,7 @@ class APIConnection:
return res[0]
def _handle_fatal_error(self, err: Exception) -> None:
"""Handle a fatal error that occurred during an operation."""
self._connection_state = ConnectionState.CLOSED
for handler in self._read_exception_handlers[:]:
handler(err)
self._read_exception_handlers.clear()
def _report_fatal_error_and_cleanup_task(self, err: Exception) -> None:
"""Handle a fatal error that occurred during an operation.
This should only be called for errors that mean the connection
can no longer be used.
The connection will be closed, all exception handlers notified.
This method does not log the error, the call site should do so.
"""
self._handle_fatal_error(err)
asyncio.create_task(self._cleanup())
async def _report_fatal_error(self, err: Exception) -> None:
def _report_fatal_error(self, err: Exception) -> None:
"""Report a fatal error that occurred during an operation.
This should only be called for errors that mean the connection
@ -553,8 +554,20 @@ class APIConnection:
The connection will be closed, all exception handlers notified.
This method does not log the error, the call site should do so.
"""
self._handle_fatal_error(err)
await self._cleanup()
if not self._expected_disconnect and not self._fatal_exception:
# Only log the first error
_LOGGER.warning(
"%s: Connection error occurred: %s",
self.log_name,
err or type(err),
exc_info=not str(err), # Log the full stack on empty error string
)
self._fatal_exception = err
self._connection_state = ConnectionState.CLOSED
for handler in self._read_exception_handlers[:]:
handler(err)
self._read_exception_handlers.clear()
self._cleanup()
def _process_packet(self, pkt: Packet) -> None:
"""Process a packet from the socket."""
@ -575,9 +588,7 @@ class APIConnection:
e,
exc_info=True,
)
self._report_fatal_error_and_cleanup_task(
ProtocolAPIError(f"Invalid protobuf message: {e}")
)
self._report_fatal_error(ProtocolAPIError(f"Invalid protobuf message: {e}"))
raise
msg_type = type(msg)
@ -595,7 +606,8 @@ class APIConnection:
if isinstance(msg, DisconnectRequest):
self.send_message(DisconnectResponse())
self._connection_state = ConnectionState.CLOSED
asyncio.create_task(self._cleanup())
self._expected_disconnect = True
self._cleanup()
elif isinstance(msg, PingRequest):
self.send_message(PingResponse())
elif isinstance(msg, GetTimeRequest):
@ -612,6 +624,7 @@ class APIConnection:
# already disconnected
return
self._expected_disconnect = True
try:
await self.send_message_await_response(
DisconnectRequest(), DisconnectResponse
@ -620,11 +633,12 @@ class APIConnection:
pass
self._connection_state = ConnectionState.CLOSED
await self._cleanup()
self._cleanup()
async def force_disconnect(self) -> None:
self._connection_state = ConnectionState.CLOSED
await self._cleanup()
self._expected_disconnect = True
self._cleanup()
@property
def api_version(self) -> Optional[APIVersion]: