Refactor frame helper to get debug state from connection (#679)

This commit is contained in:
J. Nick Koston 2023-11-23 19:20:52 +01:00 committed by GitHub
parent f1a9f4b452
commit 83b4f43610
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 23 additions and 19 deletions

View File

@ -19,7 +19,6 @@ cdef class APIFrameHelper:
cdef unsigned int _pos cdef unsigned int _pos
cdef object _client_info cdef object _client_info
cdef str _log_name cdef str _log_name
cdef object _debug_enabled
cpdef set_log_name(self, str log_name) cpdef set_log_name(self, str log_name)
@ -32,6 +31,6 @@ cdef class APIFrameHelper:
@cython.locals(end_of_frame_pos="unsigned int") @cython.locals(end_of_frame_pos="unsigned int")
cdef _remove_from_buffer(self) cdef _remove_from_buffer(self)
cpdef write_packets(self, list packets) cpdef write_packets(self, list packets, bint debug_enabled)
cdef _write_bytes(self, bytes data) cdef _write_bytes(self, bytes data, bint debug_enabled)

View File

@ -3,7 +3,6 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
from abc import abstractmethod from abc import abstractmethod
from functools import partial
from typing import TYPE_CHECKING, Callable, cast from typing import TYPE_CHECKING, Callable, cast
from ..core import HandshakeAPIError, SocketClosedAPIError from ..core import HandshakeAPIError, SocketClosedAPIError
@ -39,7 +38,6 @@ class APIFrameHelper:
"_pos", "_pos",
"_client_info", "_client_info",
"_log_name", "_log_name",
"_debug_enabled",
) )
def __init__( def __init__(
@ -60,7 +58,6 @@ class APIFrameHelper:
self._pos = 0 self._pos = 0
self._client_info = client_info self._client_info = client_info
self._log_name = log_name self._log_name = log_name
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
def set_log_name(self, log_name: str) -> None: def set_log_name(self, log_name: str) -> None:
"""Set the log name.""" """Set the log name."""
@ -139,7 +136,9 @@ class APIFrameHelper:
handshake_handle.cancel() handshake_handle.cancel()
@abstractmethod @abstractmethod
def write_packets(self, packets: list[tuple[int, bytes]]) -> None: def write_packets(
self, packets: list[tuple[int, bytes]], debug_enabled: bool
) -> None:
"""Write a packets to the socket. """Write a packets to the socket.
Packets are in the format of tuple[protobuf_type, protobuf_data] Packets are in the format of tuple[protobuf_type, protobuf_data]
@ -181,9 +180,9 @@ class APIFrameHelper:
def resume_writing(self) -> None: def resume_writing(self) -> None:
"""Stub.""" """Stub."""
def _write_bytes(self, data: bytes) -> None: def _write_bytes(self, data: bytes, debug_enabled: bool) -> None:
"""Write bytes to the socket.""" """Write bytes to the socket."""
if self._debug_enabled() is True: if debug_enabled:
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, data.hex()) _LOGGER.debug("%s: Sending frame: [%s]", self._log_name, data.hex())
if TYPE_CHECKING: if TYPE_CHECKING:

View File

@ -51,6 +51,6 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
frame_len=cython.uint, frame_len=cython.uint,
type_=object type_=object
) )
cpdef write_packets(self, list packets) cpdef write_packets(self, list packets, bint debug_enabled)
cdef _error_on_incorrect_preamble(self, bytes msg) cdef _error_on_incorrect_preamble(self, bytes msg)

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import binascii import binascii
import logging
from functools import partial from functools import partial
from struct import Struct from struct import Struct
from typing import TYPE_CHECKING, Any, Callable from typing import TYPE_CHECKING, Any, Callable
@ -20,7 +21,7 @@ from ..core import (
InvalidEncryptionKeyAPIError, InvalidEncryptionKeyAPIError,
ProtocolAPIError, ProtocolAPIError,
) )
from .base import APIFrameHelper from .base import _LOGGER, APIFrameHelper
if TYPE_CHECKING: if TYPE_CHECKING:
from ..connection import APIConnection from ..connection import APIConnection
@ -180,7 +181,7 @@ class APINoiseFrameHelper(APIFrameHelper):
frame_len = len(handshake_frame) frame_len = len(handshake_frame)
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF)) header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
hello_handshake = NOISE_HELLO + header + handshake_frame hello_handshake = NOISE_HELLO + header + handshake_frame
self._write_bytes(hello_handshake) self._write_bytes(hello_handshake, _LOGGER.isEnabledFor(logging.DEBUG))
def _handle_hello(self, server_hello: bytes) -> None: def _handle_hello(self, server_hello: bytes) -> None:
"""Perform the handshake with the server.""" """Perform the handshake with the server."""
@ -284,7 +285,9 @@ class APINoiseFrameHelper(APIFrameHelper):
) )
self._ready_future.set_result(None) self._ready_future.set_result(None)
def write_packets(self, packets: list[tuple[int, bytes]]) -> None: def write_packets(
self, packets: list[tuple[int, bytes]], debug_enabled: bool
) -> None:
"""Write a packets to the socket. """Write a packets to the socket.
Packets are in the format of tuple[protobuf_type, protobuf_data] Packets are in the format of tuple[protobuf_type, protobuf_data]
@ -314,7 +317,7 @@ class APINoiseFrameHelper(APIFrameHelper):
out.append(header) out.append(header)
out.append(frame) out.append(frame)
self._write_bytes(b"".join(out)) self._write_bytes(b"".join(out), debug_enabled)
def _handle_frame(self, frame: bytes) -> None: def _handle_frame(self, frame: bytes) -> None:
"""Handle an incoming frame.""" """Handle an incoming frame."""

