Refactor frame helpers to share more code (#500)

This commit is contained in:
J. Nick Koston 2023-07-21 03:11:04 -05:00 committed by GitHub
parent 85c2638cba
commit 49d86f940e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 37 additions and 41 deletions

View File

@ -6,7 +6,7 @@ from abc import abstractmethod
from functools import partial
from typing import Callable, cast
from ..core import SocketClosedAPIError
from ..core import HandshakeAPIError, SocketClosedAPIError
_LOGGER = logging.getLogger(__name__)
@ -24,11 +24,12 @@ class APIFrameHelper(asyncio.Protocol):
"""Helper class to handle the API frame protocol."""
__slots__ = (
"_loop",
"_on_pkt",
"_on_error",
"_transport",
"_writer",
"_connected_event",
"_ready_future",
"_buffer",
"_buffer_len",
"_pos",
@ -45,11 +46,13 @@ class APIFrameHelper(asyncio.Protocol):
log_name: str,
) -> None:
"""Initialize the API frame helper."""
loop = asyncio.get_event_loop()
self._loop = loop
self._on_pkt = on_pkt
self._on_error = on_error
self._transport: asyncio.Transport | None = None
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
self._connected_event = asyncio.Event()
self._ready_future = self._loop.create_future()
self._buffer = bytearray()
self._buffer_len = 0
self._pos = 0
@ -57,6 +60,10 @@ class APIFrameHelper(asyncio.Protocol):
self._log_name = log_name
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
def _set_ready_future_exception(self, exc: Exception) -> None:
if not self._ready_future.done():
self._ready_future.set_exception(exc)
def _read_exactly(self, length: int) -> bytearray | None:
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
original_pos = self._pos
@ -66,9 +73,19 @@ class APIFrameHelper(asyncio.Protocol):
self._pos = new_pos
return self._buffer[original_pos:new_pos]
@abstractmethod
async def perform_handshake(self) -> None:
"""Perform the handshake."""
async def perform_handshake(self, timeout: float) -> None:
"""Perform the handshake with the server."""
handshake_handle = self._loop.call_later(
timeout, self._set_ready_future_exception, asyncio.TimeoutError()
)
try:
await self._ready_future
except asyncio.TimeoutError as err:
raise HandshakeAPIError(
f"{self._log_name}: Timeout during handshake"
) from err
finally:
handshake_handle.cancel()
@abstractmethod
def write_packet(self, type_: int, data: bytes) -> None:
@ -78,7 +95,6 @@ class APIFrameHelper(asyncio.Protocol):
"""Handle a new connection."""
self._transport = cast(asyncio.Transport, transport)
self._writer = self._transport.write
self._connected_event.set()
def _handle_error_and_close(self, exc: Exception) -> None:
self._handle_error(exc)

View File

@ -1,6 +1,5 @@
from __future__ import annotations
import asyncio
import base64
import logging
from enum import Enum
@ -65,7 +64,6 @@ class APINoiseFrameHelper(APIFrameHelper):
"""Frame helper for noise encrypted connections."""
__slots__ = (
"_ready_future",
"_noise_psk",
"_expected_name",
"_state",
@ -75,7 +73,6 @@ class APINoiseFrameHelper(APIFrameHelper):
"_decrypt",
"_encrypt",
"_is_ready",
"_loop",
)
def __init__(
@ -89,9 +86,6 @@ class APINoiseFrameHelper(APIFrameHelper):
) -> None:
"""Initialize the API frame helper."""
super().__init__(on_pkt, on_error, client_info, log_name)
loop = asyncio.get_event_loop()
self._loop = loop
self._ready_future = loop.create_future()
self._noise_psk = noise_psk
self._expected_name = expected_name
self._set_state(NoiseConnectionState.HELLO)
@ -101,10 +95,6 @@ class APINoiseFrameHelper(APIFrameHelper):
self._setup_proto()
self._is_ready = False
def _set_ready_future_exception(self, exc: Exception) -> None:
if not self._ready_future.done():
self._ready_future.set_exception(exc)
def _set_state(self, state: NoiseConnectionState) -> None:
"""Set the current state."""
self._state = state
@ -141,20 +131,10 @@ class APINoiseFrameHelper(APIFrameHelper):
exc.__cause__ = original_exc
super()._handle_error(exc)
async def perform_handshake(self) -> None:
async def perform_handshake(self, timeout: float) -> None:
"""Perform the handshake with the server."""
self._send_hello_handshake()
handshake_handle = self._loop.call_later(
60, self._set_ready_future_exception, asyncio.TimeoutError()
)
try:
await self._ready_future
except asyncio.TimeoutError as err:
raise HandshakeAPIError(
f"{self._log_name}: Timeout during handshake"
) from err
finally:
handshake_handle.cancel()
await super().perform_handshake(timeout)
def data_received(self, data: bytes) -> None:
self._buffer += data

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
import logging
from typing import TYPE_CHECKING
@ -13,6 +14,11 @@ _LOGGER = logging.getLogger(__name__)
class APIPlaintextFrameHelper(APIFrameHelper):
"""Frame helper for plaintext API connections."""
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Handle a new connection."""
super().connection_made(transport)
self._ready_future.set_result(None)
def write_packet(self, type_: int, data: bytes) -> None:
"""Write a packet to the socket.
@ -32,10 +38,6 @@ class APIPlaintextFrameHelper(APIFrameHelper):
f"{self._log_name}: Error while writing data: {err}"
) from err
async def perform_handshake(self) -> None:
"""Perform the handshake."""
await self._connected_event.wait()
def data_received(self, data: bytes) -> None: # pylint: disable=too-many-branches
self._buffer += data
self._buffer_len += len(data)

View File

@ -349,8 +349,7 @@ class APIConnection:
self._frame_helper = fh
self._set_connection_state(ConnectionState.SOCKET_OPENED)
try:
async with async_timeout.timeout(HANDSHAKE_TIMEOUT):
await fh.perform_handshake()
await fh.perform_handshake(HANDSHAKE_TIMEOUT)
except OSError as err:
raise HandshakeAPIError(f"Handshake failed: {err}") from err
except asyncio.TimeoutError as err:

View File

@ -149,7 +149,7 @@ async def test_noise_frame_helper_incorrect_key():
helper.data_received(bytes.fromhex(pkt))
with pytest.raises(InvalidEncryptionKeyAPIError):
await helper.perform_handshake()
await helper.perform_handshake(30)
@pytest.mark.asyncio
@ -192,7 +192,7 @@ async def test_noise_frame_helper_incorrect_key_fragments():
helper.data_received(in_pkt[i : i + 1])
with pytest.raises(InvalidEncryptionKeyAPIError):
await helper.perform_handshake()
await helper.perform_handshake(30)
@pytest.mark.asyncio
@ -233,4 +233,4 @@ async def test_noise_incorrect_name():
helper.data_received(bytes.fromhex(pkt))
with pytest.raises(BadNameAPIError):
await helper.perform_handshake()
await helper.perform_handshake(30)

View File

@ -59,9 +59,8 @@ def _get_mock_protocol(conn: APIConnection):
client_info="mock",
log_name="mock_device",
)
protocol._connected_event.set()
protocol._transport = MagicMock()
protocol._writer = MagicMock()
transport = MagicMock()
protocol.connection_made(transport)
return protocol