Refactor to reduce duplicate connection code (#644)

This commit is contained in:
J. Nick Koston 2023-11-18 15:10:40 -06:00 committed by GitHub
parent d350e96405
commit c76d741cb8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 84 additions and 121 deletions

View File

@ -5,6 +5,7 @@ from ..connection cimport APIConnection
cdef bint TYPE_CHECKING
cdef object WRITE_EXCEPTIONS
cdef class APIFrameHelper:
@ -32,3 +33,5 @@ cdef class APIFrameHelper:
cdef _remove_from_buffer(self)
cpdef write_packets(self, list packets)
cdef _write_bytes(self, bytes data)

View File

@ -180,3 +180,18 @@ class APIFrameHelper:
def resume_writing(self) -> None:
"""Stub."""
def _write_bytes(self, data: bytes) -> None:
"""Write bytes to the socket."""
if self._debug_enabled() is True:
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, data.hex())
if TYPE_CHECKING:
assert self._writer is not None, "Writer is not set"
try:
self._writer(data)
except WRITE_EXCEPTIONS as err:
raise SocketClosedAPIError(
f"{self._log_name}: Error while writing data: {err}"
) from err

View File

@ -48,6 +48,9 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
packet=tuple,
data_len=cython.uint,
frame=bytes,
frame_len=cython.uint
frame_len=cython.uint,
type_=object
)
cpdef write_packets(self, list packets)
cdef _error_on_incorrect_preamble(self, bytes msg)

View File

