Add wake_word_phrase to VoiceAssistant request (#830)

This commit is contained in:
Jesse Hills 2024-02-27 15:53:26 +13:00 committed by GitHub
parent e37200eee9
commit efa92b5cce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 101 additions and 80 deletions

View File

@ -1469,6 +1469,7 @@ message VoiceAssistantRequest {
string conversation_id = 2;
uint32 flags = 3;
VoiceAssistantAudioSettings audio_settings = 4;
string wake_word_phrase = 5;
}
message VoiceAssistantResponse {

File diff suppressed because one or more lines are too long

View File

@ -1212,7 +1212,7 @@ class APIClient:
def subscribe_voice_assistant(
self,
handle_start: Callable[
[str, int, VoiceAssistantAudioSettingsModel],
[str, int, VoiceAssistantAudioSettingsModel, str | None],
Coroutine[Any, Any, int | None],
],
handle_stop: Callable[[], Coroutine[Any, Any, None]],
@ -1244,9 +1244,15 @@ class APIClient:
command = VoiceAssistantCommand.from_pb(msg)
if command.start:
wake_word_phrase: str | None = command.wake_word_phrase
if wake_word_phrase == "":
wake_word_phrase = None
start_task = asyncio.create_task(
handle_start(
command.conversation_id, command.flags, command.audio_settings
command.conversation_id,
command.flags,
command.audio_settings,
wake_word_phrase,
)
)
start_task.add_done_callback(_started)

View File

@ -1115,6 +1115,7 @@ class VoiceAssistantCommand(APIModelBase):
default=VoiceAssistantAudioSettings(),
converter=VoiceAssistantAudioSettings.from_pb,
)
wake_word_phrase: str = ""
class LogLevel(APIIntEnum):

View File

@ -2047,9 +2047,12 @@ async def test_subscribe_voice_assistant(
stops = []
async def handle_start(
conversation_id: str, flags: int, audio_settings: VoiceAssistantAudioSettings
conversation_id: str,
flags: int,
audio_settings: VoiceAssistantAudioSettings,
wake_word_phrase: str | None,
) -> int | None:
starts.append((conversation_id, flags, audio_settings))
starts.append((conversation_id, flags, audio_settings, wake_word_phrase))
return 42
async def handle_stop() -> None:
@ -2068,6 +2071,7 @@ async def test_subscribe_voice_assistant(
start=True,
flags=42,
audio_settings=audio_settings,
wake_word_phrase="okay nabu",
)
mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0)
@ -2081,6 +2085,7 @@ async def test_subscribe_voice_assistant(
auto_gain=42,
volume_multiplier=42,
),
"okay nabu",
)
]
assert stops == []
@ -2117,9 +2122,12 @@ async def test_subscribe_voice_assistant_failure(
stops = []
async def handle_start(
conversation_id: str, flags: int, audio_settings: VoiceAssistantAudioSettings
conversation_id: str,
flags: int,
audio_settings: VoiceAssistantAudioSettings,
wake_word_phrase: str | None,
) -> int | None:
starts.append((conversation_id, flags, audio_settings))
starts.append((conversation_id, flags, audio_settings, wake_word_phrase))
# Return None to indicate failure
return None
@ -2152,6 +2160,7 @@ async def test_subscribe_voice_assistant_failure(
auto_gain=42,
volume_multiplier=42,
),
None,
)
]
assert stops == []
@ -2188,9 +2197,12 @@ async def test_subscribe_voice_assistant_cancels_long_running_handle_start(
stops = []
async def handle_start(
conversation_id: str, flags: int, audio_settings: VoiceAssistantAudioSettings
conversation_id: str,
flags: int,
audio_settings: VoiceAssistantAudioSettings,
wake_word_phrase: str | None,
) -> int | None:
starts.append((conversation_id, flags, audio_settings))
starts.append((conversation_id, flags, audio_settings, wake_word_phrase))
await asyncio.sleep(10)
# Return None to indicate failure
starts.append("never")
@ -2228,6 +2240,7 @@ async def test_subscribe_voice_assistant_cancels_long_running_handle_start(
auto_gain=42,
volume_multiplier=42,
),
None,
)
]