diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 4245c7c..232550c 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -537,25 +537,34 @@ class APIConnection: # If the task was cancelled, we need to clean up the connection # and raise the CancelledError as APIConnectionError self._cleanup() - if not isinstance(ex, APIConnectionError): - cause: Exception | None = None - if isinstance(ex, CancelledError): - err_str = "Starting connection cancelled" - if self._fatal_exception: - err_str += f" due to fatal exception: {self._fatal_exception}" - cause = self._fatal_exception - else: - err_str = str(ex) or type(ex).__name__ - new_exc = APIConnectionError( - f"Error while starting connection: {err_str}" - ) - new_exc.__cause__ = cause or ex - raise new_exc - raise ex + raise self._wrap_fatal_connection_exception("starting", ex) finally: self._start_connect_task = None self._set_connection_state(ConnectionState.SOCKET_OPENED) + def _wrap_fatal_connection_exception( + self, action: str, ex: BaseException + ) -> APIConnectionError: + """Ensure a fatal exception is wrapped as as an APIConnectionError.""" + if isinstance(ex, APIConnectionError): + return ex + cause: BaseException | None = None + if isinstance(ex, CancelledError): + err_str = f"{action.title()} connection cancelled" + if self._fatal_exception: + err_str += f" due to fatal exception: {self._fatal_exception}" + cause = self._fatal_exception + else: + err_str = str(ex) or type(ex).__name__ + cause = ex + if isinstance(self._fatal_exception, APIConnectionError): + klass = type(self._fatal_exception) + else: + klass = APIConnectionError + new_exc = klass(f"Error while {action} connection: {err_str}") + new_exc.__cause__ = cause or ex + return new_exc + async def _do_finish_connect(self, login: bool) -> None: """Finish the connection process.""" in_do_connect.set(True) @@ -585,22 +594,7 @@ class APIConnection: # If the task was cancelled, we need to clean up the connection # and raise the CancelledError as APIConnectionError self._cleanup() - if not isinstance(ex, APIConnectionError): - cause: Exception | None = None - if isinstance(ex, CancelledError): - err_str = "Finishing connection cancelled" - if self._fatal_exception: - err_str += f" due to fatal exception: {self._fatal_exception}" - cause = self._fatal_exception - else: - err_str = str(ex) or type(ex).__name__ - cause = ex - new_exc = APIConnectionError( - f"Error while finishing connection: {err_str}" - ) - new_exc.__cause__ = cause or ex - raise new_exc - raise ex + raise self._wrap_fatal_connection_exception("finishing", ex) finally: self._finish_connect_task = None self._set_connection_state(ConnectionState.CONNECTED) diff --git a/tests/common.py b/tests/common.py index 19855ff..362f14c 100644 --- a/tests/common.py +++ b/tests/common.py @@ -117,3 +117,14 @@ def send_plaintext_connect_response( def send_ping_response(protocol: APIPlaintextFrameHelper) -> None: ping_response: message.Message = PingResponse() protocol.data_received(generate_plaintext_packet(ping_response)) + + +def get_mock_protocol(conn: APIConnection): + protocol = APIPlaintextFrameHelper( + connection=conn, + client_info="mock", + log_name="mock_device", + ) + transport = MagicMock() + protocol.connection_made(transport) + return protocol diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index ae838fc..b7ce35b 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -4,7 +4,7 @@ import asyncio import base64 from datetime import timedelta from typing import Any -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest from noise.connection import NoiseConnection # type: ignore[import-untyped] @@ -20,16 +20,17 @@ from aioesphomeapi._frame_helper.plain_text import ( _cached_varuint_to_bytes as cached_varuint_to_bytes, ) from aioesphomeapi._frame_helper.plain_text import _varuint_to_bytes as varuint_to_bytes +from aioesphomeapi.connection import ConnectionState from aioesphomeapi.core import ( + APIConnectionError, BadNameAPIError, HandshakeAPIError, InvalidEncryptionKeyAPIError, ProtocolAPIError, - SocketAPIError, SocketClosedAPIError, ) -from .common import async_fire_time_changed, utcnow +from .common import async_fire_time_changed, get_mock_protocol, utcnow PREAMBLE = b"\x00" @@ -385,6 +386,10 @@ def test_bytes_to_varuint(val, encoded): assert cached_bytes_to_varuint(encoded) == val +def test_bytes_to_varuint_invalid(): + assert bytes_to_varuint(b"\xFF") is None + + @pytest.mark.asyncio async def test_noise_frame_helper_handshake_failure(): """Test the noise frame helper handshake failure.""" @@ -568,3 +573,52 @@ async def test_noise_frame_helper_handshake_success_with_single_packet(): with pytest.raises(ProtocolAPIError, match="Connection closed"): helper.data_received(encrypted_header + encrypted_payload) + + +@pytest.mark.asyncio +async def test_init_plaintext_with_wrong_preamble(conn: APIConnection): + loop = asyncio.get_event_loop() + protocol = get_mock_protocol(conn) + with patch.object(loop, "create_connection") as create_connection: + create_connection.return_value = (MagicMock(), protocol) + + conn._socket = MagicMock() + await conn._connect_init_frame_helper() + loop.call_soon(conn._frame_helper._ready_future.set_result, None) + conn.connection_state = ConnectionState.CONNECTED + + task = asyncio.create_task(conn._connect_hello_login(login=True)) + await asyncio.sleep(0) + # The preamble should be \x00 but we send \x09 + protocol.data_received(b"\x09\x00\x00") + + with pytest.raises(ProtocolAPIError): + await task + + +@pytest.mark.asyncio +async def test_eof_received_closes_connection( + plaintext_connect_task_with_login: tuple[ + APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task + ], +) -> None: + conn, transport, protocol, connect_task = plaintext_connect_task_with_login + assert protocol.eof_received() is False + assert conn.is_connected is False + with pytest.raises(SocketClosedAPIError, match="EOF received"): + await connect_task + + +@pytest.mark.asyncio +async def test_connection_lost_closes_connection_and_logs( + caplog: pytest.LogCaptureFixture, + plaintext_connect_task_with_login: tuple[ + APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task + ], +) -> None: + conn, transport, protocol, connect_task = plaintext_connect_task_with_login + protocol.connection_lost(OSError("original message")) + assert conn.is_connected is False + assert "original message" in caplog.text + with pytest.raises(APIConnectionError, match="original message"): + await connect_task diff --git a/tests/test_client.py b/tests/test_client.py index d50b001..00a7e1c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -58,12 +58,7 @@ from aioesphomeapi.model import ( ) from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState -from .common import ( - PROTO_TO_MESSAGE_TYPE, - Estr, - generate_plaintext_packet, - get_mock_zeroconf, -) +from .common import Estr, generate_plaintext_packet, get_mock_zeroconf @pytest.fixture diff --git a/tests/test_connection.py b/tests/test_connection.py index f38b8c5..6fe816e 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -31,6 +31,7 @@ from .common import ( async_fire_time_changed, connect, generate_plaintext_packet, + get_mock_protocol, send_ping_response, send_plaintext_connect_response, send_plaintext_hello, @@ -41,17 +42,6 @@ from .conftest import KEEP_ALIVE_INTERVAL KEEP_ALIVE_TIMEOUT_RATIO = 4.5 -def _get_mock_protocol(conn: APIConnection): - protocol = APIPlaintextFrameHelper( - connection=conn, - client_info="mock", - log_name="mock_device", - ) - transport = MagicMock() - protocol.connection_made(transport) - return protocol - - @pytest.mark.asyncio async def test_connect( plaintext_connect_task_no_login: tuple[ @@ -152,7 +142,7 @@ async def test_disconnect_when_not_fully_connected( @pytest.mark.asyncio async def test_requires_encryption_propagates(conn: APIConnection): loop = asyncio.get_event_loop() - protocol = _get_mock_protocol(conn) + protocol = get_mock_protocol(conn) with patch.object(loop, "create_connection") as create_connection: create_connection.return_value = (MagicMock(), protocol) @@ -357,7 +347,7 @@ async def test_plaintext_connection_fails_handshake( """ loop = asyncio.get_event_loop() exception, raised_exception = exception_map - protocol = _get_mock_protocol(conn) + protocol = get_mock_protocol(conn) messages = [] protocol: APIPlaintextFrameHelper | None = None transport = MagicMock() diff --git a/tests/test_log_runner.py b/tests/test_log_runner.py index aec65ca..e393675 100644 --- a/tests/test_log_runner.py +++ b/tests/test_log_runner.py @@ -14,7 +14,6 @@ from aioesphomeapi.connection import APIConnection from aioesphomeapi.log_runner import async_run from .common import ( - PROTO_TO_MESSAGE_TYPE, Estr, generate_plaintext_packet, get_mock_async_zeroconf,