@ -1,7 +1,6 @@
from __future__ import annotations
import binascii
import logging
from functools import partial
from struct import Struct
from typing import TYPE_CHECKING, Any, Callable
@ -20,15 +19,12 @@ from ..core import (
HandshakeAPIError,
InvalidEncryptionKeyAPIError,
ProtocolAPIError,
SocketAPIError,
)
from .base import WRITE_EXCEPTIONS, APIFrameHelper
from .base import APIFrameHelper
if TYPE_CHECKING:
from ..connection import APIConnection
_LOGGER = logging.getLogger(__name__)
PACK_NONCE = partial(Struct("<LQ").pack, 0)
@ -119,14 +115,20 @@ class APINoiseFrameHelper(APIFrameHelper):
def _handle_error(self, exc: Exception) -> None:
"""Handle an error, and provide a good message when during hello."""
if isinstance(exc, ConnectionResetError) and self._state == NOISE_STATE_HELLO:
original_exc = exc
if self._state == NOISE_STATE_HELLO and isinstance(exc, ConnectionResetError):
original_exc: Exception = exc
exc = HandshakeAPIError(
f"{self._log_name}: The connection dropped immediately after encrypted hello; "
"Try enabling encryption on the device or turning off "
f"encryption on the client ({self._client_info})."
)
exc.__cause__ = original_exc
elif isinstance(exc, InvalidTag):
original_exc = exc
exc = InvalidEncryptionKeyAPIError(
f"{self._log_name}: Invalid encryption key", self._server_name
)
exc.__cause__ = original_exc
super()._handle_error(exc)
async def perform_handshake(self, timeout: float) -> None:
@ -159,43 +161,26 @@ class APINoiseFrameHelper(APIFrameHelper):
if frame is None:
return
try:
if self._state == NOISE_STATE_READY:
self._handle_frame(frame)
elif self._state == NOISE_STATE_HELLO:
self._handle_hello(frame)
elif self._state == NOISE_STATE_HANDSHAKE:
self._handle_handshake(frame)
else:
self._handle_closed(frame)
except Exception as err: # pylint: disable=broad-except
self._handle_error_and_close(err)
finally:
self._remove_from_buffer()
# asyncio already runs data_received in a try block
# which will call connection_lost if an exception is raised
if self._state == NOISE_STATE_READY:
self._handle_frame(frame)
elif self._state == NOISE_STATE_HELLO:
self._handle_hello(frame)
elif self._state == NOISE_STATE_HANDSHAKE:
self._handle_handshake(frame)
else:
self._handle_closed(frame)
self._remove_from_buffer()
def _send_hello_handshake(self) -> None:
"""Send a ClientHello to the server."""
if TYPE_CHECKING:
assert self._writer is not None, "Writer is not set"
handshake_frame = b"\x00" + self._proto.write_message()
frame_len = len(handshake_frame)
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
hello_handshake = NOISE_HELLO + header + handshake_frame
if self._debug_enabled():
_LOGGER.debug(
"%s: Sending encrypted hello handshake: [%s]",
self._log_name,
hello_handshake.hex(),
)
try:
self._writer(hello_handshake)
except WRITE_EXCEPTIONS as err:
raise SocketAPIError(
f"{self._log_name}: Error while writing data: {err}"
) from err
self._write_bytes(hello_handshake)
def _handle_hello(self, server_hello: bytes) -> None:
"""Perform the handshake with the server."""
@ -268,31 +253,25 @@ class APINoiseFrameHelper(APIFrameHelper):
proto.start_handshake()
self._proto = proto
def _handle_handshake(self, msg: bytes) -> None:
_LOGGER.debug("Starting handshake...")
if msg[0] != 0:
explanation = msg[1:].decode()
if explanation == "Handshake MAC failure":
self._handle_error_and_close(
InvalidEncryptionKeyAPIError(
f"{self._log_name}: Invalid encryption key", self._server_name
)
)
return
def _error_on_incorrect_preamble(self, msg: bytes) -> None:
"""Handle an incorrect preamble."""
explanation = msg[1:].decode()
if explanation == "Handshake MAC failure":
self._handle_error_and_close(
HandshakeAPIError(f"{self._log_name}: Handshake failure: {explanation}")
InvalidEncryptionKeyAPIError(
f"{self._log_name}: Invalid encryption key", self._server_name
)
)
return
try:
self._proto.read_message(msg[1:])
except InvalidTag as invalid_tag_exc:
ex = InvalidEncryptionKeyAPIError(
f"{self._log_name}: Invalid encryption key", self._server_name
)
ex.__cause__ = invalid_tag_exc
self._handle_error_and_close(ex)
self._handle_error_and_close(
HandshakeAPIError(f"{self._log_name}: Handshake failure: {explanation}")
)
def _handle_handshake(self, msg: bytes) -> None:
if msg[0] != 0:
self._error_on_incorrect_preamble(msg)
return
_LOGGER.debug("Handshake complete")
self._proto.read_message(msg[1:])
self._state = NOISE_STATE_READY
noise_protocol = self._proto.noise_protocol
self._decrypt = partial(
@ -315,10 +294,8 @@ class APINoiseFrameHelper(APIFrameHelper):
if TYPE_CHECKING:
assert self._encrypt is not None, "Handshake should be complete"
assert self._writer is not None, "Writer is not set"
out: list[bytes] = []
debug_enabled = self._debug_enabled()
for packet in packets:
type_: int = packet[0]
data: bytes = packet[1]
@ -332,33 +309,18 @@ class APINoiseFrameHelper(APIFrameHelper):
)
)
frame = self._encrypt(data_header + data)
if debug_enabled is True:
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, frame.hex())
frame_len = len(frame)
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
out.append(header)
out.append(frame)
try:
self._writer(b"".join(out))
except WRITE_EXCEPTIONS as err:
raise SocketAPIError(
f"{self._log_name}: Error while writing data: {err}"
) from err
self._write_bytes(b"".join(out))
def _handle_frame(self, frame: bytes) -> None:
"""Handle an incoming frame."""
if TYPE_CHECKING:
assert self._decrypt is not None, "Handshake should be complete"
try:
msg = self._decrypt(frame)
except InvalidTag as ex:
self._handle_error_and_close(
ProtocolAPIError(f"{self._log_name}: Bad encryption frame: {ex!r}")
)
return
msg = self._decrypt(frame)
# Message layout is
# 2 bytes: message type
# 2 bytes: message length

View File

@ -5,7 +5,6 @@ from .base cimport APIFrameHelper
cdef bint TYPE_CHECKING
cdef object WRITE_EXCEPTIONS
cdef object bytes_to_varuint, varuint_to_bytes
cpdef _varuint_to_bytes(cython.int value)
@ -33,6 +32,7 @@ cdef class APIPlaintextFrameHelper(APIFrameHelper):
@cython.locals(
type_="unsigned int",
data=bytes,
packet=tuple
packet=tuple,
type_=object
)
cpdef write_packets(self, list packets)

View File