View File

@ -35,4 +35,4 @@ cdef class APIPlaintextFrameHelper(APIFrameHelper):
packet=tuple, packet=tuple,
type_=object type_=object
) )
cpdef write_packets(self, list packets) cpdef write_packets(self, list packets, bint debug_enabled)

View File

@ -56,7 +56,9 @@ class APIPlaintextFrameHelper(APIFrameHelper):
super().connection_made(transport) super().connection_made(transport)
self._ready_future.set_result(None) self._ready_future.set_result(None)
def write_packets(self, packets: list[tuple[int, bytes]]) -> None: def write_packets(
self, packets: list[tuple[int, bytes]], debug_enabled: bool
) -> None:
"""Write a packets to the socket. """Write a packets to the socket.
Packets are in the format of tuple[protobuf_type, protobuf_data] Packets are in the format of tuple[protobuf_type, protobuf_data]
@ -72,7 +74,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
out.append(varuint_to_bytes(type_)) out.append(varuint_to_bytes(type_))
out.append(data) out.append(data)
self._write_bytes(b"".join(out)) self._write_bytes(b"".join(out), debug_enabled)
def data_received( # pylint: disable=too-many-branches,too-many-return-statements def data_received( # pylint: disable=too-many-branches,too-many-return-statements
self, data: bytes | bytearray | memoryview self, data: bytes | bytearray | memoryview

View File

@ -643,7 +643,7 @@ class APIConnection:
assert self._frame_helper is not None assert self._frame_helper is not None
try: try:
self._frame_helper.write_packets(packets) self._frame_helper.write_packets(packets, debug_enabled)
except SocketAPIError as err: except SocketAPIError as err:
# If writing packet fails, we don't know what state the frames # If writing packet fails, we don't know what state the frames
# are in anymore and we have to close the connection # are in anymore and we have to close the connection

View File

@ -8,6 +8,7 @@ from aioesphomeapi.api_pb2 import (
) )
from aioesphomeapi.client import APIClient from aioesphomeapi.client import APIClient
from aioesphomeapi.client_callbacks import on_ble_raw_advertisement_response from aioesphomeapi.client_callbacks import on_ble_raw_advertisement_response
# cythonize -X language_level=3 -a -i aioesphomeapi/client_callbacks.py # cythonize -X language_level=3 -a -i aioesphomeapi/client_callbacks.py
# cythonize -X language_level=3 -a -i aioesphomeapi/connection.py # cythonize -X language_level=3 -a -i aioesphomeapi/connection.py

View File

@ -541,7 +541,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
assert not writes assert not writes
await handshake_task await handshake_task
helper.write_packets([(1, b"to device")]) helper.write_packets([(1, b"to device")], True)
encrypted_packet = writes.pop() encrypted_packet = writes.pop()
header = encrypted_packet[0:1] header = encrypted_packet[0:1]
assert header == b"\x01" assert header == b"\x01"