From ccf2f1f245779fa478b652932dc7226b402f8027 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 21 Nov 2023 14:01:58 +0100 Subject: [PATCH] Add tests for ping responses (#650) --- tests/common.py | 7 +++- tests/conftest.py | 4 ++- tests/test_connection.py | 78 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 86 insertions(+), 3 deletions(-) diff --git a/tests/common.py b/tests/common.py index 248659f..19855ff 100644 --- a/tests/common.py +++ b/tests/common.py @@ -12,7 +12,7 @@ from zeroconf.asyncio import AsyncZeroconf from aioesphomeapi._frame_helper import APIPlaintextFrameHelper from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes -from aioesphomeapi.api_pb2 import ConnectResponse, HelloResponse +from aioesphomeapi.api_pb2 import ConnectResponse, HelloResponse, PingResponse from aioesphomeapi.connection import APIConnection from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO @@ -112,3 +112,8 @@ def send_plaintext_connect_response( connect_response: message.Message = ConnectResponse() connect_response.invalid_password = invalid_password protocol.data_received(generate_plaintext_packet(connect_response)) + + +def send_ping_response(protocol: APIPlaintextFrameHelper) -> None: + ping_response: message.Message = PingResponse() + protocol.data_received(generate_plaintext_packet(ping_response)) diff --git a/tests/conftest.py b/tests/conftest.py index 1ef58bb..1aa48c0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,6 +16,8 @@ from aioesphomeapi.zeroconf import ZeroconfManager from .common import connect, get_mock_async_zeroconf, send_plaintext_hello +KEEP_ALIVE_INTERVAL = 15.0 + @pytest.fixture def async_zeroconf(): @@ -47,7 +49,7 @@ def connection_params() -> ConnectionParams: port=6052, password=None, client_info="Tests client", - keepalive=15.0, + keepalive=KEEP_ALIVE_INTERVAL, zeroconf_manager=ZeroconfManager(), noise_psk=None, expected_name=None, diff --git a/tests/test_connection.py b/tests/test_connection.py index d1f4b6e..f38b8c5 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -4,7 +4,7 @@ import asyncio from collections.abc import Coroutine from datetime import timedelta from typing import Any -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, call, patch import pytest @@ -31,10 +31,14 @@ from .common import ( async_fire_time_changed, connect, generate_plaintext_packet, + send_ping_response, send_plaintext_connect_response, send_plaintext_hello, utcnow, ) +from .conftest import KEEP_ALIVE_INTERVAL + +KEEP_ALIVE_TIMEOUT_RATIO = 4.5 def _get_mock_protocol(conn: APIConnection): @@ -543,3 +547,75 @@ async def test_disconnect_fails_to_send_response( # Wait one loop iteration for the disconnect to be processed await asyncio.sleep(0) assert expected_disconnect is True + + +@pytest.mark.asyncio +async def test_ping_disconnects_after_no_responses( + plaintext_connect_task_with_login: tuple[ + APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task + ], +) -> None: + conn, transport, protocol, connect_task = plaintext_connect_task_with_login + + send_plaintext_hello(protocol) + send_plaintext_connect_response(protocol, False) + + await connect_task + + ping_request_bytes = b"\x00\x00\x07" + + assert conn.is_connected + transport.reset_mock() + expected_calls = [] + start_time = utcnow() + max_pings_to_disconnect_after = int(KEEP_ALIVE_TIMEOUT_RATIO) + for count in range(1, max_pings_to_disconnect_after + 1): + async_fire_time_changed( + start_time + timedelta(seconds=KEEP_ALIVE_INTERVAL * count) + ) + assert transport.write.call_count == count + expected_calls.append(call(ping_request_bytes)) + assert transport.write.mock_calls == expected_calls + + assert conn.is_connected is True + + # We should disconnect once we reach more than 4 missed pings + async_fire_time_changed( + start_time + + timedelta(seconds=KEEP_ALIVE_INTERVAL * (max_pings_to_disconnect_after + 1)) + ) + assert transport.write.call_count == max_pings_to_disconnect_after + + assert conn.is_connected is False + + +@pytest.mark.asyncio +async def test_ping_does_not_disconnect_if_we_get_responses( + plaintext_connect_task_with_login: tuple[ + APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task + ], +) -> None: + conn, transport, protocol, connect_task = plaintext_connect_task_with_login + + send_plaintext_hello(protocol) + send_plaintext_connect_response(protocol, False) + + await connect_task + ping_request_bytes = b"\x00\x00\x07" + + assert conn.is_connected + transport.reset_mock() + start_time = utcnow() + max_pings_to_disconnect_after = int(KEEP_ALIVE_TIMEOUT_RATIO) + for count in range(1, max_pings_to_disconnect_after + 2): + async_fire_time_changed( + start_time + timedelta(seconds=KEEP_ALIVE_INTERVAL * count) + ) + send_ping_response(protocol) + + # We should only send 1 ping request if we are getting responses + assert transport.write.call_count == 1 + assert transport.write.mock_calls == [call(ping_request_bytes)] + + # We should disconnect if we are getting ping responses + assert conn.is_connected is True