mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-16 11:05:10 +01:00
Reduce duplicate code between connection and frame helper
This commit is contained in:
parent
5fb9c9243b
commit
361ddebeaf
@ -42,3 +42,5 @@ cdef class APIFrameHelper:
|
||||
cpdef void write_packets(self, list packets, bint debug_enabled)
|
||||
|
||||
cdef void _write_bytes(self, object data, bint debug_enabled)
|
||||
|
||||
cdef get_handshake_future(self)
|
||||
|
@ -5,7 +5,7 @@ import logging
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Callable, cast
|
||||
|
||||
from ..core import HandshakeAPIError, SocketClosedAPIError
|
||||
from ..core import SocketClosedAPIError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..connection import APIConnection
|
||||
@ -23,6 +23,7 @@ WRITE_EXCEPTIONS = (RuntimeError, ConnectionResetError, OSError)
|
||||
|
||||
_int = int
|
||||
_bytes = bytes
|
||||
_float = float
|
||||
|
||||
|
||||
class APIFrameHelper:
|
||||
@ -135,21 +136,9 @@ class APIFrameHelper:
|
||||
bitpos += 7
|
||||
return -1
|
||||
|
||||
async def perform_handshake(self, timeout: float) -> None:
|
||||
"""Perform the handshake with the server."""
|
||||
handshake_handle = self._loop.call_at(
|
||||
self._loop.time() + timeout,
|
||||
self._set_ready_future_exception,
|
||||
asyncio.TimeoutError,
|
||||
)
|
||||
try:
|
||||
await self._ready_future
|
||||
except asyncio.TimeoutError as err:
|
||||
raise HandshakeAPIError(
|
||||
f"{self._log_name}: Timeout during handshake"
|
||||
) from err
|
||||
finally:
|
||||
handshake_handle.cancel()
|
||||
def get_handshake_future(self) -> None:
|
||||
"""Get the handshake future."""
|
||||
return self._ready_future
|
||||
|
||||
@abstractmethod
|
||||
def write_packets(
|
||||
|
@ -9,6 +9,7 @@ cdef dict PROTO_TO_MESSAGE_TYPE
|
||||
cdef set OPEN_STATES
|
||||
|
||||
cdef float KEEP_ALIVE_TIMEOUT_RATIO
|
||||
cdef object HANDSHAKE_TIMEOUT
|
||||
|
||||
cdef bint TYPE_CHECKING
|
||||
|
||||
@ -16,12 +17,13 @@ cdef object DISCONNECT_REQUEST_MESSAGE
|
||||
cdef object DISCONNECT_RESPONSE_MESSAGE
|
||||
cdef object PING_REQUEST_MESSAGE
|
||||
cdef object PING_RESPONSE_MESSAGE
|
||||
cdef object NO_PASSWORD_CONNECT_REQUEST
|
||||
|
||||
cdef object asyncio_timeout
|
||||
cdef object CancelledError
|
||||
cdef object asyncio_TimeoutError
|
||||
|
||||
cdef object ConnectResponse
|
||||
cdef object ConnectRequest, ConnectResponse
|
||||
cdef object DisconnectRequest
|
||||
cdef object PingRequest
|
||||
cdef object GetTimeRequest, GetTimeResponse
|
||||
@ -53,6 +55,10 @@ cdef object CONNECTION_STATE_HANDSHAKE_COMPLETE
|
||||
cdef object CONNECTION_STATE_CONNECTED
|
||||
cdef object CONNECTION_STATE_CLOSED
|
||||
|
||||
cdef object make_hello_request
|
||||
cdef object handle_timeout
|
||||
cdef object handle_complex_message
|
||||
|
||||
@cython.dataclasses.dataclass
|
||||
cdef class ConnectionParams:
|
||||
cdef public str address
|
||||
@ -91,43 +97,45 @@ cdef class APIConnection:
|
||||
cdef public str received_name
|
||||
cdef public object resolved_addr_info
|
||||
|
||||
cpdef send_message(self, object msg)
|
||||
cpdef void send_message(self, object msg)
|
||||
|
||||
cdef send_messages(self, tuple messages)
|
||||
cdef void send_messages(self, tuple messages)
|
||||
|
||||
@cython.locals(handlers=set, handlers_copy=set)
|
||||
cpdef void process_packet(self, object msg_type_proto, object data)
|
||||
|
||||
cpdef _async_cancel_pong_timer(self)
|
||||
cdef void _async_cancel_pong_timer(self)
|
||||
|
||||
cpdef _async_schedule_keep_alive(self, object now)
|
||||
cdef void _async_schedule_keep_alive(self, object now)
|
||||
|
||||
cdef _cleanup(self)
|
||||
cdef void _cleanup(self)
|
||||
|
||||
cpdef set_log_name(self, str name)
|
||||
|
||||
cdef _make_connect_request(self)
|
||||
|
||||
cdef _process_hello_resp(self, object resp)
|
||||
cdef void _process_hello_resp(self, object resp)
|
||||
|
||||
cdef _process_login_response(self, object hello_response)
|
||||
cdef void _process_login_response(self, object hello_response)
|
||||
|
||||
cdef _set_connection_state(self, object state)
|
||||
cdef void _set_connection_state(self, object state)
|
||||
|
||||
cpdef report_fatal_error(self, Exception err)
|
||||
|
||||
@cython.locals(handlers=set)
|
||||
cpdef _add_message_callback_without_remove(self, object on_message, tuple msg_types)
|
||||
cdef void _add_message_callback_without_remove(self, object on_message, tuple msg_types)
|
||||
|
||||
cpdef add_message_callback(self, object on_message, tuple msg_types)
|
||||
|
||||
@cython.locals(handlers=set)
|
||||
cpdef _remove_message_callback(self, object on_message, tuple msg_types)
|
||||
cpdef void _remove_message_callback(self, object on_message, tuple msg_types)
|
||||
|
||||
cpdef _handle_disconnect_request_internal(self, object msg)
|
||||
cpdef void _handle_disconnect_request_internal(self, object msg)
|
||||
|
||||
cpdef _handle_ping_request_internal(self, object msg)
|
||||
cpdef void _handle_ping_request_internal(self, object msg)
|
||||
|
||||
cpdef _handle_get_time_request_internal(self, object msg)
|
||||
cpdef void _handle_get_time_request_internal(self, object msg)
|
||||
|
||||
cdef _set_fatal_exception_if_unset(self, Exception err)
|
||||
cdef void _set_fatal_exception_if_unset(self, Exception err)
|
||||
|
||||
cdef void _register_internal_message_handlers(self)
|
||||
|
@ -12,7 +12,7 @@ import time
|
||||
from asyncio import CancelledError
|
||||
from asyncio import TimeoutError as asyncio_TimeoutError
|
||||
from dataclasses import astuple, dataclass
|
||||
from functools import partial
|
||||
from functools import lru_cache, partial
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from google.protobuf import message
|
||||
@ -66,6 +66,7 @@ DISCONNECT_REQUEST_MESSAGE = DisconnectRequest()
|
||||
DISCONNECT_RESPONSE_MESSAGE = DisconnectResponse()
|
||||
PING_REQUEST_MESSAGE = PingRequest()
|
||||
PING_RESPONSE_MESSAGE = PingResponse()
|
||||
NO_PASSWORD_CONNECT_REQUEST = ConnectRequest()
|
||||
|
||||
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
|
||||
|
||||
@ -131,6 +132,46 @@ CONNECTION_STATE_CONNECTED = ConnectionState.CONNECTED
|
||||
CONNECTION_STATE_CLOSED = ConnectionState.CLOSED
|
||||
|
||||
|
||||
def _make_hello_request(client_info: str) -> HelloRequest:
|
||||
"""Make a HelloRequest."""
|
||||
return HelloRequest(
|
||||
client_info=client_info,
|
||||
api_version_major=1,
|
||||
api_version_minor=9,
|
||||
)
|
||||
|
||||
|
||||
_cached_make_hello_request = lru_cache(maxsize=16)(_make_hello_request)
|
||||
make_hello_request = _cached_make_hello_request
|
||||
|
||||
|
||||
def _handle_timeout(fut: asyncio.Future[None]) -> None:
|
||||
"""Handle a timeout."""
|
||||
if not fut.done():
|
||||
fut.set_exception(asyncio_TimeoutError)
|
||||
|
||||
|
||||
handle_timeout = _handle_timeout
|
||||
|
||||
|
||||
def _handle_complex_message(
|
||||
fut: asyncio.Future[None],
|
||||
responses: list[message.Message],
|
||||
do_append: Callable[[message.Message], bool] | None,
|
||||
do_stop: Callable[[message.Message], bool] | None,
|
||||
resp: message.Message,
|
||||
) -> None:
|
||||
"""Handle a message that is part of a response."""
|
||||
if not fut.done():
|
||||
if do_append is None or do_append(resp):
|
||||
responses.append(resp)
|
||||
if do_stop is None or do_stop(resp):
|
||||
fut.set_result(None)
|
||||
|
||||
|
||||
handle_complex_message = _handle_complex_message
|
||||
|
||||
|
||||
class APIConnection:
|
||||
"""This class represents _one_ connection to a remote native API device.
|
||||
|
||||
@ -359,24 +400,23 @@ class APIConnection:
|
||||
# Set the frame helper right away to ensure
|
||||
# the socket gets closed if we fail to handshake
|
||||
self._frame_helper = fh
|
||||
|
||||
future = self._frame_helper.get_handshake_future()
|
||||
handshake_handle = self._loop.call_at(
|
||||
self._loop.time() + HANDSHAKE_TIMEOUT, handle_timeout, future
|
||||
)
|
||||
try:
|
||||
await fh.perform_handshake(HANDSHAKE_TIMEOUT)
|
||||
await future
|
||||
except asyncio_TimeoutError as err:
|
||||
raise TimeoutAPIError("Handshake timed out") from err
|
||||
except OSError as err:
|
||||
raise HandshakeAPIError(f"Handshake failed: {err}") from err
|
||||
finally:
|
||||
handshake_handle.cancel()
|
||||
self._set_connection_state(CONNECTION_STATE_HANDSHAKE_COMPLETE)
|
||||
|
||||
async def _connect_hello_login(self, login: bool) -> None:
|
||||
"""Step 4 in connect process: send hello and login and get api version."""
|
||||
messages = [
|
||||
HelloRequest(
|
||||
client_info=self._params.client_info,
|
||||
api_version_major=1,
|
||||
api_version_minor=9,
|
||||
)
|
||||
]
|
||||
messages = [make_hello_request(self._params.client_info)]
|
||||
msg_types = [HelloResponse]
|
||||
if login:
|
||||
messages.append(self._make_connect_request())
|
||||
@ -600,10 +640,9 @@ class APIConnection:
|
||||
|
||||
def _make_connect_request(self) -> ConnectRequest:
|
||||
"""Make a ConnectRequest."""
|
||||
connect = ConnectRequest()
|
||||
if self._params.password is not None:
|
||||
connect.password = self._params.password
|
||||
return connect
|
||||
return ConnectRequest(password=self._params.password)
|
||||
return NO_PASSWORD_CONNECT_REQUEST
|
||||
|
||||
def send_message(self, msg: message.Message) -> None:
|
||||
"""Send a message to the remote."""
|
||||
@ -679,26 +718,6 @@ class APIConnection:
|
||||
# we register the handler after sending the message
|
||||
return self.add_message_callback(on_message, msg_types)
|
||||
|
||||
def _handle_timeout(self, fut: asyncio.Future[None]) -> None:
|
||||
"""Handle a timeout."""
|
||||
if not fut.done():
|
||||
fut.set_exception(asyncio_TimeoutError)
|
||||
|
||||
def _handle_complex_message(
|
||||
self,
|
||||
fut: asyncio.Future[None],
|
||||
responses: list[message.Message],
|
||||
do_append: Callable[[message.Message], bool] | None,
|
||||
do_stop: Callable[[message.Message], bool] | None,
|
||||
resp: message.Message,
|
||||
) -> None:
|
||||
"""Handle a message that is part of a response."""
|
||||
if not fut.done():
|
||||
if do_append is None or do_append(resp):
|
||||
responses.append(resp)
|
||||
if do_stop is None or do_stop(resp):
|
||||
fut.set_result(None)
|
||||
|
||||
async def send_messages_await_response_complex( # pylint: disable=too-many-locals
|
||||
self,
|
||||
messages: tuple[message.Message, ...],
|
||||
@ -724,19 +743,16 @@ class APIConnection:
|
||||
# Unsafe to await between sending the message and registering the handler
|
||||
fut: asyncio.Future[None] = loop.create_future()
|
||||
responses: list[message.Message] = []
|
||||
handler = self._handle_complex_message
|
||||
on_message = partial(handler, fut, responses, do_append, do_stop)
|
||||
|
||||
read_exception_futures = self._read_exception_futures
|
||||
on_message = partial(handle_complex_message, fut, responses, do_append, do_stop)
|
||||
self._add_message_callback_without_remove(on_message, msg_types)
|
||||
|
||||
read_exception_futures.add(fut)
|
||||
self._read_exception_futures.add(fut)
|
||||
# Now safe to await since we have registered the handler
|
||||
|
||||
# We must not await without a finally or
|
||||
# the message could fail to be removed if the
|
||||
# the await is cancelled
|
||||
timeout_handle = loop.call_at(loop.time() + timeout, self._handle_timeout, fut)
|
||||
timeout_handle = loop.call_at(loop.time() + timeout, handle_timeout, fut)
|
||||
timeout_expired = False
|
||||
try:
|
||||
await fut
|
||||
@ -750,7 +766,7 @@ class APIConnection:
|
||||
if not timeout_expired:
|
||||
timeout_handle.cancel()
|
||||
self._remove_message_callback(on_message, msg_types)
|
||||
read_exception_futures.discard(fut)
|
||||
self._read_exception_futures.discard(fut)
|
||||
|
||||
return responses
|
||||
|
||||
@ -775,7 +791,7 @@ class APIConnection:
|
||||
The connection will be closed, all exception handlers notified.
|
||||
This method does not log the error, the call site should do so.
|
||||
"""
|
||||
if not self._fatal_exception:
|
||||
if self._fatal_exception is None:
|
||||
if self._expected_disconnect is False:
|
||||
# Only log the first error
|
||||
_LOGGER.warning(
|
||||
@ -810,7 +826,7 @@ class APIConnection:
|
||||
return
|
||||
|
||||
try:
|
||||
msg = klass()
|
||||
msg: message.Message = klass()
|
||||
# MergeFromString instead of ParseFromString since
|
||||
# ParseFromString will clear the message first and
|
||||
# the msg is already empty.
|
||||
@ -895,7 +911,7 @@ class APIConnection:
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from the API."""
|
||||
if self._finish_connect_task:
|
||||
if self._finish_connect_task is not None:
|
||||
# Try to wait for the handshake to finish so we can send
|
||||
# a disconnect request. If it doesn't finish in time
|
||||
# we will just close the socket.
|
||||
|
Loading…
Reference in New Issue
Block a user