Reduce duplicate code between connection and frame helper

This commit is contained in:
J. Nick Koston 2023-11-27 22:13:37 -06:00
parent 5fb9c9243b
commit 361ddebeaf
No known key found for this signature in database
4 changed files with 89 additions and 74 deletions

View File

@ -42,3 +42,5 @@ cdef class APIFrameHelper:
cpdef void write_packets(self, list packets, bint debug_enabled)
cdef void _write_bytes(self, object data, bint debug_enabled)
cdef get_handshake_future(self)

View File

@ -5,7 +5,7 @@ import logging
from abc import abstractmethod
from typing import TYPE_CHECKING, Callable, cast
from ..core import HandshakeAPIError, SocketClosedAPIError
from ..core import SocketClosedAPIError
if TYPE_CHECKING:
from ..connection import APIConnection
@ -23,6 +23,7 @@ WRITE_EXCEPTIONS = (RuntimeError, ConnectionResetError, OSError)
_int = int
_bytes = bytes
_float = float
class APIFrameHelper:
@ -135,21 +136,9 @@ class APIFrameHelper:
bitpos += 7
return -1
async def perform_handshake(self, timeout: float) -> None:
"""Perform the handshake with the server."""
handshake_handle = self._loop.call_at(
self._loop.time() + timeout,
self._set_ready_future_exception,
asyncio.TimeoutError,
)
try:
await self._ready_future
except asyncio.TimeoutError as err:
raise HandshakeAPIError(
f"{self._log_name}: Timeout during handshake"
) from err
finally:
handshake_handle.cancel()
def get_handshake_future(self) -> None:
"""Get the handshake future."""
return self._ready_future
@abstractmethod
def write_packets(

View File

@ -9,6 +9,7 @@ cdef dict PROTO_TO_MESSAGE_TYPE
cdef set OPEN_STATES
cdef float KEEP_ALIVE_TIMEOUT_RATIO
cdef object HANDSHAKE_TIMEOUT
cdef bint TYPE_CHECKING
@ -16,12 +17,13 @@ cdef object DISCONNECT_REQUEST_MESSAGE
cdef object DISCONNECT_RESPONSE_MESSAGE
cdef object PING_REQUEST_MESSAGE
cdef object PING_RESPONSE_MESSAGE
cdef object NO_PASSWORD_CONNECT_REQUEST
cdef object asyncio_timeout
cdef object CancelledError
cdef object asyncio_TimeoutError
cdef object ConnectResponse
cdef object ConnectRequest, ConnectResponse
cdef object DisconnectRequest
cdef object PingRequest
cdef object GetTimeRequest, GetTimeResponse
@ -53,6 +55,10 @@ cdef object CONNECTION_STATE_HANDSHAKE_COMPLETE
cdef object CONNECTION_STATE_CONNECTED
cdef object CONNECTION_STATE_CLOSED
cdef object make_hello_request
cdef object handle_timeout
cdef object handle_complex_message
@cython.dataclasses.dataclass
cdef class ConnectionParams:
cdef public str address
@ -91,43 +97,45 @@ cdef class APIConnection:
cdef public str received_name
cdef public object resolved_addr_info
cpdef send_message(self, object msg)
cpdef void send_message(self, object msg)
cdef send_messages(self, tuple messages)
cdef void send_messages(self, tuple messages)
@cython.locals(handlers=set, handlers_copy=set)
cpdef void process_packet(self, object msg_type_proto, object data)
cpdef _async_cancel_pong_timer(self)
cdef void _async_cancel_pong_timer(self)
cpdef _async_schedule_keep_alive(self, object now)
cdef void _async_schedule_keep_alive(self, object now)
cdef _cleanup(self)
cdef void _cleanup(self)
cpdef set_log_name(self, str name)
cdef _make_connect_request(self)
cdef _process_hello_resp(self, object resp)
cdef void _process_hello_resp(self, object resp)
cdef _process_login_response(self, object hello_response)
cdef void _process_login_response(self, object hello_response)
cdef _set_connection_state(self, object state)
cdef void _set_connection_state(self, object state)
cpdef report_fatal_error(self, Exception err)
@cython.locals(handlers=set)
cpdef _add_message_callback_without_remove(self, object on_message, tuple msg_types)
cdef void _add_message_callback_without_remove(self, object on_message, tuple msg_types)
cpdef add_message_callback(self, object on_message, tuple msg_types)
@cython.locals(handlers=set)
cpdef _remove_message_callback(self, object on_message, tuple msg_types)
cpdef void _remove_message_callback(self, object on_message, tuple msg_types)
cpdef _handle_disconnect_request_internal(self, object msg)
cpdef void _handle_disconnect_request_internal(self, object msg)
cpdef _handle_ping_request_internal(self, object msg)
cpdef void _handle_ping_request_internal(self, object msg)
cpdef _handle_get_time_request_internal(self, object msg)
cpdef void _handle_get_time_request_internal(self, object msg)
cdef _set_fatal_exception_if_unset(self, Exception err)
cdef void _set_fatal_exception_if_unset(self, Exception err)
cdef void _register_internal_message_handlers(self)

View File

@ -12,7 +12,7 @@ import time
from asyncio import CancelledError
from asyncio import TimeoutError as asyncio_TimeoutError
from dataclasses import astuple, dataclass
from functools import partial
from functools import lru_cache, partial
from typing import TYPE_CHECKING, Any, Callable
from google.protobuf import message
@ -66,6 +66,7 @@ DISCONNECT_REQUEST_MESSAGE = DisconnectRequest()
DISCONNECT_RESPONSE_MESSAGE = DisconnectResponse()
PING_REQUEST_MESSAGE = PingRequest()
PING_RESPONSE_MESSAGE = PingResponse()
NO_PASSWORD_CONNECT_REQUEST = ConnectRequest()
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
@ -131,6 +132,46 @@ CONNECTION_STATE_CONNECTED = ConnectionState.CONNECTED
CONNECTION_STATE_CLOSED = ConnectionState.CLOSED
def _make_hello_request(client_info: str) -> HelloRequest:
"""Make a HelloRequest."""
return HelloRequest(
client_info=client_info,
api_version_major=1,
api_version_minor=9,
)
_cached_make_hello_request = lru_cache(maxsize=16)(_make_hello_request)
make_hello_request = _cached_make_hello_request
def _handle_timeout(fut: asyncio.Future[None]) -> None:
"""Handle a timeout."""
if not fut.done():
fut.set_exception(asyncio_TimeoutError)
handle_timeout = _handle_timeout
def _handle_complex_message(
fut: asyncio.Future[None],
responses: list[message.Message],
do_append: Callable[[message.Message], bool] | None,
do_stop: Callable[[message.Message], bool] | None,
resp: message.Message,
) -> None:
"""Handle a message that is part of a response."""
if not fut.done():
if do_append is None or do_append(resp):
responses.append(resp)
if do_stop is None or do_stop(resp):
fut.set_result(None)
handle_complex_message = _handle_complex_message
class APIConnection:
"""This class represents _one_ connection to a remote native API device.
@ -359,24 +400,23 @@ class APIConnection:
# Set the frame helper right away to ensure
# the socket gets closed if we fail to handshake
self._frame_helper = fh
future = self._frame_helper.get_handshake_future()
handshake_handle = self._loop.call_at(
self._loop.time() + HANDSHAKE_TIMEOUT, handle_timeout, future
)
try:
await fh.perform_handshake(HANDSHAKE_TIMEOUT)
await future
except asyncio_TimeoutError as err:
raise TimeoutAPIError("Handshake timed out") from err
except OSError as err:
raise HandshakeAPIError(f"Handshake failed: {err}") from err
finally:
handshake_handle.cancel()
self._set_connection_state(CONNECTION_STATE_HANDSHAKE_COMPLETE)
async def _connect_hello_login(self, login: bool) -> None:
"""Step 4 in connect process: send hello and login and get api version."""
messages = [
HelloRequest(
client_info=self._params.client_info,
api_version_major=1,
api_version_minor=9,
)
]
messages = [make_hello_request(self._params.client_info)]
msg_types = [HelloResponse]
if login:
messages.append(self._make_connect_request())
@ -600,10 +640,9 @@ class APIConnection:
def _make_connect_request(self) -> ConnectRequest:
"""Make a ConnectRequest."""
connect = ConnectRequest()
if self._params.password is not None:
connect.password = self._params.password
return connect
return ConnectRequest(password=self._params.password)
return NO_PASSWORD_CONNECT_REQUEST
def send_message(self, msg: message.Message) -> None:
"""Send a message to the remote."""
@ -679,26 +718,6 @@ class APIConnection:
# we register the handler after sending the message
return self.add_message_callback(on_message, msg_types)
def _handle_timeout(self, fut: asyncio.Future[None]) -> None:
"""Handle a timeout."""
if not fut.done():
fut.set_exception(asyncio_TimeoutError)
def _handle_complex_message(
self,
fut: asyncio.Future[None],
responses: list[message.Message],
do_append: Callable[[message.Message], bool] | None,
do_stop: Callable[[message.Message], bool] | None,
resp: message.Message,
) -> None:
"""Handle a message that is part of a response."""
if not fut.done():
if do_append is None or do_append(resp):
responses.append(resp)
if do_stop is None or do_stop(resp):
fut.set_result(None)
async def send_messages_await_response_complex( # pylint: disable=too-many-locals
self,
messages: tuple[message.Message, ...],
@ -724,19 +743,16 @@ class APIConnection:
# Unsafe to await between sending the message and registering the handler
fut: asyncio.Future[None] = loop.create_future()
responses: list[message.Message] = []
handler = self._handle_complex_message
on_message = partial(handler, fut, responses, do_append, do_stop)
read_exception_futures = self._read_exception_futures
on_message = partial(handle_complex_message, fut, responses, do_append, do_stop)
self._add_message_callback_without_remove(on_message, msg_types)
read_exception_futures.add(fut)
self._read_exception_futures.add(fut)
# Now safe to await since we have registered the handler
# We must not await without a finally or
# the message could fail to be removed if the
# the await is cancelled
timeout_handle = loop.call_at(loop.time() + timeout, self._handle_timeout, fut)
timeout_handle = loop.call_at(loop.time() + timeout, handle_timeout, fut)
timeout_expired = False
try:
await fut
@ -750,7 +766,7 @@ class APIConnection:
if not timeout_expired:
timeout_handle.cancel()
self._remove_message_callback(on_message, msg_types)
read_exception_futures.discard(fut)
self._read_exception_futures.discard(fut)
return responses
@ -775,7 +791,7 @@ class APIConnection:
The connection will be closed, all exception handlers notified.
This method does not log the error, the call site should do so.
"""
if not self._fatal_exception:
if self._fatal_exception is None:
if self._expected_disconnect is False:
# Only log the first error
_LOGGER.warning(
@ -810,7 +826,7 @@ class APIConnection:
return
try:
msg = klass()
msg: message.Message = klass()
# MergeFromString instead of ParseFromString since
# ParseFromString will clear the message first and
# the msg is already empty.
@ -895,7 +911,7 @@ class APIConnection:
async def disconnect(self) -> None:
"""Disconnect from the API."""
if self._finish_connect_task:
if self._finish_connect_task is not None:
# Try to wait for the handshake to finish so we can send
# a disconnect request. If it doesn't finish in time
# we will just close the socket.