Interpret VoiceAssistantCommand(start=False) as abort (#957)

This commit is contained in:
Michael Hansen 2024-09-12 17:00:08 -05:00 committed by GitHub
parent a2a0bbfb4b
commit 0765a8730c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 41 additions and 18 deletions

View File

@ -1281,7 +1281,7 @@ class APIClient:
[str, int, VoiceAssistantAudioSettingsModel, str | None],
Coroutine[Any, Any, int | None],
],
handle_stop: Callable[[], Coroutine[Any, Any, None]],
handle_stop: Callable[[bool], Coroutine[Any, Any, None]],
handle_audio: (
Callable[
[bytes],
@ -1302,7 +1302,7 @@ class APIClient:
handle_start: called when the devices requests a server to send audio data to.
This callback is asynchronous and returns the port number the server is started on.
handle_stop: called when the device has stopped sending audio data and the pipeline should be closed.
handle_stop: called when the device has stopped sending audio data and the pipeline should be closed or aborted.
handle_audio: called when a chunk of audio is sent from the device.
@ -1343,7 +1343,7 @@ class APIClient:
# We hold a reference to the start_task in unsub function
# so we don't need to add it to the background tasks.
else:
self._create_background_task(handle_stop())
self._create_background_task(handle_stop(True))
remove_callbacks = []
flags = 0
@ -1353,7 +1353,7 @@ class APIClient:
def _on_voice_assistant_audio(msg: VoiceAssistantAudio) -> None:
audio = VoiceAssistantAudioData.from_pb(msg)
if audio.end:
self._create_background_task(handle_stop())
self._create_background_task(handle_stop(False))
else:
self._create_background_task(handle_audio(audio.data))

View File

@ -2224,6 +2224,7 @@ async def test_subscribe_voice_assistant(
send = patch_send(client)
starts = []
stops = []
aborts = []
async def handle_start(
conversation_id: str,
@ -2234,7 +2235,10 @@ async def test_subscribe_voice_assistant(
starts.append((conversation_id, flags, audio_settings, wake_word_phrase))
return 42
async def handle_stop() -> None:
async def handle_stop(abort: bool) -> None:
if abort:
aborts.append(True)
else:
stops.append(True)
unsub = client.subscribe_voice_assistant(
@ -2269,7 +2273,8 @@ async def test_subscribe_voice_assistant(
"okay nabu",
)
]
assert stops == []
assert not stops
assert not aborts
send.assert_called_once_with(VoiceAssistantResponse(port=42))
send.reset_mock()
response: message.Message = VoiceAssistantRequest(
@ -2278,7 +2283,8 @@ async def test_subscribe_voice_assistant(
)
mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0)
assert stops == [True]
assert not stops
assert aborts == [True]
send.reset_mock()
unsub()
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False))
@ -2301,6 +2307,7 @@ async def test_subscribe_voice_assistant_failure(
send = patch_send(client)
starts = []
stops = []
aborts = []
async def handle_start(
conversation_id: str,
@ -2312,7 +2319,10 @@ async def test_subscribe_voice_assistant_failure(
# Return None to indicate failure
return None
async def handle_stop() -> None:
async def handle_stop(abort: bool) -> None:
if abort:
aborts.append(True)
else:
stops.append(True)
unsub = client.subscribe_voice_assistant(
@ -2346,7 +2356,8 @@ async def test_subscribe_voice_assistant_failure(
None,
)
]
assert stops == []
assert not stops
assert not aborts
send.assert_called_once_with(VoiceAssistantResponse(error=True))
send.reset_mock()
response: message.Message = VoiceAssistantRequest(
@ -2355,7 +2366,8 @@ async def test_subscribe_voice_assistant_failure(
)
mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0)
assert stops == [True]
assert not stops
assert aborts == [True]
send.reset_mock()
unsub()
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False))
@ -2378,6 +2390,7 @@ async def test_subscribe_voice_assistant_cancels_long_running_handle_start(
send = patch_send(client)
starts = []
stops = []
aborts = []
async def handle_start(
conversation_id: str,
@ -2391,7 +2404,10 @@ async def test_subscribe_voice_assistant_cancels_long_running_handle_start(
starts.append("never")
return None
async def handle_stop() -> None:
async def handle_stop(abort: bool) -> None:
if abort:
aborts.append(True)
else:
stops.append(True)
unsub = client.subscribe_voice_assistant(
@ -2416,6 +2432,7 @@ async def test_subscribe_voice_assistant_cancels_long_running_handle_start(
unsub()
await asyncio.sleep(0)
assert not stops
assert not aborts
assert starts == [
(
"theone",
@ -2441,6 +2458,7 @@ async def test_subscribe_voice_assistant_api_audio(
send = patch_send(client)
starts = []
stops = []
aborts = []
data_received = 0
async def handle_start(
@ -2452,7 +2470,10 @@ async def test_subscribe_voice_assistant_api_audio(
starts.append((conversation_id, flags, audio_settings, wake_word_phrase))
return 0
async def handle_stop() -> None:
async def handle_stop(abort: bool) -> None:
if abort:
aborts.append(True)
else:
stops.append(True)
async def handle_audio(data: bytes) -> None:
@ -2493,7 +2514,8 @@ async def test_subscribe_voice_assistant_api_audio(
"okay nabu",
)
]
assert stops == []
assert not stops
assert not aborts
send.assert_called_once_with(VoiceAssistantResponse(port=0))
send.reset_mock()
@ -2523,7 +2545,8 @@ async def test_subscribe_voice_assistant_api_audio(
)
mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0)
assert stops == [True, True]
assert stops == [True]
assert aborts == [True]
send.reset_mock()
unsub()
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False))