mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-02-02 23:31:55 +01:00
Ensure expected_disconnect is True when sending DisconnectResponse fails (#646)
This commit is contained in:
parent
041cbad89f
commit
f783438a7d
@ -883,8 +883,11 @@ class APIConnection:
|
|||||||
self, _msg: DisconnectRequest
|
self, _msg: DisconnectRequest
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle a DisconnectRequest."""
|
"""Handle a DisconnectRequest."""
|
||||||
self.send_message(DISCONNECT_RESPONSE_MESSAGE)
|
# Set _expected_disconnect to True before sending
|
||||||
|
# the response if for some reason sending the response
|
||||||
|
# fails we will still mark the disconnect as expected
|
||||||
self._expected_disconnect = True
|
self._expected_disconnect = True
|
||||||
|
self.send_message(DISCONNECT_RESPONSE_MESSAGE)
|
||||||
self._cleanup()
|
self._cleanup()
|
||||||
|
|
||||||
def _handle_ping_request_internal( # pylint: disable=unused-argument
|
def _handle_ping_request_internal( # pylint: disable=unused-argument
|
||||||
|
@ -8,25 +8,29 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from aioesphomeapi import APIClient
|
||||||
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
||||||
from aioesphomeapi.api_pb2 import (
|
from aioesphomeapi.api_pb2 import (
|
||||||
DeviceInfoResponse,
|
DeviceInfoResponse,
|
||||||
|
DisconnectRequest,
|
||||||
HelloResponse,
|
HelloResponse,
|
||||||
PingRequest,
|
PingRequest,
|
||||||
PingResponse,
|
PingResponse,
|
||||||
)
|
)
|
||||||
from aioesphomeapi.connection import APIConnection, ConnectionState
|
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
|
||||||
from aioesphomeapi.core import (
|
from aioesphomeapi.core import (
|
||||||
APIConnectionError,
|
APIConnectionError,
|
||||||
HandshakeAPIError,
|
HandshakeAPIError,
|
||||||
InvalidAuthAPIError,
|
InvalidAuthAPIError,
|
||||||
RequiresEncryptionAPIError,
|
RequiresEncryptionAPIError,
|
||||||
|
SocketAPIError,
|
||||||
TimeoutAPIError,
|
TimeoutAPIError,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .common import (
|
from .common import (
|
||||||
async_fire_time_changed,
|
async_fire_time_changed,
|
||||||
connect,
|
connect,
|
||||||
|
generate_plaintext_packet,
|
||||||
send_plaintext_connect_response,
|
send_plaintext_connect_response,
|
||||||
send_plaintext_hello,
|
send_plaintext_hello,
|
||||||
utcnow,
|
utcnow,
|
||||||
@ -461,3 +465,81 @@ async def test_connect_correct_password(
|
|||||||
await connect_task
|
await connect_task
|
||||||
|
|
||||||
assert conn.is_connected
|
assert conn.is_connected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_force_disconnect_fails(
|
||||||
|
caplog: pytest.LogCaptureFixture,
|
||||||
|
plaintext_connect_task_with_login: tuple[
|
||||||
|
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
|
||||||
|
],
|
||||||
|
) -> None:
|
||||||
|
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
|
||||||
|
|
||||||
|
send_plaintext_hello(protocol)
|
||||||
|
send_plaintext_connect_response(protocol, False)
|
||||||
|
|
||||||
|
await connect_task
|
||||||
|
assert conn.is_connected
|
||||||
|
|
||||||
|
with patch.object(protocol, "_writer", side_effect=OSError):
|
||||||
|
await conn.force_disconnect()
|
||||||
|
assert "Failed to send (forced) disconnect request" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_disconnect_fails_to_send_response(
|
||||||
|
connection_params: ConnectionParams,
|
||||||
|
event_loop: asyncio.AbstractEventLoop,
|
||||||
|
resolve_host,
|
||||||
|
socket_socket,
|
||||||
|
) -> None:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
protocol: APIPlaintextFrameHelper | None = None
|
||||||
|
transport = MagicMock()
|
||||||
|
connected = asyncio.Event()
|
||||||
|
client = APIClient(
|
||||||
|
address="mydevice.local",
|
||||||
|
port=6052,
|
||||||
|
password=None,
|
||||||
|
)
|
||||||
|
expected_disconnect = None
|
||||||
|
|
||||||
|
async def _on_stop(_expected_disconnect: bool) -> None:
|
||||||
|
nonlocal expected_disconnect
|
||||||
|
expected_disconnect = _expected_disconnect
|
||||||
|
|
||||||
|
conn = APIConnection(connection_params, _on_stop)
|
||||||
|
|
||||||
|
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||||
|
nonlocal protocol
|
||||||
|
protocol = create_func()
|
||||||
|
protocol.connection_made(transport)
|
||||||
|
connected.set()
|
||||||
|
return transport, protocol
|
||||||
|
|
||||||
|
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||||
|
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||||
|
):
|
||||||
|
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||||
|
await connected.wait()
|
||||||
|
send_plaintext_hello(protocol)
|
||||||
|
client._connection = conn
|
||||||
|
await connect_task
|
||||||
|
transport.reset_mock()
|
||||||
|
|
||||||
|
send_plaintext_hello(protocol)
|
||||||
|
send_plaintext_connect_response(protocol, False)
|
||||||
|
|
||||||
|
await connect_task
|
||||||
|
assert conn.is_connected
|
||||||
|
|
||||||
|
with pytest.raises(SocketAPIError), patch.object(
|
||||||
|
protocol, "_writer", side_effect=OSError
|
||||||
|
):
|
||||||
|
disconnect_request = DisconnectRequest()
|
||||||
|
protocol.data_received(generate_plaintext_packet(disconnect_request))
|
||||||
|
|
||||||
|
# Wait one loop iteration for the disconnect to be processed
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert expected_disconnect is True
|
||||||
|
Loading…
Reference in New Issue
Block a user