From 66e654084b044e604bd88663ec195e52c17d642f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 25 Nov 2023 09:58:30 -0600 Subject: [PATCH] Add test for unexpected hello responses (#712) --- aioesphomeapi/connection.py | 2 +- tests/common.py | 10 +++-- tests/conftest.py | 81 +++++++++++++++++++++++-------------- tests/test_connection.py | 36 +++++++++++++++++ 4 files changed, 95 insertions(+), 34 deletions(-) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index ad8bbb4..7784c61 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -419,7 +419,7 @@ class APIConnection: self.log_name, api_version.major, ) - raise APIConnectionError("Incompatible API version.") + raise APIConnectionError(f"Incompatible API version ({api_version}).") self.api_version = api_version expected_name = self._params.expected_name diff --git a/tests/common.py b/tests/common.py index a195838..f4c5b10 100644 --- a/tests/common.py +++ b/tests/common.py @@ -129,10 +129,14 @@ async def connect_client( await client.finish_connection(login=login) -def send_plaintext_hello(protocol: APIPlaintextFrameHelper) -> None: +def send_plaintext_hello( + protocol: APIPlaintextFrameHelper, + major: int | None = None, + minor: int | None = None, +) -> None: hello_response: message.Message = HelloResponse() - hello_response.api_version_major = 1 - hello_response.api_version_minor = 9 + hello_response.api_version_major = 1 if major is None else major + hello_response.api_version_minor = 9 if minor is None else minor hello_response.name = "fake" protocol.data_received(generate_plaintext_packet(hello_response)) diff --git a/tests/conftest.py b/tests/conftest.py index 555412e..a39e267 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,8 @@ from __future__ import annotations import asyncio import socket from dataclasses import replace +from functools import partial +from typing import Callable from unittest.mock import MagicMock, patch import pytest @@ -88,59 +90,82 @@ def noise_conn(connection_params: ConnectionParams) -> APIConnection: return PatchableAPIConnection(connection_params, on_stop, True, None) +@pytest.fixture +def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnection: + connection_params = replace(connection_params, expected_name="test") + return PatchableAPIConnection(connection_params, on_stop, True, None) + + +def _create_mock_transport_protocol( + transport: asyncio.Transport, + connected: asyncio.Event, + create_func: Callable[[], APIPlaintextFrameHelper], + **kwargs, +) -> tuple[asyncio.Transport, APIPlaintextFrameHelper]: + protocol: APIPlaintextFrameHelper = create_func() + protocol.connection_made(transport) + connected.set() + return transport, protocol + + @pytest_asyncio.fixture(name="plaintext_connect_task_no_login") async def plaintext_connect_task_no_login( conn: APIConnection, resolve_host, socket_socket, event_loop ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: loop = asyncio.get_event_loop() - protocol: APIPlaintextFrameHelper | None = None transport = MagicMock() connected = asyncio.Event() - def _create_mock_transport_protocol(create_func, **kwargs): - nonlocal protocol - protocol = create_func() - protocol.connection_made(transport) - connected.set() - return transport, protocol - with patch.object(event_loop, "sock_connect"), patch.object( - loop, "create_connection", side_effect=_create_mock_transport_protocol + loop, + "create_connection", + side_effect=partial(_create_mock_transport_protocol, transport, connected), ): connect_task = asyncio.create_task(connect(conn, login=False)) await connected.wait() - yield conn, transport, protocol, connect_task + yield conn, transport, conn._frame_helper, connect_task + + +@pytest_asyncio.fixture(name="plaintext_connect_task_expected_name") +async def plaintext_connect_task_no_login_with_expected_name( + conn_with_expected_name: APIConnection, resolve_host, socket_socket, event_loop +) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: + transport = MagicMock() + connected = asyncio.Event() + + with patch.object(event_loop, "sock_connect"), patch.object( + event_loop, + "create_connection", + side_effect=partial(_create_mock_transport_protocol, transport, connected), + ): + connect_task = asyncio.create_task( + connect(conn_with_expected_name, login=False) + ) + await connected.wait() + yield conn_with_expected_name, transport, conn_with_expected_name._frame_helper, connect_task @pytest_asyncio.fixture(name="plaintext_connect_task_with_login") async def plaintext_connect_task_with_login( conn_with_password: APIConnection, resolve_host, socket_socket, event_loop ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: - loop = asyncio.get_event_loop() - protocol: APIPlaintextFrameHelper | None = None transport = MagicMock() connected = asyncio.Event() - def _create_mock_transport_protocol(create_func, **kwargs): - nonlocal protocol - protocol = create_func() - protocol.connection_made(transport) - connected.set() - return transport, protocol - with patch.object(event_loop, "sock_connect"), patch.object( - loop, "create_connection", side_effect=_create_mock_transport_protocol + event_loop, + "create_connection", + side_effect=partial(_create_mock_transport_protocol, transport, connected), ): connect_task = asyncio.create_task(connect(conn_with_password, login=True)) await connected.wait() - yield conn_with_password, transport, protocol, connect_task + yield conn_with_password, transport, conn_with_password._frame_helper, connect_task @pytest_asyncio.fixture(name="api_client") async def api_client( conn: APIConnection, resolve_host, socket_socket, event_loop ) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]: - loop = asyncio.get_event_loop() protocol: APIPlaintextFrameHelper | None = None transport = MagicMock() connected = asyncio.Event() @@ -150,18 +175,14 @@ async def api_client( password=None, ) - def _create_mock_transport_protocol(create_func, **kwargs): - nonlocal protocol - protocol = create_func() - protocol.connection_made(transport) - connected.set() - return transport, protocol - with patch.object(event_loop, "sock_connect"), patch.object( - loop, "create_connection", side_effect=_create_mock_transport_protocol + event_loop, + "create_connection", + side_effect=partial(_create_mock_transport_protocol, transport, connected), ): connect_task = asyncio.create_task(connect(conn, login=False)) await connected.wait() + protocol = conn._frame_helper send_plaintext_hello(protocol) client._connection = conn await connect_task diff --git a/tests/test_connection.py b/tests/test_connection.py index 8cb58f1..e26f270 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -477,6 +477,42 @@ async def test_connect_correct_password( assert conn.is_connected +@pytest.mark.asyncio +async def test_connect_wrong_version( + plaintext_connect_task_with_login: tuple[ + APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task + ], +) -> None: + conn, transport, protocol, connect_task = plaintext_connect_task_with_login + + send_plaintext_hello(protocol, 3, 2) + send_plaintext_connect_response(protocol, False) + + with pytest.raises(APIConnectionError, match="Incompatible API version"): + await connect_task + + assert conn.is_connected is False + + +@pytest.mark.asyncio +async def test_connect_wrong_name( + plaintext_connect_task_expected_name: tuple[ + APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task + ], +) -> None: + conn, transport, protocol, connect_task = plaintext_connect_task_expected_name + send_plaintext_hello(protocol) + send_plaintext_connect_response(protocol, False) + + with pytest.raises( + APIConnectionError, + match="Expected 'test' but server sent a different name: 'fake'", + ): + await connect_task + + assert conn.is_connected is False + + @pytest.mark.asyncio async def test_force_disconnect_fails( caplog: pytest.LogCaptureFixture,