@ -1,14 +1,11 @@
from __future__ import annotations
import asyncio
import logging
from functools import lru_cache
from typing import TYPE_CHECKING
from ..core import ProtocolAPIError, RequiresEncryptionAPIError, SocketAPIError
from .base import WRITE_EXCEPTIONS, APIFrameHelper
_LOGGER = logging.getLogger(__name__)
from ..core import ProtocolAPIError, RequiresEncryptionAPIError
from .base import APIFrameHelper
_int = int
_bytes = bytes
@ -66,11 +63,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
The entire packet must be written in a single call.
"""
if TYPE_CHECKING:
assert self._writer is not None, "Writer should be set"
out: list[bytes] = []
debug_enabled = self._debug_enabled()
for packet in packets:
type_: int = packet[0]
data: bytes = packet[1]
@ -78,17 +71,8 @@ class APIPlaintextFrameHelper(APIFrameHelper):
out.append(varuint_to_bytes(len(data)))
out.append(varuint_to_bytes(type_))
out.append(data)
if debug_enabled is True:
_LOGGER.debug(
"%s: Sending plaintext frame %s", self._log_name, data.hex()
)
try:
self._writer(b"".join(out))
except WRITE_EXCEPTIONS as err:
raise SocketAPIError(
f"{self._log_name}: Error while writing data: {err}"
) from err
self._write_bytes(b"".join(out))
def data_received( # pylint: disable=too-many-branches,too-many-return-statements
self, data: bytes | bytearray | memoryview

View File

@ -186,13 +186,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
try:
await self._cli.start_connection(on_stop=self._on_disconnect)
except Exception as err: # pylint: disable=broad-except
self._async_set_connection_state_while_locked(
ReconnectLogicState.DISCONNECTED
)
if self._on_connect_error_cb is not None:
await self._on_connect_error_cb(err)
self._async_log_connection_error(err)
self._tries += 1
await self._handle_connection_failure(err)
return False
finish_connect_time = time.perf_counter()
connect_time = finish_connect_time - start_connect_time
@ -204,18 +198,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
try:
await self._cli.finish_connection(login=True)
except Exception as err: # pylint: disable=broad-except
self._async_set_connection_state_while_locked(
ReconnectLogicState.DISCONNECTED
)
if self._on_connect_error_cb is not None:
await self._on_connect_error_cb(err)
self._async_log_connection_error(err)
if isinstance(err, AUTH_EXCEPTIONS):
# If we get an encryption or password error,
# backoff for the maximum amount of time
self._tries = MAXIMUM_BACKOFF_TRIES
else:
self._tries += 1
await self._handle_connection_failure(err)
return False
self._tries = 0
finish_handshake_time = time.perf_counter()
@ -227,6 +210,19 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
await self._on_connect_cb()
return True
async def _handle_connection_failure(self, err: Exception) -> None:
"""Handle a connection failure."""
self._async_set_connection_state_while_locked(ReconnectLogicState.DISCONNECTED)
if self._on_connect_error_cb is not None:
await self._on_connect_error_cb(err)
self._async_log_connection_error(err)
if isinstance(err, AUTH_EXCEPTIONS):
# If we get an encryption or password error,
# backoff for the maximum amount of time
self._tries = MAXIMUM_BACKOFF_TRIES
else:
self._tries += 1
def _schedule_connect(self, delay: float) -> None:
"""Schedule a connect attempt."""
self._cancel_connect("Scheduling new connect attempt")

View File

@ -11,7 +11,6 @@ from noise.connection import NoiseConnection # type: ignore[import-untyped]
from aioesphomeapi import APIConnection
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
from aioesphomeapi._frame_helper.base import WRITE_EXCEPTIONS
from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND, NOISE_HELLO
from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint
from aioesphomeapi._frame_helper.plain_text import (
@ -27,6 +26,7 @@ from aioesphomeapi.core import (
InvalidEncryptionKeyAPIError,
ProtocolAPIError,
SocketAPIError,
SocketClosedAPIError,
)
from .common import async_fire_time_changed, utcnow
@ -62,8 +62,8 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
try:
self._writer(header + frame)
except WRITE_EXCEPTIONS as err:
raise SocketAPIError(
except (RuntimeError, ConnectionResetError, OSError) as err:
raise SocketClosedAPIError(
f"{self._log_name}: Error while writing data: {err}"
) from err