Add test for unexpected hello responses (#712)

This commit is contained in:
J. Nick Koston 2023-11-25 09:58:30 -06:00 committed by GitHub
parent 79686bf729
commit 66e654084b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 95 additions and 34 deletions

View File

@ -419,7 +419,7 @@ class APIConnection:
self.log_name,
api_version.major,
)
raise APIConnectionError("Incompatible API version.")
raise APIConnectionError(f"Incompatible API version ({api_version}).")
self.api_version = api_version
expected_name = self._params.expected_name

View File

@ -129,10 +129,14 @@ async def connect_client(
await client.finish_connection(login=login)
def send_plaintext_hello(protocol: APIPlaintextFrameHelper) -> None:
def send_plaintext_hello(
protocol: APIPlaintextFrameHelper,
major: int | None = None,
minor: int | None = None,
) -> None:
hello_response: message.Message = HelloResponse()
hello_response.api_version_major = 1
hello_response.api_version_minor = 9
hello_response.api_version_major = 1 if major is None else major
hello_response.api_version_minor = 9 if minor is None else minor
hello_response.name = "fake"
protocol.data_received(generate_plaintext_packet(hello_response))

View File

@ -4,6 +4,8 @@ from __future__ import annotations
import asyncio
import socket
from dataclasses import replace
from functools import partial
from typing import Callable
from unittest.mock import MagicMock, patch
import pytest
@ -88,59 +90,82 @@ def noise_conn(connection_params: ConnectionParams) -> APIConnection:
return PatchableAPIConnection(connection_params, on_stop, True, None)
@pytest.fixture
def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnection:
connection_params = replace(connection_params, expected_name="test")
return PatchableAPIConnection(connection_params, on_stop, True, None)
def _create_mock_transport_protocol(
transport: asyncio.Transport,
connected: asyncio.Event,
create_func: Callable[[], APIPlaintextFrameHelper],
**kwargs,
) -> tuple[asyncio.Transport, APIPlaintextFrameHelper]:
protocol: APIPlaintextFrameHelper = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
@pytest_asyncio.fixture(name="plaintext_connect_task_no_login")
async def plaintext_connect_task_no_login(
conn: APIConnection, resolve_host, socket_socket, event_loop
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()
connected = asyncio.Event()
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
):
connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait()
yield conn, transport, protocol, connect_task
yield conn, transport, conn._frame_helper, connect_task
@pytest_asyncio.fixture(name="plaintext_connect_task_expected_name")
async def plaintext_connect_task_no_login_with_expected_name(
conn_with_expected_name: APIConnection, resolve_host, socket_socket, event_loop
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
transport = MagicMock()
connected = asyncio.Event()
with patch.object(event_loop, "sock_connect"), patch.object(
event_loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
):
connect_task = asyncio.create_task(
connect(conn_with_expected_name, login=False)
)
await connected.wait()
yield conn_with_expected_name, transport, conn_with_expected_name._frame_helper, connect_task
@pytest_asyncio.fixture(name="plaintext_connect_task_with_login")
async def plaintext_connect_task_with_login(
conn_with_password: APIConnection, resolve_host, socket_socket, event_loop
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()
connected = asyncio.Event()
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
event_loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
):
connect_task = asyncio.create_task(connect(conn_with_password, login=True))
await connected.wait()
yield conn_with_password, transport, protocol, connect_task
yield conn_with_password, transport, conn_with_password._frame_helper, connect_task
@pytest_asyncio.fixture(name="api_client")
async def api_client(
conn: APIConnection, resolve_host, socket_socket, event_loop
) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]:
loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()
connected = asyncio.Event()
@ -150,18 +175,14 @@ async def api_client(
password=None,
)
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
event_loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
):
connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait()
protocol = conn._frame_helper
send_plaintext_hello(protocol)
client._connection = conn
await connect_task

View File

@ -477,6 +477,42 @@ async def test_connect_correct_password(
assert conn.is_connected
@pytest.mark.asyncio
async def test_connect_wrong_version(
plaintext_connect_task_with_login: tuple[
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
],
) -> None:
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
send_plaintext_hello(protocol, 3, 2)
send_plaintext_connect_response(protocol, False)
with pytest.raises(APIConnectionError, match="Incompatible API version"):
await connect_task
assert conn.is_connected is False
@pytest.mark.asyncio
async def test_connect_wrong_name(
plaintext_connect_task_expected_name: tuple[
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
],
) -> None:
conn, transport, protocol, connect_task = plaintext_connect_task_expected_name
send_plaintext_hello(protocol)
send_plaintext_connect_response(protocol, False)
with pytest.raises(
APIConnectionError,
match="Expected 'test' but server sent a different name: 'fake'",
):
await connect_task
assert conn.is_connected is False
@pytest.mark.asyncio
async def test_force_disconnect_fails(
caplog: pytest.LogCaptureFixture,