From c76d741cb8b93be81ea6cfe3c44166afaecc6db4 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 18 Nov 2023 15:10:40 -0600 Subject: [PATCH] Refactor to reduce duplicate connection code (#644) --- aioesphomeapi/_frame_helper/base.pxd | 3 + aioesphomeapi/_frame_helper/base.py | 15 +++ aioesphomeapi/_frame_helper/noise.pxd | 5 +- aioesphomeapi/_frame_helper/noise.py | 116 +++++++-------------- aioesphomeapi/_frame_helper/plain_text.pxd | 4 +- aioesphomeapi/_frame_helper/plain_text.py | 22 +--- aioesphomeapi/reconnect_logic.py | 34 +++--- tests/test__frame_helper.py | 6 +- 8 files changed, 84 insertions(+), 121 deletions(-) diff --git a/aioesphomeapi/_frame_helper/base.pxd b/aioesphomeapi/_frame_helper/base.pxd index 794357d..babfbb4 100644 --- a/aioesphomeapi/_frame_helper/base.pxd +++ b/aioesphomeapi/_frame_helper/base.pxd @@ -5,6 +5,7 @@ from ..connection cimport APIConnection cdef bint TYPE_CHECKING +cdef object WRITE_EXCEPTIONS cdef class APIFrameHelper: @@ -32,3 +33,5 @@ cdef class APIFrameHelper: cdef _remove_from_buffer(self) cpdef write_packets(self, list packets) + + cdef _write_bytes(self, bytes data) diff --git a/aioesphomeapi/_frame_helper/base.py b/aioesphomeapi/_frame_helper/base.py index cf3860c..55f1a14 100644 --- a/aioesphomeapi/_frame_helper/base.py +++ b/aioesphomeapi/_frame_helper/base.py @@ -180,3 +180,18 @@ class APIFrameHelper: def resume_writing(self) -> None: """Stub.""" + + def _write_bytes(self, data: bytes) -> None: + """Write bytes to the socket.""" + if self._debug_enabled() is True: + _LOGGER.debug("%s: Sending frame: [%s]", self._log_name, data.hex()) + + if TYPE_CHECKING: + assert self._writer is not None, "Writer is not set" + + try: + self._writer(data) + except WRITE_EXCEPTIONS as err: + raise SocketClosedAPIError( + f"{self._log_name}: Error while writing data: {err}" + ) from err diff --git a/aioesphomeapi/_frame_helper/noise.pxd b/aioesphomeapi/_frame_helper/noise.pxd index a50e907..fe32420 100644 --- a/aioesphomeapi/_frame_helper/noise.pxd +++ b/aioesphomeapi/_frame_helper/noise.pxd @@ -48,6 +48,9 @@ cdef class APINoiseFrameHelper(APIFrameHelper): packet=tuple, data_len=cython.uint, frame=bytes, - frame_len=cython.uint + frame_len=cython.uint, + type_=object ) cpdef write_packets(self, list packets) + + cdef _error_on_incorrect_preamble(self, bytes msg) diff --git a/aioesphomeapi/_frame_helper/noise.py b/aioesphomeapi/_frame_helper/noise.py index 3ad3bb3..b85c90f 100644 --- a/aioesphomeapi/_frame_helper/noise.py +++ b/aioesphomeapi/_frame_helper/noise.py @@ -1,7 +1,6 @@ from __future__ import annotations import binascii -import logging from functools import partial from struct import Struct from typing import TYPE_CHECKING, Any, Callable @@ -20,15 +19,12 @@ from ..core import ( HandshakeAPIError, InvalidEncryptionKeyAPIError, ProtocolAPIError, - SocketAPIError, ) -from .base import WRITE_EXCEPTIONS, APIFrameHelper +from .base import APIFrameHelper if TYPE_CHECKING: from ..connection import APIConnection -_LOGGER = logging.getLogger(__name__) - PACK_NONCE = partial(Struct(" None: """Handle an error, and provide a good message when during hello.""" - if isinstance(exc, ConnectionResetError) and self._state == NOISE_STATE_HELLO: - original_exc = exc + if self._state == NOISE_STATE_HELLO and isinstance(exc, ConnectionResetError): + original_exc: Exception = exc exc = HandshakeAPIError( f"{self._log_name}: The connection dropped immediately after encrypted hello; " "Try enabling encryption on the device or turning off " f"encryption on the client ({self._client_info})." ) exc.__cause__ = original_exc + elif isinstance(exc, InvalidTag): + original_exc = exc + exc = InvalidEncryptionKeyAPIError( + f"{self._log_name}: Invalid encryption key", self._server_name + ) + exc.__cause__ = original_exc super()._handle_error(exc) async def perform_handshake(self, timeout: float) -> None: @@ -159,43 +161,26 @@ class APINoiseFrameHelper(APIFrameHelper): if frame is None: return - try: - if self._state == NOISE_STATE_READY: - self._handle_frame(frame) - elif self._state == NOISE_STATE_HELLO: - self._handle_hello(frame) - elif self._state == NOISE_STATE_HANDSHAKE: - self._handle_handshake(frame) - else: - self._handle_closed(frame) - except Exception as err: # pylint: disable=broad-except - self._handle_error_and_close(err) - finally: - self._remove_from_buffer() + # asyncio already runs data_received in a try block + # which will call connection_lost if an exception is raised + if self._state == NOISE_STATE_READY: + self._handle_frame(frame) + elif self._state == NOISE_STATE_HELLO: + self._handle_hello(frame) + elif self._state == NOISE_STATE_HANDSHAKE: + self._handle_handshake(frame) + else: + self._handle_closed(frame) + + self._remove_from_buffer() def _send_hello_handshake(self) -> None: """Send a ClientHello to the server.""" - if TYPE_CHECKING: - assert self._writer is not None, "Writer is not set" - handshake_frame = b"\x00" + self._proto.write_message() frame_len = len(handshake_frame) header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF)) hello_handshake = NOISE_HELLO + header + handshake_frame - - if self._debug_enabled(): - _LOGGER.debug( - "%s: Sending encrypted hello handshake: [%s]", - self._log_name, - hello_handshake.hex(), - ) - - try: - self._writer(hello_handshake) - except WRITE_EXCEPTIONS as err: - raise SocketAPIError( - f"{self._log_name}: Error while writing data: {err}" - ) from err + self._write_bytes(hello_handshake) def _handle_hello(self, server_hello: bytes) -> None: """Perform the handshake with the server.""" @@ -268,31 +253,25 @@ class APINoiseFrameHelper(APIFrameHelper): proto.start_handshake() self._proto = proto - def _handle_handshake(self, msg: bytes) -> None: - _LOGGER.debug("Starting handshake...") - if msg[0] != 0: - explanation = msg[1:].decode() - if explanation == "Handshake MAC failure": - self._handle_error_and_close( - InvalidEncryptionKeyAPIError( - f"{self._log_name}: Invalid encryption key", self._server_name - ) - ) - return + def _error_on_incorrect_preamble(self, msg: bytes) -> None: + """Handle an incorrect preamble.""" + explanation = msg[1:].decode() + if explanation == "Handshake MAC failure": self._handle_error_and_close( - HandshakeAPIError(f"{self._log_name}: Handshake failure: {explanation}") + InvalidEncryptionKeyAPIError( + f"{self._log_name}: Invalid encryption key", self._server_name + ) ) return - try: - self._proto.read_message(msg[1:]) - except InvalidTag as invalid_tag_exc: - ex = InvalidEncryptionKeyAPIError( - f"{self._log_name}: Invalid encryption key", self._server_name - ) - ex.__cause__ = invalid_tag_exc - self._handle_error_and_close(ex) + self._handle_error_and_close( + HandshakeAPIError(f"{self._log_name}: Handshake failure: {explanation}") + ) + + def _handle_handshake(self, msg: bytes) -> None: + if msg[0] != 0: + self._error_on_incorrect_preamble(msg) return - _LOGGER.debug("Handshake complete") + self._proto.read_message(msg[1:]) self._state = NOISE_STATE_READY noise_protocol = self._proto.noise_protocol self._decrypt = partial( @@ -315,10 +294,8 @@ class APINoiseFrameHelper(APIFrameHelper): if TYPE_CHECKING: assert self._encrypt is not None, "Handshake should be complete" - assert self._writer is not None, "Writer is not set" out: list[bytes] = [] - debug_enabled = self._debug_enabled() for packet in packets: type_: int = packet[0] data: bytes = packet[1] @@ -332,33 +309,18 @@ class APINoiseFrameHelper(APIFrameHelper): ) ) frame = self._encrypt(data_header + data) - - if debug_enabled is True: - _LOGGER.debug("%s: Sending frame: [%s]", self._log_name, frame.hex()) - frame_len = len(frame) header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF)) out.append(header) out.append(frame) - try: - self._writer(b"".join(out)) - except WRITE_EXCEPTIONS as err: - raise SocketAPIError( - f"{self._log_name}: Error while writing data: {err}" - ) from err + self._write_bytes(b"".join(out)) def _handle_frame(self, frame: bytes) -> None: """Handle an incoming frame.""" if TYPE_CHECKING: assert self._decrypt is not None, "Handshake should be complete" - try: - msg = self._decrypt(frame) - except InvalidTag as ex: - self._handle_error_and_close( - ProtocolAPIError(f"{self._log_name}: Bad encryption frame: {ex!r}") - ) - return + msg = self._decrypt(frame) # Message layout is # 2 bytes: message type # 2 bytes: message length diff --git a/aioesphomeapi/_frame_helper/plain_text.pxd b/aioesphomeapi/_frame_helper/plain_text.pxd index ae26e71..4591a8b 100644 --- a/aioesphomeapi/_frame_helper/plain_text.pxd +++ b/aioesphomeapi/_frame_helper/plain_text.pxd @@ -5,7 +5,6 @@ from .base cimport APIFrameHelper cdef bint TYPE_CHECKING -cdef object WRITE_EXCEPTIONS cdef object bytes_to_varuint, varuint_to_bytes cpdef _varuint_to_bytes(cython.int value) @@ -33,6 +32,7 @@ cdef class APIPlaintextFrameHelper(APIFrameHelper): @cython.locals( type_="unsigned int", data=bytes, - packet=tuple + packet=tuple, + type_=object ) cpdef write_packets(self, list packets) diff --git a/aioesphomeapi/_frame_helper/plain_text.py b/aioesphomeapi/_frame_helper/plain_text.py index bc20ce7..82a5b51 100644 --- a/aioesphomeapi/_frame_helper/plain_text.py +++ b/aioesphomeapi/_frame_helper/plain_text.py @@ -1,14 +1,11 @@ from __future__ import annotations import asyncio -import logging from functools import lru_cache from typing import TYPE_CHECKING -from ..core import ProtocolAPIError, RequiresEncryptionAPIError, SocketAPIError -from .base import WRITE_EXCEPTIONS, APIFrameHelper - -_LOGGER = logging.getLogger(__name__) +from ..core import ProtocolAPIError, RequiresEncryptionAPIError +from .base import APIFrameHelper _int = int _bytes = bytes @@ -66,11 +63,7 @@ class APIPlaintextFrameHelper(APIFrameHelper): The entire packet must be written in a single call. """ - if TYPE_CHECKING: - assert self._writer is not None, "Writer should be set" - out: list[bytes] = [] - debug_enabled = self._debug_enabled() for packet in packets: type_: int = packet[0] data: bytes = packet[1] @@ -78,17 +71,8 @@ class APIPlaintextFrameHelper(APIFrameHelper): out.append(varuint_to_bytes(len(data))) out.append(varuint_to_bytes(type_)) out.append(data) - if debug_enabled is True: - _LOGGER.debug( - "%s: Sending plaintext frame %s", self._log_name, data.hex() - ) - try: - self._writer(b"".join(out)) - except WRITE_EXCEPTIONS as err: - raise SocketAPIError( - f"{self._log_name}: Error while writing data: {err}" - ) from err + self._write_bytes(b"".join(out)) def data_received( # pylint: disable=too-many-branches,too-many-return-statements self, data: bytes | bytearray | memoryview diff --git a/aioesphomeapi/reconnect_logic.py b/aioesphomeapi/reconnect_logic.py index 1214eff..fca5191 100644 --- a/aioesphomeapi/reconnect_logic.py +++ b/aioesphomeapi/reconnect_logic.py @@ -186,13 +186,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): try: await self._cli.start_connection(on_stop=self._on_disconnect) except Exception as err: # pylint: disable=broad-except - self._async_set_connection_state_while_locked( - ReconnectLogicState.DISCONNECTED - ) - if self._on_connect_error_cb is not None: - await self._on_connect_error_cb(err) - self._async_log_connection_error(err) - self._tries += 1 + await self._handle_connection_failure(err) return False finish_connect_time = time.perf_counter() connect_time = finish_connect_time - start_connect_time @@ -204,18 +198,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): try: await self._cli.finish_connection(login=True) except Exception as err: # pylint: disable=broad-except - self._async_set_connection_state_while_locked( - ReconnectLogicState.DISCONNECTED - ) - if self._on_connect_error_cb is not None: - await self._on_connect_error_cb(err) - self._async_log_connection_error(err) - if isinstance(err, AUTH_EXCEPTIONS): - # If we get an encryption or password error, - # backoff for the maximum amount of time - self._tries = MAXIMUM_BACKOFF_TRIES - else: - self._tries += 1 + await self._handle_connection_failure(err) return False self._tries = 0 finish_handshake_time = time.perf_counter() @@ -227,6 +210,19 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): await self._on_connect_cb() return True + async def _handle_connection_failure(self, err: Exception) -> None: + """Handle a connection failure.""" + self._async_set_connection_state_while_locked(ReconnectLogicState.DISCONNECTED) + if self._on_connect_error_cb is not None: + await self._on_connect_error_cb(err) + self._async_log_connection_error(err) + if isinstance(err, AUTH_EXCEPTIONS): + # If we get an encryption or password error, + # backoff for the maximum amount of time + self._tries = MAXIMUM_BACKOFF_TRIES + else: + self._tries += 1 + def _schedule_connect(self, delay: float) -> None: """Schedule a connect attempt.""" self._cancel_connect("Scheduling new connect attempt") diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index adcd221..ae838fc 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -11,7 +11,6 @@ from noise.connection import NoiseConnection # type: ignore[import-untyped] from aioesphomeapi import APIConnection from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper -from aioesphomeapi._frame_helper.base import WRITE_EXCEPTIONS from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND, NOISE_HELLO from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint from aioesphomeapi._frame_helper.plain_text import ( @@ -27,6 +26,7 @@ from aioesphomeapi.core import ( InvalidEncryptionKeyAPIError, ProtocolAPIError, SocketAPIError, + SocketClosedAPIError, ) from .common import async_fire_time_changed, utcnow @@ -62,8 +62,8 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper): header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF)) try: self._writer(header + frame) - except WRITE_EXCEPTIONS as err: - raise SocketAPIError( + except (RuntimeError, ConnectionResetError, OSError) as err: + raise SocketClosedAPIError( f"{self._log_name}: Error while writing data: {err}" ) from err