diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index f62b60e..a30a348 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -20,7 +20,8 @@ from google.protobuf import message import aioesphomeapi.host_resolver as hr -from ._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper +from ._frame_helper.noise import APINoiseFrameHelper +from ._frame_helper.plain_text import APIPlaintextFrameHelper from .api_pb2 import ( # type: ignore ConnectRequest, ConnectResponse, diff --git a/tests/conftest.py b/tests/conftest.py index 1aa48c0..bd956b6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -57,13 +57,33 @@ def connection_params() -> ConnectionParams: @pytest.fixture -def conn(connection_params) -> APIConnection: - async def on_stop(expected_disconnect: bool) -> None: - pass +def noise_connection_params() -> ConnectionParams: + return ConnectionParams( + address="fake.address", + port=6052, + password=None, + client_info="Tests client", + keepalive=KEEP_ALIVE_INTERVAL, + zeroconf_manager=ZeroconfManager(), + noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", + expected_name="test", + ) + +async def on_stop(expected_disconnect: bool) -> None: + pass + + +@pytest.fixture +def conn(connection_params: ConnectionParams) -> APIConnection: return APIConnection(connection_params, on_stop) +@pytest.fixture +def noise_conn(noise_connection_params: ConnectionParams) -> APIConnection: + return APIConnection(noise_connection_params, on_stop) + + @pytest_asyncio.fixture(name="plaintext_connect_task_no_login") async def plaintext_connect_task_no_login( conn: APIConnection, resolve_host, socket_socket, event_loop diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index b7ce35b..e68df21 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -596,6 +596,31 @@ async def test_init_plaintext_with_wrong_preamble(conn: APIConnection): await task +@pytest.mark.asyncio +async def test_init_noise_with_wrong_byte_marker(noise_conn: APIConnection) -> None: + loop = asyncio.get_event_loop() + transport = MagicMock() + protocol: APINoiseFrameHelper | None = None + + async def _create_connection(create, sock, *args, **kwargs): + nonlocal protocol + protocol = create() + protocol.connection_made(transport) + return transport, protocol + + with patch.object(loop, "create_connection", side_effect=_create_connection): + task = asyncio.create_task(noise_conn._connect_init_frame_helper()) + await asyncio.sleep(0) + + assert protocol is not None + assert isinstance(noise_conn._frame_helper, APINoiseFrameHelper) + + protocol.data_received(b"\x00\x00\x00") + + with pytest.raises(ProtocolAPIError, match="Marker byte invalid"): + await task + + @pytest.mark.asyncio async def test_eof_received_closes_connection( plaintext_connect_task_with_login: tuple[