From 8c7204464f746790aa6d8f10e6b85bcb6d2f6469 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Mon, 9 Sep 2024 15:10:26 -0500 Subject: [PATCH] Add handle_announcement_finished callback (#954) --- aioesphomeapi/client.py | 26 ++++++++++++++++++++++ tests/test_client.py | 49 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+) diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 04963c8..0816120 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -1289,6 +1289,13 @@ class APIClient: ] | None ) = None, + handle_announcement_finished: ( + Callable[ + [VoiceAssistantAnnounceFinishedModel], + Coroutine[Any, Any, None], + ] + | None + ) = None, ) -> Callable[[], None]: """Subscribes to voice assistant messages from the device. @@ -1297,6 +1304,10 @@ class APIClient: handle_stop: called when the device has stopped sending audio data and the pipeline should be closed. + handle_audio: called when a chunk of audio is sent from the device. + + handle_announcement_finished: called when a VoiceAssistantAnnounceFinished message is sent from the device. + Returns a callback to unsubscribe. """ connection = self._get_connection() @@ -1362,6 +1373,21 @@ class APIClient: ) ) + if handle_announcement_finished is not None: + + def _on_voice_assistant_announcement_finished( + msg: VoiceAssistantAnnounceFinished, + ) -> None: + finished = VoiceAssistantAnnounceFinishedModel.from_pb(msg) + self._create_background_task(handle_announcement_finished(finished)) + + remove_callbacks.append( + connection.add_message_callback( + _on_voice_assistant_announcement_finished, + (VoiceAssistantAnnounceFinished,), + ) + ) + def unsub() -> None: nonlocal start_task diff --git a/tests/test_client.py b/tests/test_client.py index 32f2117..2cb8bb4 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -2588,6 +2588,55 @@ async def test_send_voice_assistant_announcement_await_response( assert isinstance(finished, VoiceAssistantAnnounceFinishedModel) +@pytest.mark.asyncio +async def test_subscribe_voice_assistant_announcement_finished( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test subscribe_voice_assistant with handle_announcement_finished.""" + client, connection, transport, protocol = api_client + send = patch_send(client) + done = asyncio.Event() + + async def handle_start( + conversation_id: str, + flags: int, + audio_settings: VoiceAssistantAudioSettings, + wake_word_phrase: str | None, + ) -> int | None: + return 0 + + async def handle_stop() -> None: + pass + + async def handle_announcement_finished( + finished: VoiceAssistantAnnounceFinishedModel, + ) -> None: + assert finished.success + done.set() + + unsub = client.subscribe_voice_assistant( + handle_start=handle_start, + handle_stop=handle_stop, + handle_announcement_finished=handle_announcement_finished, + ) + send.assert_called_once_with( + SubscribeVoiceAssistantRequest(subscribe=True, flags=0) + ) + send.reset_mock() + response: message.Message = VoiceAssistantAnnounceFinished(success=True) + mock_data_received(protocol, generate_plaintext_packet(response)) + + await asyncio.wait_for(done.wait(), 1) + + await client.disconnect(force=True) + # Ensure abort callback is a no-op after disconnect + # and does not raise + unsub() + assert len(send.mock_calls) == 0 + + @pytest.mark.asyncio async def test_api_version_after_connection_closed( api_client: tuple[