From 3abf9ff8d4ca98eebd9b56427d00bc073e336b11 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 25 Nov 2023 08:39:04 -0600 Subject: [PATCH] Make force_disconnect a normal function (#705) --- aioesphomeapi/client.py | 2 +- aioesphomeapi/connection.py | 2 +- tests/test_client.py | 15 +++++++++++++++ tests/test_connection.py | 10 +++++----- 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index b7e41b2..23d8bd0 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -360,7 +360,7 @@ class APIClient: if self._connection is None: return if force: - await self._connection.force_disconnect() + self._connection.force_disconnect() else: await self._connection.disconnect() diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index c2de199..9131a09 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -915,7 +915,7 @@ class APIConnection: self._cleanup() - async def force_disconnect(self) -> None: + def force_disconnect(self) -> None: """Forcefully disconnect from the API.""" self._expected_disconnect = True if self._handshake_complete: diff --git a/tests/test_client.py b/tests/test_client.py index dfa1b6b..8d50a72 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1386,3 +1386,18 @@ async def test_set_debug( mock_data_received(protocol, generate_plaintext_packet(response)) await device_info_task assert "My Device" not in caplog.text + + +@pytest.mark.asyncio +async def test_force_disconnect( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], + caplog: pytest.LogCaptureFixture, +) -> None: + """Test force disconnect can be called multiple times.""" + client, connection, transport, protocol = api_client + await client.disconnect(force=True) + assert connection.is_connected is False + await client.disconnect(force=False) + assert connection.is_connected is False diff --git a/tests/test_connection.py b/tests/test_connection.py index 241341b..16e6e8c 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -206,7 +206,7 @@ async def test_plaintext_connection( assert isinstance(messages[1], DeviceInfoResponse) assert messages[1].name == "m5stackatomproxy" remove() - await conn.force_disconnect() + conn.force_disconnect() await asyncio.sleep(0) @@ -336,7 +336,7 @@ async def test_finish_connection_times_out( assert not conn.is_connected remove() - await conn.force_disconnect() + conn.force_disconnect() await asyncio.sleep(0) @@ -440,7 +440,7 @@ async def test_plaintext_connection_fails_handshake( assert isinstance(messages[1], DeviceInfoResponse) assert messages[1].name == "m5stackatomproxy" remove() - await conn.force_disconnect() + conn.force_disconnect() await asyncio.sleep(0) @@ -493,7 +493,7 @@ async def test_force_disconnect_fails( assert conn.is_connected with patch.object(protocol, "_writer", side_effect=OSError): - await conn.force_disconnect() + conn.force_disconnect() assert "Failed to send (forced) disconnect request" in caplog.text await asyncio.sleep(0) @@ -737,7 +737,7 @@ async def test_unknown_protobuf_message_type_logged( assert "Skipping unknown message type 16385" in caplog.text assert connection.is_connected - await connection.force_disconnect() + connection.force_disconnect() await asyncio.sleep(0)