Add test for unexpected hello responses (#712)
This commit is contained in:
parent
79686bf729
commit
66e654084b
|
@ -419,7 +419,7 @@ class APIConnection:
|
|||
self.log_name,
|
||||
api_version.major,
|
||||
)
|
||||
raise APIConnectionError("Incompatible API version.")
|
||||
raise APIConnectionError(f"Incompatible API version ({api_version}).")
|
||||
|
||||
self.api_version = api_version
|
||||
expected_name = self._params.expected_name
|
||||
|
|
|
@ -129,10 +129,14 @@ async def connect_client(
|
|||
await client.finish_connection(login=login)
|
||||
|
||||
|
||||
def send_plaintext_hello(protocol: APIPlaintextFrameHelper) -> None:
|
||||
def send_plaintext_hello(
|
||||
protocol: APIPlaintextFrameHelper,
|
||||
major: int | None = None,
|
||||
minor: int | None = None,
|
||||
) -> None:
|
||||
hello_response: message.Message = HelloResponse()
|
||||
hello_response.api_version_major = 1
|
||||
hello_response.api_version_minor = 9
|
||||
hello_response.api_version_major = 1 if major is None else major
|
||||
hello_response.api_version_minor = 9 if minor is None else minor
|
||||
hello_response.name = "fake"
|
||||
protocol.data_received(generate_plaintext_packet(hello_response))
|
||||
|
||||
|
|
|
@ -4,6 +4,8 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import socket
|
||||
from dataclasses import replace
|
||||
from functools import partial
|
||||
from typing import Callable
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
@ -88,59 +90,82 @@ def noise_conn(connection_params: ConnectionParams) -> APIConnection:
|
|||
return PatchableAPIConnection(connection_params, on_stop, True, None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnection:
|
||||
connection_params = replace(connection_params, expected_name="test")
|
||||
return PatchableAPIConnection(connection_params, on_stop, True, None)
|
||||
|
||||
|
||||
def _create_mock_transport_protocol(
|
||||
transport: asyncio.Transport,
|
||||
connected: asyncio.Event,
|
||||
create_func: Callable[[], APIPlaintextFrameHelper],
|
||||
**kwargs,
|
||||
) -> tuple[asyncio.Transport, APIPlaintextFrameHelper]:
|
||||
protocol: APIPlaintextFrameHelper = create_func()
|
||||
protocol.connection_made(transport)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(name="plaintext_connect_task_no_login")
|
||||
async def plaintext_connect_task_no_login(
|
||||
conn: APIConnection, resolve_host, socket_socket, event_loop
|
||||
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol: APIPlaintextFrameHelper | None = None
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
protocol.connection_made(transport)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await connected.wait()
|
||||
yield conn, transport, protocol, connect_task
|
||||
yield conn, transport, conn._frame_helper, connect_task
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(name="plaintext_connect_task_expected_name")
|
||||
async def plaintext_connect_task_no_login_with_expected_name(
|
||||
conn_with_expected_name: APIConnection, resolve_host, socket_socket, event_loop
|
||||
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
event_loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
):
|
||||
connect_task = asyncio.create_task(
|
||||
connect(conn_with_expected_name, login=False)
|
||||
)
|
||||
await connected.wait()
|
||||
yield conn_with_expected_name, transport, conn_with_expected_name._frame_helper, connect_task
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(name="plaintext_connect_task_with_login")
|
||||
async def plaintext_connect_task_with_login(
|
||||
conn_with_password: APIConnection, resolve_host, socket_socket, event_loop
|
||||
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol: APIPlaintextFrameHelper | None = None
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
protocol.connection_made(transport)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
event_loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn_with_password, login=True))
|
||||
await connected.wait()
|
||||
yield conn_with_password, transport, protocol, connect_task
|
||||
yield conn_with_password, transport, conn_with_password._frame_helper, connect_task
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(name="api_client")
|
||||
async def api_client(
|
||||
conn: APIConnection, resolve_host, socket_socket, event_loop
|
||||
) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]:
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol: APIPlaintextFrameHelper | None = None
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
@ -150,18 +175,14 @@ async def api_client(
|
|||
password=None,
|
||||
)
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
protocol.connection_made(transport)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
event_loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await connected.wait()
|
||||
protocol = conn._frame_helper
|
||||
send_plaintext_hello(protocol)
|
||||
client._connection = conn
|
||||
await connect_task
|
||||
|
|
|
@ -477,6 +477,42 @@ async def test_connect_correct_password(
|
|||
assert conn.is_connected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_wrong_version(
|
||||
plaintext_connect_task_with_login: tuple[
|
||||
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
|
||||
],
|
||||
) -> None:
|
||||
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
|
||||
|
||||
send_plaintext_hello(protocol, 3, 2)
|
||||
send_plaintext_connect_response(protocol, False)
|
||||
|
||||
with pytest.raises(APIConnectionError, match="Incompatible API version"):
|
||||
await connect_task
|
||||
|
||||
assert conn.is_connected is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_wrong_name(
|
||||
plaintext_connect_task_expected_name: tuple[
|
||||
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
|
||||
],
|
||||
) -> None:
|
||||
conn, transport, protocol, connect_task = plaintext_connect_task_expected_name
|
||||
send_plaintext_hello(protocol)
|
||||
send_plaintext_connect_response(protocol, False)
|
||||
|
||||
with pytest.raises(
|
||||
APIConnectionError,
|
||||
match="Expected 'test' but server sent a different name: 'fake'",
|
||||
):
|
||||
await connect_task
|
||||
|
||||
assert conn.is_connected is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_force_disconnect_fails(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
|
|
Loading…
Reference in New Issue