Ensure all command and service calls raise when disconnected (#840)

This commit is contained in:
J. Nick Koston 2024-03-10 09:44:03 -10:00 committed by GitHub
parent a3009097a8
commit eabc3d421f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 58 additions and 3 deletions

View File

@ -928,6 +928,7 @@ class APIClient:
tilt: float | None = None,
stop: bool = False,
) -> None:
connection = self._get_connection()
req = CoverCommandRequest(key=key)
apiv = self.api_version
if TYPE_CHECKING:
@ -951,7 +952,7 @@ class APIClient:
elif position == 0.0:
req.legacy_command = LegacyCoverCommand.CLOSE
req.has_legacy_command = True
self._get_connection().send_message(req)
connection.send_message(req)
def fan_command(
self,
@ -1058,6 +1059,7 @@ class APIClient:
custom_preset: str | None = None,
target_humidity: float | None = None,
) -> None:
connection = self._get_connection()
req = ClimateCommandRequest(key=key)
if mode is not None:
req.has_mode = True
@ -1096,7 +1098,7 @@ class APIClient:
if target_humidity is not None:
req.has_target_humidity = True
req.target_humidity = target_humidity
self._get_connection().send_message(req)
connection.send_message(req)
def number_command(self, key: int, state: float) -> None:
self._get_connection().send_message(NumberCommandRequest(key=key, state=state))
@ -1172,6 +1174,7 @@ class APIClient:
def execute_service(
self, service: UserService, data: ExecuteServiceDataType
) -> None:
connection = self._get_connection()
req = ExecuteServiceRequest(key=service.key)
args = []
apiv = self.api_version
@ -1196,7 +1199,7 @@ class APIClient:
# pylint: disable=no-member
req.args.extend(args)
self._get_connection().send_message(req)
connection.send_message(req)
def _request_image(self, *, single: bool = False, stream: bool = False) -> None:
self._get_connection().send_message(

View File

@ -2280,3 +2280,55 @@ async def test_api_version_after_connection_closed(
assert client.api_version == APIVersion(1, 9)
await client.disconnect(force=True)
assert client.api_version is None
@pytest.mark.asyncio
async def test_calls_after_connection_closed(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test calls after connection close should raise APIConnectionError."""
client, connection, transport, protocol = api_client
assert client.api_version == APIVersion(1, 9)
await client.disconnect(force=True)
assert client.api_version is None
service = UserService(
name="my_service",
key=1,
args=[],
)
with pytest.raises(APIConnectionError):
client.execute_service(service, {})
for method in (
client.button_command,
client.climate_command,
client.cover_command,
client.fan_command,
client.light_command,
client.media_player_command,
client.siren_command,
):
with pytest.raises(APIConnectionError):
await method(1)
with pytest.raises(APIConnectionError):
await client.alarm_control_panel_command(1, AlarmControlPanelCommand.ARM_HOME)
with pytest.raises(APIConnectionError):
await client.date_command(1, 1, 1, 1)
with pytest.raises(APIConnectionError):
await client.lock_command(1, LockCommand.LOCK)
with pytest.raises(APIConnectionError):
await client.number_command(1, 1)
with pytest.raises(APIConnectionError):
await client.select_command(1, "1")
with pytest.raises(APIConnectionError):
await client.switch_command(1, True)
with pytest.raises(APIConnectionError):
await client.text_command(1, "1")