Avoid expensive debug logging checks in packet processor (#700)

This commit is contained in:
J. Nick Koston 2023-11-25 07:51:48 -06:00 committed by GitHub
parent 0d25cc92a0
commit 67661dbd7f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 145 additions and 64 deletions

View File

@ -168,7 +168,17 @@ def _stringify_or_none(value: str | None) -> str | None:
# pylint: disable=too-many-public-methods
class APIClient:
"""The ESPHome API client.
This class is the main entrypoint for interacting with the API.
It is recommended to use this class in combination with the
ReconnectLogic class to automatically reconnect to the device
if the connection is lost.
"""
__slots__ = (
"_debug_enabled",
"_params",
"_connection",
"cached_name",
@ -205,6 +215,7 @@ class APIClient:
Can be used to prevent accidentally connecting to a different device if
IP passed as address but DHCP reassigned IP.
"""
self._debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
self._params = ConnectionParams(
address=str(address),
port=port,
@ -223,6 +234,12 @@ class APIClient:
self._on_stop_task: asyncio.Task[None] | None = None
self._set_log_name()
def set_debug(self, enabled: bool) -> None:
"""Enable debug logging."""
self._debug_enabled = enabled
if self._connection:
self._connection.set_debug(enabled)
@property
def zeroconf_manager(self) -> ZeroconfManager:
return self._params.zeroconf_manager
@ -299,7 +316,10 @@ class APIClient:
raise APIConnectionError(f"Already connected to {self.log_name}!")
self._connection = APIConnection(
self._params, partial(self._on_stop, on_stop), log_name=self.log_name
self._params,
partial(self._on_stop, on_stop),
self._debug_enabled,
self.log_name,
)
try:
@ -556,7 +576,6 @@ class APIClient:
has_cache: bool = False,
address_type: int | None = None,
) -> Callable[[], None]:
debug = _LOGGER.isEnabledFor(logging.DEBUG)
connect_future: asyncio.Future[None] = self._loop.create_future()
if has_cache:
@ -570,7 +589,7 @@ class APIClient:
# of the connection. This can crash the esp if the service list is too large.
request_type = BluetoothDeviceRequestType.CONNECT
if debug:
if self._debug_enabled:
_LOGGER.debug("%s: Using connection version %s", address, request_type)
unsub = self._get_connection().send_message_callback_response(
@ -604,7 +623,7 @@ class APIClient:
# the slot is recovered before the timeout is raised
# to avoid race were we run out even though we have a slot.
addr = to_human_readable_address(address)
if debug:
if self._debug_enabled:
_LOGGER.debug("%s: Connecting timed out, waiting for disconnect", addr)
disconnect_timed_out = (
not await self._bluetooth_device_disconnect_guard_timeout(
@ -640,7 +659,7 @@ class APIClient:
try:
await self.bluetooth_device_disconnect(address, timeout=timeout)
except TimeoutAPIError:
if _LOGGER.isEnabledFor(logging.DEBUG):
if self._debug_enabled:
_LOGGER.debug(
"%s: Disconnect timed out: %s",
to_human_readable_address(address),

View File

@ -79,7 +79,7 @@ cdef class APIConnection:
cdef bint _send_pending_ping
cdef public bint is_connected
cdef bint _handshake_complete
cdef object _debug_enabled
cdef bint _debug_enabled
cdef public str received_name
cdef public object resolved_addr_info

View File

@ -128,6 +128,8 @@ class APIConnection:
An instance of this class may only be used once, for every new connection
a new instance should be established.
This class should only be created from APIClient and should not be used directly.
"""
__slots__ = (
@ -161,7 +163,8 @@ class APIConnection:
self,
params: ConnectionParams,
on_stop: Callable[[bool], None],
log_name: str | None = None,
debug_enabled: bool,
log_name: str | None,
) -> None:
self._params = params
self.on_stop: Callable[[bool], None] | None = on_stop
@ -195,7 +198,7 @@ class APIConnection:
self._loop = asyncio.get_event_loop()
self.is_connected = False
self._handshake_complete = False
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
self._debug_enabled = debug_enabled
self.received_name: str = ""
self.resolved_addr_info: hr.AddrInfo | None = None
@ -214,7 +217,8 @@ class APIConnection:
return
was_connected = self.is_connected
self._set_connection_state(ConnectionState.CLOSED)
_LOGGER.debug("Cleaning up connection to %s", self.log_name)
if self._debug_enabled:
_LOGGER.debug("Cleaning up connection to %s", self.log_name)
for fut in self._read_exception_futures:
if fut.done():
continue
@ -262,6 +266,10 @@ class APIConnection:
self.on_stop = None
on_stop(self._expected_disconnect)
def set_debug(self, enable: bool) -> None:
"""Enable or disable debug logging."""
self._debug_enabled = enable
async def _connect_resolve_host(self) -> hr.AddrInfo:
"""Step 1 in connect process: resolve the address."""
try:
@ -278,7 +286,6 @@ class APIConnection:
async def _connect_socket_connect(self, addr: hr.AddrInfo) -> None:
"""Step 2 in connect process: connect the socket."""
debug_enable = self._debug_enabled()
sock = socket.socket(family=addr.family, type=addr.type, proto=addr.proto)
self._socket = sock
sock.setblocking(False)
@ -294,7 +301,7 @@ class APIConnection:
err,
)
if debug_enable is True:
if self._debug_enabled:
_LOGGER.debug(
"%s: Connecting to %s:%s (%s)",
self.log_name,
@ -312,7 +319,7 @@ class APIConnection:
except OSError as err:
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err
if debug_enable is True:
if self._debug_enabled:
_LOGGER.debug(
"%s: Opened socket to %s:%s (%s)",
self.log_name,
@ -403,13 +410,14 @@ class APIConnection:
def _process_hello_resp(self, resp: HelloResponse) -> None:
"""Process a HelloResponse."""
_LOGGER.debug(
"%s: Successfully connected ('%s' API=%s.%s)",
self.log_name,
resp.server_info,
resp.api_version_major,
resp.api_version_minor,
)
if self._debug_enabled:
_LOGGER.debug(
"%s: Successfully connected ('%s' API=%s.%s)",
self.log_name,
resp.server_info,
resp.api_version_major,
resp.api_version_minor,
)
api_version = APIVersion(resp.api_version_major, resp.api_version_minor)
if api_version.major > 2:
_LOGGER.error(
@ -456,7 +464,7 @@ class APIConnection:
self._pong_timer = loop.call_at(
now + self._keep_alive_timeout, self._async_pong_not_received
)
elif self._debug_enabled() is True:
elif self._debug_enabled:
#
# We haven't reached the ping response (pong) timeout yet
# and we haven't seen a response to the last ping
@ -485,11 +493,12 @@ class APIConnection:
"""Ping not received."""
if not self.is_connected:
return
_LOGGER.debug(
"%s: Ping response not received after %s seconds",
self.log_name,
self._keep_alive_timeout,
)
if self._debug_enabled:
_LOGGER.debug(
"%s: Ping response not received after %s seconds",
self.log_name,
self._keep_alive_timeout,
)
self.report_fatal_error(
PingFailedAPIError(
f"Ping response not received after {self._keep_alive_timeout} seconds"
@ -608,14 +617,14 @@ class APIConnection:
)
packets: list[tuple[int, bytes]] = []
debug_enabled = self._debug_enabled()
debug_enabled = self._debug_enabled
for msg in msgs:
msg_type = type(msg)
if (message_type := PROTO_TO_MESSAGE_TYPE.get(msg_type)) is None:
raise ValueError(f"Message type id not found for type {msg_type}")
if debug_enabled is True:
if debug_enabled:
_LOGGER.debug(
"%s: Sending %s: %s", self.log_name, msg_type.__name__, msg
)
@ -786,12 +795,14 @@ class APIConnection:
def process_packet(self, msg_type_proto: _int, data: _bytes) -> None:
"""Process an incoming packet."""
debug_enabled = self._debug_enabled
if (klass := MESSAGE_TYPE_TO_PROTO.get(msg_type_proto)) is None:
_LOGGER.debug(
"%s: Skipping message type %s",
self.log_name,
msg_type_proto,
)
if debug_enabled:
_LOGGER.debug(
"%s: Skipping unknown message type %s",
self.log_name,
msg_type_proto,
)
return
try:
@ -818,7 +829,7 @@ class APIConnection:
msg_type = type(msg)
if self._debug_enabled() is True:
if debug_enabled:
_LOGGER.debug(
"%s: Got message of type %s: %s",
self.log_name,
@ -891,10 +902,11 @@ class APIConnection:
self._fatal_exception = TimeoutAPIError(
"Timed out waiting to finish connect before disconnecting"
)
_LOGGER.debug(
"%s: Connect task didn't finish before disconnect",
self.log_name,
)
if self._debug_enabled:
_LOGGER.debug(
"%s: Connect task didn't finish before disconnect",
self.log_name,
)
self._expected_disconnect = True
if self._handshake_complete:

View File

@ -78,18 +78,18 @@ def noise_connection_params() -> ConnectionParams:
)
async def on_stop(expected_disconnect: bool) -> None:
def on_stop(expected_disconnect: bool) -> None:
pass
@pytest.fixture
def conn(connection_params: ConnectionParams) -> APIConnection:
return PatchableAPIConnection(connection_params, on_stop)
return PatchableAPIConnection(connection_params, on_stop, True, None)
@pytest.fixture
def noise_conn(noise_connection_params: ConnectionParams) -> APIConnection:
return PatchableAPIConnection(noise_connection_params, on_stop)
return PatchableAPIConnection(noise_connection_params, on_stop, True, None)
@pytest_asyncio.fixture(name="plaintext_connect_task_no_login")

View File

@ -124,7 +124,9 @@ def _make_mock_connection() -> tuple[APIConnection, list[tuple[int, bytes]]]:
class MockConnection(APIConnection):
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Swallow args."""
super().__init__(get_mock_connection_params(), AsyncMock(), *args, **kwargs)
super().__init__(
get_mock_connection_params(), AsyncMock(), True, None, *args, **kwargs
)
def process_packet(self, type_: int, data: bytes):
packets.append((type_, data))

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import itertools
import logging
from typing import Any
from unittest.mock import AsyncMock, MagicMock, call, patch
@ -1068,29 +1069,6 @@ async def test_bluetooth_gatt_write_descriptor_without_response(
await client.bluetooth_gatt_write_descriptor(1234, 1234, b"1234", timeout=0)
@pytest.mark.asyncio
async def test_bluetooth_gatt_read_descriptor(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test bluetooth_gatt_read_descriptor."""
client, connection, transport, protocol = api_client
read_task = asyncio.create_task(client.bluetooth_gatt_read_descriptor(1234, 1234))
await asyncio.sleep(0)
other_response: message.Message = BluetoothGATTReadResponse(
address=1234, handle=4567, data=b"4567"
)
mock_data_received(protocol, generate_plaintext_packet(other_response))
response: message.Message = BluetoothGATTReadResponse(
address=1234, handle=1234, data=b"1234"
)
mock_data_received(protocol, generate_plaintext_packet(response))
assert await read_task == b"1234"
@pytest.mark.asyncio
async def test_bluetooth_gatt_get_services(
api_client: tuple[
@ -1374,3 +1352,37 @@ async def test_subscribe_service_calls(auth_client: APIClient) -> None:
service_msg = HomeassistantServiceResponse(service="bob")
await send(service_msg)
on_service_call.assert_called_with(HomeassistantServiceCall.from_pb(service_msg))
@pytest.mark.asyncio
async def test_set_debug(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test set_debug."""
client, connection, transport, protocol = api_client
response: message.Message = DeviceInfoResponse(
name="realname",
friendly_name="My Device",
has_deep_sleep=True,
)
caplog.set_level(logging.DEBUG)
client.set_debug(True)
assert client.log_name == "mydevice.local"
device_info_task = asyncio.create_task(client.device_info())
await asyncio.sleep(0)
mock_data_received(protocol, generate_plaintext_packet(response))
await device_info_task
assert "My Device" in caplog.text
caplog.clear()
client.set_debug(False)
device_info_task = asyncio.create_task(client.device_info())
await asyncio.sleep(0)
mock_data_received(protocol, generate_plaintext_packet(response))
await device_info_task
assert "My Device" not in caplog.text

View File

@ -1,15 +1,18 @@
from __future__ import annotations
import asyncio
import logging
from collections.abc import Coroutine
from datetime import timedelta
from typing import Any
from unittest.mock import AsyncMock, MagicMock, call, patch
import pytest
from google.protobuf import message
from aioesphomeapi import APIClient
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
from aioesphomeapi.api_pb2 import (
DeviceInfoResponse,
DisconnectRequest,
@ -491,6 +494,7 @@ async def test_force_disconnect_fails(
with patch.object(protocol, "_writer", side_effect=OSError):
await conn.force_disconnect()
assert "Failed to send (forced) disconnect request" in caplog.text
await asyncio.sleep(0)
@pytest.mark.asyncio
@ -702,3 +706,35 @@ async def test_respond_to_ping_request(
ping_response_bytes = b"\x00\x00\x08"
assert transport.write.call_count == 1
assert transport.write.mock_calls == [call(ping_response_bytes)]
@pytest.mark.asyncio
async def test_unknown_protobuf_message_type_logged(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test unknown protobuf messages are logged but do not cause the connection to collapse."""
client, connection, transport, protocol = api_client
response: message.Message = DeviceInfoResponse(
name="realname",
friendly_name="My Device",
has_deep_sleep=True,
)
caplog.set_level(logging.DEBUG)
client.set_debug(True)
bytes_ = response.SerializeToString()
message_with_invalid_protobuf_number = (
b"\0"
+ _cached_varuint_to_bytes(len(bytes_))
+ _cached_varuint_to_bytes(16385)
+ bytes_
)
mock_data_received(protocol, message_with_invalid_protobuf_number)
assert "Skipping unknown message type 16385" in caplog.text
assert connection.is_connected
await connection.force_disconnect()
await asyncio.sleep(0)