Avoid expensive debug logging checks in packet processor (#700)

This commit is contained in:
J. Nick Koston 2023-11-25 07:51:48 -06:00 committed by GitHub
parent 0d25cc92a0
commit 67661dbd7f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 145 additions and 64 deletions

View File

@ -168,7 +168,17 @@ def _stringify_or_none(value: str | None) -> str | None:
# pylint: disable=too-many-public-methods # pylint: disable=too-many-public-methods
class APIClient: class APIClient:
"""The ESPHome API client.
This class is the main entrypoint for interacting with the API.
It is recommended to use this class in combination with the
ReconnectLogic class to automatically reconnect to the device
if the connection is lost.
"""
__slots__ = ( __slots__ = (
"_debug_enabled",
"_params", "_params",
"_connection", "_connection",
"cached_name", "cached_name",
@ -205,6 +215,7 @@ class APIClient:
Can be used to prevent accidentally connecting to a different device if Can be used to prevent accidentally connecting to a different device if
IP passed as address but DHCP reassigned IP. IP passed as address but DHCP reassigned IP.
""" """
self._debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
self._params = ConnectionParams( self._params = ConnectionParams(
address=str(address), address=str(address),
port=port, port=port,
@ -223,6 +234,12 @@ class APIClient:
self._on_stop_task: asyncio.Task[None] | None = None self._on_stop_task: asyncio.Task[None] | None = None
self._set_log_name() self._set_log_name()
def set_debug(self, enabled: bool) -> None:
"""Enable debug logging."""
self._debug_enabled = enabled
if self._connection:
self._connection.set_debug(enabled)
@property @property
def zeroconf_manager(self) -> ZeroconfManager: def zeroconf_manager(self) -> ZeroconfManager:
return self._params.zeroconf_manager return self._params.zeroconf_manager
@ -299,7 +316,10 @@ class APIClient:
raise APIConnectionError(f"Already connected to {self.log_name}!") raise APIConnectionError(f"Already connected to {self.log_name}!")
self._connection = APIConnection( self._connection = APIConnection(
self._params, partial(self._on_stop, on_stop), log_name=self.log_name self._params,
partial(self._on_stop, on_stop),
self._debug_enabled,
self.log_name,
) )
try: try:
@ -556,7 +576,6 @@ class APIClient:
has_cache: bool = False, has_cache: bool = False,
address_type: int | None = None, address_type: int | None = None,
) -> Callable[[], None]: ) -> Callable[[], None]:
debug = _LOGGER.isEnabledFor(logging.DEBUG)
connect_future: asyncio.Future[None] = self._loop.create_future() connect_future: asyncio.Future[None] = self._loop.create_future()
if has_cache: if has_cache:
@ -570,7 +589,7 @@ class APIClient:
# of the connection. This can crash the esp if the service list is too large. # of the connection. This can crash the esp if the service list is too large.
request_type = BluetoothDeviceRequestType.CONNECT request_type = BluetoothDeviceRequestType.CONNECT
if debug: if self._debug_enabled:
_LOGGER.debug("%s: Using connection version %s", address, request_type) _LOGGER.debug("%s: Using connection version %s", address, request_type)
unsub = self._get_connection().send_message_callback_response( unsub = self._get_connection().send_message_callback_response(
@ -604,7 +623,7 @@ class APIClient:
# the slot is recovered before the timeout is raised # the slot is recovered before the timeout is raised
# to avoid race were we run out even though we have a slot. # to avoid race were we run out even though we have a slot.
addr = to_human_readable_address(address) addr = to_human_readable_address(address)
if debug: if self._debug_enabled:
_LOGGER.debug("%s: Connecting timed out, waiting for disconnect", addr) _LOGGER.debug("%s: Connecting timed out, waiting for disconnect", addr)
disconnect_timed_out = ( disconnect_timed_out = (
not await self._bluetooth_device_disconnect_guard_timeout( not await self._bluetooth_device_disconnect_guard_timeout(
@ -640,7 +659,7 @@ class APIClient:
try: try:
await self.bluetooth_device_disconnect(address, timeout=timeout) await self.bluetooth_device_disconnect(address, timeout=timeout)
except TimeoutAPIError: except TimeoutAPIError:
if _LOGGER.isEnabledFor(logging.DEBUG): if self._debug_enabled:
_LOGGER.debug( _LOGGER.debug(
"%s: Disconnect timed out: %s", "%s: Disconnect timed out: %s",
to_human_readable_address(address), to_human_readable_address(address),

View File

@ -79,7 +79,7 @@ cdef class APIConnection:
cdef bint _send_pending_ping cdef bint _send_pending_ping
cdef public bint is_connected cdef public bint is_connected
cdef bint _handshake_complete cdef bint _handshake_complete
cdef object _debug_enabled cdef bint _debug_enabled
cdef public str received_name cdef public str received_name
cdef public object resolved_addr_info cdef public object resolved_addr_info

View File

@ -128,6 +128,8 @@ class APIConnection:
An instance of this class may only be used once, for every new connection An instance of this class may only be used once, for every new connection
a new instance should be established. a new instance should be established.
This class should only be created from APIClient and should not be used directly.
""" """
__slots__ = ( __slots__ = (
@ -161,7 +163,8 @@ class APIConnection:
self, self,
params: ConnectionParams, params: ConnectionParams,
on_stop: Callable[[bool], None], on_stop: Callable[[bool], None],
log_name: str | None = None, debug_enabled: bool,
log_name: str | None,
) -> None: ) -> None:
self._params = params self._params = params
self.on_stop: Callable[[bool], None] | None = on_stop self.on_stop: Callable[[bool], None] | None = on_stop
@ -195,7 +198,7 @@ class APIConnection:
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
self.is_connected = False self.is_connected = False
self._handshake_complete = False self._handshake_complete = False
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG) self._debug_enabled = debug_enabled
self.received_name: str = "" self.received_name: str = ""
self.resolved_addr_info: hr.AddrInfo | None = None self.resolved_addr_info: hr.AddrInfo | None = None
@ -214,7 +217,8 @@ class APIConnection:
return return
was_connected = self.is_connected was_connected = self.is_connected
self._set_connection_state(ConnectionState.CLOSED) self._set_connection_state(ConnectionState.CLOSED)
_LOGGER.debug("Cleaning up connection to %s", self.log_name) if self._debug_enabled:
_LOGGER.debug("Cleaning up connection to %s", self.log_name)
for fut in self._read_exception_futures: for fut in self._read_exception_futures:
if fut.done(): if fut.done():
continue continue
@ -262,6 +266,10 @@ class APIConnection:
self.on_stop = None self.on_stop = None
on_stop(self._expected_disconnect) on_stop(self._expected_disconnect)
def set_debug(self, enable: bool) -> None:
"""Enable or disable debug logging."""
self._debug_enabled = enable
async def _connect_resolve_host(self) -> hr.AddrInfo: async def _connect_resolve_host(self) -> hr.AddrInfo:
"""Step 1 in connect process: resolve the address.""" """Step 1 in connect process: resolve the address."""
try: try:
@ -278,7 +286,6 @@ class APIConnection:
async def _connect_socket_connect(self, addr: hr.AddrInfo) -> None: async def _connect_socket_connect(self, addr: hr.AddrInfo) -> None:
"""Step 2 in connect process: connect the socket.""" """Step 2 in connect process: connect the socket."""
debug_enable = self._debug_enabled()
sock = socket.socket(family=addr.family, type=addr.type, proto=addr.proto) sock = socket.socket(family=addr.family, type=addr.type, proto=addr.proto)
self._socket = sock self._socket = sock
sock.setblocking(False) sock.setblocking(False)
@ -294,7 +301,7 @@ class APIConnection:
err, err,
) )
if debug_enable is True: if self._debug_enabled:
_LOGGER.debug( _LOGGER.debug(
"%s: Connecting to %s:%s (%s)", "%s: Connecting to %s:%s (%s)",
self.log_name, self.log_name,
@ -312,7 +319,7 @@ class APIConnection:
except OSError as err: except OSError as err:
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err
if debug_enable is True: if self._debug_enabled:
_LOGGER.debug( _LOGGER.debug(
"%s: Opened socket to %s:%s (%s)", "%s: Opened socket to %s:%s (%s)",
self.log_name, self.log_name,
@ -403,13 +410,14 @@ class APIConnection:
def _process_hello_resp(self, resp: HelloResponse) -> None: def _process_hello_resp(self, resp: HelloResponse) -> None:
"""Process a HelloResponse.""" """Process a HelloResponse."""
_LOGGER.debug( if self._debug_enabled:
"%s: Successfully connected ('%s' API=%s.%s)", _LOGGER.debug(
self.log_name, "%s: Successfully connected ('%s' API=%s.%s)",
resp.server_info, self.log_name,
resp.api_version_major, resp.server_info,
resp.api_version_minor, resp.api_version_major,
) resp.api_version_minor,
)
api_version = APIVersion(resp.api_version_major, resp.api_version_minor) api_version = APIVersion(resp.api_version_major, resp.api_version_minor)
if api_version.major > 2: if api_version.major > 2:
_LOGGER.error( _LOGGER.error(
@ -456,7 +464,7 @@ class APIConnection:
self._pong_timer = loop.call_at( self._pong_timer = loop.call_at(
now + self._keep_alive_timeout, self._async_pong_not_received now + self._keep_alive_timeout, self._async_pong_not_received
) )
elif self._debug_enabled() is True: elif self._debug_enabled:
# #
# We haven't reached the ping response (pong) timeout yet # We haven't reached the ping response (pong) timeout yet
# and we haven't seen a response to the last ping # and we haven't seen a response to the last ping
@ -485,11 +493,12 @@ class APIConnection:
"""Ping not received.""" """Ping not received."""
if not self.is_connected: if not self.is_connected:
return return
_LOGGER.debug( if self._debug_enabled:
"%s: Ping response not received after %s seconds", _LOGGER.debug(
self.log_name, "%s: Ping response not received after %s seconds",
self._keep_alive_timeout, self.log_name,
) self._keep_alive_timeout,
)
self.report_fatal_error( self.report_fatal_error(
PingFailedAPIError( PingFailedAPIError(
f"Ping response not received after {self._keep_alive_timeout} seconds" f"Ping response not received after {self._keep_alive_timeout} seconds"
@ -608,14 +617,14 @@ class APIConnection:
) )
packets: list[tuple[int, bytes]] = [] packets: list[tuple[int, bytes]] = []
debug_enabled = self._debug_enabled() debug_enabled = self._debug_enabled
for msg in msgs: for msg in msgs:
msg_type = type(msg) msg_type = type(msg)
if (message_type := PROTO_TO_MESSAGE_TYPE.get(msg_type)) is None: if (message_type := PROTO_TO_MESSAGE_TYPE.get(msg_type)) is None:
raise ValueError(f"Message type id not found for type {msg_type}") raise ValueError(f"Message type id not found for type {msg_type}")
if debug_enabled is True: if debug_enabled:
_LOGGER.debug( _LOGGER.debug(
"%s: Sending %s: %s", self.log_name, msg_type.__name__, msg "%s: Sending %s: %s", self.log_name, msg_type.__name__, msg
) )
@ -786,12 +795,14 @@ class APIConnection:
def process_packet(self, msg_type_proto: _int, data: _bytes) -> None: def process_packet(self, msg_type_proto: _int, data: _bytes) -> None:
"""Process an incoming packet.""" """Process an incoming packet."""
debug_enabled = self._debug_enabled
if (klass := MESSAGE_TYPE_TO_PROTO.get(msg_type_proto)) is None: if (klass := MESSAGE_TYPE_TO_PROTO.get(msg_type_proto)) is None:
_LOGGER.debug( if debug_enabled:
"%s: Skipping message type %s", _LOGGER.debug(
self.log_name, "%s: Skipping unknown message type %s",
msg_type_proto, self.log_name,
) msg_type_proto,
)
return return
try: try:
@ -818,7 +829,7 @@ class APIConnection:
msg_type = type(msg) msg_type = type(msg)
if self._debug_enabled() is True: if debug_enabled:
_LOGGER.debug( _LOGGER.debug(
"%s: Got message of type %s: %s", "%s: Got message of type %s: %s",
self.log_name, self.log_name,
@ -891,10 +902,11 @@ class APIConnection:
self._fatal_exception = TimeoutAPIError( self._fatal_exception = TimeoutAPIError(
"Timed out waiting to finish connect before disconnecting" "Timed out waiting to finish connect before disconnecting"
) )
_LOGGER.debug( if self._debug_enabled:
"%s: Connect task didn't finish before disconnect", _LOGGER.debug(
self.log_name, "%s: Connect task didn't finish before disconnect",
) self.log_name,
)
self._expected_disconnect = True self._expected_disconnect = True
if self._handshake_complete: if self._handshake_complete:

View File

@ -78,18 +78,18 @@ def noise_connection_params() -> ConnectionParams:
) )
async def on_stop(expected_disconnect: bool) -> None: def on_stop(expected_disconnect: bool) -> None:
pass pass
@pytest.fixture @pytest.fixture
def conn(connection_params: ConnectionParams) -> APIConnection: def conn(connection_params: ConnectionParams) -> APIConnection:
return PatchableAPIConnection(connection_params, on_stop) return PatchableAPIConnection(connection_params, on_stop, True, None)
@pytest.fixture @pytest.fixture
def noise_conn(noise_connection_params: ConnectionParams) -> APIConnection: def noise_conn(noise_connection_params: ConnectionParams) -> APIConnection:
return PatchableAPIConnection(noise_connection_params, on_stop) return PatchableAPIConnection(noise_connection_params, on_stop, True, None)
@pytest_asyncio.fixture(name="plaintext_connect_task_no_login") @pytest_asyncio.fixture(name="plaintext_connect_task_no_login")

View File

@ -124,7 +124,9 @@ def _make_mock_connection() -> tuple[APIConnection, list[tuple[int, bytes]]]:
class MockConnection(APIConnection): class MockConnection(APIConnection):
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Swallow args.""" """Swallow args."""
super().__init__(get_mock_connection_params(), AsyncMock(), *args, **kwargs) super().__init__(
get_mock_connection_params(), AsyncMock(), True, None, *args, **kwargs
)
def process_packet(self, type_: int, data: bytes): def process_packet(self, type_: int, data: bytes):
packets.append((type_, data)) packets.append((type_, data))

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio import asyncio
import itertools import itertools
import logging
from typing import Any from typing import Any
from unittest.mock import AsyncMock, MagicMock, call, patch from unittest.mock import AsyncMock, MagicMock, call, patch
@ -1068,29 +1069,6 @@ async def test_bluetooth_gatt_write_descriptor_without_response(
await client.bluetooth_gatt_write_descriptor(1234, 1234, b"1234", timeout=0) await client.bluetooth_gatt_write_descriptor(1234, 1234, b"1234", timeout=0)
@pytest.mark.asyncio
async def test_bluetooth_gatt_read_descriptor(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test bluetooth_gatt_read_descriptor."""
client, connection, transport, protocol = api_client
read_task = asyncio.create_task(client.bluetooth_gatt_read_descriptor(1234, 1234))
await asyncio.sleep(0)
other_response: message.Message = BluetoothGATTReadResponse(
address=1234, handle=4567, data=b"4567"
)
mock_data_received(protocol, generate_plaintext_packet(other_response))
response: message.Message = BluetoothGATTReadResponse(
address=1234, handle=1234, data=b"1234"
)
mock_data_received(protocol, generate_plaintext_packet(response))
assert await read_task == b"1234"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bluetooth_gatt_get_services( async def test_bluetooth_gatt_get_services(
api_client: tuple[ api_client: tuple[
@ -1374,3 +1352,37 @@ async def test_subscribe_service_calls(auth_client: APIClient) -> None:
service_msg = HomeassistantServiceResponse(service="bob") service_msg = HomeassistantServiceResponse(service="bob")
await send(service_msg) await send(service_msg)
on_service_call.assert_called_with(HomeassistantServiceCall.from_pb(service_msg)) on_service_call.assert_called_with(HomeassistantServiceCall.from_pb(service_msg))
@pytest.mark.asyncio
async def test_set_debug(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test set_debug."""
client, connection, transport, protocol = api_client
response: message.Message = DeviceInfoResponse(
name="realname",
friendly_name="My Device",
has_deep_sleep=True,
)
caplog.set_level(logging.DEBUG)
client.set_debug(True)
assert client.log_name == "mydevice.local"
device_info_task = asyncio.create_task(client.device_info())
await asyncio.sleep(0)
mock_data_received(protocol, generate_plaintext_packet(response))
await device_info_task
assert "My Device" in caplog.text
caplog.clear()
client.set_debug(False)
device_info_task = asyncio.create_task(client.device_info())
await asyncio.sleep(0)
mock_data_received(protocol, generate_plaintext_packet(response))
await device_info_task
assert "My Device" not in caplog.text

View File

@ -1,15 +1,18 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import logging
from collections.abc import Coroutine from collections.abc import Coroutine
from datetime import timedelta from datetime import timedelta
from typing import Any from typing import Any
from unittest.mock import AsyncMock, MagicMock, call, patch from unittest.mock import AsyncMock, MagicMock, call, patch
import pytest import pytest
from google.protobuf import message
from aioesphomeapi import APIClient from aioesphomeapi import APIClient
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
from aioesphomeapi.api_pb2 import ( from aioesphomeapi.api_pb2 import (
DeviceInfoResponse, DeviceInfoResponse,
DisconnectRequest, DisconnectRequest,
@ -491,6 +494,7 @@ async def test_force_disconnect_fails(
with patch.object(protocol, "_writer", side_effect=OSError): with patch.object(protocol, "_writer", side_effect=OSError):
await conn.force_disconnect() await conn.force_disconnect()
assert "Failed to send (forced) disconnect request" in caplog.text assert "Failed to send (forced) disconnect request" in caplog.text
await asyncio.sleep(0)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -702,3 +706,35 @@ async def test_respond_to_ping_request(
ping_response_bytes = b"\x00\x00\x08" ping_response_bytes = b"\x00\x00\x08"
assert transport.write.call_count == 1 assert transport.write.call_count == 1
assert transport.write.mock_calls == [call(ping_response_bytes)] assert transport.write.mock_calls == [call(ping_response_bytes)]
@pytest.mark.asyncio
async def test_unknown_protobuf_message_type_logged(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test unknown protobuf messages are logged but do not cause the connection to collapse."""
client, connection, transport, protocol = api_client
response: message.Message = DeviceInfoResponse(
name="realname",
friendly_name="My Device",
has_deep_sleep=True,
)
caplog.set_level(logging.DEBUG)
client.set_debug(True)
bytes_ = response.SerializeToString()
message_with_invalid_protobuf_number = (
b"\0"
+ _cached_varuint_to_bytes(len(bytes_))
+ _cached_varuint_to_bytes(16385)
+ bytes_
)
mock_data_received(protocol, message_with_invalid_protobuf_number)
assert "Skipping unknown message type 16385" in caplog.text
assert connection.is_connected
await connection.force_disconnect()
await asyncio.sleep(0)