Refactor frame helper to get debug state from connection (#679)

This commit is contained in:
J. Nick Koston 2023-11-23 19:20:52 +01:00 committed by GitHub
parent f1a9f4b452
commit 83b4f43610
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 23 additions and 19 deletions

View File

@ -19,7 +19,6 @@ cdef class APIFrameHelper:
cdef unsigned int _pos
cdef object _client_info
cdef str _log_name
cdef object _debug_enabled
cpdef set_log_name(self, str log_name)
@ -32,6 +31,6 @@ cdef class APIFrameHelper:
@cython.locals(end_of_frame_pos="unsigned int")
cdef _remove_from_buffer(self)
cpdef write_packets(self, list packets)
cpdef write_packets(self, list packets, bint debug_enabled)
cdef _write_bytes(self, bytes data)
cdef _write_bytes(self, bytes data, bint debug_enabled)

View File

@ -3,7 +3,6 @@ from __future__ import annotations
import asyncio
import logging
from abc import abstractmethod
from functools import partial
from typing import TYPE_CHECKING, Callable, cast
from ..core import HandshakeAPIError, SocketClosedAPIError
@ -39,7 +38,6 @@ class APIFrameHelper:
"_pos",
"_client_info",
"_log_name",
"_debug_enabled",
)
def __init__(
@ -60,7 +58,6 @@ class APIFrameHelper:
self._pos = 0
self._client_info = client_info
self._log_name = log_name
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
def set_log_name(self, log_name: str) -> None:
"""Set the log name."""
@ -139,7 +136,9 @@ class APIFrameHelper:
handshake_handle.cancel()
@abstractmethod
def write_packets(self, packets: list[tuple[int, bytes]]) -> None:
def write_packets(
self, packets: list[tuple[int, bytes]], debug_enabled: bool
) -> None:
"""Write a packets to the socket.
Packets are in the format of tuple[protobuf_type, protobuf_data]
@ -181,9 +180,9 @@ class APIFrameHelper:
def resume_writing(self) -> None:
"""Stub."""
def _write_bytes(self, data: bytes) -> None:
def _write_bytes(self, data: bytes, debug_enabled: bool) -> None:
"""Write bytes to the socket."""
if self._debug_enabled() is True:
if debug_enabled:
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, data.hex())
if TYPE_CHECKING:

View File

@ -51,6 +51,6 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
frame_len=cython.uint,
type_=object
)
cpdef write_packets(self, list packets)
cpdef write_packets(self, list packets, bint debug_enabled)
cdef _error_on_incorrect_preamble(self, bytes msg)

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import binascii
import logging
from functools import partial
from struct import Struct
from typing import TYPE_CHECKING, Any, Callable
@ -20,7 +21,7 @@ from ..core import (
InvalidEncryptionKeyAPIError,
ProtocolAPIError,
)
from .base import APIFrameHelper
from .base import _LOGGER, APIFrameHelper
if TYPE_CHECKING:
from ..connection import APIConnection
@ -180,7 +181,7 @@ class APINoiseFrameHelper(APIFrameHelper):
frame_len = len(handshake_frame)
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
hello_handshake = NOISE_HELLO + header + handshake_frame
self._write_bytes(hello_handshake)
self._write_bytes(hello_handshake, _LOGGER.isEnabledFor(logging.DEBUG))
def _handle_hello(self, server_hello: bytes) -> None:
"""Perform the handshake with the server."""
@ -284,7 +285,9 @@ class APINoiseFrameHelper(APIFrameHelper):
)
self._ready_future.set_result(None)
def write_packets(self, packets: list[tuple[int, bytes]]) -> None:
def write_packets(
self, packets: list[tuple[int, bytes]], debug_enabled: bool
) -> None:
"""Write a packets to the socket.
Packets are in the format of tuple[protobuf_type, protobuf_data]
@ -314,7 +317,7 @@ class APINoiseFrameHelper(APIFrameHelper):
out.append(header)
out.append(frame)
self._write_bytes(b"".join(out))
self._write_bytes(b"".join(out), debug_enabled)
def _handle_frame(self, frame: bytes) -> None:
"""Handle an incoming frame."""

View File

@ -35,4 +35,4 @@ cdef class APIPlaintextFrameHelper(APIFrameHelper):
packet=tuple,
type_=object
)
cpdef write_packets(self, list packets)
cpdef write_packets(self, list packets, bint debug_enabled)

View File

@ -56,7 +56,9 @@ class APIPlaintextFrameHelper(APIFrameHelper):
super().connection_made(transport)
self._ready_future.set_result(None)
def write_packets(self, packets: list[tuple[int, bytes]]) -> None:
def write_packets(
self, packets: list[tuple[int, bytes]], debug_enabled: bool
) -> None:
"""Write a packets to the socket.
Packets are in the format of tuple[protobuf_type, protobuf_data]
@ -72,7 +74,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
out.append(varuint_to_bytes(type_))
out.append(data)
self._write_bytes(b"".join(out))
self._write_bytes(b"".join(out), debug_enabled)
def data_received( # pylint: disable=too-many-branches,too-many-return-statements
self, data: bytes | bytearray | memoryview

View File

@ -643,7 +643,7 @@ class APIConnection:
assert self._frame_helper is not None
try:
self._frame_helper.write_packets(packets)
self._frame_helper.write_packets(packets, debug_enabled)
except SocketAPIError as err:
# If writing packet fails, we don't know what state the frames
# are in anymore and we have to close the connection

View File

@ -8,6 +8,7 @@ from aioesphomeapi.api_pb2 import (
)
from aioesphomeapi.client import APIClient
from aioesphomeapi.client_callbacks import on_ble_raw_advertisement_response
# cythonize -X language_level=3 -a -i aioesphomeapi/client_callbacks.py
# cythonize -X language_level=3 -a -i aioesphomeapi/connection.py

View File

@ -541,7 +541,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
assert not writes
await handshake_task
helper.write_packets([(1, b"to device")])
helper.write_packets([(1, b"to device")], True)
encrypted_packet = writes.pop()
header = encrypted_packet[0:1]
assert header == b"\x01"