mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-27 17:37:39 +01:00
Interpret VoiceAssistantCommand(start=False) as abort (#957)
This commit is contained in:
parent
a2a0bbfb4b
commit
0765a8730c
@ -1281,7 +1281,7 @@ class APIClient:
|
||||
[str, int, VoiceAssistantAudioSettingsModel, str | None],
|
||||
Coroutine[Any, Any, int | None],
|
||||
],
|
||||
handle_stop: Callable[[], Coroutine[Any, Any, None]],
|
||||
handle_stop: Callable[[bool], Coroutine[Any, Any, None]],
|
||||
handle_audio: (
|
||||
Callable[
|
||||
[bytes],
|
||||
@ -1302,7 +1302,7 @@ class APIClient:
|
||||
handle_start: called when the devices requests a server to send audio data to.
|
||||
This callback is asynchronous and returns the port number the server is started on.
|
||||
|
||||
handle_stop: called when the device has stopped sending audio data and the pipeline should be closed.
|
||||
handle_stop: called when the device has stopped sending audio data and the pipeline should be closed or aborted.
|
||||
|
||||
handle_audio: called when a chunk of audio is sent from the device.
|
||||
|
||||
@ -1343,7 +1343,7 @@ class APIClient:
|
||||
# We hold a reference to the start_task in unsub function
|
||||
# so we don't need to add it to the background tasks.
|
||||
else:
|
||||
self._create_background_task(handle_stop())
|
||||
self._create_background_task(handle_stop(True))
|
||||
|
||||
remove_callbacks = []
|
||||
flags = 0
|
||||
@ -1353,7 +1353,7 @@ class APIClient:
|
||||
def _on_voice_assistant_audio(msg: VoiceAssistantAudio) -> None:
|
||||
audio = VoiceAssistantAudioData.from_pb(msg)
|
||||
if audio.end:
|
||||
self._create_background_task(handle_stop())
|
||||
self._create_background_task(handle_stop(False))
|
||||
else:
|
||||
self._create_background_task(handle_audio(audio.data))
|
||||
|
||||
|
@ -2224,6 +2224,7 @@ async def test_subscribe_voice_assistant(
|
||||
send = patch_send(client)
|
||||
starts = []
|
||||
stops = []
|
||||
aborts = []
|
||||
|
||||
async def handle_start(
|
||||
conversation_id: str,
|
||||
@ -2234,8 +2235,11 @@ async def test_subscribe_voice_assistant(
|
||||
starts.append((conversation_id, flags, audio_settings, wake_word_phrase))
|
||||
return 42
|
||||
|
||||
async def handle_stop() -> None:
|
||||
stops.append(True)
|
||||
async def handle_stop(abort: bool) -> None:
|
||||
if abort:
|
||||
aborts.append(True)
|
||||
else:
|
||||
stops.append(True)
|
||||
|
||||
unsub = client.subscribe_voice_assistant(
|
||||
handle_start=handle_start, handle_stop=handle_stop
|
||||
@ -2269,7 +2273,8 @@ async def test_subscribe_voice_assistant(
|
||||
"okay nabu",
|
||||
)
|
||||
]
|
||||
assert stops == []
|
||||
assert not stops
|
||||
assert not aborts
|
||||
send.assert_called_once_with(VoiceAssistantResponse(port=42))
|
||||
send.reset_mock()
|
||||
response: message.Message = VoiceAssistantRequest(
|
||||
@ -2278,7 +2283,8 @@ async def test_subscribe_voice_assistant(
|
||||
)
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
await asyncio.sleep(0)
|
||||
assert stops == [True]
|
||||
assert not stops
|
||||
assert aborts == [True]
|
||||
send.reset_mock()
|
||||
unsub()
|
||||
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False))
|
||||
@ -2301,6 +2307,7 @@ async def test_subscribe_voice_assistant_failure(
|
||||
send = patch_send(client)
|
||||
starts = []
|
||||
stops = []
|
||||
aborts = []
|
||||
|
||||
async def handle_start(
|
||||
conversation_id: str,
|
||||
@ -2312,8 +2319,11 @@ async def test_subscribe_voice_assistant_failure(
|
||||
# Return None to indicate failure
|
||||
return None
|
||||
|
||||
async def handle_stop() -> None:
|
||||
stops.append(True)
|
||||
async def handle_stop(abort: bool) -> None:
|
||||
if abort:
|
||||
aborts.append(True)
|
||||
else:
|
||||
stops.append(True)
|
||||
|
||||
unsub = client.subscribe_voice_assistant(
|
||||
handle_start=handle_start, handle_stop=handle_stop
|
||||
@ -2346,7 +2356,8 @@ async def test_subscribe_voice_assistant_failure(
|
||||
None,
|
||||
)
|
||||
]
|
||||
assert stops == []
|
||||
assert not stops
|
||||
assert not aborts
|
||||
send.assert_called_once_with(VoiceAssistantResponse(error=True))
|
||||
send.reset_mock()
|
||||
response: message.Message = VoiceAssistantRequest(
|
||||
@ -2355,7 +2366,8 @@ async def test_subscribe_voice_assistant_failure(
|
||||
)
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
await asyncio.sleep(0)
|
||||
assert stops == [True]
|
||||
assert not stops
|
||||
assert aborts == [True]
|
||||
send.reset_mock()
|
||||
unsub()
|
||||
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False))
|
||||
@ -2378,6 +2390,7 @@ async def test_subscribe_voice_assistant_cancels_long_running_handle_start(
|
||||
send = patch_send(client)
|
||||
starts = []
|
||||
stops = []
|
||||
aborts = []
|
||||
|
||||
async def handle_start(
|
||||
conversation_id: str,
|
||||
@ -2391,8 +2404,11 @@ async def test_subscribe_voice_assistant_cancels_long_running_handle_start(
|
||||
starts.append("never")
|
||||
return None
|
||||
|
||||
async def handle_stop() -> None:
|
||||
stops.append(True)
|
||||
async def handle_stop(abort: bool) -> None:
|
||||
if abort:
|
||||
aborts.append(True)
|
||||
else:
|
||||
stops.append(True)
|
||||
|
||||
unsub = client.subscribe_voice_assistant(
|
||||
handle_start=handle_start, handle_stop=handle_stop
|
||||
@ -2416,6 +2432,7 @@ async def test_subscribe_voice_assistant_cancels_long_running_handle_start(
|
||||
unsub()
|
||||
await asyncio.sleep(0)
|
||||
assert not stops
|
||||
assert not aborts
|
||||
assert starts == [
|
||||
(
|
||||
"theone",
|
||||
@ -2441,6 +2458,7 @@ async def test_subscribe_voice_assistant_api_audio(
|
||||
send = patch_send(client)
|
||||
starts = []
|
||||
stops = []
|
||||
aborts = []
|
||||
data_received = 0
|
||||
|
||||
async def handle_start(
|
||||
@ -2452,8 +2470,11 @@ async def test_subscribe_voice_assistant_api_audio(
|
||||
starts.append((conversation_id, flags, audio_settings, wake_word_phrase))
|
||||
return 0
|
||||
|
||||
async def handle_stop() -> None:
|
||||
stops.append(True)
|
||||
async def handle_stop(abort: bool) -> None:
|
||||
if abort:
|
||||
aborts.append(True)
|
||||
else:
|
||||
stops.append(True)
|
||||
|
||||
async def handle_audio(data: bytes) -> None:
|
||||
nonlocal data_received
|
||||
@ -2493,7 +2514,8 @@ async def test_subscribe_voice_assistant_api_audio(
|
||||
"okay nabu",
|
||||
)
|
||||
]
|
||||
assert stops == []
|
||||
assert not stops
|
||||
assert not aborts
|
||||
send.assert_called_once_with(VoiceAssistantResponse(port=0))
|
||||
send.reset_mock()
|
||||
|
||||
@ -2523,7 +2545,8 @@ async def test_subscribe_voice_assistant_api_audio(
|
||||
)
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
await asyncio.sleep(0)
|
||||
assert stops == [True, True]
|
||||
assert stops == [True]
|
||||
assert aborts == [True]
|
||||
send.reset_mock()
|
||||
unsub()
|
||||
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False))
|
||||
|
Loading…
Reference in New Issue
Block a user