From 79686bf7298b367b9c7d59c5b306e90b6ba69ee2 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 25 Nov 2023 09:33:43 -0600 Subject: [PATCH] Fix client connection code swallowing unhandled exceptions as debug logging (#711) --- aioesphomeapi/client.py | 15 ++------------- aioesphomeapi/connection.py | 8 +++++++- aioesphomeapi/core.py | 4 ++++ tests/test_client.py | 22 ++++++++++++++++++++++ 4 files changed, 35 insertions(+), 14 deletions(-) diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 23d8bd0..9d515c2 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -85,7 +85,6 @@ from .core import ( APIConnectionError, BluetoothGATTAPIError, TimeoutAPIError, - UnhandledAPIConnectionError, to_human_readable_address, ) from .model import ( @@ -324,14 +323,9 @@ class APIClient: try: await self._connection.start_connection() - except APIConnectionError: + except Exception: self._connection = None raise - except Exception as e: - self._connection = None - raise UnhandledAPIConnectionError( - f"Unexpected error while connecting to {self.log_name}: {e}" - ) from e # If we resolved the address, we should set the log name now if self._connection.resolved_addr_info: self._set_log_name() @@ -345,14 +339,9 @@ class APIClient: assert self._connection is not None try: await self._connection.finish_connection(login=login) - except APIConnectionError: + except Exception: self._connection = None raise - except Exception as e: - self._connection = None - raise UnhandledAPIConnectionError( - f"Unexpected error while connecting to {self.log_name}: {e}" - ) from e if received_name := self._connection.received_name: self._set_name_from_device(received_name) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 8c1bfcc..ad8bbb4 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -35,6 +35,7 @@ from .api_pb2 import ( # type: ignore ) from .core import ( MESSAGE_TYPE_TO_PROTO, + APIConnectionCancelledError, APIConnectionError, BadNameAPIError, ConnectionNotEstablishedAPIError, @@ -46,6 +47,7 @@ from .core import ( ResolveAPIError, SocketAPIError, TimeoutAPIError, + UnhandledAPIConnectionError, ) from .model import APIVersion from .zeroconf import ZeroconfManager @@ -540,8 +542,12 @@ class APIConnection: cause = ex if isinstance(self._fatal_exception, APIConnectionError): klass = type(self._fatal_exception) + elif isinstance(ex, CancelledError): + klass = APIConnectionCancelledError + elif isinstance(ex, OSError): + klass = SocketAPIError else: - klass = APIConnectionError + klass = UnhandledAPIConnectionError new_exc = klass(f"Error while {action} connection: {err_str}") new_exc.__cause__ = cause or ex return new_exc diff --git a/aioesphomeapi/core.py b/aioesphomeapi/core.py index 5114bdb..eeb61ef 100644 --- a/aioesphomeapi/core.py +++ b/aioesphomeapi/core.py @@ -160,6 +160,10 @@ class APIConnectionError(Exception): pass +class APIConnectionCancelledError(APIConnectionError): + pass + + class InvalidAuthAPIError(APIConnectionError): pass diff --git a/tests/test_client.py b/tests/test_client.py index 8d50a72..1333dff 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -60,6 +60,7 @@ from aioesphomeapi.core import ( APIConnectionError, BluetoothGATTAPIError, TimeoutAPIError, + UnhandledAPIConnectionError, ) from aioesphomeapi.model import ( AlarmControlPanelCommand, @@ -95,6 +96,7 @@ from .common import ( get_mock_zeroconf, mock_data_received, ) +from .conftest import PatchableAPIConnection @pytest.fixture @@ -172,6 +174,26 @@ async def test_connect_backwards_compat() -> None: assert mock_finish_connection.mock_calls == [call(False)] +@pytest.mark.asyncio +async def test_finish_connection_wraps_exceptions_as_unhandled_api_error() -> None: + """Verify finish_connect re-wraps exceptions as UnhandledAPIError.""" + + cli = APIClient("1.2.3.4", 1234, None) + loop = asyncio.get_event_loop() + with patch( + "aioesphomeapi.client.APIConnection", PatchableAPIConnection + ), patch.object(loop, "sock_connect"): + await cli.start_connection() + + with patch.object( + cli._connection, + "send_messages_await_response_complex", + side_effect=Exception("foo"), + ): + with pytest.raises(UnhandledAPIConnectionError, match="foo"): + await cli.finish_connection(False) + + @pytest.mark.asyncio async def test_connect_while_already_connected(auth_client: APIClient) -> None: """Test connecting while already connected raises."""