mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-01-23 21:51:34 +01:00
Ensure expected_disconnect is True when sending DisconnectResponse fails (#646)
This commit is contained in:
parent
041cbad89f
commit
f783438a7d
@ -883,8 +883,11 @@ class APIConnection:
|
||||
self, _msg: DisconnectRequest
|
||||
) -> None:
|
||||
"""Handle a DisconnectRequest."""
|
||||
self.send_message(DISCONNECT_RESPONSE_MESSAGE)
|
||||
# Set _expected_disconnect to True before sending
|
||||
# the response if for some reason sending the response
|
||||
# fails we will still mark the disconnect as expected
|
||||
self._expected_disconnect = True
|
||||
self.send_message(DISCONNECT_RESPONSE_MESSAGE)
|
||||
self._cleanup()
|
||||
|
||||
def _handle_ping_request_internal( # pylint: disable=unused-argument
|
||||
|
@ -8,25 +8,29 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from aioesphomeapi import APIClient
|
||||
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
||||
from aioesphomeapi.api_pb2 import (
|
||||
DeviceInfoResponse,
|
||||
DisconnectRequest,
|
||||
HelloResponse,
|
||||
PingRequest,
|
||||
PingResponse,
|
||||
)
|
||||
from aioesphomeapi.connection import APIConnection, ConnectionState
|
||||
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
|
||||
from aioesphomeapi.core import (
|
||||
APIConnectionError,
|
||||
HandshakeAPIError,
|
||||
InvalidAuthAPIError,
|
||||
RequiresEncryptionAPIError,
|
||||
SocketAPIError,
|
||||
TimeoutAPIError,
|
||||
)
|
||||
|
||||
from .common import (
|
||||
async_fire_time_changed,
|
||||
connect,
|
||||
generate_plaintext_packet,
|
||||
send_plaintext_connect_response,
|
||||
send_plaintext_hello,
|
||||
utcnow,
|
||||
@ -461,3 +465,81 @@ async def test_connect_correct_password(
|
||||
await connect_task
|
||||
|
||||
assert conn.is_connected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_force_disconnect_fails(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
plaintext_connect_task_with_login: tuple[
|
||||
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
|
||||
],
|
||||
) -> None:
|
||||
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
|
||||
|
||||
send_plaintext_hello(protocol)
|
||||
send_plaintext_connect_response(protocol, False)
|
||||
|
||||
await connect_task
|
||||
assert conn.is_connected
|
||||
|
||||
with patch.object(protocol, "_writer", side_effect=OSError):
|
||||
await conn.force_disconnect()
|
||||
assert "Failed to send (forced) disconnect request" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_fails_to_send_response(
|
||||
connection_params: ConnectionParams,
|
||||
event_loop: asyncio.AbstractEventLoop,
|
||||
resolve_host,
|
||||
socket_socket,
|
||||
) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol: APIPlaintextFrameHelper | None = None
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
client = APIClient(
|
||||
address="mydevice.local",
|
||||
port=6052,
|
||||
password=None,
|
||||
)
|
||||
expected_disconnect = None
|
||||
|
||||
async def _on_stop(_expected_disconnect: bool) -> None:
|
||||
nonlocal expected_disconnect
|
||||
expected_disconnect = _expected_disconnect
|
||||
|
||||
conn = APIConnection(connection_params, _on_stop)
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
protocol.connection_made(transport)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await connected.wait()
|
||||
send_plaintext_hello(protocol)
|
||||
client._connection = conn
|
||||
await connect_task
|
||||
transport.reset_mock()
|
||||
|
||||
send_plaintext_hello(protocol)
|
||||
send_plaintext_connect_response(protocol, False)
|
||||
|
||||
await connect_task
|
||||
assert conn.is_connected
|
||||
|
||||
with pytest.raises(SocketAPIError), patch.object(
|
||||
protocol, "_writer", side_effect=OSError
|
||||
):
|
||||
disconnect_request = DisconnectRequest()
|
||||
protocol.data_received(generate_plaintext_packet(disconnect_request))
|
||||
|
||||
# Wait one loop iteration for the disconnect to be processed
|
||||
await asyncio.sleep(0)
|
||||
assert expected_disconnect is True
|
||||
|
Loading…
Reference in New Issue
Block a user