From e1447dd249452f140de881e07fc4abcf5a20f359 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 23 Nov 2023 13:36:30 +0100 Subject: [PATCH] Improve connection tests for handling pings (#663) --- tests/common.py | 12 +++++- tests/test_connection.py | 80 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 1 deletion(-) diff --git a/tests/common.py b/tests/common.py index 362f14c..c7f76ad 100644 --- a/tests/common.py +++ b/tests/common.py @@ -12,7 +12,12 @@ from zeroconf.asyncio import AsyncZeroconf from aioesphomeapi._frame_helper import APIPlaintextFrameHelper from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes -from aioesphomeapi.api_pb2 import ConnectResponse, HelloResponse, PingResponse +from aioesphomeapi.api_pb2 import ( + ConnectResponse, + HelloResponse, + PingRequest, + PingResponse, +) from aioesphomeapi.connection import APIConnection from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO @@ -119,6 +124,11 @@ def send_ping_response(protocol: APIPlaintextFrameHelper) -> None: protocol.data_received(generate_plaintext_packet(ping_response)) +def send_ping_request(protocol: APIPlaintextFrameHelper) -> None: + ping_request: message.Message = PingRequest() + protocol.data_received(generate_plaintext_packet(ping_request)) + + def get_mock_protocol(conn: APIConnection): protocol = APIPlaintextFrameHelper( connection=conn, diff --git a/tests/test_connection.py b/tests/test_connection.py index 8475115..16fc061 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -33,6 +33,7 @@ from .common import ( connect, generate_plaintext_packet, get_mock_protocol, + send_ping_request, send_ping_response, send_plaintext_connect_response, send_plaintext_hello, @@ -540,6 +541,62 @@ async def test_disconnect_fails_to_send_response( assert expected_disconnect is True +@pytest.mark.asyncio +async def test_disconnect_success_case( + connection_params: ConnectionParams, + event_loop: asyncio.AbstractEventLoop, + resolve_host, + socket_socket, +) -> None: + loop = asyncio.get_event_loop() + protocol: APIPlaintextFrameHelper | None = None + transport = MagicMock() + connected = asyncio.Event() + client = APIClient( + address="mydevice.local", + port=6052, + password=None, + ) + expected_disconnect = None + + async def _on_stop(_expected_disconnect: bool) -> None: + nonlocal expected_disconnect + expected_disconnect = _expected_disconnect + + conn = APIConnection(connection_params, _on_stop) + + def _create_mock_transport_protocol(create_func, **kwargs): + nonlocal protocol + protocol = create_func() + protocol.connection_made(transport) + connected.set() + return transport, protocol + + with patch.object(event_loop, "sock_connect"), patch.object( + loop, "create_connection", side_effect=_create_mock_transport_protocol + ): + connect_task = asyncio.create_task(connect(conn, login=False)) + await connected.wait() + send_plaintext_hello(protocol) + client._connection = conn + await connect_task + transport.reset_mock() + + send_plaintext_hello(protocol) + send_plaintext_connect_response(protocol, False) + + await connect_task + assert conn.is_connected + + disconnect_request = DisconnectRequest() + protocol.data_received(generate_plaintext_packet(disconnect_request)) + + # Wait one loop iteration for the disconnect to be processed + await asyncio.sleep(0) + assert expected_disconnect is True + assert not conn.is_connected + + @pytest.mark.asyncio async def test_ping_disconnects_after_no_responses( plaintext_connect_task_with_login: tuple[ @@ -616,3 +673,26 @@ def test_raise_during_send_messages_when_not_yet_connected(conn: APIConnection) """Test that we raise when sending messages before we are connected.""" with pytest.raises(ConnectionNotEstablishedAPIError): conn.send_message(PingRequest()) + + +@pytest.mark.asyncio +async def test_respond_to_ping_request( + caplog: pytest.LogCaptureFixture, + plaintext_connect_task_with_login: tuple[ + APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task + ], +) -> None: + conn, transport, protocol, connect_task = plaintext_connect_task_with_login + + send_plaintext_hello(protocol) + send_plaintext_connect_response(protocol, False) + + await connect_task + assert conn.is_connected + + transport.reset_mock() + send_ping_request(protocol) + # We should respond to ping requests + ping_response_bytes = b"\x00\x00\x08" + assert transport.write.call_count == 1 + assert transport.write.mock_calls == [call(ping_response_bytes)]