diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index ef79f02..b7e41b2 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -168,7 +168,17 @@ def _stringify_or_none(value: str | None) -> str | None: # pylint: disable=too-many-public-methods class APIClient: + """The ESPHome API client. + + This class is the main entrypoint for interacting with the API. + + It is recommended to use this class in combination with the + ReconnectLogic class to automatically reconnect to the device + if the connection is lost. + """ + __slots__ = ( + "_debug_enabled", "_params", "_connection", "cached_name", @@ -205,6 +215,7 @@ class APIClient: Can be used to prevent accidentally connecting to a different device if IP passed as address but DHCP reassigned IP. """ + self._debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) self._params = ConnectionParams( address=str(address), port=port, @@ -223,6 +234,12 @@ class APIClient: self._on_stop_task: asyncio.Task[None] | None = None self._set_log_name() + def set_debug(self, enabled: bool) -> None: + """Enable debug logging.""" + self._debug_enabled = enabled + if self._connection: + self._connection.set_debug(enabled) + @property def zeroconf_manager(self) -> ZeroconfManager: return self._params.zeroconf_manager @@ -299,7 +316,10 @@ class APIClient: raise APIConnectionError(f"Already connected to {self.log_name}!") self._connection = APIConnection( - self._params, partial(self._on_stop, on_stop), log_name=self.log_name + self._params, + partial(self._on_stop, on_stop), + self._debug_enabled, + self.log_name, ) try: @@ -556,7 +576,6 @@ class APIClient: has_cache: bool = False, address_type: int | None = None, ) -> Callable[[], None]: - debug = _LOGGER.isEnabledFor(logging.DEBUG) connect_future: asyncio.Future[None] = self._loop.create_future() if has_cache: @@ -570,7 +589,7 @@ class APIClient: # of the connection. This can crash the esp if the service list is too large. request_type = BluetoothDeviceRequestType.CONNECT - if debug: + if self._debug_enabled: _LOGGER.debug("%s: Using connection version %s", address, request_type) unsub = self._get_connection().send_message_callback_response( @@ -604,7 +623,7 @@ class APIClient: # the slot is recovered before the timeout is raised # to avoid race were we run out even though we have a slot. addr = to_human_readable_address(address) - if debug: + if self._debug_enabled: _LOGGER.debug("%s: Connecting timed out, waiting for disconnect", addr) disconnect_timed_out = ( not await self._bluetooth_device_disconnect_guard_timeout( @@ -640,7 +659,7 @@ class APIClient: try: await self.bluetooth_device_disconnect(address, timeout=timeout) except TimeoutAPIError: - if _LOGGER.isEnabledFor(logging.DEBUG): + if self._debug_enabled: _LOGGER.debug( "%s: Disconnect timed out: %s", to_human_readable_address(address), diff --git a/aioesphomeapi/connection.pxd b/aioesphomeapi/connection.pxd index 6e850e0..38002c9 100644 --- a/aioesphomeapi/connection.pxd +++ b/aioesphomeapi/connection.pxd @@ -79,7 +79,7 @@ cdef class APIConnection: cdef bint _send_pending_ping cdef public bint is_connected cdef bint _handshake_complete - cdef object _debug_enabled + cdef bint _debug_enabled cdef public str received_name cdef public object resolved_addr_info diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 97e8029..bdf2290 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -128,6 +128,8 @@ class APIConnection: An instance of this class may only be used once, for every new connection a new instance should be established. + + This class should only be created from APIClient and should not be used directly. """ __slots__ = ( @@ -161,7 +163,8 @@ class APIConnection: self, params: ConnectionParams, on_stop: Callable[[bool], None], - log_name: str | None = None, + debug_enabled: bool, + log_name: str | None, ) -> None: self._params = params self.on_stop: Callable[[bool], None] | None = on_stop @@ -195,7 +198,7 @@ class APIConnection: self._loop = asyncio.get_event_loop() self.is_connected = False self._handshake_complete = False - self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG) + self._debug_enabled = debug_enabled self.received_name: str = "" self.resolved_addr_info: hr.AddrInfo | None = None @@ -214,7 +217,8 @@ class APIConnection: return was_connected = self.is_connected self._set_connection_state(ConnectionState.CLOSED) - _LOGGER.debug("Cleaning up connection to %s", self.log_name) + if self._debug_enabled: + _LOGGER.debug("Cleaning up connection to %s", self.log_name) for fut in self._read_exception_futures: if fut.done(): continue @@ -262,6 +266,10 @@ class APIConnection: self.on_stop = None on_stop(self._expected_disconnect) + def set_debug(self, enable: bool) -> None: + """Enable or disable debug logging.""" + self._debug_enabled = enable + async def _connect_resolve_host(self) -> hr.AddrInfo: """Step 1 in connect process: resolve the address.""" try: @@ -278,7 +286,6 @@ class APIConnection: async def _connect_socket_connect(self, addr: hr.AddrInfo) -> None: """Step 2 in connect process: connect the socket.""" - debug_enable = self._debug_enabled() sock = socket.socket(family=addr.family, type=addr.type, proto=addr.proto) self._socket = sock sock.setblocking(False) @@ -294,7 +301,7 @@ class APIConnection: err, ) - if debug_enable is True: + if self._debug_enabled: _LOGGER.debug( "%s: Connecting to %s:%s (%s)", self.log_name, @@ -312,7 +319,7 @@ class APIConnection: except OSError as err: raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err - if debug_enable is True: + if self._debug_enabled: _LOGGER.debug( "%s: Opened socket to %s:%s (%s)", self.log_name, @@ -403,13 +410,14 @@ class APIConnection: def _process_hello_resp(self, resp: HelloResponse) -> None: """Process a HelloResponse.""" - _LOGGER.debug( - "%s: Successfully connected ('%s' API=%s.%s)", - self.log_name, - resp.server_info, - resp.api_version_major, - resp.api_version_minor, - ) + if self._debug_enabled: + _LOGGER.debug( + "%s: Successfully connected ('%s' API=%s.%s)", + self.log_name, + resp.server_info, + resp.api_version_major, + resp.api_version_minor, + ) api_version = APIVersion(resp.api_version_major, resp.api_version_minor) if api_version.major > 2: _LOGGER.error( @@ -456,7 +464,7 @@ class APIConnection: self._pong_timer = loop.call_at( now + self._keep_alive_timeout, self._async_pong_not_received ) - elif self._debug_enabled() is True: + elif self._debug_enabled: # # We haven't reached the ping response (pong) timeout yet # and we haven't seen a response to the last ping @@ -485,11 +493,12 @@ class APIConnection: """Ping not received.""" if not self.is_connected: return - _LOGGER.debug( - "%s: Ping response not received after %s seconds", - self.log_name, - self._keep_alive_timeout, - ) + if self._debug_enabled: + _LOGGER.debug( + "%s: Ping response not received after %s seconds", + self.log_name, + self._keep_alive_timeout, + ) self.report_fatal_error( PingFailedAPIError( f"Ping response not received after {self._keep_alive_timeout} seconds" @@ -608,14 +617,14 @@ class APIConnection: ) packets: list[tuple[int, bytes]] = [] - debug_enabled = self._debug_enabled() + debug_enabled = self._debug_enabled for msg in msgs: msg_type = type(msg) if (message_type := PROTO_TO_MESSAGE_TYPE.get(msg_type)) is None: raise ValueError(f"Message type id not found for type {msg_type}") - if debug_enabled is True: + if debug_enabled: _LOGGER.debug( "%s: Sending %s: %s", self.log_name, msg_type.__name__, msg ) @@ -786,12 +795,14 @@ class APIConnection: def process_packet(self, msg_type_proto: _int, data: _bytes) -> None: """Process an incoming packet.""" + debug_enabled = self._debug_enabled if (klass := MESSAGE_TYPE_TO_PROTO.get(msg_type_proto)) is None: - _LOGGER.debug( - "%s: Skipping message type %s", - self.log_name, - msg_type_proto, - ) + if debug_enabled: + _LOGGER.debug( + "%s: Skipping unknown message type %s", + self.log_name, + msg_type_proto, + ) return try: @@ -818,7 +829,7 @@ class APIConnection: msg_type = type(msg) - if self._debug_enabled() is True: + if debug_enabled: _LOGGER.debug( "%s: Got message of type %s: %s", self.log_name, @@ -891,10 +902,11 @@ class APIConnection: self._fatal_exception = TimeoutAPIError( "Timed out waiting to finish connect before disconnecting" ) - _LOGGER.debug( - "%s: Connect task didn't finish before disconnect", - self.log_name, - ) + if self._debug_enabled: + _LOGGER.debug( + "%s: Connect task didn't finish before disconnect", + self.log_name, + ) self._expected_disconnect = True if self._handshake_complete: diff --git a/tests/conftest.py b/tests/conftest.py index 6bd41ff..e05ede7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -78,18 +78,18 @@ def noise_connection_params() -> ConnectionParams: ) -async def on_stop(expected_disconnect: bool) -> None: +def on_stop(expected_disconnect: bool) -> None: pass @pytest.fixture def conn(connection_params: ConnectionParams) -> APIConnection: - return PatchableAPIConnection(connection_params, on_stop) + return PatchableAPIConnection(connection_params, on_stop, True, None) @pytest.fixture def noise_conn(noise_connection_params: ConnectionParams) -> APIConnection: - return PatchableAPIConnection(noise_connection_params, on_stop) + return PatchableAPIConnection(noise_connection_params, on_stop, True, None) @pytest_asyncio.fixture(name="plaintext_connect_task_no_login") diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index b69a073..38d5235 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -124,7 +124,9 @@ def _make_mock_connection() -> tuple[APIConnection, list[tuple[int, bytes]]]: class MockConnection(APIConnection): def __init__(self, *args: Any, **kwargs: Any) -> None: """Swallow args.""" - super().__init__(get_mock_connection_params(), AsyncMock(), *args, **kwargs) + super().__init__( + get_mock_connection_params(), AsyncMock(), True, None, *args, **kwargs + ) def process_packet(self, type_: int, data: bytes): packets.append((type_, data)) diff --git a/tests/test_client.py b/tests/test_client.py index 2cc4ff8..dfa1b6b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import itertools +import logging from typing import Any from unittest.mock import AsyncMock, MagicMock, call, patch @@ -1068,29 +1069,6 @@ async def test_bluetooth_gatt_write_descriptor_without_response( await client.bluetooth_gatt_write_descriptor(1234, 1234, b"1234", timeout=0) -@pytest.mark.asyncio -async def test_bluetooth_gatt_read_descriptor( - api_client: tuple[ - APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper - ], -) -> None: - """Test bluetooth_gatt_read_descriptor.""" - client, connection, transport, protocol = api_client - read_task = asyncio.create_task(client.bluetooth_gatt_read_descriptor(1234, 1234)) - await asyncio.sleep(0) - - other_response: message.Message = BluetoothGATTReadResponse( - address=1234, handle=4567, data=b"4567" - ) - mock_data_received(protocol, generate_plaintext_packet(other_response)) - - response: message.Message = BluetoothGATTReadResponse( - address=1234, handle=1234, data=b"1234" - ) - mock_data_received(protocol, generate_plaintext_packet(response)) - assert await read_task == b"1234" - - @pytest.mark.asyncio async def test_bluetooth_gatt_get_services( api_client: tuple[ @@ -1374,3 +1352,37 @@ async def test_subscribe_service_calls(auth_client: APIClient) -> None: service_msg = HomeassistantServiceResponse(service="bob") await send(service_msg) on_service_call.assert_called_with(HomeassistantServiceCall.from_pb(service_msg)) + + +@pytest.mark.asyncio +async def test_set_debug( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], + caplog: pytest.LogCaptureFixture, +) -> None: + """Test set_debug.""" + client, connection, transport, protocol = api_client + response: message.Message = DeviceInfoResponse( + name="realname", + friendly_name="My Device", + has_deep_sleep=True, + ) + + caplog.set_level(logging.DEBUG) + + client.set_debug(True) + assert client.log_name == "mydevice.local" + device_info_task = asyncio.create_task(client.device_info()) + await asyncio.sleep(0) + mock_data_received(protocol, generate_plaintext_packet(response)) + await device_info_task + + assert "My Device" in caplog.text + caplog.clear() + client.set_debug(False) + device_info_task = asyncio.create_task(client.device_info()) + await asyncio.sleep(0) + mock_data_received(protocol, generate_plaintext_packet(response)) + await device_info_task + assert "My Device" not in caplog.text diff --git a/tests/test_connection.py b/tests/test_connection.py index de1d424..190fd42 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,15 +1,18 @@ from __future__ import annotations import asyncio +import logging from collections.abc import Coroutine from datetime import timedelta from typing import Any from unittest.mock import AsyncMock, MagicMock, call, patch import pytest +from google.protobuf import message from aioesphomeapi import APIClient from aioesphomeapi._frame_helper import APIPlaintextFrameHelper +from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes from aioesphomeapi.api_pb2 import ( DeviceInfoResponse, DisconnectRequest, @@ -491,6 +494,7 @@ async def test_force_disconnect_fails( with patch.object(protocol, "_writer", side_effect=OSError): await conn.force_disconnect() assert "Failed to send (forced) disconnect request" in caplog.text + await asyncio.sleep(0) @pytest.mark.asyncio @@ -702,3 +706,35 @@ async def test_respond_to_ping_request( ping_response_bytes = b"\x00\x00\x08" assert transport.write.call_count == 1 assert transport.write.mock_calls == [call(ping_response_bytes)] + + +@pytest.mark.asyncio +async def test_unknown_protobuf_message_type_logged( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], + caplog: pytest.LogCaptureFixture, +) -> None: + """Test unknown protobuf messages are logged but do not cause the connection to collapse.""" + client, connection, transport, protocol = api_client + response: message.Message = DeviceInfoResponse( + name="realname", + friendly_name="My Device", + has_deep_sleep=True, + ) + caplog.set_level(logging.DEBUG) + client.set_debug(True) + bytes_ = response.SerializeToString() + message_with_invalid_protobuf_number = ( + b"\0" + + _cached_varuint_to_bytes(len(bytes_)) + + _cached_varuint_to_bytes(16385) + + bytes_ + ) + + mock_data_received(protocol, message_with_invalid_protobuf_number) + + assert "Skipping unknown message type 16385" in caplog.text + assert connection.is_connected + await connection.force_disconnect() + await asyncio.sleep(0)