diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index c2a5e21..fac3ea2 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -883,8 +883,11 @@ class APIConnection: self, _msg: DisconnectRequest ) -> None: """Handle a DisconnectRequest.""" - self.send_message(DISCONNECT_RESPONSE_MESSAGE) + # Set _expected_disconnect to True before sending + # the response if for some reason sending the response + # fails we will still mark the disconnect as expected self._expected_disconnect = True + self.send_message(DISCONNECT_RESPONSE_MESSAGE) self._cleanup() def _handle_ping_request_internal( # pylint: disable=unused-argument diff --git a/tests/test_connection.py b/tests/test_connection.py index 85967a7..5531624 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -8,25 +8,29 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from aioesphomeapi import APIClient from aioesphomeapi._frame_helper import APIPlaintextFrameHelper from aioesphomeapi.api_pb2 import ( DeviceInfoResponse, + DisconnectRequest, HelloResponse, PingRequest, PingResponse, ) -from aioesphomeapi.connection import APIConnection, ConnectionState +from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState from aioesphomeapi.core import ( APIConnectionError, HandshakeAPIError, InvalidAuthAPIError, RequiresEncryptionAPIError, + SocketAPIError, TimeoutAPIError, ) from .common import ( async_fire_time_changed, connect, + generate_plaintext_packet, send_plaintext_connect_response, send_plaintext_hello, utcnow, @@ -461,3 +465,81 @@ async def test_connect_correct_password( await connect_task assert conn.is_connected + + +@pytest.mark.asyncio +async def test_force_disconnect_fails( + caplog: pytest.LogCaptureFixture, + plaintext_connect_task_with_login: tuple[ + APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task + ], +) -> None: + conn, transport, protocol, connect_task = plaintext_connect_task_with_login + + send_plaintext_hello(protocol) + send_plaintext_connect_response(protocol, False) + + await connect_task + assert conn.is_connected + + with patch.object(protocol, "_writer", side_effect=OSError): + await conn.force_disconnect() + assert "Failed to send (forced) disconnect request" in caplog.text + + +@pytest.mark.asyncio +async def test_disconnect_fails_to_send_response( + connection_params: ConnectionParams, + event_loop: asyncio.AbstractEventLoop, + resolve_host, + socket_socket, +) -> None: + loop = asyncio.get_event_loop() + protocol: APIPlaintextFrameHelper | None = None + transport = MagicMock() + connected = asyncio.Event() + client = APIClient( + address="mydevice.local", + port=6052, + password=None, + ) + expected_disconnect = None + + async def _on_stop(_expected_disconnect: bool) -> None: + nonlocal expected_disconnect + expected_disconnect = _expected_disconnect + + conn = APIConnection(connection_params, _on_stop) + + def _create_mock_transport_protocol(create_func, **kwargs): + nonlocal protocol + protocol = create_func() + protocol.connection_made(transport) + connected.set() + return transport, protocol + + with patch.object(event_loop, "sock_connect"), patch.object( + loop, "create_connection", side_effect=_create_mock_transport_protocol + ): + connect_task = asyncio.create_task(connect(conn, login=False)) + await connected.wait() + send_plaintext_hello(protocol) + client._connection = conn + await connect_task + transport.reset_mock() + + send_plaintext_hello(protocol) + send_plaintext_connect_response(protocol, False) + + await connect_task + assert conn.is_connected + + with pytest.raises(SocketAPIError), patch.object( + protocol, "_writer", side_effect=OSError + ): + disconnect_request = DisconnectRequest() + protocol.data_received(generate_plaintext_packet(disconnect_request)) + + # Wait one loop iteration for the disconnect to be processed + await asyncio.sleep(0) + assert expected_disconnect is True