Add test for unexpected hello responses (#712)

This commit is contained in:
J. Nick Koston 2023-11-25 09:58:30 -06:00 committed by GitHub
parent 79686bf729
commit 66e654084b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 95 additions and 34 deletions

View File

@ -419,7 +419,7 @@ class APIConnection:
self.log_name, self.log_name,
api_version.major, api_version.major,
) )
raise APIConnectionError("Incompatible API version.") raise APIConnectionError(f"Incompatible API version ({api_version}).")
self.api_version = api_version self.api_version = api_version
expected_name = self._params.expected_name expected_name = self._params.expected_name

View File

@ -129,10 +129,14 @@ async def connect_client(
await client.finish_connection(login=login) await client.finish_connection(login=login)
def send_plaintext_hello(protocol: APIPlaintextFrameHelper) -> None: def send_plaintext_hello(
protocol: APIPlaintextFrameHelper,
major: int | None = None,
minor: int | None = None,
) -> None:
hello_response: message.Message = HelloResponse() hello_response: message.Message = HelloResponse()
hello_response.api_version_major = 1 hello_response.api_version_major = 1 if major is None else major
hello_response.api_version_minor = 9 hello_response.api_version_minor = 9 if minor is None else minor
hello_response.name = "fake" hello_response.name = "fake"
protocol.data_received(generate_plaintext_packet(hello_response)) protocol.data_received(generate_plaintext_packet(hello_response))

View File

@ -4,6 +4,8 @@ from __future__ import annotations
import asyncio import asyncio
import socket import socket
from dataclasses import replace from dataclasses import replace
from functools import partial
from typing import Callable
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
@ -88,59 +90,82 @@ def noise_conn(connection_params: ConnectionParams) -> APIConnection:
return PatchableAPIConnection(connection_params, on_stop, True, None) return PatchableAPIConnection(connection_params, on_stop, True, None)
@pytest.fixture
def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnection:
connection_params = replace(connection_params, expected_name="test")
return PatchableAPIConnection(connection_params, on_stop, True, None)
def _create_mock_transport_protocol(
transport: asyncio.Transport,
connected: asyncio.Event,
create_func: Callable[[], APIPlaintextFrameHelper],
**kwargs,
) -> tuple[asyncio.Transport, APIPlaintextFrameHelper]:
protocol: APIPlaintextFrameHelper = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
@pytest_asyncio.fixture(name="plaintext_connect_task_no_login") @pytest_asyncio.fixture(name="plaintext_connect_task_no_login")
async def plaintext_connect_task_no_login( async def plaintext_connect_task_no_login(
conn: APIConnection, resolve_host, socket_socket, event_loop conn: APIConnection, resolve_host, socket_socket, event_loop
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock() transport = MagicMock()
connected = asyncio.Event() connected = asyncio.Event()
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
with patch.object(event_loop, "sock_connect"), patch.object( with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
): ):
connect_task = asyncio.create_task(connect(conn, login=False)) connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait() await connected.wait()
yield conn, transport, protocol, connect_task yield conn, transport, conn._frame_helper, connect_task
@pytest_asyncio.fixture(name="plaintext_connect_task_expected_name")
async def plaintext_connect_task_no_login_with_expected_name(
conn_with_expected_name: APIConnection, resolve_host, socket_socket, event_loop
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
transport = MagicMock()
connected = asyncio.Event()
with patch.object(event_loop, "sock_connect"), patch.object(
event_loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
):
connect_task = asyncio.create_task(
connect(conn_with_expected_name, login=False)
)
await connected.wait()
yield conn_with_expected_name, transport, conn_with_expected_name._frame_helper, connect_task
@pytest_asyncio.fixture(name="plaintext_connect_task_with_login") @pytest_asyncio.fixture(name="plaintext_connect_task_with_login")
async def plaintext_connect_task_with_login( async def plaintext_connect_task_with_login(
conn_with_password: APIConnection, resolve_host, socket_socket, event_loop conn_with_password: APIConnection, resolve_host, socket_socket, event_loop
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock() transport = MagicMock()
connected = asyncio.Event() connected = asyncio.Event()
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
with patch.object(event_loop, "sock_connect"), patch.object( with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol event_loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
): ):
connect_task = asyncio.create_task(connect(conn_with_password, login=True)) connect_task = asyncio.create_task(connect(conn_with_password, login=True))
await connected.wait() await connected.wait()
yield conn_with_password, transport, protocol, connect_task yield conn_with_password, transport, conn_with_password._frame_helper, connect_task
@pytest_asyncio.fixture(name="api_client") @pytest_asyncio.fixture(name="api_client")
async def api_client( async def api_client(
conn: APIConnection, resolve_host, socket_socket, event_loop conn: APIConnection, resolve_host, socket_socket, event_loop
) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]: ) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]:
loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock() transport = MagicMock()
connected = asyncio.Event() connected = asyncio.Event()
@ -150,18 +175,14 @@ async def api_client(
password=None, password=None,
) )
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
with patch.object(event_loop, "sock_connect"), patch.object( with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol event_loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
): ):
connect_task = asyncio.create_task(connect(conn, login=False)) connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait() await connected.wait()
protocol = conn._frame_helper
send_plaintext_hello(protocol) send_plaintext_hello(protocol)
client._connection = conn client._connection = conn
await connect_task await connect_task

View File

@ -477,6 +477,42 @@ async def test_connect_correct_password(
assert conn.is_connected assert conn.is_connected
@pytest.mark.asyncio
async def test_connect_wrong_version(
plaintext_connect_task_with_login: tuple[
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
],
) -> None:
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
send_plaintext_hello(protocol, 3, 2)
send_plaintext_connect_response(protocol, False)
with pytest.raises(APIConnectionError, match="Incompatible API version"):
await connect_task
assert conn.is_connected is False
@pytest.mark.asyncio
async def test_connect_wrong_name(
plaintext_connect_task_expected_name: tuple[
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
],
) -> None:
conn, transport, protocol, connect_task = plaintext_connect_task_expected_name
send_plaintext_hello(protocol)
send_plaintext_connect_response(protocol, False)
with pytest.raises(
APIConnectionError,
match="Expected 'test' but server sent a different name: 'fake'",
):
await connect_task
assert conn.is_connected is False
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_force_disconnect_fails( async def test_force_disconnect_fails(
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,