Add test for unexpected hello responses (#712)
This commit is contained in:
parent
79686bf729
commit
66e654084b
|
@ -419,7 +419,7 @@ class APIConnection:
|
||||||
self.log_name,
|
self.log_name,
|
||||||
api_version.major,
|
api_version.major,
|
||||||
)
|
)
|
||||||
raise APIConnectionError("Incompatible API version.")
|
raise APIConnectionError(f"Incompatible API version ({api_version}).")
|
||||||
|
|
||||||
self.api_version = api_version
|
self.api_version = api_version
|
||||||
expected_name = self._params.expected_name
|
expected_name = self._params.expected_name
|
||||||
|
|
|
@ -129,10 +129,14 @@ async def connect_client(
|
||||||
await client.finish_connection(login=login)
|
await client.finish_connection(login=login)
|
||||||
|
|
||||||
|
|
||||||
def send_plaintext_hello(protocol: APIPlaintextFrameHelper) -> None:
|
def send_plaintext_hello(
|
||||||
|
protocol: APIPlaintextFrameHelper,
|
||||||
|
major: int | None = None,
|
||||||
|
minor: int | None = None,
|
||||||
|
) -> None:
|
||||||
hello_response: message.Message = HelloResponse()
|
hello_response: message.Message = HelloResponse()
|
||||||
hello_response.api_version_major = 1
|
hello_response.api_version_major = 1 if major is None else major
|
||||||
hello_response.api_version_minor = 9
|
hello_response.api_version_minor = 9 if minor is None else minor
|
||||||
hello_response.name = "fake"
|
hello_response.name = "fake"
|
||||||
protocol.data_received(generate_plaintext_packet(hello_response))
|
protocol.data_received(generate_plaintext_packet(hello_response))
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,8 @@ from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
import socket
|
import socket
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
|
from functools import partial
|
||||||
|
from typing import Callable
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -88,59 +90,82 @@ def noise_conn(connection_params: ConnectionParams) -> APIConnection:
|
||||||
return PatchableAPIConnection(connection_params, on_stop, True, None)
|
return PatchableAPIConnection(connection_params, on_stop, True, None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnection:
|
||||||
|
connection_params = replace(connection_params, expected_name="test")
|
||||||
|
return PatchableAPIConnection(connection_params, on_stop, True, None)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_mock_transport_protocol(
|
||||||
|
transport: asyncio.Transport,
|
||||||
|
connected: asyncio.Event,
|
||||||
|
create_func: Callable[[], APIPlaintextFrameHelper],
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[asyncio.Transport, APIPlaintextFrameHelper]:
|
||||||
|
protocol: APIPlaintextFrameHelper = create_func()
|
||||||
|
protocol.connection_made(transport)
|
||||||
|
connected.set()
|
||||||
|
return transport, protocol
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(name="plaintext_connect_task_no_login")
|
@pytest_asyncio.fixture(name="plaintext_connect_task_no_login")
|
||||||
async def plaintext_connect_task_no_login(
|
async def plaintext_connect_task_no_login(
|
||||||
conn: APIConnection, resolve_host, socket_socket, event_loop
|
conn: APIConnection, resolve_host, socket_socket, event_loop
|
||||||
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
protocol: APIPlaintextFrameHelper | None = None
|
|
||||||
transport = MagicMock()
|
transport = MagicMock()
|
||||||
connected = asyncio.Event()
|
connected = asyncio.Event()
|
||||||
|
|
||||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
|
||||||
nonlocal protocol
|
|
||||||
protocol = create_func()
|
|
||||||
protocol.connection_made(transport)
|
|
||||||
connected.set()
|
|
||||||
return transport, protocol
|
|
||||||
|
|
||||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
loop,
|
||||||
|
"create_connection",
|
||||||
|
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||||
):
|
):
|
||||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||||
await connected.wait()
|
await connected.wait()
|
||||||
yield conn, transport, protocol, connect_task
|
yield conn, transport, conn._frame_helper, connect_task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(name="plaintext_connect_task_expected_name")
|
||||||
|
async def plaintext_connect_task_no_login_with_expected_name(
|
||||||
|
conn_with_expected_name: APIConnection, resolve_host, socket_socket, event_loop
|
||||||
|
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||||
|
transport = MagicMock()
|
||||||
|
connected = asyncio.Event()
|
||||||
|
|
||||||
|
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||||
|
event_loop,
|
||||||
|
"create_connection",
|
||||||
|
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||||
|
):
|
||||||
|
connect_task = asyncio.create_task(
|
||||||
|
connect(conn_with_expected_name, login=False)
|
||||||
|
)
|
||||||
|
await connected.wait()
|
||||||
|
yield conn_with_expected_name, transport, conn_with_expected_name._frame_helper, connect_task
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(name="plaintext_connect_task_with_login")
|
@pytest_asyncio.fixture(name="plaintext_connect_task_with_login")
|
||||||
async def plaintext_connect_task_with_login(
|
async def plaintext_connect_task_with_login(
|
||||||
conn_with_password: APIConnection, resolve_host, socket_socket, event_loop
|
conn_with_password: APIConnection, resolve_host, socket_socket, event_loop
|
||||||
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
protocol: APIPlaintextFrameHelper | None = None
|
|
||||||
transport = MagicMock()
|
transport = MagicMock()
|
||||||
connected = asyncio.Event()
|
connected = asyncio.Event()
|
||||||
|
|
||||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
|
||||||
nonlocal protocol
|
|
||||||
protocol = create_func()
|
|
||||||
protocol.connection_made(transport)
|
|
||||||
connected.set()
|
|
||||||
return transport, protocol
|
|
||||||
|
|
||||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
event_loop,
|
||||||
|
"create_connection",
|
||||||
|
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||||
):
|
):
|
||||||
connect_task = asyncio.create_task(connect(conn_with_password, login=True))
|
connect_task = asyncio.create_task(connect(conn_with_password, login=True))
|
||||||
await connected.wait()
|
await connected.wait()
|
||||||
yield conn_with_password, transport, protocol, connect_task
|
yield conn_with_password, transport, conn_with_password._frame_helper, connect_task
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(name="api_client")
|
@pytest_asyncio.fixture(name="api_client")
|
||||||
async def api_client(
|
async def api_client(
|
||||||
conn: APIConnection, resolve_host, socket_socket, event_loop
|
conn: APIConnection, resolve_host, socket_socket, event_loop
|
||||||
) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]:
|
) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]:
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
protocol: APIPlaintextFrameHelper | None = None
|
protocol: APIPlaintextFrameHelper | None = None
|
||||||
transport = MagicMock()
|
transport = MagicMock()
|
||||||
connected = asyncio.Event()
|
connected = asyncio.Event()
|
||||||
|
@ -150,18 +175,14 @@ async def api_client(
|
||||||
password=None,
|
password=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
|
||||||
nonlocal protocol
|
|
||||||
protocol = create_func()
|
|
||||||
protocol.connection_made(transport)
|
|
||||||
connected.set()
|
|
||||||
return transport, protocol
|
|
||||||
|
|
||||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
event_loop,
|
||||||
|
"create_connection",
|
||||||
|
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||||
):
|
):
|
||||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||||
await connected.wait()
|
await connected.wait()
|
||||||
|
protocol = conn._frame_helper
|
||||||
send_plaintext_hello(protocol)
|
send_plaintext_hello(protocol)
|
||||||
client._connection = conn
|
client._connection = conn
|
||||||
await connect_task
|
await connect_task
|
||||||
|
|
|
@ -477,6 +477,42 @@ async def test_connect_correct_password(
|
||||||
assert conn.is_connected
|
assert conn.is_connected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_connect_wrong_version(
|
||||||
|
plaintext_connect_task_with_login: tuple[
|
||||||
|
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
|
||||||
|
],
|
||||||
|
) -> None:
|
||||||
|
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
|
||||||
|
|
||||||
|
send_plaintext_hello(protocol, 3, 2)
|
||||||
|
send_plaintext_connect_response(protocol, False)
|
||||||
|
|
||||||
|
with pytest.raises(APIConnectionError, match="Incompatible API version"):
|
||||||
|
await connect_task
|
||||||
|
|
||||||
|
assert conn.is_connected is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_connect_wrong_name(
|
||||||
|
plaintext_connect_task_expected_name: tuple[
|
||||||
|
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
|
||||||
|
],
|
||||||
|
) -> None:
|
||||||
|
conn, transport, protocol, connect_task = plaintext_connect_task_expected_name
|
||||||
|
send_plaintext_hello(protocol)
|
||||||
|
send_plaintext_connect_response(protocol, False)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
APIConnectionError,
|
||||||
|
match="Expected 'test' but server sent a different name: 'fake'",
|
||||||
|
):
|
||||||
|
await connect_task
|
||||||
|
|
||||||
|
assert conn.is_connected is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_force_disconnect_fails(
|
async def test_force_disconnect_fails(
|
||||||
caplog: pytest.LogCaptureFixture,
|
caplog: pytest.LogCaptureFixture,
|
||||||
|
|
Loading…
Reference in New Issue