2023-07-19 22:33:28 +02:00
|
|
|
from __future__ import annotations
|
|
|
|
|
2023-07-18 21:28:56 +02:00
|
|
|
import asyncio
|
|
|
|
import logging
|
|
|
|
from abc import abstractmethod
|
|
|
|
from functools import partial
|
2023-07-19 22:33:28 +02:00
|
|
|
from typing import Callable, cast
|
2023-07-18 21:28:56 +02:00
|
|
|
|
2023-07-21 10:11:04 +02:00
|
|
|
from ..core import HandshakeAPIError, SocketClosedAPIError
|
2023-07-18 21:28:56 +02:00
|
|
|
|
|
|
|
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)

# Exception types grouped as socket-level failures.
# ConnectionResetError and TimeoutError are already OSError subclasses; they
# are listed explicitly anyway and the tuple is kept exactly as declared.
SOCKET_ERRORS = (
    ConnectionResetError,
    asyncio.IncompleteReadError,
    OSError,
    TimeoutError,
)

# Exception types grouped as failures that can occur while writing.
WRITE_EXCEPTIONS = (
    RuntimeError,
    ConnectionResetError,
    OSError,
)

# Alias for the builtin ``int``; used as the annotation of ``length`` in
# ``APIFrameHelper._read_exactly``.
# NOTE(review): presumably lets a compiled build retype the annotation - confirm.
_int = int
|
2023-07-18 21:28:56 +02:00
|
|
|
|
2023-10-12 20:12:39 +02:00
|
|
|
|
|
|
|
class APIFrameHelper:
|
2023-07-18 21:28:56 +02:00
|
|
|
"""Helper class to handle the API frame protocol."""
|
|
|
|
|
|
|
|
__slots__ = (
|
2023-07-21 10:11:04 +02:00
|
|
|
"_loop",
|
2023-07-18 21:28:56 +02:00
|
|
|
"_on_pkt",
|
|
|
|
"_on_error",
|
|
|
|
"_transport",
|
|
|
|
"_writer",
|
2023-07-21 10:11:04 +02:00
|
|
|
"_ready_future",
|
2023-07-18 21:28:56 +02:00
|
|
|
"_buffer",
|
|
|
|
"_buffer_len",
|
|
|
|
"_pos",
|
|
|
|
"_client_info",
|
|
|
|
"_log_name",
|
|
|
|
"_debug_enabled",
|
|
|
|
)
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
on_pkt: Callable[[int, bytes], None],
|
|
|
|
on_error: Callable[[Exception], None],
|
|
|
|
client_info: str,
|
|
|
|
log_name: str,
|
|
|
|
) -> None:
|
|
|
|
"""Initialize the API frame helper."""
|
2023-07-21 10:11:04 +02:00
|
|
|
loop = asyncio.get_event_loop()
|
|
|
|
self._loop = loop
|
2023-07-18 21:28:56 +02:00
|
|
|
self._on_pkt = on_pkt
|
|
|
|
self._on_error = on_error
|
2023-07-19 22:33:28 +02:00
|
|
|
self._transport: asyncio.Transport | None = None
|
|
|
|
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
|
2023-07-21 10:11:04 +02:00
|
|
|
self._ready_future = self._loop.create_future()
|
2023-07-18 21:28:56 +02:00
|
|
|
self._buffer = bytearray()
|
|
|
|
self._buffer_len = 0
|
|
|
|
self._pos = 0
|
|
|
|
self._client_info = client_info
|
|
|
|
self._log_name = log_name
|
|
|
|
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
|
|
|
|
|
2023-07-21 10:11:04 +02:00
|
|
|
def _set_ready_future_exception(self, exc: Exception) -> None:
|
|
|
|
if not self._ready_future.done():
|
|
|
|
self._ready_future.set_exception(exc)
|
|
|
|
|
2023-10-12 20:12:39 +02:00
|
|
|
def _read_exactly(self, length: _int) -> bytearray | None:
|
2023-07-18 21:28:56 +02:00
|
|
|
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
|
|
|
|
original_pos = self._pos
|
|
|
|
new_pos = original_pos + length
|
|
|
|
if self._buffer_len < new_pos:
|
|
|
|
return None
|
|
|
|
self._pos = new_pos
|
|
|
|
return self._buffer[original_pos:new_pos]
|
|
|
|
|
2023-07-21 10:11:04 +02:00
|
|
|
async def perform_handshake(self, timeout: float) -> None:
|
|
|
|
"""Perform the handshake with the server."""
|
2023-09-04 19:56:23 +02:00
|
|
|
handshake_handle = self._loop.call_at(
|
|
|
|
self._loop.time() + timeout,
|
|
|
|
self._set_ready_future_exception,
|
|
|
|
asyncio.TimeoutError,
|
2023-07-21 10:11:04 +02:00
|
|
|
)
|
|
|
|
try:
|
|
|
|
await self._ready_future
|
|
|
|
except asyncio.TimeoutError as err:
|
|
|
|
raise HandshakeAPIError(
|
|
|
|
f"{self._log_name}: Timeout during handshake"
|
|
|
|
) from err
|
|
|
|
finally:
|
|
|
|
handshake_handle.cancel()
|
2023-07-18 21:28:56 +02:00
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
def write_packet(self, type_: int, data: bytes) -> None:
|
|
|
|
"""Write a packet to the socket."""
|
|
|
|
|
|
|
|
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
|
|
|
"""Handle a new connection."""
|
|
|
|
self._transport = cast(asyncio.Transport, transport)
|
|
|
|
self._writer = self._transport.write
|
|
|
|
|
|
|
|
def _handle_error_and_close(self, exc: Exception) -> None:
|
|
|
|
self._handle_error(exc)
|
|
|
|
self.close()
|
|
|
|
|
|
|
|
def _handle_error(self, exc: Exception) -> None:
|
|
|
|
self._on_error(exc)
|
|
|
|
|
2023-07-19 22:33:28 +02:00
|
|
|
def connection_lost(self, exc: Exception | None) -> None:
|
2023-10-12 20:12:39 +02:00
|
|
|
"""Handle the connection being lost."""
|
2023-07-18 21:28:56 +02:00
|
|
|
self._handle_error(
|
|
|
|
exc or SocketClosedAPIError(f"{self._log_name}: Connection lost")
|
|
|
|
)
|
|
|
|
|
2023-07-19 22:33:28 +02:00
|
|
|
def eof_received(self) -> bool | None:
|
2023-10-12 20:12:39 +02:00
|
|
|
"""Handle EOF received."""
|
2023-07-18 21:28:56 +02:00
|
|
|
self._handle_error(SocketClosedAPIError(f"{self._log_name}: EOF received"))
|
2023-10-12 20:12:39 +02:00
|
|
|
return False
|
2023-07-18 21:28:56 +02:00
|
|
|
|
|
|
|
def close(self) -> None:
|
|
|
|
"""Close the connection."""
|
|
|
|
if self._transport:
|
|
|
|
self._transport.close()
|
|
|
|
self._transport = None
|
|
|
|
self._writer = None
|
2023-10-12 20:12:39 +02:00
|
|
|
|
|
|
|
def pause_writing(self) -> None:
|
|
|
|
"""Stub."""
|
|
|
|
|
|
|
|
def resume_writing(self) -> None:
|
|
|
|
"""Stub."""
|