diff --git a/tests/test_client.py b/tests/test_client.py index c6a41e3..2bceab0 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1833,7 +1833,7 @@ async def test_subscribe_voice_assistant_failure( APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper ], ) -> None: - """Test subscribe_voice_assistant.""" + """Test subscribe_voice_assistant failure.""" client, connection, transport, protocol = api_client send = patch_send(client) starts = [] @@ -1898,6 +1898,63 @@ async def test_subscribe_voice_assistant_failure( assert len(send.mock_calls) == 0 +@pytest.mark.asyncio +async def test_subscribe_voice_assistant_cancels_long_running_handle_start( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test subscribe_voice_assistant cancels long running tasks on unsub.""" + client, connection, transport, protocol = api_client + send = patch_send(client) + starts = [] + stops = [] + + async def handle_start( + conversation_id: str, flags: int, audio_settings: VoiceAssistantAudioSettings + ) -> int | None: + starts.append((conversation_id, flags, audio_settings)) + await asyncio.sleep(10) + # Return None to indicate failure + starts.append("never") + return None + + async def handle_stop() -> None: + stops.append(True) + + unsub = await client.subscribe_voice_assistant(handle_start, handle_stop) + send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=True)) + send.reset_mock() + audio_settings = VoiceAssistantAudioSettings( + noise_suppression_level=42, + auto_gain=42, + volume_multiplier=42, + ) + response: message.Message = VoiceAssistantRequest( + conversation_id="theone", + start=True, + flags=42, + audio_settings=audio_settings, + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + await asyncio.sleep(0) + await asyncio.sleep(0) + unsub() + await asyncio.sleep(0) + assert not stops + assert starts == [ + ( + "theone", + 42, + VoiceAssistantAudioSettingsModel( + noise_suppression_level=42, + auto_gain=42, + volume_multiplier=42, + ), + ) + ] + + @pytest.mark.asyncio async def test_api_version_after_connection_closed( api_client: tuple[