Ensure expected_disconnect is True when sending DisconnectResponse fails (#646)

This commit is contained in:
J. Nick Koston 2023-11-20 19:08:29 +01:00 committed by GitHub
parent 041cbad89f
commit f783438a7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 87 additions and 2 deletions

View File

@ -883,8 +883,11 @@ class APIConnection:
self, _msg: DisconnectRequest
) -> None:
"""Handle a DisconnectRequest."""
self.send_message(DISCONNECT_RESPONSE_MESSAGE)
# Set _expected_disconnect to True before sending
# the response if for some reason sending the response
# fails we will still mark the disconnect as expected
self._expected_disconnect = True
self.send_message(DISCONNECT_RESPONSE_MESSAGE)
self._cleanup()
def _handle_ping_request_internal( # pylint: disable=unused-argument

View File

@ -8,25 +8,29 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from aioesphomeapi import APIClient
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi.api_pb2 import (
DeviceInfoResponse,
DisconnectRequest,
HelloResponse,
PingRequest,
PingResponse,
)
from aioesphomeapi.connection import APIConnection, ConnectionState
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
from aioesphomeapi.core import (
APIConnectionError,
HandshakeAPIError,
InvalidAuthAPIError,
RequiresEncryptionAPIError,
SocketAPIError,
TimeoutAPIError,
)
from .common import (
async_fire_time_changed,
connect,
generate_plaintext_packet,
send_plaintext_connect_response,
send_plaintext_hello,
utcnow,
@ -461,3 +465,81 @@ async def test_connect_correct_password(
await connect_task
assert conn.is_connected
@pytest.mark.asyncio
async def test_force_disconnect_fails(
caplog: pytest.LogCaptureFixture,
plaintext_connect_task_with_login: tuple[
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
],
) -> None:
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
send_plaintext_hello(protocol)
send_plaintext_connect_response(protocol, False)
await connect_task
assert conn.is_connected
with patch.object(protocol, "_writer", side_effect=OSError):
await conn.force_disconnect()
assert "Failed to send (forced) disconnect request" in caplog.text
@pytest.mark.asyncio
async def test_disconnect_fails_to_send_response(
connection_params: ConnectionParams,
event_loop: asyncio.AbstractEventLoop,
resolve_host,
socket_socket,
) -> None:
loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()
connected = asyncio.Event()
client = APIClient(
address="mydevice.local",
port=6052,
password=None,
)
expected_disconnect = None
async def _on_stop(_expected_disconnect: bool) -> None:
nonlocal expected_disconnect
expected_disconnect = _expected_disconnect
conn = APIConnection(connection_params, _on_stop)
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait()
send_plaintext_hello(protocol)
client._connection = conn
await connect_task
transport.reset_mock()
send_plaintext_hello(protocol)
send_plaintext_connect_response(protocol, False)
await connect_task
assert conn.is_connected
with pytest.raises(SocketAPIError), patch.object(
protocol, "_writer", side_effect=OSError
):
disconnect_request = DisconnectRequest()
protocol.data_received(generate_plaintext_packet(disconnect_request))
# Wait one loop iteration for the disconnect to be processed
await asyncio.sleep(0)
assert expected_disconnect is True