mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-01-23 21:51:34 +01:00
Refactor to reduce duplicate connection code (#644)
This commit is contained in:
parent
d350e96405
commit
c76d741cb8
@ -5,6 +5,7 @@ from ..connection cimport APIConnection
|
||||
|
||||
|
||||
cdef bint TYPE_CHECKING
|
||||
cdef object WRITE_EXCEPTIONS
|
||||
|
||||
cdef class APIFrameHelper:
|
||||
|
||||
@ -32,3 +33,5 @@ cdef class APIFrameHelper:
|
||||
cdef _remove_from_buffer(self)
|
||||
|
||||
cpdef write_packets(self, list packets)
|
||||
|
||||
cdef _write_bytes(self, bytes data)
|
||||
|
@ -180,3 +180,18 @@ class APIFrameHelper:
|
||||
|
||||
def resume_writing(self) -> None:
|
||||
"""Stub."""
|
||||
|
||||
def _write_bytes(self, data: bytes) -> None:
|
||||
"""Write bytes to the socket."""
|
||||
if self._debug_enabled() is True:
|
||||
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, data.hex())
|
||||
|
||||
if TYPE_CHECKING:
|
||||
assert self._writer is not None, "Writer is not set"
|
||||
|
||||
try:
|
||||
self._writer(data)
|
||||
except WRITE_EXCEPTIONS as err:
|
||||
raise SocketClosedAPIError(
|
||||
f"{self._log_name}: Error while writing data: {err}"
|
||||
) from err
|
||||
|
@ -48,6 +48,9 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
|
||||
packet=tuple,
|
||||
data_len=cython.uint,
|
||||
frame=bytes,
|
||||
frame_len=cython.uint
|
||||
frame_len=cython.uint,
|
||||
type_=object
|
||||
)
|
||||
cpdef write_packets(self, list packets)
|
||||
|
||||
cdef _error_on_incorrect_preamble(self, bytes msg)
|
||||
|
@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import binascii
|
||||
import logging
|
||||
from functools import partial
|
||||
from struct import Struct
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
@ -20,15 +19,12 @@ from ..core import (
|
||||
HandshakeAPIError,
|
||||
InvalidEncryptionKeyAPIError,
|
||||
ProtocolAPIError,
|
||||
SocketAPIError,
|
||||
)
|
||||
from .base import WRITE_EXCEPTIONS, APIFrameHelper
|
||||
from .base import APIFrameHelper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..connection import APIConnection
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
PACK_NONCE = partial(Struct("<LQ").pack, 0)
|
||||
|
||||
@ -119,14 +115,20 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
|
||||
def _handle_error(self, exc: Exception) -> None:
|
||||
"""Handle an error, and provide a good message when during hello."""
|
||||
if isinstance(exc, ConnectionResetError) and self._state == NOISE_STATE_HELLO:
|
||||
original_exc = exc
|
||||
if self._state == NOISE_STATE_HELLO and isinstance(exc, ConnectionResetError):
|
||||
original_exc: Exception = exc
|
||||
exc = HandshakeAPIError(
|
||||
f"{self._log_name}: The connection dropped immediately after encrypted hello; "
|
||||
"Try enabling encryption on the device or turning off "
|
||||
f"encryption on the client ({self._client_info})."
|
||||
)
|
||||
exc.__cause__ = original_exc
|
||||
elif isinstance(exc, InvalidTag):
|
||||
original_exc = exc
|
||||
exc = InvalidEncryptionKeyAPIError(
|
||||
f"{self._log_name}: Invalid encryption key", self._server_name
|
||||
)
|
||||
exc.__cause__ = original_exc
|
||||
super()._handle_error(exc)
|
||||
|
||||
async def perform_handshake(self, timeout: float) -> None:
|
||||
@ -159,43 +161,26 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
if frame is None:
|
||||
return
|
||||
|
||||
try:
|
||||
if self._state == NOISE_STATE_READY:
|
||||
self._handle_frame(frame)
|
||||
elif self._state == NOISE_STATE_HELLO:
|
||||
self._handle_hello(frame)
|
||||
elif self._state == NOISE_STATE_HANDSHAKE:
|
||||
self._handle_handshake(frame)
|
||||
else:
|
||||
self._handle_closed(frame)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
self._handle_error_and_close(err)
|
||||
finally:
|
||||
self._remove_from_buffer()
|
||||
# asyncio already runs data_received in a try block
|
||||
# which will call connection_lost if an exception is raised
|
||||
if self._state == NOISE_STATE_READY:
|
||||
self._handle_frame(frame)
|
||||
elif self._state == NOISE_STATE_HELLO:
|
||||
self._handle_hello(frame)
|
||||
elif self._state == NOISE_STATE_HANDSHAKE:
|
||||
self._handle_handshake(frame)
|
||||
else:
|
||||
self._handle_closed(frame)
|
||||
|
||||
self._remove_from_buffer()
|
||||
|
||||
def _send_hello_handshake(self) -> None:
|
||||
"""Send a ClientHello to the server."""
|
||||
if TYPE_CHECKING:
|
||||
assert self._writer is not None, "Writer is not set"
|
||||
|
||||
handshake_frame = b"\x00" + self._proto.write_message()
|
||||
frame_len = len(handshake_frame)
|
||||
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
||||
hello_handshake = NOISE_HELLO + header + handshake_frame
|
||||
|
||||
if self._debug_enabled():
|
||||
_LOGGER.debug(
|
||||
"%s: Sending encrypted hello handshake: [%s]",
|
||||
self._log_name,
|
||||
hello_handshake.hex(),
|
||||
)
|
||||
|
||||
try:
|
||||
self._writer(hello_handshake)
|
||||
except WRITE_EXCEPTIONS as err:
|
||||
raise SocketAPIError(
|
||||
f"{self._log_name}: Error while writing data: {err}"
|
||||
) from err
|
||||
self._write_bytes(hello_handshake)
|
||||
|
||||
def _handle_hello(self, server_hello: bytes) -> None:
|
||||
"""Perform the handshake with the server."""
|
||||
@ -268,31 +253,25 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
proto.start_handshake()
|
||||
self._proto = proto
|
||||
|
||||
def _handle_handshake(self, msg: bytes) -> None:
|
||||
_LOGGER.debug("Starting handshake...")
|
||||
if msg[0] != 0:
|
||||
explanation = msg[1:].decode()
|
||||
if explanation == "Handshake MAC failure":
|
||||
self._handle_error_and_close(
|
||||
InvalidEncryptionKeyAPIError(
|
||||
f"{self._log_name}: Invalid encryption key", self._server_name
|
||||
)
|
||||
)
|
||||
return
|
||||
def _error_on_incorrect_preamble(self, msg: bytes) -> None:
|
||||
"""Handle an incorrect preamble."""
|
||||
explanation = msg[1:].decode()
|
||||
if explanation == "Handshake MAC failure":
|
||||
self._handle_error_and_close(
|
||||
HandshakeAPIError(f"{self._log_name}: Handshake failure: {explanation}")
|
||||
InvalidEncryptionKeyAPIError(
|
||||
f"{self._log_name}: Invalid encryption key", self._server_name
|
||||
)
|
||||
)
|
||||
return
|
||||
try:
|
||||
self._proto.read_message(msg[1:])
|
||||
except InvalidTag as invalid_tag_exc:
|
||||
ex = InvalidEncryptionKeyAPIError(
|
||||
f"{self._log_name}: Invalid encryption key", self._server_name
|
||||
)
|
||||
ex.__cause__ = invalid_tag_exc
|
||||
self._handle_error_and_close(ex)
|
||||
self._handle_error_and_close(
|
||||
HandshakeAPIError(f"{self._log_name}: Handshake failure: {explanation}")
|
||||
)
|
||||
|
||||
def _handle_handshake(self, msg: bytes) -> None:
|
||||
if msg[0] != 0:
|
||||
self._error_on_incorrect_preamble(msg)
|
||||
return
|
||||
_LOGGER.debug("Handshake complete")
|
||||
self._proto.read_message(msg[1:])
|
||||
self._state = NOISE_STATE_READY
|
||||
noise_protocol = self._proto.noise_protocol
|
||||
self._decrypt = partial(
|
||||
@ -315,10 +294,8 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
|
||||
if TYPE_CHECKING:
|
||||
assert self._encrypt is not None, "Handshake should be complete"
|
||||
assert self._writer is not None, "Writer is not set"
|
||||
|
||||
out: list[bytes] = []
|
||||
debug_enabled = self._debug_enabled()
|
||||
for packet in packets:
|
||||
type_: int = packet[0]
|
||||
data: bytes = packet[1]
|
||||
@ -332,33 +309,18 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
)
|
||||
)
|
||||
frame = self._encrypt(data_header + data)
|
||||
|
||||
if debug_enabled is True:
|
||||
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, frame.hex())
|
||||
|
||||
frame_len = len(frame)
|
||||
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
||||
out.append(header)
|
||||
out.append(frame)
|
||||
|
||||
try:
|
||||
self._writer(b"".join(out))
|
||||
except WRITE_EXCEPTIONS as err:
|
||||
raise SocketAPIError(
|
||||
f"{self._log_name}: Error while writing data: {err}"
|
||||
) from err
|
||||
self._write_bytes(b"".join(out))
|
||||
|
||||
def _handle_frame(self, frame: bytes) -> None:
|
||||
"""Handle an incoming frame."""
|
||||
if TYPE_CHECKING:
|
||||
assert self._decrypt is not None, "Handshake should be complete"
|
||||
try:
|
||||
msg = self._decrypt(frame)
|
||||
except InvalidTag as ex:
|
||||
self._handle_error_and_close(
|
||||
ProtocolAPIError(f"{self._log_name}: Bad encryption frame: {ex!r}")
|
||||
)
|
||||
return
|
||||
msg = self._decrypt(frame)
|
||||
# Message layout is
|
||||
# 2 bytes: message type
|
||||
# 2 bytes: message length
|
||||
|
@ -5,7 +5,6 @@ from .base cimport APIFrameHelper
|
||||
|
||||
|
||||
cdef bint TYPE_CHECKING
|
||||
cdef object WRITE_EXCEPTIONS
|
||||
cdef object bytes_to_varuint, varuint_to_bytes
|
||||
|
||||
cpdef _varuint_to_bytes(cython.int value)
|
||||
@ -33,6 +32,7 @@ cdef class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
@cython.locals(
|
||||
type_="unsigned int",
|
||||
data=bytes,
|
||||
packet=tuple
|
||||
packet=tuple,
|
||||
type_=object
|
||||
)
|
||||
cpdef write_packets(self, list packets)
|
||||
|
@ -1,14 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..core import ProtocolAPIError, RequiresEncryptionAPIError, SocketAPIError
|
||||
from .base import WRITE_EXCEPTIONS, APIFrameHelper
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
from ..core import ProtocolAPIError, RequiresEncryptionAPIError
|
||||
from .base import APIFrameHelper
|
||||
|
||||
_int = int
|
||||
_bytes = bytes
|
||||
@ -66,11 +63,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
|
||||
The entire packet must be written in a single call.
|
||||
"""
|
||||
if TYPE_CHECKING:
|
||||
assert self._writer is not None, "Writer should be set"
|
||||
|
||||
out: list[bytes] = []
|
||||
debug_enabled = self._debug_enabled()
|
||||
for packet in packets:
|
||||
type_: int = packet[0]
|
||||
data: bytes = packet[1]
|
||||
@ -78,17 +71,8 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
out.append(varuint_to_bytes(len(data)))
|
||||
out.append(varuint_to_bytes(type_))
|
||||
out.append(data)
|
||||
if debug_enabled is True:
|
||||
_LOGGER.debug(
|
||||
"%s: Sending plaintext frame %s", self._log_name, data.hex()
|
||||
)
|
||||
|
||||
try:
|
||||
self._writer(b"".join(out))
|
||||
except WRITE_EXCEPTIONS as err:
|
||||
raise SocketAPIError(
|
||||
f"{self._log_name}: Error while writing data: {err}"
|
||||
) from err
|
||||
self._write_bytes(b"".join(out))
|
||||
|
||||
def data_received( # pylint: disable=too-many-branches,too-many-return-statements
|
||||
self, data: bytes | bytearray | memoryview
|
||||
|
@ -186,13 +186,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
try:
|
||||
await self._cli.start_connection(on_stop=self._on_disconnect)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
self._async_set_connection_state_while_locked(
|
||||
ReconnectLogicState.DISCONNECTED
|
||||
)
|
||||
if self._on_connect_error_cb is not None:
|
||||
await self._on_connect_error_cb(err)
|
||||
self._async_log_connection_error(err)
|
||||
self._tries += 1
|
||||
await self._handle_connection_failure(err)
|
||||
return False
|
||||
finish_connect_time = time.perf_counter()
|
||||
connect_time = finish_connect_time - start_connect_time
|
||||
@ -204,18 +198,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
try:
|
||||
await self._cli.finish_connection(login=True)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
self._async_set_connection_state_while_locked(
|
||||
ReconnectLogicState.DISCONNECTED
|
||||
)
|
||||
if self._on_connect_error_cb is not None:
|
||||
await self._on_connect_error_cb(err)
|
||||
self._async_log_connection_error(err)
|
||||
if isinstance(err, AUTH_EXCEPTIONS):
|
||||
# If we get an encryption or password error,
|
||||
# backoff for the maximum amount of time
|
||||
self._tries = MAXIMUM_BACKOFF_TRIES
|
||||
else:
|
||||
self._tries += 1
|
||||
await self._handle_connection_failure(err)
|
||||
return False
|
||||
self._tries = 0
|
||||
finish_handshake_time = time.perf_counter()
|
||||
@ -227,6 +210,19 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
await self._on_connect_cb()
|
||||
return True
|
||||
|
||||
async def _handle_connection_failure(self, err: Exception) -> None:
|
||||
"""Handle a connection failure."""
|
||||
self._async_set_connection_state_while_locked(ReconnectLogicState.DISCONNECTED)
|
||||
if self._on_connect_error_cb is not None:
|
||||
await self._on_connect_error_cb(err)
|
||||
self._async_log_connection_error(err)
|
||||
if isinstance(err, AUTH_EXCEPTIONS):
|
||||
# If we get an encryption or password error,
|
||||
# backoff for the maximum amount of time
|
||||
self._tries = MAXIMUM_BACKOFF_TRIES
|
||||
else:
|
||||
self._tries += 1
|
||||
|
||||
def _schedule_connect(self, delay: float) -> None:
|
||||
"""Schedule a connect attempt."""
|
||||
self._cancel_connect("Scheduling new connect attempt")
|
||||
|
@ -11,7 +11,6 @@ from noise.connection import NoiseConnection # type: ignore[import-untyped]
|
||||
|
||||
from aioesphomeapi import APIConnection
|
||||
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
|
||||
from aioesphomeapi._frame_helper.base import WRITE_EXCEPTIONS
|
||||
from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND, NOISE_HELLO
|
||||
from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint
|
||||
from aioesphomeapi._frame_helper.plain_text import (
|
||||
@ -27,6 +26,7 @@ from aioesphomeapi.core import (
|
||||
InvalidEncryptionKeyAPIError,
|
||||
ProtocolAPIError,
|
||||
SocketAPIError,
|
||||
SocketClosedAPIError,
|
||||
)
|
||||
|
||||
from .common import async_fire_time_changed, utcnow
|
||||
@ -62,8 +62,8 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
|
||||
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
||||
try:
|
||||
self._writer(header + frame)
|
||||
except WRITE_EXCEPTIONS as err:
|
||||
raise SocketAPIError(
|
||||
except (RuntimeError, ConnectionResetError, OSError) as err:
|
||||
raise SocketClosedAPIError(
|
||||
f"{self._log_name}: Error while writing data: {err}"
|
||||
) from err
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user