mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-09-28 04:27:27 +02:00
Refactor to reduce duplicate connection code (#644)
This commit is contained in:
parent
d350e96405
commit
c76d741cb8
@ -5,6 +5,7 @@ from ..connection cimport APIConnection
|
|||||||
|
|
||||||
|
|
||||||
cdef bint TYPE_CHECKING
|
cdef bint TYPE_CHECKING
|
||||||
|
cdef object WRITE_EXCEPTIONS
|
||||||
|
|
||||||
cdef class APIFrameHelper:
|
cdef class APIFrameHelper:
|
||||||
|
|
||||||
@ -32,3 +33,5 @@ cdef class APIFrameHelper:
|
|||||||
cdef _remove_from_buffer(self)
|
cdef _remove_from_buffer(self)
|
||||||
|
|
||||||
cpdef write_packets(self, list packets)
|
cpdef write_packets(self, list packets)
|
||||||
|
|
||||||
|
cdef _write_bytes(self, bytes data)
|
||||||
|
@ -180,3 +180,18 @@ class APIFrameHelper:
|
|||||||
|
|
||||||
def resume_writing(self) -> None:
|
def resume_writing(self) -> None:
|
||||||
"""Stub."""
|
"""Stub."""
|
||||||
|
|
||||||
|
def _write_bytes(self, data: bytes) -> None:
|
||||||
|
"""Write bytes to the socket."""
|
||||||
|
if self._debug_enabled() is True:
|
||||||
|
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, data.hex())
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
assert self._writer is not None, "Writer is not set"
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._writer(data)
|
||||||
|
except WRITE_EXCEPTIONS as err:
|
||||||
|
raise SocketClosedAPIError(
|
||||||
|
f"{self._log_name}: Error while writing data: {err}"
|
||||||
|
) from err
|
||||||
|
@ -48,6 +48,9 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
packet=tuple,
|
packet=tuple,
|
||||||
data_len=cython.uint,
|
data_len=cython.uint,
|
||||||
frame=bytes,
|
frame=bytes,
|
||||||
frame_len=cython.uint
|
frame_len=cython.uint,
|
||||||
|
type_=object
|
||||||
)
|
)
|
||||||
cpdef write_packets(self, list packets)
|
cpdef write_packets(self, list packets)
|
||||||
|
|
||||||
|
cdef _error_on_incorrect_preamble(self, bytes msg)
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import binascii
|
import binascii
|
||||||
import logging
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from struct import Struct
|
from struct import Struct
|
||||||
from typing import TYPE_CHECKING, Any, Callable
|
from typing import TYPE_CHECKING, Any, Callable
|
||||||
@ -20,15 +19,12 @@ from ..core import (
|
|||||||
HandshakeAPIError,
|
HandshakeAPIError,
|
||||||
InvalidEncryptionKeyAPIError,
|
InvalidEncryptionKeyAPIError,
|
||||||
ProtocolAPIError,
|
ProtocolAPIError,
|
||||||
SocketAPIError,
|
|
||||||
)
|
)
|
||||||
from .base import WRITE_EXCEPTIONS, APIFrameHelper
|
from .base import APIFrameHelper
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..connection import APIConnection
|
from ..connection import APIConnection
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
PACK_NONCE = partial(Struct("<LQ").pack, 0)
|
PACK_NONCE = partial(Struct("<LQ").pack, 0)
|
||||||
|
|
||||||
@ -119,14 +115,20 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
|
|
||||||
def _handle_error(self, exc: Exception) -> None:
|
def _handle_error(self, exc: Exception) -> None:
|
||||||
"""Handle an error, and provide a good message when during hello."""
|
"""Handle an error, and provide a good message when during hello."""
|
||||||
if isinstance(exc, ConnectionResetError) and self._state == NOISE_STATE_HELLO:
|
if self._state == NOISE_STATE_HELLO and isinstance(exc, ConnectionResetError):
|
||||||
original_exc = exc
|
original_exc: Exception = exc
|
||||||
exc = HandshakeAPIError(
|
exc = HandshakeAPIError(
|
||||||
f"{self._log_name}: The connection dropped immediately after encrypted hello; "
|
f"{self._log_name}: The connection dropped immediately after encrypted hello; "
|
||||||
"Try enabling encryption on the device or turning off "
|
"Try enabling encryption on the device or turning off "
|
||||||
f"encryption on the client ({self._client_info})."
|
f"encryption on the client ({self._client_info})."
|
||||||
)
|
)
|
||||||
exc.__cause__ = original_exc
|
exc.__cause__ = original_exc
|
||||||
|
elif isinstance(exc, InvalidTag):
|
||||||
|
original_exc = exc
|
||||||
|
exc = InvalidEncryptionKeyAPIError(
|
||||||
|
f"{self._log_name}: Invalid encryption key", self._server_name
|
||||||
|
)
|
||||||
|
exc.__cause__ = original_exc
|
||||||
super()._handle_error(exc)
|
super()._handle_error(exc)
|
||||||
|
|
||||||
async def perform_handshake(self, timeout: float) -> None:
|
async def perform_handshake(self, timeout: float) -> None:
|
||||||
@ -159,7 +161,8 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
if frame is None:
|
if frame is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
# asyncio already runs data_received in a try block
|
||||||
|
# which will call connection_lost if an exception is raised
|
||||||
if self._state == NOISE_STATE_READY:
|
if self._state == NOISE_STATE_READY:
|
||||||
self._handle_frame(frame)
|
self._handle_frame(frame)
|
||||||
elif self._state == NOISE_STATE_HELLO:
|
elif self._state == NOISE_STATE_HELLO:
|
||||||
@ -168,34 +171,16 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
self._handle_handshake(frame)
|
self._handle_handshake(frame)
|
||||||
else:
|
else:
|
||||||
self._handle_closed(frame)
|
self._handle_closed(frame)
|
||||||
except Exception as err: # pylint: disable=broad-except
|
|
||||||
self._handle_error_and_close(err)
|
|
||||||
finally:
|
|
||||||
self._remove_from_buffer()
|
self._remove_from_buffer()
|
||||||
|
|
||||||
def _send_hello_handshake(self) -> None:
|
def _send_hello_handshake(self) -> None:
|
||||||
"""Send a ClientHello to the server."""
|
"""Send a ClientHello to the server."""
|
||||||
if TYPE_CHECKING:
|
|
||||||
assert self._writer is not None, "Writer is not set"
|
|
||||||
|
|
||||||
handshake_frame = b"\x00" + self._proto.write_message()
|
handshake_frame = b"\x00" + self._proto.write_message()
|
||||||
frame_len = len(handshake_frame)
|
frame_len = len(handshake_frame)
|
||||||
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
||||||
hello_handshake = NOISE_HELLO + header + handshake_frame
|
hello_handshake = NOISE_HELLO + header + handshake_frame
|
||||||
|
self._write_bytes(hello_handshake)
|
||||||
if self._debug_enabled():
|
|
||||||
_LOGGER.debug(
|
|
||||||
"%s: Sending encrypted hello handshake: [%s]",
|
|
||||||
self._log_name,
|
|
||||||
hello_handshake.hex(),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
self._writer(hello_handshake)
|
|
||||||
except WRITE_EXCEPTIONS as err:
|
|
||||||
raise SocketAPIError(
|
|
||||||
f"{self._log_name}: Error while writing data: {err}"
|
|
||||||
) from err
|
|
||||||
|
|
||||||
def _handle_hello(self, server_hello: bytes) -> None:
|
def _handle_hello(self, server_hello: bytes) -> None:
|
||||||
"""Perform the handshake with the server."""
|
"""Perform the handshake with the server."""
|
||||||
@ -268,9 +253,8 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
proto.start_handshake()
|
proto.start_handshake()
|
||||||
self._proto = proto
|
self._proto = proto
|
||||||
|
|
||||||
def _handle_handshake(self, msg: bytes) -> None:
|
def _error_on_incorrect_preamble(self, msg: bytes) -> None:
|
||||||
_LOGGER.debug("Starting handshake...")
|
"""Handle an incorrect preamble."""
|
||||||
if msg[0] != 0:
|
|
||||||
explanation = msg[1:].decode()
|
explanation = msg[1:].decode()
|
||||||
if explanation == "Handshake MAC failure":
|
if explanation == "Handshake MAC failure":
|
||||||
self._handle_error_and_close(
|
self._handle_error_and_close(
|
||||||
@ -282,17 +266,12 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
self._handle_error_and_close(
|
self._handle_error_and_close(
|
||||||
HandshakeAPIError(f"{self._log_name}: Handshake failure: {explanation}")
|
HandshakeAPIError(f"{self._log_name}: Handshake failure: {explanation}")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _handle_handshake(self, msg: bytes) -> None:
|
||||||
|
if msg[0] != 0:
|
||||||
|
self._error_on_incorrect_preamble(msg)
|
||||||
return
|
return
|
||||||
try:
|
|
||||||
self._proto.read_message(msg[1:])
|
self._proto.read_message(msg[1:])
|
||||||
except InvalidTag as invalid_tag_exc:
|
|
||||||
ex = InvalidEncryptionKeyAPIError(
|
|
||||||
f"{self._log_name}: Invalid encryption key", self._server_name
|
|
||||||
)
|
|
||||||
ex.__cause__ = invalid_tag_exc
|
|
||||||
self._handle_error_and_close(ex)
|
|
||||||
return
|
|
||||||
_LOGGER.debug("Handshake complete")
|
|
||||||
self._state = NOISE_STATE_READY
|
self._state = NOISE_STATE_READY
|
||||||
noise_protocol = self._proto.noise_protocol
|
noise_protocol = self._proto.noise_protocol
|
||||||
self._decrypt = partial(
|
self._decrypt = partial(
|
||||||
@ -315,10 +294,8 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
assert self._encrypt is not None, "Handshake should be complete"
|
assert self._encrypt is not None, "Handshake should be complete"
|
||||||
assert self._writer is not None, "Writer is not set"
|
|
||||||
|
|
||||||
out: list[bytes] = []
|
out: list[bytes] = []
|
||||||
debug_enabled = self._debug_enabled()
|
|
||||||
for packet in packets:
|
for packet in packets:
|
||||||
type_: int = packet[0]
|
type_: int = packet[0]
|
||||||
data: bytes = packet[1]
|
data: bytes = packet[1]
|
||||||
@ -332,33 +309,18 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
frame = self._encrypt(data_header + data)
|
frame = self._encrypt(data_header + data)
|
||||||
|
|
||||||
if debug_enabled is True:
|
|
||||||
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, frame.hex())
|
|
||||||
|
|
||||||
frame_len = len(frame)
|
frame_len = len(frame)
|
||||||
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
||||||
out.append(header)
|
out.append(header)
|
||||||
out.append(frame)
|
out.append(frame)
|
||||||
|
|
||||||
try:
|
self._write_bytes(b"".join(out))
|
||||||
self._writer(b"".join(out))
|
|
||||||
except WRITE_EXCEPTIONS as err:
|
|
||||||
raise SocketAPIError(
|
|
||||||
f"{self._log_name}: Error while writing data: {err}"
|
|
||||||
) from err
|
|
||||||
|
|
||||||
def _handle_frame(self, frame: bytes) -> None:
|
def _handle_frame(self, frame: bytes) -> None:
|
||||||
"""Handle an incoming frame."""
|
"""Handle an incoming frame."""
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
assert self._decrypt is not None, "Handshake should be complete"
|
assert self._decrypt is not None, "Handshake should be complete"
|
||||||
try:
|
|
||||||
msg = self._decrypt(frame)
|
msg = self._decrypt(frame)
|
||||||
except InvalidTag as ex:
|
|
||||||
self._handle_error_and_close(
|
|
||||||
ProtocolAPIError(f"{self._log_name}: Bad encryption frame: {ex!r}")
|
|
||||||
)
|
|
||||||
return
|
|
||||||
# Message layout is
|
# Message layout is
|
||||||
# 2 bytes: message type
|
# 2 bytes: message type
|
||||||
# 2 bytes: message length
|
# 2 bytes: message length
|
||||||
|
@ -5,7 +5,6 @@ from .base cimport APIFrameHelper
|
|||||||
|
|
||||||
|
|
||||||
cdef bint TYPE_CHECKING
|
cdef bint TYPE_CHECKING
|
||||||
cdef object WRITE_EXCEPTIONS
|
|
||||||
cdef object bytes_to_varuint, varuint_to_bytes
|
cdef object bytes_to_varuint, varuint_to_bytes
|
||||||
|
|
||||||
cpdef _varuint_to_bytes(cython.int value)
|
cpdef _varuint_to_bytes(cython.int value)
|
||||||
@ -33,6 +32,7 @@ cdef class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
@cython.locals(
|
@cython.locals(
|
||||||
type_="unsigned int",
|
type_="unsigned int",
|
||||||
data=bytes,
|
data=bytes,
|
||||||
packet=tuple
|
packet=tuple,
|
||||||
|
type_=object
|
||||||
)
|
)
|
||||||
cpdef write_packets(self, list packets)
|
cpdef write_packets(self, list packets)
|
||||||
|
@ -1,14 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from ..core import ProtocolAPIError, RequiresEncryptionAPIError, SocketAPIError
|
from ..core import ProtocolAPIError, RequiresEncryptionAPIError
|
||||||
from .base import WRITE_EXCEPTIONS, APIFrameHelper
|
from .base import APIFrameHelper
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_int = int
|
_int = int
|
||||||
_bytes = bytes
|
_bytes = bytes
|
||||||
@ -66,11 +63,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
|
|
||||||
The entire packet must be written in a single call.
|
The entire packet must be written in a single call.
|
||||||
"""
|
"""
|
||||||
if TYPE_CHECKING:
|
|
||||||
assert self._writer is not None, "Writer should be set"
|
|
||||||
|
|
||||||
out: list[bytes] = []
|
out: list[bytes] = []
|
||||||
debug_enabled = self._debug_enabled()
|
|
||||||
for packet in packets:
|
for packet in packets:
|
||||||
type_: int = packet[0]
|
type_: int = packet[0]
|
||||||
data: bytes = packet[1]
|
data: bytes = packet[1]
|
||||||
@ -78,17 +71,8 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
out.append(varuint_to_bytes(len(data)))
|
out.append(varuint_to_bytes(len(data)))
|
||||||
out.append(varuint_to_bytes(type_))
|
out.append(varuint_to_bytes(type_))
|
||||||
out.append(data)
|
out.append(data)
|
||||||
if debug_enabled is True:
|
|
||||||
_LOGGER.debug(
|
|
||||||
"%s: Sending plaintext frame %s", self._log_name, data.hex()
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
self._write_bytes(b"".join(out))
|
||||||
self._writer(b"".join(out))
|
|
||||||
except WRITE_EXCEPTIONS as err:
|
|
||||||
raise SocketAPIError(
|
|
||||||
f"{self._log_name}: Error while writing data: {err}"
|
|
||||||
) from err
|
|
||||||
|
|
||||||
def data_received( # pylint: disable=too-many-branches,too-many-return-statements
|
def data_received( # pylint: disable=too-many-branches,too-many-return-statements
|
||||||
self, data: bytes | bytearray | memoryview
|
self, data: bytes | bytearray | memoryview
|
||||||
|
@ -186,13 +186,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||||||
try:
|
try:
|
||||||
await self._cli.start_connection(on_stop=self._on_disconnect)
|
await self._cli.start_connection(on_stop=self._on_disconnect)
|
||||||
except Exception as err: # pylint: disable=broad-except
|
except Exception as err: # pylint: disable=broad-except
|
||||||
self._async_set_connection_state_while_locked(
|
await self._handle_connection_failure(err)
|
||||||
ReconnectLogicState.DISCONNECTED
|
|
||||||
)
|
|
||||||
if self._on_connect_error_cb is not None:
|
|
||||||
await self._on_connect_error_cb(err)
|
|
||||||
self._async_log_connection_error(err)
|
|
||||||
self._tries += 1
|
|
||||||
return False
|
return False
|
||||||
finish_connect_time = time.perf_counter()
|
finish_connect_time = time.perf_counter()
|
||||||
connect_time = finish_connect_time - start_connect_time
|
connect_time = finish_connect_time - start_connect_time
|
||||||
@ -204,18 +198,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||||||
try:
|
try:
|
||||||
await self._cli.finish_connection(login=True)
|
await self._cli.finish_connection(login=True)
|
||||||
except Exception as err: # pylint: disable=broad-except
|
except Exception as err: # pylint: disable=broad-except
|
||||||
self._async_set_connection_state_while_locked(
|
await self._handle_connection_failure(err)
|
||||||
ReconnectLogicState.DISCONNECTED
|
|
||||||
)
|
|
||||||
if self._on_connect_error_cb is not None:
|
|
||||||
await self._on_connect_error_cb(err)
|
|
||||||
self._async_log_connection_error(err)
|
|
||||||
if isinstance(err, AUTH_EXCEPTIONS):
|
|
||||||
# If we get an encryption or password error,
|
|
||||||
# backoff for the maximum amount of time
|
|
||||||
self._tries = MAXIMUM_BACKOFF_TRIES
|
|
||||||
else:
|
|
||||||
self._tries += 1
|
|
||||||
return False
|
return False
|
||||||
self._tries = 0
|
self._tries = 0
|
||||||
finish_handshake_time = time.perf_counter()
|
finish_handshake_time = time.perf_counter()
|
||||||
@ -227,6 +210,19 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||||||
await self._on_connect_cb()
|
await self._on_connect_cb()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
async def _handle_connection_failure(self, err: Exception) -> None:
|
||||||
|
"""Handle a connection failure."""
|
||||||
|
self._async_set_connection_state_while_locked(ReconnectLogicState.DISCONNECTED)
|
||||||
|
if self._on_connect_error_cb is not None:
|
||||||
|
await self._on_connect_error_cb(err)
|
||||||
|
self._async_log_connection_error(err)
|
||||||
|
if isinstance(err, AUTH_EXCEPTIONS):
|
||||||
|
# If we get an encryption or password error,
|
||||||
|
# backoff for the maximum amount of time
|
||||||
|
self._tries = MAXIMUM_BACKOFF_TRIES
|
||||||
|
else:
|
||||||
|
self._tries += 1
|
||||||
|
|
||||||
def _schedule_connect(self, delay: float) -> None:
|
def _schedule_connect(self, delay: float) -> None:
|
||||||
"""Schedule a connect attempt."""
|
"""Schedule a connect attempt."""
|
||||||
self._cancel_connect("Scheduling new connect attempt")
|
self._cancel_connect("Scheduling new connect attempt")
|
||||||
|
@ -11,7 +11,6 @@ from noise.connection import NoiseConnection # type: ignore[import-untyped]
|
|||||||
|
|
||||||
from aioesphomeapi import APIConnection
|
from aioesphomeapi import APIConnection
|
||||||
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
|
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
|
||||||
from aioesphomeapi._frame_helper.base import WRITE_EXCEPTIONS
|
|
||||||
from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND, NOISE_HELLO
|
from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND, NOISE_HELLO
|
||||||
from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint
|
from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint
|
||||||
from aioesphomeapi._frame_helper.plain_text import (
|
from aioesphomeapi._frame_helper.plain_text import (
|
||||||
@ -27,6 +26,7 @@ from aioesphomeapi.core import (
|
|||||||
InvalidEncryptionKeyAPIError,
|
InvalidEncryptionKeyAPIError,
|
||||||
ProtocolAPIError,
|
ProtocolAPIError,
|
||||||
SocketAPIError,
|
SocketAPIError,
|
||||||
|
SocketClosedAPIError,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .common import async_fire_time_changed, utcnow
|
from .common import async_fire_time_changed, utcnow
|
||||||
@ -62,8 +62,8 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
|
|||||||
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
||||||
try:
|
try:
|
||||||
self._writer(header + frame)
|
self._writer(header + frame)
|
||||||
except WRITE_EXCEPTIONS as err:
|
except (RuntimeError, ConnectionResetError, OSError) as err:
|
||||||
raise SocketAPIError(
|
raise SocketClosedAPIError(
|
||||||
f"{self._log_name}: Error while writing data: {err}"
|
f"{self._log_name}: Error while writing data: {err}"
|
||||||
) from err
|
) from err
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user