From eabc3d421f8fdf69c3b9179b6d4a1b1b1984eb91 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 10 Mar 2024 09:44:03 -1000 Subject: [PATCH] Ensure all command and service calls raise when disconnected (#840) --- aioesphomeapi/client.py | 9 ++++--- tests/test_client.py | 52 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 63d23c5..dcb82df 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -928,6 +928,7 @@ class APIClient: tilt: float | None = None, stop: bool = False, ) -> None: + connection = self._get_connection() req = CoverCommandRequest(key=key) apiv = self.api_version if TYPE_CHECKING: @@ -951,7 +952,7 @@ class APIClient: elif position == 0.0: req.legacy_command = LegacyCoverCommand.CLOSE req.has_legacy_command = True - self._get_connection().send_message(req) + connection.send_message(req) def fan_command( self, @@ -1058,6 +1059,7 @@ class APIClient: custom_preset: str | None = None, target_humidity: float | None = None, ) -> None: + connection = self._get_connection() req = ClimateCommandRequest(key=key) if mode is not None: req.has_mode = True @@ -1096,7 +1098,7 @@ class APIClient: if target_humidity is not None: req.has_target_humidity = True req.target_humidity = target_humidity - self._get_connection().send_message(req) + connection.send_message(req) def number_command(self, key: int, state: float) -> None: self._get_connection().send_message(NumberCommandRequest(key=key, state=state)) @@ -1172,6 +1174,7 @@ class APIClient: def execute_service( self, service: UserService, data: ExecuteServiceDataType ) -> None: + connection = self._get_connection() req = ExecuteServiceRequest(key=service.key) args = [] apiv = self.api_version @@ -1196,7 +1199,7 @@ class APIClient: # pylint: disable=no-member req.args.extend(args) - self._get_connection().send_message(req) + connection.send_message(req) def _request_image(self, *, single: bool = False, stream: bool = False) -> None: self._get_connection().send_message( diff --git a/tests/test_client.py b/tests/test_client.py index 2c290e5..b8b41fb 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -2280,3 +2280,55 @@ async def test_api_version_after_connection_closed( assert client.api_version == APIVersion(1, 9) await client.disconnect(force=True) assert client.api_version is None + + +@pytest.mark.asyncio +async def test_calls_after_connection_closed( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test calls after connection close should raise APIConnectionError.""" + client, connection, transport, protocol = api_client + assert client.api_version == APIVersion(1, 9) + await client.disconnect(force=True) + assert client.api_version is None + service = UserService( + name="my_service", + key=1, + args=[], + ) + with pytest.raises(APIConnectionError): + client.execute_service(service, {}) + for method in ( + client.button_command, + client.climate_command, + client.cover_command, + client.fan_command, + client.light_command, + client.media_player_command, + client.siren_command, + ): + with pytest.raises(APIConnectionError): + await method(1) + + with pytest.raises(APIConnectionError): + await client.alarm_control_panel_command(1, AlarmControlPanelCommand.ARM_HOME) + + with pytest.raises(APIConnectionError): + await client.date_command(1, 1, 1, 1) + + with pytest.raises(APIConnectionError): + await client.lock_command(1, LockCommand.LOCK) + + with pytest.raises(APIConnectionError): + await client.number_command(1, 1) + + with pytest.raises(APIConnectionError): + await client.select_command(1, "1") + + with pytest.raises(APIConnectionError): + await client.switch_command(1, True) + + with pytest.raises(APIConnectionError): + await client.text_command(1, "1")