mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-02-17 01:51:23 +01:00
Avoid expensive debug logging checks in packet processor (#700)
This commit is contained in:
parent
0d25cc92a0
commit
67661dbd7f
@ -168,7 +168,17 @@ def _stringify_or_none(value: str | None) -> str | None:
|
||||
|
||||
# pylint: disable=too-many-public-methods
|
||||
class APIClient:
|
||||
"""The ESPHome API client.
|
||||
|
||||
This class is the main entrypoint for interacting with the API.
|
||||
|
||||
It is recommended to use this class in combination with the
|
||||
ReconnectLogic class to automatically reconnect to the device
|
||||
if the connection is lost.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"_debug_enabled",
|
||||
"_params",
|
||||
"_connection",
|
||||
"cached_name",
|
||||
@ -205,6 +215,7 @@ class APIClient:
|
||||
Can be used to prevent accidentally connecting to a different device if
|
||||
IP passed as address but DHCP reassigned IP.
|
||||
"""
|
||||
self._debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||
self._params = ConnectionParams(
|
||||
address=str(address),
|
||||
port=port,
|
||||
@ -223,6 +234,12 @@ class APIClient:
|
||||
self._on_stop_task: asyncio.Task[None] | None = None
|
||||
self._set_log_name()
|
||||
|
||||
def set_debug(self, enabled: bool) -> None:
|
||||
"""Enable debug logging."""
|
||||
self._debug_enabled = enabled
|
||||
if self._connection:
|
||||
self._connection.set_debug(enabled)
|
||||
|
||||
@property
|
||||
def zeroconf_manager(self) -> ZeroconfManager:
|
||||
return self._params.zeroconf_manager
|
||||
@ -299,7 +316,10 @@ class APIClient:
|
||||
raise APIConnectionError(f"Already connected to {self.log_name}!")
|
||||
|
||||
self._connection = APIConnection(
|
||||
self._params, partial(self._on_stop, on_stop), log_name=self.log_name
|
||||
self._params,
|
||||
partial(self._on_stop, on_stop),
|
||||
self._debug_enabled,
|
||||
self.log_name,
|
||||
)
|
||||
|
||||
try:
|
||||
@ -556,7 +576,6 @@ class APIClient:
|
||||
has_cache: bool = False,
|
||||
address_type: int | None = None,
|
||||
) -> Callable[[], None]:
|
||||
debug = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||
connect_future: asyncio.Future[None] = self._loop.create_future()
|
||||
|
||||
if has_cache:
|
||||
@ -570,7 +589,7 @@ class APIClient:
|
||||
# of the connection. This can crash the esp if the service list is too large.
|
||||
request_type = BluetoothDeviceRequestType.CONNECT
|
||||
|
||||
if debug:
|
||||
if self._debug_enabled:
|
||||
_LOGGER.debug("%s: Using connection version %s", address, request_type)
|
||||
|
||||
unsub = self._get_connection().send_message_callback_response(
|
||||
@ -604,7 +623,7 @@ class APIClient:
|
||||
# the slot is recovered before the timeout is raised
|
||||
# to avoid race were we run out even though we have a slot.
|
||||
addr = to_human_readable_address(address)
|
||||
if debug:
|
||||
if self._debug_enabled:
|
||||
_LOGGER.debug("%s: Connecting timed out, waiting for disconnect", addr)
|
||||
disconnect_timed_out = (
|
||||
not await self._bluetooth_device_disconnect_guard_timeout(
|
||||
@ -640,7 +659,7 @@ class APIClient:
|
||||
try:
|
||||
await self.bluetooth_device_disconnect(address, timeout=timeout)
|
||||
except TimeoutAPIError:
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
if self._debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"%s: Disconnect timed out: %s",
|
||||
to_human_readable_address(address),
|
||||
|
@ -79,7 +79,7 @@ cdef class APIConnection:
|
||||
cdef bint _send_pending_ping
|
||||
cdef public bint is_connected
|
||||
cdef bint _handshake_complete
|
||||
cdef object _debug_enabled
|
||||
cdef bint _debug_enabled
|
||||
cdef public str received_name
|
||||
cdef public object resolved_addr_info
|
||||
|
||||
|
@ -128,6 +128,8 @@ class APIConnection:
|
||||
|
||||
An instance of this class may only be used once, for every new connection
|
||||
a new instance should be established.
|
||||
|
||||
This class should only be created from APIClient and should not be used directly.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
@ -161,7 +163,8 @@ class APIConnection:
|
||||
self,
|
||||
params: ConnectionParams,
|
||||
on_stop: Callable[[bool], None],
|
||||
log_name: str | None = None,
|
||||
debug_enabled: bool,
|
||||
log_name: str | None,
|
||||
) -> None:
|
||||
self._params = params
|
||||
self.on_stop: Callable[[bool], None] | None = on_stop
|
||||
@ -195,7 +198,7 @@ class APIConnection:
|
||||
self._loop = asyncio.get_event_loop()
|
||||
self.is_connected = False
|
||||
self._handshake_complete = False
|
||||
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
|
||||
self._debug_enabled = debug_enabled
|
||||
self.received_name: str = ""
|
||||
self.resolved_addr_info: hr.AddrInfo | None = None
|
||||
|
||||
@ -214,7 +217,8 @@ class APIConnection:
|
||||
return
|
||||
was_connected = self.is_connected
|
||||
self._set_connection_state(ConnectionState.CLOSED)
|
||||
_LOGGER.debug("Cleaning up connection to %s", self.log_name)
|
||||
if self._debug_enabled:
|
||||
_LOGGER.debug("Cleaning up connection to %s", self.log_name)
|
||||
for fut in self._read_exception_futures:
|
||||
if fut.done():
|
||||
continue
|
||||
@ -262,6 +266,10 @@ class APIConnection:
|
||||
self.on_stop = None
|
||||
on_stop(self._expected_disconnect)
|
||||
|
||||
def set_debug(self, enable: bool) -> None:
|
||||
"""Enable or disable debug logging."""
|
||||
self._debug_enabled = enable
|
||||
|
||||
async def _connect_resolve_host(self) -> hr.AddrInfo:
|
||||
"""Step 1 in connect process: resolve the address."""
|
||||
try:
|
||||
@ -278,7 +286,6 @@ class APIConnection:
|
||||
|
||||
async def _connect_socket_connect(self, addr: hr.AddrInfo) -> None:
|
||||
"""Step 2 in connect process: connect the socket."""
|
||||
debug_enable = self._debug_enabled()
|
||||
sock = socket.socket(family=addr.family, type=addr.type, proto=addr.proto)
|
||||
self._socket = sock
|
||||
sock.setblocking(False)
|
||||
@ -294,7 +301,7 @@ class APIConnection:
|
||||
err,
|
||||
)
|
||||
|
||||
if debug_enable is True:
|
||||
if self._debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"%s: Connecting to %s:%s (%s)",
|
||||
self.log_name,
|
||||
@ -312,7 +319,7 @@ class APIConnection:
|
||||
except OSError as err:
|
||||
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err
|
||||
|
||||
if debug_enable is True:
|
||||
if self._debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"%s: Opened socket to %s:%s (%s)",
|
||||
self.log_name,
|
||||
@ -403,13 +410,14 @@ class APIConnection:
|
||||
|
||||
def _process_hello_resp(self, resp: HelloResponse) -> None:
|
||||
"""Process a HelloResponse."""
|
||||
_LOGGER.debug(
|
||||
"%s: Successfully connected ('%s' API=%s.%s)",
|
||||
self.log_name,
|
||||
resp.server_info,
|
||||
resp.api_version_major,
|
||||
resp.api_version_minor,
|
||||
)
|
||||
if self._debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"%s: Successfully connected ('%s' API=%s.%s)",
|
||||
self.log_name,
|
||||
resp.server_info,
|
||||
resp.api_version_major,
|
||||
resp.api_version_minor,
|
||||
)
|
||||
api_version = APIVersion(resp.api_version_major, resp.api_version_minor)
|
||||
if api_version.major > 2:
|
||||
_LOGGER.error(
|
||||
@ -456,7 +464,7 @@ class APIConnection:
|
||||
self._pong_timer = loop.call_at(
|
||||
now + self._keep_alive_timeout, self._async_pong_not_received
|
||||
)
|
||||
elif self._debug_enabled() is True:
|
||||
elif self._debug_enabled:
|
||||
#
|
||||
# We haven't reached the ping response (pong) timeout yet
|
||||
# and we haven't seen a response to the last ping
|
||||
@ -485,11 +493,12 @@ class APIConnection:
|
||||
"""Ping not received."""
|
||||
if not self.is_connected:
|
||||
return
|
||||
_LOGGER.debug(
|
||||
"%s: Ping response not received after %s seconds",
|
||||
self.log_name,
|
||||
self._keep_alive_timeout,
|
||||
)
|
||||
if self._debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"%s: Ping response not received after %s seconds",
|
||||
self.log_name,
|
||||
self._keep_alive_timeout,
|
||||
)
|
||||
self.report_fatal_error(
|
||||
PingFailedAPIError(
|
||||
f"Ping response not received after {self._keep_alive_timeout} seconds"
|
||||
@ -608,14 +617,14 @@ class APIConnection:
|
||||
)
|
||||
|
||||
packets: list[tuple[int, bytes]] = []
|
||||
debug_enabled = self._debug_enabled()
|
||||
debug_enabled = self._debug_enabled
|
||||
|
||||
for msg in msgs:
|
||||
msg_type = type(msg)
|
||||
if (message_type := PROTO_TO_MESSAGE_TYPE.get(msg_type)) is None:
|
||||
raise ValueError(f"Message type id not found for type {msg_type}")
|
||||
|
||||
if debug_enabled is True:
|
||||
if debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"%s: Sending %s: %s", self.log_name, msg_type.__name__, msg
|
||||
)
|
||||
@ -786,12 +795,14 @@ class APIConnection:
|
||||
|
||||
def process_packet(self, msg_type_proto: _int, data: _bytes) -> None:
|
||||
"""Process an incoming packet."""
|
||||
debug_enabled = self._debug_enabled
|
||||
if (klass := MESSAGE_TYPE_TO_PROTO.get(msg_type_proto)) is None:
|
||||
_LOGGER.debug(
|
||||
"%s: Skipping message type %s",
|
||||
self.log_name,
|
||||
msg_type_proto,
|
||||
)
|
||||
if debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"%s: Skipping unknown message type %s",
|
||||
self.log_name,
|
||||
msg_type_proto,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
@ -818,7 +829,7 @@ class APIConnection:
|
||||
|
||||
msg_type = type(msg)
|
||||
|
||||
if self._debug_enabled() is True:
|
||||
if debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"%s: Got message of type %s: %s",
|
||||
self.log_name,
|
||||
@ -891,10 +902,11 @@ class APIConnection:
|
||||
self._fatal_exception = TimeoutAPIError(
|
||||
"Timed out waiting to finish connect before disconnecting"
|
||||
)
|
||||
_LOGGER.debug(
|
||||
"%s: Connect task didn't finish before disconnect",
|
||||
self.log_name,
|
||||
)
|
||||
if self._debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"%s: Connect task didn't finish before disconnect",
|
||||
self.log_name,
|
||||
)
|
||||
|
||||
self._expected_disconnect = True
|
||||
if self._handshake_complete:
|
||||
|
@ -78,18 +78,18 @@ def noise_connection_params() -> ConnectionParams:
|
||||
)
|
||||
|
||||
|
||||
async def on_stop(expected_disconnect: bool) -> None:
|
||||
def on_stop(expected_disconnect: bool) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conn(connection_params: ConnectionParams) -> APIConnection:
|
||||
return PatchableAPIConnection(connection_params, on_stop)
|
||||
return PatchableAPIConnection(connection_params, on_stop, True, None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def noise_conn(noise_connection_params: ConnectionParams) -> APIConnection:
|
||||
return PatchableAPIConnection(noise_connection_params, on_stop)
|
||||
return PatchableAPIConnection(noise_connection_params, on_stop, True, None)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(name="plaintext_connect_task_no_login")
|
||||
|
@ -124,7 +124,9 @@ def _make_mock_connection() -> tuple[APIConnection, list[tuple[int, bytes]]]:
|
||||
class MockConnection(APIConnection):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Swallow args."""
|
||||
super().__init__(get_mock_connection_params(), AsyncMock(), *args, **kwargs)
|
||||
super().__init__(
|
||||
get_mock_connection_params(), AsyncMock(), True, None, *args, **kwargs
|
||||
)
|
||||
|
||||
def process_packet(self, type_: int, data: bytes):
|
||||
packets.append((type_, data))
|
||||
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import itertools
|
||||
import logging
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||
|
||||
@ -1068,29 +1069,6 @@ async def test_bluetooth_gatt_write_descriptor_without_response(
|
||||
await client.bluetooth_gatt_write_descriptor(1234, 1234, b"1234", timeout=0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bluetooth_gatt_read_descriptor(
|
||||
api_client: tuple[
|
||||
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
|
||||
],
|
||||
) -> None:
|
||||
"""Test bluetooth_gatt_read_descriptor."""
|
||||
client, connection, transport, protocol = api_client
|
||||
read_task = asyncio.create_task(client.bluetooth_gatt_read_descriptor(1234, 1234))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
other_response: message.Message = BluetoothGATTReadResponse(
|
||||
address=1234, handle=4567, data=b"4567"
|
||||
)
|
||||
mock_data_received(protocol, generate_plaintext_packet(other_response))
|
||||
|
||||
response: message.Message = BluetoothGATTReadResponse(
|
||||
address=1234, handle=1234, data=b"1234"
|
||||
)
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
assert await read_task == b"1234"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bluetooth_gatt_get_services(
|
||||
api_client: tuple[
|
||||
@ -1374,3 +1352,37 @@ async def test_subscribe_service_calls(auth_client: APIClient) -> None:
|
||||
service_msg = HomeassistantServiceResponse(service="bob")
|
||||
await send(service_msg)
|
||||
on_service_call.assert_called_with(HomeassistantServiceCall.from_pb(service_msg))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_debug(
|
||||
api_client: tuple[
|
||||
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
|
||||
],
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test set_debug."""
|
||||
client, connection, transport, protocol = api_client
|
||||
response: message.Message = DeviceInfoResponse(
|
||||
name="realname",
|
||||
friendly_name="My Device",
|
||||
has_deep_sleep=True,
|
||||
)
|
||||
|
||||
caplog.set_level(logging.DEBUG)
|
||||
|
||||
client.set_debug(True)
|
||||
assert client.log_name == "mydevice.local"
|
||||
device_info_task = asyncio.create_task(client.device_info())
|
||||
await asyncio.sleep(0)
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
await device_info_task
|
||||
|
||||
assert "My Device" in caplog.text
|
||||
caplog.clear()
|
||||
client.set_debug(False)
|
||||
device_info_task = asyncio.create_task(client.device_info())
|
||||
await asyncio.sleep(0)
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
await device_info_task
|
||||
assert "My Device" not in caplog.text
|
||||
|
@ -1,15 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Coroutine
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
from google.protobuf import message
|
||||
|
||||
from aioesphomeapi import APIClient
|
||||
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
||||
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
|
||||
from aioesphomeapi.api_pb2 import (
|
||||
DeviceInfoResponse,
|
||||
DisconnectRequest,
|
||||
@ -491,6 +494,7 @@ async def test_force_disconnect_fails(
|
||||
with patch.object(protocol, "_writer", side_effect=OSError):
|
||||
await conn.force_disconnect()
|
||||
assert "Failed to send (forced) disconnect request" in caplog.text
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -702,3 +706,35 @@ async def test_respond_to_ping_request(
|
||||
ping_response_bytes = b"\x00\x00\x08"
|
||||
assert transport.write.call_count == 1
|
||||
assert transport.write.mock_calls == [call(ping_response_bytes)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_protobuf_message_type_logged(
|
||||
api_client: tuple[
|
||||
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
|
||||
],
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test unknown protobuf messages are logged but do not cause the connection to collapse."""
|
||||
client, connection, transport, protocol = api_client
|
||||
response: message.Message = DeviceInfoResponse(
|
||||
name="realname",
|
||||
friendly_name="My Device",
|
||||
has_deep_sleep=True,
|
||||
)
|
||||
caplog.set_level(logging.DEBUG)
|
||||
client.set_debug(True)
|
||||
bytes_ = response.SerializeToString()
|
||||
message_with_invalid_protobuf_number = (
|
||||
b"\0"
|
||||
+ _cached_varuint_to_bytes(len(bytes_))
|
||||
+ _cached_varuint_to_bytes(16385)
|
||||
+ bytes_
|
||||
)
|
||||
|
||||
mock_data_received(protocol, message_with_invalid_protobuf_number)
|
||||
|
||||
assert "Skipping unknown message type 16385" in caplog.text
|
||||
assert connection.is_connected
|
||||
await connection.force_disconnect()
|
||||
await asyncio.sleep(0)
|
||||
|
Loading…
Reference in New Issue
Block a user