mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-04 09:19:37 +01:00
Refactor frame helpers to share more code (#500)
This commit is contained in:
parent
85c2638cba
commit
49d86f940e
@ -6,7 +6,7 @@ from abc import abstractmethod
|
||||
from functools import partial
|
||||
from typing import Callable, cast
|
||||
|
||||
from ..core import SocketClosedAPIError
|
||||
from ..core import HandshakeAPIError, SocketClosedAPIError
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -24,11 +24,12 @@ class APIFrameHelper(asyncio.Protocol):
|
||||
"""Helper class to handle the API frame protocol."""
|
||||
|
||||
__slots__ = (
|
||||
"_loop",
|
||||
"_on_pkt",
|
||||
"_on_error",
|
||||
"_transport",
|
||||
"_writer",
|
||||
"_connected_event",
|
||||
"_ready_future",
|
||||
"_buffer",
|
||||
"_buffer_len",
|
||||
"_pos",
|
||||
@ -45,11 +46,13 @@ class APIFrameHelper(asyncio.Protocol):
|
||||
log_name: str,
|
||||
) -> None:
|
||||
"""Initialize the API frame helper."""
|
||||
loop = asyncio.get_event_loop()
|
||||
self._loop = loop
|
||||
self._on_pkt = on_pkt
|
||||
self._on_error = on_error
|
||||
self._transport: asyncio.Transport | None = None
|
||||
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
|
||||
self._connected_event = asyncio.Event()
|
||||
self._ready_future = self._loop.create_future()
|
||||
self._buffer = bytearray()
|
||||
self._buffer_len = 0
|
||||
self._pos = 0
|
||||
@ -57,6 +60,10 @@ class APIFrameHelper(asyncio.Protocol):
|
||||
self._log_name = log_name
|
||||
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
|
||||
|
||||
def _set_ready_future_exception(self, exc: Exception) -> None:
|
||||
if not self._ready_future.done():
|
||||
self._ready_future.set_exception(exc)
|
||||
|
||||
def _read_exactly(self, length: int) -> bytearray | None:
|
||||
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
|
||||
original_pos = self._pos
|
||||
@ -66,9 +73,19 @@ class APIFrameHelper(asyncio.Protocol):
|
||||
self._pos = new_pos
|
||||
return self._buffer[original_pos:new_pos]
|
||||
|
||||
@abstractmethod
|
||||
async def perform_handshake(self) -> None:
|
||||
"""Perform the handshake."""
|
||||
async def perform_handshake(self, timeout: float) -> None:
|
||||
"""Perform the handshake with the server."""
|
||||
handshake_handle = self._loop.call_later(
|
||||
timeout, self._set_ready_future_exception, asyncio.TimeoutError()
|
||||
)
|
||||
try:
|
||||
await self._ready_future
|
||||
except asyncio.TimeoutError as err:
|
||||
raise HandshakeAPIError(
|
||||
f"{self._log_name}: Timeout during handshake"
|
||||
) from err
|
||||
finally:
|
||||
handshake_handle.cancel()
|
||||
|
||||
@abstractmethod
|
||||
def write_packet(self, type_: int, data: bytes) -> None:
|
||||
@ -78,7 +95,6 @@ class APIFrameHelper(asyncio.Protocol):
|
||||
"""Handle a new connection."""
|
||||
self._transport = cast(asyncio.Transport, transport)
|
||||
self._writer = self._transport.write
|
||||
self._connected_event.set()
|
||||
|
||||
def _handle_error_and_close(self, exc: Exception) -> None:
|
||||
self._handle_error(exc)
|
||||
|
@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
from enum import Enum
|
||||
@ -65,7 +64,6 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
"""Frame helper for noise encrypted connections."""
|
||||
|
||||
__slots__ = (
|
||||
"_ready_future",
|
||||
"_noise_psk",
|
||||
"_expected_name",
|
||||
"_state",
|
||||
@ -75,7 +73,6 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
"_decrypt",
|
||||
"_encrypt",
|
||||
"_is_ready",
|
||||
"_loop",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@ -89,9 +86,6 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
) -> None:
|
||||
"""Initialize the API frame helper."""
|
||||
super().__init__(on_pkt, on_error, client_info, log_name)
|
||||
loop = asyncio.get_event_loop()
|
||||
self._loop = loop
|
||||
self._ready_future = loop.create_future()
|
||||
self._noise_psk = noise_psk
|
||||
self._expected_name = expected_name
|
||||
self._set_state(NoiseConnectionState.HELLO)
|
||||
@ -101,10 +95,6 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
self._setup_proto()
|
||||
self._is_ready = False
|
||||
|
||||
def _set_ready_future_exception(self, exc: Exception) -> None:
|
||||
if not self._ready_future.done():
|
||||
self._ready_future.set_exception(exc)
|
||||
|
||||
def _set_state(self, state: NoiseConnectionState) -> None:
|
||||
"""Set the current state."""
|
||||
self._state = state
|
||||
@ -141,20 +131,10 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
exc.__cause__ = original_exc
|
||||
super()._handle_error(exc)
|
||||
|
||||
async def perform_handshake(self) -> None:
|
||||
async def perform_handshake(self, timeout: float) -> None:
|
||||
"""Perform the handshake with the server."""
|
||||
self._send_hello_handshake()
|
||||
handshake_handle = self._loop.call_later(
|
||||
60, self._set_ready_future_exception, asyncio.TimeoutError()
|
||||
)
|
||||
try:
|
||||
await self._ready_future
|
||||
except asyncio.TimeoutError as err:
|
||||
raise HandshakeAPIError(
|
||||
f"{self._log_name}: Timeout during handshake"
|
||||
) from err
|
||||
finally:
|
||||
handshake_handle.cancel()
|
||||
await super().perform_handshake(timeout)
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
self._buffer += data
|
||||
|
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@ -13,6 +14,11 @@ _LOGGER = logging.getLogger(__name__)
|
||||
class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
"""Frame helper for plaintext API connections."""
|
||||
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
"""Handle a new connection."""
|
||||
super().connection_made(transport)
|
||||
self._ready_future.set_result(None)
|
||||
|
||||
def write_packet(self, type_: int, data: bytes) -> None:
|
||||
"""Write a packet to the socket.
|
||||
|
||||
@ -32,10 +38,6 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
f"{self._log_name}: Error while writing data: {err}"
|
||||
) from err
|
||||
|
||||
async def perform_handshake(self) -> None:
|
||||
"""Perform the handshake."""
|
||||
await self._connected_event.wait()
|
||||
|
||||
def data_received(self, data: bytes) -> None: # pylint: disable=too-many-branches
|
||||
self._buffer += data
|
||||
self._buffer_len += len(data)
|
||||
|
@ -349,8 +349,7 @@ class APIConnection:
|
||||
self._frame_helper = fh
|
||||
self._set_connection_state(ConnectionState.SOCKET_OPENED)
|
||||
try:
|
||||
async with async_timeout.timeout(HANDSHAKE_TIMEOUT):
|
||||
await fh.perform_handshake()
|
||||
await fh.perform_handshake(HANDSHAKE_TIMEOUT)
|
||||
except OSError as err:
|
||||
raise HandshakeAPIError(f"Handshake failed: {err}") from err
|
||||
except asyncio.TimeoutError as err:
|
||||
|
@ -149,7 +149,7 @@ async def test_noise_frame_helper_incorrect_key():
|
||||
helper.data_received(bytes.fromhex(pkt))
|
||||
|
||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||
await helper.perform_handshake()
|
||||
await helper.perform_handshake(30)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -192,7 +192,7 @@ async def test_noise_frame_helper_incorrect_key_fragments():
|
||||
helper.data_received(in_pkt[i : i + 1])
|
||||
|
||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||
await helper.perform_handshake()
|
||||
await helper.perform_handshake(30)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -233,4 +233,4 @@ async def test_noise_incorrect_name():
|
||||
helper.data_received(bytes.fromhex(pkt))
|
||||
|
||||
with pytest.raises(BadNameAPIError):
|
||||
await helper.perform_handshake()
|
||||
await helper.perform_handshake(30)
|
||||
|
@ -59,9 +59,8 @@ def _get_mock_protocol(conn: APIConnection):
|
||||
client_info="mock",
|
||||
log_name="mock_device",
|
||||
)
|
||||
protocol._connected_event.set()
|
||||
protocol._transport = MagicMock()
|
||||
protocol._writer = MagicMock()
|
||||
transport = MagicMock()
|
||||
protocol.connection_made(transport)
|
||||
return protocol
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user