Add tests for ping responses (#650)

This commit is contained in:
J. Nick Koston 2023-11-21 14:01:58 +01:00 committed by GitHub
parent 298aa01b00
commit ccf2f1f245
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 86 additions and 3 deletions

View File

@ -12,7 +12,7 @@ from zeroconf.asyncio import AsyncZeroconf
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
from aioesphomeapi.api_pb2 import ConnectResponse, HelloResponse
from aioesphomeapi.api_pb2 import ConnectResponse, HelloResponse, PingResponse
from aioesphomeapi.connection import APIConnection
from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO
@ -112,3 +112,8 @@ def send_plaintext_connect_response(
connect_response: message.Message = ConnectResponse()
connect_response.invalid_password = invalid_password
protocol.data_received(generate_plaintext_packet(connect_response))
def send_ping_response(protocol: APIPlaintextFrameHelper) -> None:
ping_response: message.Message = PingResponse()
protocol.data_received(generate_plaintext_packet(ping_response))

View File

@ -16,6 +16,8 @@ from aioesphomeapi.zeroconf import ZeroconfManager
from .common import connect, get_mock_async_zeroconf, send_plaintext_hello
KEEP_ALIVE_INTERVAL = 15.0
@pytest.fixture
def async_zeroconf():
@ -47,7 +49,7 @@ def connection_params() -> ConnectionParams:
port=6052,
password=None,
client_info="Tests client",
keepalive=15.0,
keepalive=KEEP_ALIVE_INTERVAL,
zeroconf_manager=ZeroconfManager(),
noise_psk=None,
expected_name=None,

View File

@ -4,7 +4,7 @@ import asyncio
from collections.abc import Coroutine
from datetime import timedelta
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, call, patch
import pytest
@ -31,10 +31,14 @@ from .common import (
async_fire_time_changed,
connect,
generate_plaintext_packet,
send_ping_response,
send_plaintext_connect_response,
send_plaintext_hello,
utcnow,
)
from .conftest import KEEP_ALIVE_INTERVAL
KEEP_ALIVE_TIMEOUT_RATIO = 4.5
def _get_mock_protocol(conn: APIConnection):
@ -543,3 +547,75 @@ async def test_disconnect_fails_to_send_response(
# Wait one loop iteration for the disconnect to be processed
await asyncio.sleep(0)
assert expected_disconnect is True
@pytest.mark.asyncio
async def test_ping_disconnects_after_no_responses(
plaintext_connect_task_with_login: tuple[
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
],
) -> None:
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
send_plaintext_hello(protocol)
send_plaintext_connect_response(protocol, False)
await connect_task
ping_request_bytes = b"\x00\x00\x07"
assert conn.is_connected
transport.reset_mock()
expected_calls = []
start_time = utcnow()
max_pings_to_disconnect_after = int(KEEP_ALIVE_TIMEOUT_RATIO)
for count in range(1, max_pings_to_disconnect_after + 1):
async_fire_time_changed(
start_time + timedelta(seconds=KEEP_ALIVE_INTERVAL * count)
)
assert transport.write.call_count == count
expected_calls.append(call(ping_request_bytes))
assert transport.write.mock_calls == expected_calls
assert conn.is_connected is True
# We should disconnect once we reach more than 4 missed pings
async_fire_time_changed(
start_time
+ timedelta(seconds=KEEP_ALIVE_INTERVAL * (max_pings_to_disconnect_after + 1))
)
assert transport.write.call_count == max_pings_to_disconnect_after
assert conn.is_connected is False
@pytest.mark.asyncio
async def test_ping_does_not_disconnect_if_we_get_responses(
plaintext_connect_task_with_login: tuple[
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
],
) -> None:
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
send_plaintext_hello(protocol)
send_plaintext_connect_response(protocol, False)
await connect_task
ping_request_bytes = b"\x00\x00\x07"
assert conn.is_connected
transport.reset_mock()
start_time = utcnow()
max_pings_to_disconnect_after = int(KEEP_ALIVE_TIMEOUT_RATIO)
for count in range(1, max_pings_to_disconnect_after + 2):
async_fire_time_changed(
start_time + timedelta(seconds=KEEP_ALIVE_INTERVAL * count)
)
send_ping_response(protocol)
# We should only send 1 ping request if we are getting responses
assert transport.write.call_count == 1
assert transport.write.mock_calls == [call(ping_request_bytes)]
# We should disconnect if we are getting ping responses
assert conn.is_connected is True