diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 23e13bf..23e12ad 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -77,6 +77,8 @@ KEEP_ALIVE_TIMEOUT_RATIO = 4.5 # from the network. # +DISCONNECT_CONNECT_TIMEOUT = 5.0 +DISCONNECT_RESPONSE_TIMEOUT = 10.0 HANDSHAKE_TIMEOUT = 30.0 RESOLVE_TIMEOUT = 30.0 CONNECT_REQUEST_TIMEOUT = 30.0 @@ -747,7 +749,7 @@ class APIConnection: except asyncio_TimeoutError as err: timeout_expired = True raise TimeoutAPIError( - f"Timeout waiting for response for {type(send_msg)} after {timeout}s" + f"Timeout waiting for response to {type(send_msg).__name__} after {timeout}s" ) from err finally: if not timeout_expired: @@ -886,8 +888,13 @@ class APIConnection: # Try to wait for the handshake to finish so we can send # a disconnect request. If it doesn't finish in time # we will just close the socket. - _, pending = await asyncio.wait([self._finish_connect_task], timeout=5.0) + _, pending = await asyncio.wait( + [self._finish_connect_task], timeout=DISCONNECT_CONNECT_TIMEOUT + ) if pending: + self._fatal_exception = TimeoutAPIError( + "Timed out waiting to finish connect before disconnecting" + ) _LOGGER.debug( "%s: Connect task didn't finish before disconnect", self.log_name, @@ -901,12 +908,12 @@ class APIConnection: # as possible. try: await self.send_message_await_response( - DISCONNECT_REQUEST_MESSAGE, DisconnectResponse + DISCONNECT_REQUEST_MESSAGE, + DisconnectResponse, + timeout=DISCONNECT_RESPONSE_TIMEOUT, ) except APIConnectionError as err: - _LOGGER.error( - "%s: Failed to send disconnect request: %s", self.log_name, err - ) + _LOGGER.error("%s: disconnect request failed: %s", self.log_name, err) self._cleanup() diff --git a/tests/test_connection.py b/tests/test_connection.py index 5338739..eb96aab 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -3,14 +3,19 @@ from __future__ import annotations import asyncio import socket from datetime import timedelta -from typing import Any, Coroutine, Optional +from typing import Any, Coroutine, Generator, Optional from unittest.mock import AsyncMock import pytest from mock import MagicMock, patch from aioesphomeapi._frame_helper import APIPlaintextFrameHelper -from aioesphomeapi.api_pb2 import DeviceInfoResponse, HelloResponse +from aioesphomeapi.api_pb2 import ( + DeviceInfoResponse, + HelloResponse, + PingRequest, + PingResponse, +) from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState from aioesphomeapi.core import ( APIConnectionError, @@ -121,6 +126,117 @@ async def test_connect(conn, resolve_host, socket_socket, event_loop): assert conn.is_connected +@pytest.mark.asyncio +async def test_timeout_sending_message( + conn: APIConnection, + resolve_host: Coroutine[Any, Any, AddrInfo], + socket_socket: Generator[Any, Any, None], + event_loop: asyncio.AbstractEventLoop, + caplog: pytest.LogCaptureFixture, +) -> None: + loop = asyncio.get_event_loop() + protocol: Optional[APIPlaintextFrameHelper] = None + transport = MagicMock() + connected = asyncio.Event() + + def _create_mock_transport_protocol(create_func, **kwargs): + nonlocal protocol + protocol = create_func() + protocol.connection_made(transport) + connected.set() + return transport, protocol + + transport = MagicMock() + + with patch.object( + loop, "create_connection", side_effect=_create_mock_transport_protocol + ): + connect_task = asyncio.create_task(connect(conn, login=False)) + await connected.wait() + protocol.data_received( + b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m' + b"5stackatomproxy" + b"\x00\x00$" + b"\x00\x00\x04" + b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d' + b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif" + ) + + await connect_task + + with pytest.raises(TimeoutAPIError): + await conn.send_message_await_response_complex( + PingRequest(), None, None, (PingResponse,), timeout=0 + ) + + transport.reset_mock() + with patch("aioesphomeapi.connection.DISCONNECT_RESPONSE_TIMEOUT", 0.0): + await conn.disconnect() + + transport.write.assert_called_with(b"\x00\x00\x05") + + assert "disconnect request failed" in caplog.text + assert ( + " Timeout waiting for response to DisconnectRequest after 0.0s" in caplog.text + ) + + +@pytest.mark.asyncio +async def test_disconnect_when_not_fully_connected( + conn: APIConnection, + resolve_host: Coroutine[Any, Any, AddrInfo], + socket_socket: Generator[Any, Any, None], + event_loop: asyncio.AbstractEventLoop, + caplog: pytest.LogCaptureFixture, +) -> None: + loop = asyncio.get_event_loop() + protocol: Optional[APIPlaintextFrameHelper] = None + transport = MagicMock() + connected = asyncio.Event() + + def _create_mock_transport_protocol(create_func, **kwargs): + nonlocal protocol + protocol = create_func() + protocol.connection_made(transport) + connected.set() + return transport, protocol + + transport = MagicMock() + + with patch.object( + loop, "create_connection", side_effect=_create_mock_transport_protocol + ): + connect_task = asyncio.create_task(connect(conn, login=False)) + await connected.wait() + + # Only send the first part of the handshake + # so we are stuck in the middle of the connection process + protocol.data_received( + b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m' + ) + + await asyncio.sleep(0) + transport.reset_mock() + + with patch("aioesphomeapi.connection.DISCONNECT_CONNECT_TIMEOUT", 0.0), patch( + "aioesphomeapi.connection.DISCONNECT_RESPONSE_TIMEOUT", 0.0 + ): + await conn.disconnect() + + with pytest.raises( + APIConnectionError, + match="Timed out waiting to finish connect before disconnecting", + ): + await connect_task + + transport.write.assert_called_with(b"\x00\x00\x05") + + assert "disconnect request failed" in caplog.text + assert ( + " Timeout waiting for response to DisconnectRequest after 0.0s" in caplog.text + ) + + @pytest.mark.asyncio async def test_requires_encryption_propagates(conn: APIConnection): loop = asyncio.get_event_loop()