Avoid expensive debug logging checks in packet processor (#700)
This commit is contained in:
parent
0d25cc92a0
commit
67661dbd7f
|
@ -168,7 +168,17 @@ def _stringify_or_none(value: str | None) -> str | None:
|
||||||
|
|
||||||
# pylint: disable=too-many-public-methods
|
# pylint: disable=too-many-public-methods
|
||||||
class APIClient:
|
class APIClient:
|
||||||
|
"""The ESPHome API client.
|
||||||
|
|
||||||
|
This class is the main entrypoint for interacting with the API.
|
||||||
|
|
||||||
|
It is recommended to use this class in combination with the
|
||||||
|
ReconnectLogic class to automatically reconnect to the device
|
||||||
|
if the connection is lost.
|
||||||
|
"""
|
||||||
|
|
||||||
__slots__ = (
|
__slots__ = (
|
||||||
|
"_debug_enabled",
|
||||||
"_params",
|
"_params",
|
||||||
"_connection",
|
"_connection",
|
||||||
"cached_name",
|
"cached_name",
|
||||||
|
@ -205,6 +215,7 @@ class APIClient:
|
||||||
Can be used to prevent accidentally connecting to a different device if
|
Can be used to prevent accidentally connecting to a different device if
|
||||||
IP passed as address but DHCP reassigned IP.
|
IP passed as address but DHCP reassigned IP.
|
||||||
"""
|
"""
|
||||||
|
self._debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||||
self._params = ConnectionParams(
|
self._params = ConnectionParams(
|
||||||
address=str(address),
|
address=str(address),
|
||||||
port=port,
|
port=port,
|
||||||
|
@ -223,6 +234,12 @@ class APIClient:
|
||||||
self._on_stop_task: asyncio.Task[None] | None = None
|
self._on_stop_task: asyncio.Task[None] | None = None
|
||||||
self._set_log_name()
|
self._set_log_name()
|
||||||
|
|
||||||
|
def set_debug(self, enabled: bool) -> None:
|
||||||
|
"""Enable debug logging."""
|
||||||
|
self._debug_enabled = enabled
|
||||||
|
if self._connection:
|
||||||
|
self._connection.set_debug(enabled)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def zeroconf_manager(self) -> ZeroconfManager:
|
def zeroconf_manager(self) -> ZeroconfManager:
|
||||||
return self._params.zeroconf_manager
|
return self._params.zeroconf_manager
|
||||||
|
@ -299,7 +316,10 @@ class APIClient:
|
||||||
raise APIConnectionError(f"Already connected to {self.log_name}!")
|
raise APIConnectionError(f"Already connected to {self.log_name}!")
|
||||||
|
|
||||||
self._connection = APIConnection(
|
self._connection = APIConnection(
|
||||||
self._params, partial(self._on_stop, on_stop), log_name=self.log_name
|
self._params,
|
||||||
|
partial(self._on_stop, on_stop),
|
||||||
|
self._debug_enabled,
|
||||||
|
self.log_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -556,7 +576,6 @@ class APIClient:
|
||||||
has_cache: bool = False,
|
has_cache: bool = False,
|
||||||
address_type: int | None = None,
|
address_type: int | None = None,
|
||||||
) -> Callable[[], None]:
|
) -> Callable[[], None]:
|
||||||
debug = _LOGGER.isEnabledFor(logging.DEBUG)
|
|
||||||
connect_future: asyncio.Future[None] = self._loop.create_future()
|
connect_future: asyncio.Future[None] = self._loop.create_future()
|
||||||
|
|
||||||
if has_cache:
|
if has_cache:
|
||||||
|
@ -570,7 +589,7 @@ class APIClient:
|
||||||
# of the connection. This can crash the esp if the service list is too large.
|
# of the connection. This can crash the esp if the service list is too large.
|
||||||
request_type = BluetoothDeviceRequestType.CONNECT
|
request_type = BluetoothDeviceRequestType.CONNECT
|
||||||
|
|
||||||
if debug:
|
if self._debug_enabled:
|
||||||
_LOGGER.debug("%s: Using connection version %s", address, request_type)
|
_LOGGER.debug("%s: Using connection version %s", address, request_type)
|
||||||
|
|
||||||
unsub = self._get_connection().send_message_callback_response(
|
unsub = self._get_connection().send_message_callback_response(
|
||||||
|
@ -604,7 +623,7 @@ class APIClient:
|
||||||
# the slot is recovered before the timeout is raised
|
# the slot is recovered before the timeout is raised
|
||||||
# to avoid race were we run out even though we have a slot.
|
# to avoid race were we run out even though we have a slot.
|
||||||
addr = to_human_readable_address(address)
|
addr = to_human_readable_address(address)
|
||||||
if debug:
|
if self._debug_enabled:
|
||||||
_LOGGER.debug("%s: Connecting timed out, waiting for disconnect", addr)
|
_LOGGER.debug("%s: Connecting timed out, waiting for disconnect", addr)
|
||||||
disconnect_timed_out = (
|
disconnect_timed_out = (
|
||||||
not await self._bluetooth_device_disconnect_guard_timeout(
|
not await self._bluetooth_device_disconnect_guard_timeout(
|
||||||
|
@ -640,7 +659,7 @@ class APIClient:
|
||||||
try:
|
try:
|
||||||
await self.bluetooth_device_disconnect(address, timeout=timeout)
|
await self.bluetooth_device_disconnect(address, timeout=timeout)
|
||||||
except TimeoutAPIError:
|
except TimeoutAPIError:
|
||||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
if self._debug_enabled:
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"%s: Disconnect timed out: %s",
|
"%s: Disconnect timed out: %s",
|
||||||
to_human_readable_address(address),
|
to_human_readable_address(address),
|
||||||
|
|
|
@ -79,7 +79,7 @@ cdef class APIConnection:
|
||||||
cdef bint _send_pending_ping
|
cdef bint _send_pending_ping
|
||||||
cdef public bint is_connected
|
cdef public bint is_connected
|
||||||
cdef bint _handshake_complete
|
cdef bint _handshake_complete
|
||||||
cdef object _debug_enabled
|
cdef bint _debug_enabled
|
||||||
cdef public str received_name
|
cdef public str received_name
|
||||||
cdef public object resolved_addr_info
|
cdef public object resolved_addr_info
|
||||||
|
|
||||||
|
|
|
@ -128,6 +128,8 @@ class APIConnection:
|
||||||
|
|
||||||
An instance of this class may only be used once, for every new connection
|
An instance of this class may only be used once, for every new connection
|
||||||
a new instance should be established.
|
a new instance should be established.
|
||||||
|
|
||||||
|
This class should only be created from APIClient and should not be used directly.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = (
|
__slots__ = (
|
||||||
|
@ -161,7 +163,8 @@ class APIConnection:
|
||||||
self,
|
self,
|
||||||
params: ConnectionParams,
|
params: ConnectionParams,
|
||||||
on_stop: Callable[[bool], None],
|
on_stop: Callable[[bool], None],
|
||||||
log_name: str | None = None,
|
debug_enabled: bool,
|
||||||
|
log_name: str | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._params = params
|
self._params = params
|
||||||
self.on_stop: Callable[[bool], None] | None = on_stop
|
self.on_stop: Callable[[bool], None] | None = on_stop
|
||||||
|
@ -195,7 +198,7 @@ class APIConnection:
|
||||||
self._loop = asyncio.get_event_loop()
|
self._loop = asyncio.get_event_loop()
|
||||||
self.is_connected = False
|
self.is_connected = False
|
||||||
self._handshake_complete = False
|
self._handshake_complete = False
|
||||||
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
|
self._debug_enabled = debug_enabled
|
||||||
self.received_name: str = ""
|
self.received_name: str = ""
|
||||||
self.resolved_addr_info: hr.AddrInfo | None = None
|
self.resolved_addr_info: hr.AddrInfo | None = None
|
||||||
|
|
||||||
|
@ -214,7 +217,8 @@ class APIConnection:
|
||||||
return
|
return
|
||||||
was_connected = self.is_connected
|
was_connected = self.is_connected
|
||||||
self._set_connection_state(ConnectionState.CLOSED)
|
self._set_connection_state(ConnectionState.CLOSED)
|
||||||
_LOGGER.debug("Cleaning up connection to %s", self.log_name)
|
if self._debug_enabled:
|
||||||
|
_LOGGER.debug("Cleaning up connection to %s", self.log_name)
|
||||||
for fut in self._read_exception_futures:
|
for fut in self._read_exception_futures:
|
||||||
if fut.done():
|
if fut.done():
|
||||||
continue
|
continue
|
||||||
|
@ -262,6 +266,10 @@ class APIConnection:
|
||||||
self.on_stop = None
|
self.on_stop = None
|
||||||
on_stop(self._expected_disconnect)
|
on_stop(self._expected_disconnect)
|
||||||
|
|
||||||
|
def set_debug(self, enable: bool) -> None:
|
||||||
|
"""Enable or disable debug logging."""
|
||||||
|
self._debug_enabled = enable
|
||||||
|
|
||||||
async def _connect_resolve_host(self) -> hr.AddrInfo:
|
async def _connect_resolve_host(self) -> hr.AddrInfo:
|
||||||
"""Step 1 in connect process: resolve the address."""
|
"""Step 1 in connect process: resolve the address."""
|
||||||
try:
|
try:
|
||||||
|
@ -278,7 +286,6 @@ class APIConnection:
|
||||||
|
|
||||||
async def _connect_socket_connect(self, addr: hr.AddrInfo) -> None:
|
async def _connect_socket_connect(self, addr: hr.AddrInfo) -> None:
|
||||||
"""Step 2 in connect process: connect the socket."""
|
"""Step 2 in connect process: connect the socket."""
|
||||||
debug_enable = self._debug_enabled()
|
|
||||||
sock = socket.socket(family=addr.family, type=addr.type, proto=addr.proto)
|
sock = socket.socket(family=addr.family, type=addr.type, proto=addr.proto)
|
||||||
self._socket = sock
|
self._socket = sock
|
||||||
sock.setblocking(False)
|
sock.setblocking(False)
|
||||||
|
@ -294,7 +301,7 @@ class APIConnection:
|
||||||
err,
|
err,
|
||||||
)
|
)
|
||||||
|
|
||||||
if debug_enable is True:
|
if self._debug_enabled:
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"%s: Connecting to %s:%s (%s)",
|
"%s: Connecting to %s:%s (%s)",
|
||||||
self.log_name,
|
self.log_name,
|
||||||
|
@ -312,7 +319,7 @@ class APIConnection:
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err
|
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err
|
||||||
|
|
||||||
if debug_enable is True:
|
if self._debug_enabled:
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"%s: Opened socket to %s:%s (%s)",
|
"%s: Opened socket to %s:%s (%s)",
|
||||||
self.log_name,
|
self.log_name,
|
||||||
|
@ -403,13 +410,14 @@ class APIConnection:
|
||||||
|
|
||||||
def _process_hello_resp(self, resp: HelloResponse) -> None:
|
def _process_hello_resp(self, resp: HelloResponse) -> None:
|
||||||
"""Process a HelloResponse."""
|
"""Process a HelloResponse."""
|
||||||
_LOGGER.debug(
|
if self._debug_enabled:
|
||||||
"%s: Successfully connected ('%s' API=%s.%s)",
|
_LOGGER.debug(
|
||||||
self.log_name,
|
"%s: Successfully connected ('%s' API=%s.%s)",
|
||||||
resp.server_info,
|
self.log_name,
|
||||||
resp.api_version_major,
|
resp.server_info,
|
||||||
resp.api_version_minor,
|
resp.api_version_major,
|
||||||
)
|
resp.api_version_minor,
|
||||||
|
)
|
||||||
api_version = APIVersion(resp.api_version_major, resp.api_version_minor)
|
api_version = APIVersion(resp.api_version_major, resp.api_version_minor)
|
||||||
if api_version.major > 2:
|
if api_version.major > 2:
|
||||||
_LOGGER.error(
|
_LOGGER.error(
|
||||||
|
@ -456,7 +464,7 @@ class APIConnection:
|
||||||
self._pong_timer = loop.call_at(
|
self._pong_timer = loop.call_at(
|
||||||
now + self._keep_alive_timeout, self._async_pong_not_received
|
now + self._keep_alive_timeout, self._async_pong_not_received
|
||||||
)
|
)
|
||||||
elif self._debug_enabled() is True:
|
elif self._debug_enabled:
|
||||||
#
|
#
|
||||||
# We haven't reached the ping response (pong) timeout yet
|
# We haven't reached the ping response (pong) timeout yet
|
||||||
# and we haven't seen a response to the last ping
|
# and we haven't seen a response to the last ping
|
||||||
|
@ -485,11 +493,12 @@ class APIConnection:
|
||||||
"""Ping not received."""
|
"""Ping not received."""
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
return
|
return
|
||||||
_LOGGER.debug(
|
if self._debug_enabled:
|
||||||
"%s: Ping response not received after %s seconds",
|
_LOGGER.debug(
|
||||||
self.log_name,
|
"%s: Ping response not received after %s seconds",
|
||||||
self._keep_alive_timeout,
|
self.log_name,
|
||||||
)
|
self._keep_alive_timeout,
|
||||||
|
)
|
||||||
self.report_fatal_error(
|
self.report_fatal_error(
|
||||||
PingFailedAPIError(
|
PingFailedAPIError(
|
||||||
f"Ping response not received after {self._keep_alive_timeout} seconds"
|
f"Ping response not received after {self._keep_alive_timeout} seconds"
|
||||||
|
@ -608,14 +617,14 @@ class APIConnection:
|
||||||
)
|
)
|
||||||
|
|
||||||
packets: list[tuple[int, bytes]] = []
|
packets: list[tuple[int, bytes]] = []
|
||||||
debug_enabled = self._debug_enabled()
|
debug_enabled = self._debug_enabled
|
||||||
|
|
||||||
for msg in msgs:
|
for msg in msgs:
|
||||||
msg_type = type(msg)
|
msg_type = type(msg)
|
||||||
if (message_type := PROTO_TO_MESSAGE_TYPE.get(msg_type)) is None:
|
if (message_type := PROTO_TO_MESSAGE_TYPE.get(msg_type)) is None:
|
||||||
raise ValueError(f"Message type id not found for type {msg_type}")
|
raise ValueError(f"Message type id not found for type {msg_type}")
|
||||||
|
|
||||||
if debug_enabled is True:
|
if debug_enabled:
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"%s: Sending %s: %s", self.log_name, msg_type.__name__, msg
|
"%s: Sending %s: %s", self.log_name, msg_type.__name__, msg
|
||||||
)
|
)
|
||||||
|
@ -786,12 +795,14 @@ class APIConnection:
|
||||||
|
|
||||||
def process_packet(self, msg_type_proto: _int, data: _bytes) -> None:
|
def process_packet(self, msg_type_proto: _int, data: _bytes) -> None:
|
||||||
"""Process an incoming packet."""
|
"""Process an incoming packet."""
|
||||||
|
debug_enabled = self._debug_enabled
|
||||||
if (klass := MESSAGE_TYPE_TO_PROTO.get(msg_type_proto)) is None:
|
if (klass := MESSAGE_TYPE_TO_PROTO.get(msg_type_proto)) is None:
|
||||||
_LOGGER.debug(
|
if debug_enabled:
|
||||||
"%s: Skipping message type %s",
|
_LOGGER.debug(
|
||||||
self.log_name,
|
"%s: Skipping unknown message type %s",
|
||||||
msg_type_proto,
|
self.log_name,
|
||||||
)
|
msg_type_proto,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -818,7 +829,7 @@ class APIConnection:
|
||||||
|
|
||||||
msg_type = type(msg)
|
msg_type = type(msg)
|
||||||
|
|
||||||
if self._debug_enabled() is True:
|
if debug_enabled:
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"%s: Got message of type %s: %s",
|
"%s: Got message of type %s: %s",
|
||||||
self.log_name,
|
self.log_name,
|
||||||
|
@ -891,10 +902,11 @@ class APIConnection:
|
||||||
self._fatal_exception = TimeoutAPIError(
|
self._fatal_exception = TimeoutAPIError(
|
||||||
"Timed out waiting to finish connect before disconnecting"
|
"Timed out waiting to finish connect before disconnecting"
|
||||||
)
|
)
|
||||||
_LOGGER.debug(
|
if self._debug_enabled:
|
||||||
"%s: Connect task didn't finish before disconnect",
|
_LOGGER.debug(
|
||||||
self.log_name,
|
"%s: Connect task didn't finish before disconnect",
|
||||||
)
|
self.log_name,
|
||||||
|
)
|
||||||
|
|
||||||
self._expected_disconnect = True
|
self._expected_disconnect = True
|
||||||
if self._handshake_complete:
|
if self._handshake_complete:
|
||||||
|
|
|
@ -78,18 +78,18 @@ def noise_connection_params() -> ConnectionParams:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def on_stop(expected_disconnect: bool) -> None:
|
def on_stop(expected_disconnect: bool) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def conn(connection_params: ConnectionParams) -> APIConnection:
|
def conn(connection_params: ConnectionParams) -> APIConnection:
|
||||||
return PatchableAPIConnection(connection_params, on_stop)
|
return PatchableAPIConnection(connection_params, on_stop, True, None)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def noise_conn(noise_connection_params: ConnectionParams) -> APIConnection:
|
def noise_conn(noise_connection_params: ConnectionParams) -> APIConnection:
|
||||||
return PatchableAPIConnection(noise_connection_params, on_stop)
|
return PatchableAPIConnection(noise_connection_params, on_stop, True, None)
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(name="plaintext_connect_task_no_login")
|
@pytest_asyncio.fixture(name="plaintext_connect_task_no_login")
|
||||||
|
|
|
@ -124,7 +124,9 @@ def _make_mock_connection() -> tuple[APIConnection, list[tuple[int, bytes]]]:
|
||||||
class MockConnection(APIConnection):
|
class MockConnection(APIConnection):
|
||||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
"""Swallow args."""
|
"""Swallow args."""
|
||||||
super().__init__(get_mock_connection_params(), AsyncMock(), *args, **kwargs)
|
super().__init__(
|
||||||
|
get_mock_connection_params(), AsyncMock(), True, None, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
def process_packet(self, type_: int, data: bytes):
|
def process_packet(self, type_: int, data: bytes):
|
||||||
packets.append((type_, data))
|
packets.append((type_, data))
|
||||||
|
|
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import itertools
|
import itertools
|
||||||
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||||
|
|
||||||
|
@ -1068,29 +1069,6 @@ async def test_bluetooth_gatt_write_descriptor_without_response(
|
||||||
await client.bluetooth_gatt_write_descriptor(1234, 1234, b"1234", timeout=0)
|
await client.bluetooth_gatt_write_descriptor(1234, 1234, b"1234", timeout=0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_bluetooth_gatt_read_descriptor(
|
|
||||||
api_client: tuple[
|
|
||||||
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
|
|
||||||
],
|
|
||||||
) -> None:
|
|
||||||
"""Test bluetooth_gatt_read_descriptor."""
|
|
||||||
client, connection, transport, protocol = api_client
|
|
||||||
read_task = asyncio.create_task(client.bluetooth_gatt_read_descriptor(1234, 1234))
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
|
|
||||||
other_response: message.Message = BluetoothGATTReadResponse(
|
|
||||||
address=1234, handle=4567, data=b"4567"
|
|
||||||
)
|
|
||||||
mock_data_received(protocol, generate_plaintext_packet(other_response))
|
|
||||||
|
|
||||||
response: message.Message = BluetoothGATTReadResponse(
|
|
||||||
address=1234, handle=1234, data=b"1234"
|
|
||||||
)
|
|
||||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
|
||||||
assert await read_task == b"1234"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_bluetooth_gatt_get_services(
|
async def test_bluetooth_gatt_get_services(
|
||||||
api_client: tuple[
|
api_client: tuple[
|
||||||
|
@ -1374,3 +1352,37 @@ async def test_subscribe_service_calls(auth_client: APIClient) -> None:
|
||||||
service_msg = HomeassistantServiceResponse(service="bob")
|
service_msg = HomeassistantServiceResponse(service="bob")
|
||||||
await send(service_msg)
|
await send(service_msg)
|
||||||
on_service_call.assert_called_with(HomeassistantServiceCall.from_pb(service_msg))
|
on_service_call.assert_called_with(HomeassistantServiceCall.from_pb(service_msg))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_debug(
|
||||||
|
api_client: tuple[
|
||||||
|
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
|
||||||
|
],
|
||||||
|
caplog: pytest.LogCaptureFixture,
|
||||||
|
) -> None:
|
||||||
|
"""Test set_debug."""
|
||||||
|
client, connection, transport, protocol = api_client
|
||||||
|
response: message.Message = DeviceInfoResponse(
|
||||||
|
name="realname",
|
||||||
|
friendly_name="My Device",
|
||||||
|
has_deep_sleep=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
caplog.set_level(logging.DEBUG)
|
||||||
|
|
||||||
|
client.set_debug(True)
|
||||||
|
assert client.log_name == "mydevice.local"
|
||||||
|
device_info_task = asyncio.create_task(client.device_info())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
|
await device_info_task
|
||||||
|
|
||||||
|
assert "My Device" in caplog.text
|
||||||
|
caplog.clear()
|
||||||
|
client.set_debug(False)
|
||||||
|
device_info_task = asyncio.create_task(client.device_info())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
|
await device_info_task
|
||||||
|
assert "My Device" not in caplog.text
|
||||||
|
|
|
@ -1,15 +1,18 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
from collections.abc import Coroutine
|
from collections.abc import Coroutine
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from google.protobuf import message
|
||||||
|
|
||||||
from aioesphomeapi import APIClient
|
from aioesphomeapi import APIClient
|
||||||
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
||||||
|
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
|
||||||
from aioesphomeapi.api_pb2 import (
|
from aioesphomeapi.api_pb2 import (
|
||||||
DeviceInfoResponse,
|
DeviceInfoResponse,
|
||||||
DisconnectRequest,
|
DisconnectRequest,
|
||||||
|
@ -491,6 +494,7 @@ async def test_force_disconnect_fails(
|
||||||
with patch.object(protocol, "_writer", side_effect=OSError):
|
with patch.object(protocol, "_writer", side_effect=OSError):
|
||||||
await conn.force_disconnect()
|
await conn.force_disconnect()
|
||||||
assert "Failed to send (forced) disconnect request" in caplog.text
|
assert "Failed to send (forced) disconnect request" in caplog.text
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -702,3 +706,35 @@ async def test_respond_to_ping_request(
|
||||||
ping_response_bytes = b"\x00\x00\x08"
|
ping_response_bytes = b"\x00\x00\x08"
|
||||||
assert transport.write.call_count == 1
|
assert transport.write.call_count == 1
|
||||||
assert transport.write.mock_calls == [call(ping_response_bytes)]
|
assert transport.write.mock_calls == [call(ping_response_bytes)]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unknown_protobuf_message_type_logged(
|
||||||
|
api_client: tuple[
|
||||||
|
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
|
||||||
|
],
|
||||||
|
caplog: pytest.LogCaptureFixture,
|
||||||
|
) -> None:
|
||||||
|
"""Test unknown protobuf messages are logged but do not cause the connection to collapse."""
|
||||||
|
client, connection, transport, protocol = api_client
|
||||||
|
response: message.Message = DeviceInfoResponse(
|
||||||
|
name="realname",
|
||||||
|
friendly_name="My Device",
|
||||||
|
has_deep_sleep=True,
|
||||||
|
)
|
||||||
|
caplog.set_level(logging.DEBUG)
|
||||||
|
client.set_debug(True)
|
||||||
|
bytes_ = response.SerializeToString()
|
||||||
|
message_with_invalid_protobuf_number = (
|
||||||
|
b"\0"
|
||||||
|
+ _cached_varuint_to_bytes(len(bytes_))
|
||||||
|
+ _cached_varuint_to_bytes(16385)
|
||||||
|
+ bytes_
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_data_received(protocol, message_with_invalid_protobuf_number)
|
||||||
|
|
||||||
|
assert "Skipping unknown message type 16385" in caplog.text
|
||||||
|
assert connection.is_connected
|
||||||
|
await connection.force_disconnect()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
Loading…
Reference in New Issue