Add tests for ping responses (#650)

This commit is contained in:
J. Nick Koston 2023-11-21 14:01:58 +01:00 committed by GitHub
parent 298aa01b00
commit ccf2f1f245
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 86 additions and 3 deletions

View File

@ -12,7 +12,7 @@ from zeroconf.asyncio import AsyncZeroconf
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
from aioesphomeapi.api_pb2 import ConnectResponse, HelloResponse from aioesphomeapi.api_pb2 import ConnectResponse, HelloResponse, PingResponse
from aioesphomeapi.connection import APIConnection from aioesphomeapi.connection import APIConnection
from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO
@ -112,3 +112,8 @@ def send_plaintext_connect_response(
connect_response: message.Message = ConnectResponse() connect_response: message.Message = ConnectResponse()
connect_response.invalid_password = invalid_password connect_response.invalid_password = invalid_password
protocol.data_received(generate_plaintext_packet(connect_response)) protocol.data_received(generate_plaintext_packet(connect_response))
def send_ping_response(protocol: APIPlaintextFrameHelper) -> None:
ping_response: message.Message = PingResponse()
protocol.data_received(generate_plaintext_packet(ping_response))

View File

@ -16,6 +16,8 @@ from aioesphomeapi.zeroconf import ZeroconfManager
from .common import connect, get_mock_async_zeroconf, send_plaintext_hello from .common import connect, get_mock_async_zeroconf, send_plaintext_hello
KEEP_ALIVE_INTERVAL = 15.0
@pytest.fixture @pytest.fixture
def async_zeroconf(): def async_zeroconf():
@ -47,7 +49,7 @@ def connection_params() -> ConnectionParams:
port=6052, port=6052,
password=None, password=None,
client_info="Tests client", client_info="Tests client",
keepalive=15.0, keepalive=KEEP_ALIVE_INTERVAL,
zeroconf_manager=ZeroconfManager(), zeroconf_manager=ZeroconfManager(),
noise_psk=None, noise_psk=None,
expected_name=None, expected_name=None,

View File

@ -4,7 +4,7 @@ import asyncio
from collections.abc import Coroutine from collections.abc import Coroutine
from datetime import timedelta from datetime import timedelta
from typing import Any from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, call, patch
import pytest import pytest
@ -31,10 +31,14 @@ from .common import (
async_fire_time_changed, async_fire_time_changed,
connect, connect,
generate_plaintext_packet, generate_plaintext_packet,
send_ping_response,
send_plaintext_connect_response, send_plaintext_connect_response,
send_plaintext_hello, send_plaintext_hello,
utcnow, utcnow,
) )
from .conftest import KEEP_ALIVE_INTERVAL
KEEP_ALIVE_TIMEOUT_RATIO = 4.5
def _get_mock_protocol(conn: APIConnection): def _get_mock_protocol(conn: APIConnection):
@ -543,3 +547,75 @@ async def test_disconnect_fails_to_send_response(
# Wait one loop iteration for the disconnect to be processed # Wait one loop iteration for the disconnect to be processed
await asyncio.sleep(0) await asyncio.sleep(0)
assert expected_disconnect is True assert expected_disconnect is True
@pytest.mark.asyncio
async def test_ping_disconnects_after_no_responses(
plaintext_connect_task_with_login: tuple[
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
],
) -> None:
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
send_plaintext_hello(protocol)
send_plaintext_connect_response(protocol, False)
await connect_task
ping_request_bytes = b"\x00\x00\x07"
assert conn.is_connected
transport.reset_mock()
expected_calls = []
start_time = utcnow()
max_pings_to_disconnect_after = int(KEEP_ALIVE_TIMEOUT_RATIO)
for count in range(1, max_pings_to_disconnect_after + 1):
async_fire_time_changed(
start_time + timedelta(seconds=KEEP_ALIVE_INTERVAL * count)
)
assert transport.write.call_count == count
expected_calls.append(call(ping_request_bytes))
assert transport.write.mock_calls == expected_calls
assert conn.is_connected is True
# We should disconnect once we reach more than 4 missed pings
async_fire_time_changed(
start_time
+ timedelta(seconds=KEEP_ALIVE_INTERVAL * (max_pings_to_disconnect_after + 1))
)
assert transport.write.call_count == max_pings_to_disconnect_after
assert conn.is_connected is False
@pytest.mark.asyncio
async def test_ping_does_not_disconnect_if_we_get_responses(
plaintext_connect_task_with_login: tuple[
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
],
) -> None:
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
send_plaintext_hello(protocol)
send_plaintext_connect_response(protocol, False)
await connect_task
ping_request_bytes = b"\x00\x00\x07"
assert conn.is_connected
transport.reset_mock()
start_time = utcnow()
max_pings_to_disconnect_after = int(KEEP_ALIVE_TIMEOUT_RATIO)
for count in range(1, max_pings_to_disconnect_after + 2):
async_fire_time_changed(
start_time + timedelta(seconds=KEEP_ALIVE_INTERVAL * count)
)
send_ping_response(protocol)
# We should only send 1 ping request if we are getting responses
assert transport.write.call_count == 1
assert transport.write.mock_calls == [call(ping_request_bytes)]
# We should disconnect if we are getting ping responses
assert conn.is_connected is True