Refactor frame helper to get debug state from connection (#679)
This commit is contained in:
parent
f1a9f4b452
commit
83b4f43610
|
@ -19,7 +19,6 @@ cdef class APIFrameHelper:
|
||||||
cdef unsigned int _pos
|
cdef unsigned int _pos
|
||||||
cdef object _client_info
|
cdef object _client_info
|
||||||
cdef str _log_name
|
cdef str _log_name
|
||||||
cdef object _debug_enabled
|
|
||||||
|
|
||||||
cpdef set_log_name(self, str log_name)
|
cpdef set_log_name(self, str log_name)
|
||||||
|
|
||||||
|
@ -32,6 +31,6 @@ cdef class APIFrameHelper:
|
||||||
@cython.locals(end_of_frame_pos="unsigned int")
|
@cython.locals(end_of_frame_pos="unsigned int")
|
||||||
cdef _remove_from_buffer(self)
|
cdef _remove_from_buffer(self)
|
||||||
|
|
||||||
cpdef write_packets(self, list packets)
|
cpdef write_packets(self, list packets, bint debug_enabled)
|
||||||
|
|
||||||
cdef _write_bytes(self, bytes data)
|
cdef _write_bytes(self, bytes data, bint debug_enabled)
|
||||||
|
|
|
@ -3,7 +3,6 @@ from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from functools import partial
|
|
||||||
from typing import TYPE_CHECKING, Callable, cast
|
from typing import TYPE_CHECKING, Callable, cast
|
||||||
|
|
||||||
from ..core import HandshakeAPIError, SocketClosedAPIError
|
from ..core import HandshakeAPIError, SocketClosedAPIError
|
||||||
|
@ -39,7 +38,6 @@ class APIFrameHelper:
|
||||||
"_pos",
|
"_pos",
|
||||||
"_client_info",
|
"_client_info",
|
||||||
"_log_name",
|
"_log_name",
|
||||||
"_debug_enabled",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -60,7 +58,6 @@ class APIFrameHelper:
|
||||||
self._pos = 0
|
self._pos = 0
|
||||||
self._client_info = client_info
|
self._client_info = client_info
|
||||||
self._log_name = log_name
|
self._log_name = log_name
|
||||||
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
|
|
||||||
|
|
||||||
def set_log_name(self, log_name: str) -> None:
|
def set_log_name(self, log_name: str) -> None:
|
||||||
"""Set the log name."""
|
"""Set the log name."""
|
||||||
|
@ -139,7 +136,9 @@ class APIFrameHelper:
|
||||||
handshake_handle.cancel()
|
handshake_handle.cancel()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def write_packets(self, packets: list[tuple[int, bytes]]) -> None:
|
def write_packets(
|
||||||
|
self, packets: list[tuple[int, bytes]], debug_enabled: bool
|
||||||
|
) -> None:
|
||||||
"""Write a packets to the socket.
|
"""Write a packets to the socket.
|
||||||
|
|
||||||
Packets are in the format of tuple[protobuf_type, protobuf_data]
|
Packets are in the format of tuple[protobuf_type, protobuf_data]
|
||||||
|
@ -181,9 +180,9 @@ class APIFrameHelper:
|
||||||
def resume_writing(self) -> None:
|
def resume_writing(self) -> None:
|
||||||
"""Stub."""
|
"""Stub."""
|
||||||
|
|
||||||
def _write_bytes(self, data: bytes) -> None:
|
def _write_bytes(self, data: bytes, debug_enabled: bool) -> None:
|
||||||
"""Write bytes to the socket."""
|
"""Write bytes to the socket."""
|
||||||
if self._debug_enabled() is True:
|
if debug_enabled:
|
||||||
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, data.hex())
|
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, data.hex())
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
|
@ -51,6 +51,6 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
|
||||||
frame_len=cython.uint,
|
frame_len=cython.uint,
|
||||||
type_=object
|
type_=object
|
||||||
)
|
)
|
||||||
cpdef write_packets(self, list packets)
|
cpdef write_packets(self, list packets, bint debug_enabled)
|
||||||
|
|
||||||
cdef _error_on_incorrect_preamble(self, bytes msg)
|
cdef _error_on_incorrect_preamble(self, bytes msg)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import binascii
|
import binascii
|
||||||
|
import logging
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from struct import Struct
|
from struct import Struct
|
||||||
from typing import TYPE_CHECKING, Any, Callable
|
from typing import TYPE_CHECKING, Any, Callable
|
||||||
|
@ -20,7 +21,7 @@ from ..core import (
|
||||||
InvalidEncryptionKeyAPIError,
|
InvalidEncryptionKeyAPIError,
|
||||||
ProtocolAPIError,
|
ProtocolAPIError,
|
||||||
)
|
)
|
||||||
from .base import APIFrameHelper
|
from .base import _LOGGER, APIFrameHelper
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..connection import APIConnection
|
from ..connection import APIConnection
|
||||||
|
@ -180,7 +181,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||||
frame_len = len(handshake_frame)
|
frame_len = len(handshake_frame)
|
||||||
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
||||||
hello_handshake = NOISE_HELLO + header + handshake_frame
|
hello_handshake = NOISE_HELLO + header + handshake_frame
|
||||||
self._write_bytes(hello_handshake)
|
self._write_bytes(hello_handshake, _LOGGER.isEnabledFor(logging.DEBUG))
|
||||||
|
|
||||||
def _handle_hello(self, server_hello: bytes) -> None:
|
def _handle_hello(self, server_hello: bytes) -> None:
|
||||||
"""Perform the handshake with the server."""
|
"""Perform the handshake with the server."""
|
||||||
|
@ -284,7 +285,9 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||||
)
|
)
|
||||||
self._ready_future.set_result(None)
|
self._ready_future.set_result(None)
|
||||||
|
|
||||||
def write_packets(self, packets: list[tuple[int, bytes]]) -> None:
|
def write_packets(
|
||||||
|
self, packets: list[tuple[int, bytes]], debug_enabled: bool
|
||||||
|
) -> None:
|
||||||
"""Write a packets to the socket.
|
"""Write a packets to the socket.
|
||||||
|
|
||||||
Packets are in the format of tuple[protobuf_type, protobuf_data]
|
Packets are in the format of tuple[protobuf_type, protobuf_data]
|
||||||
|
@ -314,7 +317,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||||
out.append(header)
|
out.append(header)
|
||||||
out.append(frame)
|
out.append(frame)
|
||||||
|
|
||||||
self._write_bytes(b"".join(out))
|
self._write_bytes(b"".join(out), debug_enabled)
|
||||||
|
|
||||||
def _handle_frame(self, frame: bytes) -> None:
|
def _handle_frame(self, frame: bytes) -> None:
|
||||||
"""Handle an incoming frame."""
|
"""Handle an incoming frame."""
|
||||||
|
|
|
@ -35,4 +35,4 @@ cdef class APIPlaintextFrameHelper(APIFrameHelper):
|
||||||
packet=tuple,
|
packet=tuple,
|
||||||
type_=object
|
type_=object
|
||||||
)
|
)
|
||||||
cpdef write_packets(self, list packets)
|
cpdef write_packets(self, list packets, bint debug_enabled)
|
||||||
|
|
|
@ -56,7 +56,9 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||||
super().connection_made(transport)
|
super().connection_made(transport)
|
||||||
self._ready_future.set_result(None)
|
self._ready_future.set_result(None)
|
||||||
|
|
||||||
def write_packets(self, packets: list[tuple[int, bytes]]) -> None:
|
def write_packets(
|
||||||
|
self, packets: list[tuple[int, bytes]], debug_enabled: bool
|
||||||
|
) -> None:
|
||||||
"""Write a packets to the socket.
|
"""Write a packets to the socket.
|
||||||
|
|
||||||
Packets are in the format of tuple[protobuf_type, protobuf_data]
|
Packets are in the format of tuple[protobuf_type, protobuf_data]
|
||||||
|
@ -72,7 +74,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||||
out.append(varuint_to_bytes(type_))
|
out.append(varuint_to_bytes(type_))
|
||||||
out.append(data)
|
out.append(data)
|
||||||
|
|
||||||
self._write_bytes(b"".join(out))
|
self._write_bytes(b"".join(out), debug_enabled)
|
||||||
|
|
||||||
def data_received( # pylint: disable=too-many-branches,too-many-return-statements
|
def data_received( # pylint: disable=too-many-branches,too-many-return-statements
|
||||||
self, data: bytes | bytearray | memoryview
|
self, data: bytes | bytearray | memoryview
|
||||||
|
|
|
@ -643,7 +643,7 @@ class APIConnection:
|
||||||
assert self._frame_helper is not None
|
assert self._frame_helper is not None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._frame_helper.write_packets(packets)
|
self._frame_helper.write_packets(packets, debug_enabled)
|
||||||
except SocketAPIError as err:
|
except SocketAPIError as err:
|
||||||
# If writing packet fails, we don't know what state the frames
|
# If writing packet fails, we don't know what state the frames
|
||||||
# are in anymore and we have to close the connection
|
# are in anymore and we have to close the connection
|
||||||
|
|
|
@ -8,6 +8,7 @@ from aioesphomeapi.api_pb2 import (
|
||||||
)
|
)
|
||||||
from aioesphomeapi.client import APIClient
|
from aioesphomeapi.client import APIClient
|
||||||
from aioesphomeapi.client_callbacks import on_ble_raw_advertisement_response
|
from aioesphomeapi.client_callbacks import on_ble_raw_advertisement_response
|
||||||
|
|
||||||
# cythonize -X language_level=3 -a -i aioesphomeapi/client_callbacks.py
|
# cythonize -X language_level=3 -a -i aioesphomeapi/client_callbacks.py
|
||||||
# cythonize -X language_level=3 -a -i aioesphomeapi/connection.py
|
# cythonize -X language_level=3 -a -i aioesphomeapi/connection.py
|
||||||
|
|
||||||
|
|
|
@ -541,7 +541,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
||||||
assert not writes
|
assert not writes
|
||||||
|
|
||||||
await handshake_task
|
await handshake_task
|
||||||
helper.write_packets([(1, b"to device")])
|
helper.write_packets([(1, b"to device")], True)
|
||||||
encrypted_packet = writes.pop()
|
encrypted_packet = writes.pop()
|
||||||
header = encrypted_packet[0:1]
|
header = encrypted_packet[0:1]
|
||||||
assert header == b"\x01"
|
assert header == b"\x01"
|
||||||
|
|
Loading…
Reference in New Issue