mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-02-22 02:42:19 +01:00
Refactor cleanup to be a normal function (#355)
This commit is contained in:
parent
2886d361f0
commit
284b767d8d
@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
from abc import abstractmethod, abstractproperty
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Callable, Optional, Union, cast
|
||||
@ -49,15 +49,10 @@ class APIFrameHelper(asyncio.Protocol):
|
||||
self._on_error = on_error
|
||||
self._transport: Optional[asyncio.Transport] = None
|
||||
self.read_lock = asyncio.Lock()
|
||||
self._closed_event = asyncio.Event()
|
||||
self._connected_event = asyncio.Event()
|
||||
self._buffer = bytearray()
|
||||
self._pos = 0
|
||||
|
||||
@abstractproperty # pylint: disable=deprecated-decorator
|
||||
def ready(self) -> bool:
|
||||
"""Return if the connection is ready."""
|
||||
|
||||
def _init_read(self, length: int) -> Optional[bytearray]:
|
||||
"""Start reading a packet from the buffer."""
|
||||
self._pos = 0
|
||||
@ -90,7 +85,6 @@ class APIFrameHelper(asyncio.Protocol):
|
||||
self.close()
|
||||
|
||||
def _handle_error(self, exc: Exception) -> None:
|
||||
self._closed_event.set()
|
||||
self._on_error(exc)
|
||||
|
||||
def connection_lost(self, exc: Optional[Exception]) -> None:
|
||||
@ -103,7 +97,6 @@ class APIFrameHelper(asyncio.Protocol):
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the connection."""
|
||||
self._closed_event.set()
|
||||
if self._transport:
|
||||
self._transport.close()
|
||||
|
||||
@ -111,11 +104,6 @@ class APIFrameHelper(asyncio.Protocol):
|
||||
class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
"""Frame helper for plaintext API connections."""
|
||||
|
||||
@property
|
||||
def ready(self) -> bool:
|
||||
"""Return if the connection is ready."""
|
||||
return self._connected_event.is_set()
|
||||
|
||||
def _callback_packet(self, type_: int, data: Union[bytes, bytearray]) -> None:
|
||||
"""Complete reading a packet from the buffer."""
|
||||
del self._buffer[: self._pos]
|
||||
@ -138,7 +126,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
|
||||
try:
|
||||
self._transport.write(data)
|
||||
except (ConnectionResetError, OSError) as err:
|
||||
except (RuntimeError, ConnectionResetError, OSError) as err:
|
||||
raise SocketAPIError(f"Error while writing data: {err}") from err
|
||||
|
||||
async def perform_handshake(self) -> None:
|
||||
@ -250,11 +238,6 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
self._state = NoiseConnectionState.HELLO
|
||||
self._setup_proto()
|
||||
|
||||
@property
|
||||
def ready(self) -> bool:
|
||||
"""Return if the connection is ready."""
|
||||
return self._ready_event.is_set()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the connection."""
|
||||
# Make sure we set the ready event if its not already set
|
||||
@ -282,7 +265,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
]
|
||||
)
|
||||
self._transport.write(header + frame)
|
||||
except OSError as err:
|
||||
except (RuntimeError, ConnectionResetError, OSError) as err:
|
||||
raise SocketAPIError(f"Error while writing data: {err}") from err
|
||||
|
||||
async def perform_handshake(self) -> None:
|
||||
|
@ -557,7 +557,8 @@ class APIClient:
|
||||
unsub()
|
||||
except (KeyError, ValueError):
|
||||
_LOGGER.warning(
|
||||
"%s: Bluetooth device connection timed out but already unsubscribed",
|
||||
"%s: Bluetooth device connection timed out but already unsubscribed "
|
||||
"(likely due to unexpected disconnect)",
|
||||
addr,
|
||||
)
|
||||
raise TimeoutAPIError(
|
||||
|
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import contextvars
|
||||
import enum
|
||||
import logging
|
||||
import socket
|
||||
@ -34,6 +35,7 @@ from .core import (
|
||||
MESSAGE_TYPE_TO_PROTO,
|
||||
APIConnectionError,
|
||||
BadNameAPIError,
|
||||
HandshakeAPIError,
|
||||
InvalidAuthAPIError,
|
||||
PingFailedAPIError,
|
||||
ProtocolAPIError,
|
||||
@ -52,6 +54,10 @@ INTERNAL_MESSAGE_TYPES = {GetTimeRequest, PingRequest, DisconnectRequest}
|
||||
|
||||
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
|
||||
|
||||
in_do_connect: contextvars.ContextVar[Optional[bool]] = contextvars.ContextVar(
|
||||
"in_do_connect"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionParams:
|
||||
@ -111,8 +117,9 @@ class APIConnection:
|
||||
|
||||
self._ping_stop_event = asyncio.Event()
|
||||
|
||||
self._connect_lock: asyncio.Lock = asyncio.Lock()
|
||||
self._cleanup_task: Optional[asyncio.Task[None]] = None
|
||||
self._connect_task: Optional[asyncio.Task[None]] = None
|
||||
self._fatal_exception: Optional[Exception] = None
|
||||
self._expected_disconnect = False
|
||||
|
||||
@property
|
||||
def connection_state(self) -> ConnectionState:
|
||||
@ -123,36 +130,36 @@ class APIConnection:
|
||||
"""Set the friendly log name for this connection."""
|
||||
self.log_name = name
|
||||
|
||||
async def _cleanup(self) -> None:
|
||||
def _cleanup(self) -> None:
|
||||
"""Clean up all resources that have been allocated.
|
||||
|
||||
Safe to call multiple times.
|
||||
"""
|
||||
_LOGGER.debug("Cleaning up connection to %s", self.log_name)
|
||||
# If we are being called from do_connect we
|
||||
# need to make sure we don't cancel the task
|
||||
# that called us
|
||||
if self._connect_task is not None and not in_do_connect.get(False):
|
||||
self._connect_task.cancel()
|
||||
self._connect_task = None
|
||||
|
||||
async def _do_cleanup() -> None:
|
||||
async with self._connect_lock:
|
||||
_LOGGER.debug("Cleaning up connection to %s", self.log_name)
|
||||
if self._frame_helper is not None:
|
||||
self._frame_helper.close()
|
||||
self._frame_helper = None
|
||||
|
||||
if self._frame_helper is not None:
|
||||
self._frame_helper.close()
|
||||
self._frame_helper = None
|
||||
if self._socket is not None:
|
||||
self._socket.close()
|
||||
self._socket = None
|
||||
|
||||
if self._socket is not None:
|
||||
self._socket.close()
|
||||
self._socket = None
|
||||
if not self._on_stop_called and self._connect_complete:
|
||||
# Ensure on_stop is called
|
||||
asyncio.create_task(self.on_stop())
|
||||
self._on_stop_called = True
|
||||
|
||||
if not self._on_stop_called and self._connect_complete:
|
||||
# Ensure on_stop is called
|
||||
asyncio.create_task(self.on_stop())
|
||||
self._on_stop_called = True
|
||||
|
||||
# Note: we don't explicitly cancel the ping/read task here
|
||||
# That's because if not written right the ping/read task could cancel
|
||||
# themselves, effectively ending execution after _cleanup which may be unexpected
|
||||
self._ping_stop_event.set()
|
||||
|
||||
if not self._cleanup_task or self._cleanup_task.done():
|
||||
self._cleanup_task = asyncio.create_task(_do_cleanup())
|
||||
# Note: we don't explicitly cancel the ping/read task here
|
||||
# That's because if not written right the ping/read task could cancel
|
||||
# themselves, effectively ending execution after _cleanup which may be unexpected
|
||||
self._ping_stop_event.set()
|
||||
|
||||
async def _connect_resolve_host(self) -> hr.AddrInfo:
|
||||
"""Step 1 in connect process: resolve the address."""
|
||||
@ -216,7 +223,7 @@ class APIConnection:
|
||||
_, fh = await loop.create_connection(
|
||||
lambda: APIPlaintextFrameHelper(
|
||||
on_pkt=self._process_packet,
|
||||
on_error=self._report_fatal_error_and_cleanup_task,
|
||||
on_error=self._report_fatal_error,
|
||||
),
|
||||
sock=self._socket,
|
||||
)
|
||||
@ -226,14 +233,20 @@ class APIConnection:
|
||||
noise_psk=self._params.noise_psk,
|
||||
expected_name=self._params.expected_name,
|
||||
on_pkt=self._process_packet,
|
||||
on_error=self._report_fatal_error_and_cleanup_task,
|
||||
on_error=self._report_fatal_error,
|
||||
),
|
||||
sock=self._socket,
|
||||
)
|
||||
|
||||
self._frame_helper = fh
|
||||
self._connection_state = ConnectionState.SOCKET_OPENED
|
||||
await fh.perform_handshake()
|
||||
try:
|
||||
async with async_timeout.timeout(30.0):
|
||||
await fh.perform_handshake()
|
||||
except OSError as err:
|
||||
raise HandshakeAPIError(f"Handshake failed: {err}") from err
|
||||
except asyncio.TimeoutError as err:
|
||||
raise TimeoutAPIError("Handshake timed out") from err
|
||||
|
||||
async def _connect_hello(self) -> None:
|
||||
"""Step 4 in connect process: send hello and get api version."""
|
||||
@ -271,8 +284,6 @@ class APIConnection:
|
||||
f"Server sent a different name '{resp.name}'", resp.name
|
||||
)
|
||||
|
||||
self._connection_state = ConnectionState.CONNECTED
|
||||
|
||||
async def _connect_start_ping(self) -> None:
|
||||
"""Step 5 in connect process: start the ping loop."""
|
||||
|
||||
@ -286,26 +297,26 @@ class APIConnection:
|
||||
pass
|
||||
|
||||
# Re-check connection state
|
||||
if not self._is_socket_open:
|
||||
return # type: ignore[unreachable]
|
||||
if not self._is_socket_open or self._ping_stop_event.is_set():
|
||||
return
|
||||
|
||||
try:
|
||||
await self._ping()
|
||||
except TimeoutAPIError:
|
||||
_LOGGER.info("%s: Ping timed out!", self.log_name)
|
||||
await self._report_fatal_error(PingFailedAPIError())
|
||||
_LOGGER.debug("%s: Ping timed out!", self.log_name)
|
||||
self._report_fatal_error(PingFailedAPIError())
|
||||
return
|
||||
except APIConnectionError as err:
|
||||
_LOGGER.info("%s: Ping Failed: %s", self.log_name, err)
|
||||
await self._report_fatal_error(err)
|
||||
_LOGGER.debug("%s: Ping Failed: %s", self.log_name, err)
|
||||
self._report_fatal_error(err)
|
||||
return
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
_LOGGER.info(
|
||||
_LOGGER.error(
|
||||
"%s: Unexpected error during ping:",
|
||||
self.log_name,
|
||||
exc_info=True,
|
||||
)
|
||||
await self._report_fatal_error(err)
|
||||
self._report_fatal_error(err)
|
||||
return
|
||||
|
||||
asyncio.create_task(_keep_alive_loop())
|
||||
@ -317,35 +328,45 @@ class APIConnection:
|
||||
)
|
||||
|
||||
async def _do_connect() -> None:
|
||||
in_do_connect.set(True)
|
||||
addr = await self._connect_resolve_host()
|
||||
await self._connect_socket_connect(addr)
|
||||
await self._connect_init_frame_helper()
|
||||
await self._connect_hello()
|
||||
await self._connect_start_ping()
|
||||
if login:
|
||||
await self.login()
|
||||
await self.login(check_connected=False)
|
||||
|
||||
# A connection lock must be created to avoid potential issues where
|
||||
# connect has succeeded but not yet returned, followed by a disconnect.
|
||||
# See esphome/aioesphomeapi#258 for more information
|
||||
async with self._connect_lock:
|
||||
try:
|
||||
# Allow 2 minutes for connect; this is only as a last measure
|
||||
# to protect from issues if some part of the connect process mistakenly
|
||||
# does not have a timeout
|
||||
async with async_timeout.timeout(120.0):
|
||||
await _do_connect()
|
||||
except Exception: # pylint: disable=broad-except
|
||||
# Always clean up the connection if an error occured during connect
|
||||
self._connection_state = ConnectionState.CLOSED
|
||||
await self._cleanup()
|
||||
raise
|
||||
self._connect_task = asyncio.create_task(_do_connect())
|
||||
|
||||
try:
|
||||
# Allow 2 minutes for connect; this is only as a last measure
|
||||
# to protect from issues if some part of the connect process mistakenly
|
||||
# does not have a timeout
|
||||
async with async_timeout.timeout(120.0):
|
||||
await self._connect_task
|
||||
except asyncio.CancelledError:
|
||||
# If the task was cancelled, we need to clean up the connection
|
||||
# and raise the CancelledError
|
||||
self._connection_state = ConnectionState.CLOSED
|
||||
self._cleanup()
|
||||
raise self._fatal_exception or APIConnectionError("Connection cancelled")
|
||||
except Exception: # pylint: disable=broad-except
|
||||
# Always clean up the connection if an error occured during connect
|
||||
self._connection_state = ConnectionState.CLOSED
|
||||
self._cleanup()
|
||||
raise
|
||||
else:
|
||||
self._connection_state = ConnectionState.CONNECTED
|
||||
self._connect_complete = True
|
||||
|
||||
async def login(self) -> None:
|
||||
async def login(self, check_connected: bool = True) -> None:
|
||||
"""Send a login (ConnectRequest) and await the response."""
|
||||
self._check_connected()
|
||||
if check_connected:
|
||||
# On first connect, we don't want to check if we're connected
|
||||
# because we don't set the connection state until after login
|
||||
# is complete
|
||||
self._check_connected()
|
||||
if self._is_authenticated:
|
||||
raise APIConnectionError("Already logged in!")
|
||||
|
||||
@ -359,7 +380,7 @@ class APIConnection:
|
||||
# We don't know what state the device may be in after ConnectRequest
|
||||
# was already sent
|
||||
_LOGGER.debug("%s: Login timed out", self.log_name)
|
||||
await self._report_fatal_error(err)
|
||||
self._report_fatal_error(err)
|
||||
raise
|
||||
|
||||
if resp.invalid_password:
|
||||
@ -393,7 +414,6 @@ class APIConnection:
|
||||
|
||||
frame_helper = self._frame_helper
|
||||
assert frame_helper is not None
|
||||
assert frame_helper.ready, "Frame helper not ready"
|
||||
message_type = PROTO_TO_MESSAGE_TYPE.get(type(msg))
|
||||
if not message_type:
|
||||
raise ValueError(f"Message type id not found for type {type(msg)}")
|
||||
@ -411,7 +431,7 @@ class APIConnection:
|
||||
# If writing packet fails, we don't know what state the frames
|
||||
# are in anymore and we have to close the connection
|
||||
_LOGGER.info("%s: Error writing packet: %s", self.log_name, err)
|
||||
self._report_fatal_error_and_cleanup_task(err)
|
||||
self._report_fatal_error(err)
|
||||
raise
|
||||
|
||||
def add_message_callback(
|
||||
@ -525,26 +545,7 @@ class APIConnection:
|
||||
|
||||
return res[0]
|
||||
|
||||
def _handle_fatal_error(self, err: Exception) -> None:
|
||||
"""Handle a fatal error that occurred during an operation."""
|
||||
self._connection_state = ConnectionState.CLOSED
|
||||
for handler in self._read_exception_handlers[:]:
|
||||
handler(err)
|
||||
self._read_exception_handlers.clear()
|
||||
|
||||
def _report_fatal_error_and_cleanup_task(self, err: Exception) -> None:
|
||||
"""Handle a fatal error that occurred during an operation.
|
||||
|
||||
This should only be called for errors that mean the connection
|
||||
can no longer be used.
|
||||
|
||||
The connection will be closed, all exception handlers notified.
|
||||
This method does not log the error, the call site should do so.
|
||||
"""
|
||||
self._handle_fatal_error(err)
|
||||
asyncio.create_task(self._cleanup())
|
||||
|
||||
async def _report_fatal_error(self, err: Exception) -> None:
|
||||
def _report_fatal_error(self, err: Exception) -> None:
|
||||
"""Report a fatal error that occurred during an operation.
|
||||
|
||||
This should only be called for errors that mean the connection
|
||||
@ -553,8 +554,20 @@ class APIConnection:
|
||||
The connection will be closed, all exception handlers notified.
|
||||
This method does not log the error, the call site should do so.
|
||||
"""
|
||||
self._handle_fatal_error(err)
|
||||
await self._cleanup()
|
||||
if not self._expected_disconnect and not self._fatal_exception:
|
||||
# Only log the first error
|
||||
_LOGGER.warning(
|
||||
"%s: Connection error occurred: %s",
|
||||
self.log_name,
|
||||
err or type(err),
|
||||
exc_info=not str(err), # Log the full stack on empty error string
|
||||
)
|
||||
self._fatal_exception = err
|
||||
self._connection_state = ConnectionState.CLOSED
|
||||
for handler in self._read_exception_handlers[:]:
|
||||
handler(err)
|
||||
self._read_exception_handlers.clear()
|
||||
self._cleanup()
|
||||
|
||||
def _process_packet(self, pkt: Packet) -> None:
|
||||
"""Process a packet from the socket."""
|
||||
@ -575,9 +588,7 @@ class APIConnection:
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
self._report_fatal_error_and_cleanup_task(
|
||||
ProtocolAPIError(f"Invalid protobuf message: {e}")
|
||||
)
|
||||
self._report_fatal_error(ProtocolAPIError(f"Invalid protobuf message: {e}"))
|
||||
raise
|
||||
|
||||
msg_type = type(msg)
|
||||
@ -595,7 +606,8 @@ class APIConnection:
|
||||
if isinstance(msg, DisconnectRequest):
|
||||
self.send_message(DisconnectResponse())
|
||||
self._connection_state = ConnectionState.CLOSED
|
||||
asyncio.create_task(self._cleanup())
|
||||
self._expected_disconnect = True
|
||||
self._cleanup()
|
||||
elif isinstance(msg, PingRequest):
|
||||
self.send_message(PingResponse())
|
||||
elif isinstance(msg, GetTimeRequest):
|
||||
@ -612,6 +624,7 @@ class APIConnection:
|
||||
# already disconnected
|
||||
return
|
||||
|
||||
self._expected_disconnect = True
|
||||
try:
|
||||
await self.send_message_await_response(
|
||||
DisconnectRequest(), DisconnectResponse
|
||||
@ -620,11 +633,12 @@ class APIConnection:
|
||||
pass
|
||||
|
||||
self._connection_state = ConnectionState.CLOSED
|
||||
await self._cleanup()
|
||||
self._cleanup()
|
||||
|
||||
async def force_disconnect(self) -> None:
|
||||
self._connection_state = ConnectionState.CLOSED
|
||||
await self._cleanup()
|
||||
self._expected_disconnect = True
|
||||
self._cleanup()
|
||||
|
||||
@property
|
||||
def api_version(self) -> Optional[APIVersion]:
|
||||
|
Loading…
Reference in New Issue
Block a user