Interpret VoiceAssistantCommand(start=False) as abort (#957)

This commit is contained in:
Michael Hansen 2024-09-12 17:00:08 -05:00 committed by GitHub
parent a2a0bbfb4b
commit 0765a8730c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 41 additions and 18 deletions

View File

@ -1281,7 +1281,7 @@ class APIClient:
[str, int, VoiceAssistantAudioSettingsModel, str | None], [str, int, VoiceAssistantAudioSettingsModel, str | None],
Coroutine[Any, Any, int | None], Coroutine[Any, Any, int | None],
], ],
handle_stop: Callable[[], Coroutine[Any, Any, None]], handle_stop: Callable[[bool], Coroutine[Any, Any, None]],
handle_audio: ( handle_audio: (
Callable[ Callable[
[bytes], [bytes],
@ -1302,7 +1302,7 @@ class APIClient:
handle_start: called when the devices requests a server to send audio data to. handle_start: called when the devices requests a server to send audio data to.
This callback is asynchronous and returns the port number the server is started on. This callback is asynchronous and returns the port number the server is started on.
handle_stop: called when the device has stopped sending audio data and the pipeline should be closed. handle_stop: called when the device has stopped sending audio data and the pipeline should be closed or aborted.
handle_audio: called when a chunk of audio is sent from the device. handle_audio: called when a chunk of audio is sent from the device.
@ -1343,7 +1343,7 @@ class APIClient:
# We hold a reference to the start_task in unsub function # We hold a reference to the start_task in unsub function
# so we don't need to add it to the background tasks. # so we don't need to add it to the background tasks.
else: else:
self._create_background_task(handle_stop()) self._create_background_task(handle_stop(True))
remove_callbacks = [] remove_callbacks = []
flags = 0 flags = 0
@ -1353,7 +1353,7 @@ class APIClient:
def _on_voice_assistant_audio(msg: VoiceAssistantAudio) -> None: def _on_voice_assistant_audio(msg: VoiceAssistantAudio) -> None:
audio = VoiceAssistantAudioData.from_pb(msg) audio = VoiceAssistantAudioData.from_pb(msg)
if audio.end: if audio.end:
self._create_background_task(handle_stop()) self._create_background_task(handle_stop(False))
else: else:
self._create_background_task(handle_audio(audio.data)) self._create_background_task(handle_audio(audio.data))

View File

@ -2224,6 +2224,7 @@ async def test_subscribe_voice_assistant(
send = patch_send(client) send = patch_send(client)
starts = [] starts = []
stops = [] stops = []
aborts = []
async def handle_start( async def handle_start(
conversation_id: str, conversation_id: str,
@ -2234,8 +2235,11 @@ async def test_subscribe_voice_assistant(
starts.append((conversation_id, flags, audio_settings, wake_word_phrase)) starts.append((conversation_id, flags, audio_settings, wake_word_phrase))
return 42 return 42
async def handle_stop() -> None: async def handle_stop(abort: bool) -> None:
stops.append(True) if abort:
aborts.append(True)
else:
stops.append(True)
unsub = client.subscribe_voice_assistant( unsub = client.subscribe_voice_assistant(
handle_start=handle_start, handle_stop=handle_stop handle_start=handle_start, handle_stop=handle_stop
@ -2269,7 +2273,8 @@ async def test_subscribe_voice_assistant(
"okay nabu", "okay nabu",
) )
] ]
assert stops == [] assert not stops
assert not aborts
send.assert_called_once_with(VoiceAssistantResponse(port=42)) send.assert_called_once_with(VoiceAssistantResponse(port=42))
send.reset_mock() send.reset_mock()
response: message.Message = VoiceAssistantRequest( response: message.Message = VoiceAssistantRequest(
@ -2278,7 +2283,8 @@ async def test_subscribe_voice_assistant(
) )
mock_data_received(protocol, generate_plaintext_packet(response)) mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0) await asyncio.sleep(0)
assert stops == [True] assert not stops
assert aborts == [True]
send.reset_mock() send.reset_mock()
unsub() unsub()
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False)) send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False))
@ -2301,6 +2307,7 @@ async def test_subscribe_voice_assistant_failure(
send = patch_send(client) send = patch_send(client)
starts = [] starts = []
stops = [] stops = []
aborts = []
async def handle_start( async def handle_start(
conversation_id: str, conversation_id: str,
@ -2312,8 +2319,11 @@ async def test_subscribe_voice_assistant_failure(
# Return None to indicate failure # Return None to indicate failure
return None return None
async def handle_stop() -> None: async def handle_stop(abort: bool) -> None:
stops.append(True) if abort:
aborts.append(True)
else:
stops.append(True)
unsub = client.subscribe_voice_assistant( unsub = client.subscribe_voice_assistant(
handle_start=handle_start, handle_stop=handle_stop handle_start=handle_start, handle_stop=handle_stop
@ -2346,7 +2356,8 @@ async def test_subscribe_voice_assistant_failure(
None, None,
) )
] ]
assert stops == [] assert not stops
assert not aborts
send.assert_called_once_with(VoiceAssistantResponse(error=True)) send.assert_called_once_with(VoiceAssistantResponse(error=True))
send.reset_mock() send.reset_mock()
response: message.Message = VoiceAssistantRequest( response: message.Message = VoiceAssistantRequest(
@ -2355,7 +2366,8 @@ async def test_subscribe_voice_assistant_failure(
) )
mock_data_received(protocol, generate_plaintext_packet(response)) mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0) await asyncio.sleep(0)
assert stops == [True] assert not stops
assert aborts == [True]
send.reset_mock() send.reset_mock()
unsub() unsub()
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False)) send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False))
@ -2378,6 +2390,7 @@ async def test_subscribe_voice_assistant_cancels_long_running_handle_start(
send = patch_send(client) send = patch_send(client)
starts = [] starts = []
stops = [] stops = []
aborts = []
async def handle_start( async def handle_start(
conversation_id: str, conversation_id: str,
@ -2391,8 +2404,11 @@ async def test_subscribe_voice_assistant_cancels_long_running_handle_start(
starts.append("never") starts.append("never")
return None return None
async def handle_stop() -> None: async def handle_stop(abort: bool) -> None:
stops.append(True) if abort:
aborts.append(True)
else:
stops.append(True)
unsub = client.subscribe_voice_assistant( unsub = client.subscribe_voice_assistant(
handle_start=handle_start, handle_stop=handle_stop handle_start=handle_start, handle_stop=handle_stop
@ -2416,6 +2432,7 @@ async def test_subscribe_voice_assistant_cancels_long_running_handle_start(
unsub() unsub()
await asyncio.sleep(0) await asyncio.sleep(0)
assert not stops assert not stops
assert not aborts
assert starts == [ assert starts == [
( (
"theone", "theone",
@ -2441,6 +2458,7 @@ async def test_subscribe_voice_assistant_api_audio(
send = patch_send(client) send = patch_send(client)
starts = [] starts = []
stops = [] stops = []
aborts = []
data_received = 0 data_received = 0
async def handle_start( async def handle_start(
@ -2452,8 +2470,11 @@ async def test_subscribe_voice_assistant_api_audio(
starts.append((conversation_id, flags, audio_settings, wake_word_phrase)) starts.append((conversation_id, flags, audio_settings, wake_word_phrase))
return 0 return 0
async def handle_stop() -> None: async def handle_stop(abort: bool) -> None:
stops.append(True) if abort:
aborts.append(True)
else:
stops.append(True)
async def handle_audio(data: bytes) -> None: async def handle_audio(data: bytes) -> None:
nonlocal data_received nonlocal data_received
@ -2493,7 +2514,8 @@ async def test_subscribe_voice_assistant_api_audio(
"okay nabu", "okay nabu",
) )
] ]
assert stops == [] assert not stops
assert not aborts
send.assert_called_once_with(VoiceAssistantResponse(port=0)) send.assert_called_once_with(VoiceAssistantResponse(port=0))
send.reset_mock() send.reset_mock()
@ -2523,7 +2545,8 @@ async def test_subscribe_voice_assistant_api_audio(
) )
mock_data_received(protocol, generate_plaintext_packet(response)) mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0) await asyncio.sleep(0)
assert stops == [True, True] assert stops == [True]
assert aborts == [True]
send.reset_mock() send.reset_mock()
unsub() unsub()
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False)) send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False))