mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-30 18:08:36 +01:00
Interpret VoiceAssistantCommand(start=False) as abort (#957)
This commit is contained in:
parent
a2a0bbfb4b
commit
0765a8730c
@ -1281,7 +1281,7 @@ class APIClient:
|
|||||||
[str, int, VoiceAssistantAudioSettingsModel, str | None],
|
[str, int, VoiceAssistantAudioSettingsModel, str | None],
|
||||||
Coroutine[Any, Any, int | None],
|
Coroutine[Any, Any, int | None],
|
||||||
],
|
],
|
||||||
handle_stop: Callable[[], Coroutine[Any, Any, None]],
|
handle_stop: Callable[[bool], Coroutine[Any, Any, None]],
|
||||||
handle_audio: (
|
handle_audio: (
|
||||||
Callable[
|
Callable[
|
||||||
[bytes],
|
[bytes],
|
||||||
@ -1302,7 +1302,7 @@ class APIClient:
|
|||||||
handle_start: called when the devices requests a server to send audio data to.
|
handle_start: called when the devices requests a server to send audio data to.
|
||||||
This callback is asynchronous and returns the port number the server is started on.
|
This callback is asynchronous and returns the port number the server is started on.
|
||||||
|
|
||||||
handle_stop: called when the device has stopped sending audio data and the pipeline should be closed.
|
handle_stop: called when the device has stopped sending audio data and the pipeline should be closed or aborted.
|
||||||
|
|
||||||
handle_audio: called when a chunk of audio is sent from the device.
|
handle_audio: called when a chunk of audio is sent from the device.
|
||||||
|
|
||||||
@ -1343,7 +1343,7 @@ class APIClient:
|
|||||||
# We hold a reference to the start_task in unsub function
|
# We hold a reference to the start_task in unsub function
|
||||||
# so we don't need to add it to the background tasks.
|
# so we don't need to add it to the background tasks.
|
||||||
else:
|
else:
|
||||||
self._create_background_task(handle_stop())
|
self._create_background_task(handle_stop(True))
|
||||||
|
|
||||||
remove_callbacks = []
|
remove_callbacks = []
|
||||||
flags = 0
|
flags = 0
|
||||||
@ -1353,7 +1353,7 @@ class APIClient:
|
|||||||
def _on_voice_assistant_audio(msg: VoiceAssistantAudio) -> None:
|
def _on_voice_assistant_audio(msg: VoiceAssistantAudio) -> None:
|
||||||
audio = VoiceAssistantAudioData.from_pb(msg)
|
audio = VoiceAssistantAudioData.from_pb(msg)
|
||||||
if audio.end:
|
if audio.end:
|
||||||
self._create_background_task(handle_stop())
|
self._create_background_task(handle_stop(False))
|
||||||
else:
|
else:
|
||||||
self._create_background_task(handle_audio(audio.data))
|
self._create_background_task(handle_audio(audio.data))
|
||||||
|
|
||||||
|
@ -2224,6 +2224,7 @@ async def test_subscribe_voice_assistant(
|
|||||||
send = patch_send(client)
|
send = patch_send(client)
|
||||||
starts = []
|
starts = []
|
||||||
stops = []
|
stops = []
|
||||||
|
aborts = []
|
||||||
|
|
||||||
async def handle_start(
|
async def handle_start(
|
||||||
conversation_id: str,
|
conversation_id: str,
|
||||||
@ -2234,7 +2235,10 @@ async def test_subscribe_voice_assistant(
|
|||||||
starts.append((conversation_id, flags, audio_settings, wake_word_phrase))
|
starts.append((conversation_id, flags, audio_settings, wake_word_phrase))
|
||||||
return 42
|
return 42
|
||||||
|
|
||||||
async def handle_stop() -> None:
|
async def handle_stop(abort: bool) -> None:
|
||||||
|
if abort:
|
||||||
|
aborts.append(True)
|
||||||
|
else:
|
||||||
stops.append(True)
|
stops.append(True)
|
||||||
|
|
||||||
unsub = client.subscribe_voice_assistant(
|
unsub = client.subscribe_voice_assistant(
|
||||||
@ -2269,7 +2273,8 @@ async def test_subscribe_voice_assistant(
|
|||||||
"okay nabu",
|
"okay nabu",
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
assert stops == []
|
assert not stops
|
||||||
|
assert not aborts
|
||||||
send.assert_called_once_with(VoiceAssistantResponse(port=42))
|
send.assert_called_once_with(VoiceAssistantResponse(port=42))
|
||||||
send.reset_mock()
|
send.reset_mock()
|
||||||
response: message.Message = VoiceAssistantRequest(
|
response: message.Message = VoiceAssistantRequest(
|
||||||
@ -2278,7 +2283,8 @@ async def test_subscribe_voice_assistant(
|
|||||||
)
|
)
|
||||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
assert stops == [True]
|
assert not stops
|
||||||
|
assert aborts == [True]
|
||||||
send.reset_mock()
|
send.reset_mock()
|
||||||
unsub()
|
unsub()
|
||||||
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False))
|
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False))
|
||||||
@ -2301,6 +2307,7 @@ async def test_subscribe_voice_assistant_failure(
|
|||||||
send = patch_send(client)
|
send = patch_send(client)
|
||||||
starts = []
|
starts = []
|
||||||
stops = []
|
stops = []
|
||||||
|
aborts = []
|
||||||
|
|
||||||
async def handle_start(
|
async def handle_start(
|
||||||
conversation_id: str,
|
conversation_id: str,
|
||||||
@ -2312,7 +2319,10 @@ async def test_subscribe_voice_assistant_failure(
|
|||||||
# Return None to indicate failure
|
# Return None to indicate failure
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def handle_stop() -> None:
|
async def handle_stop(abort: bool) -> None:
|
||||||
|
if abort:
|
||||||
|
aborts.append(True)
|
||||||
|
else:
|
||||||
stops.append(True)
|
stops.append(True)
|
||||||
|
|
||||||
unsub = client.subscribe_voice_assistant(
|
unsub = client.subscribe_voice_assistant(
|
||||||
@ -2346,7 +2356,8 @@ async def test_subscribe_voice_assistant_failure(
|
|||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
assert stops == []
|
assert not stops
|
||||||
|
assert not aborts
|
||||||
send.assert_called_once_with(VoiceAssistantResponse(error=True))
|
send.assert_called_once_with(VoiceAssistantResponse(error=True))
|
||||||
send.reset_mock()
|
send.reset_mock()
|
||||||
response: message.Message = VoiceAssistantRequest(
|
response: message.Message = VoiceAssistantRequest(
|
||||||
@ -2355,7 +2366,8 @@ async def test_subscribe_voice_assistant_failure(
|
|||||||
)
|
)
|
||||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
assert stops == [True]
|
assert not stops
|
||||||
|
assert aborts == [True]
|
||||||
send.reset_mock()
|
send.reset_mock()
|
||||||
unsub()
|
unsub()
|
||||||
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False))
|
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False))
|
||||||
@ -2378,6 +2390,7 @@ async def test_subscribe_voice_assistant_cancels_long_running_handle_start(
|
|||||||
send = patch_send(client)
|
send = patch_send(client)
|
||||||
starts = []
|
starts = []
|
||||||
stops = []
|
stops = []
|
||||||
|
aborts = []
|
||||||
|
|
||||||
async def handle_start(
|
async def handle_start(
|
||||||
conversation_id: str,
|
conversation_id: str,
|
||||||
@ -2391,7 +2404,10 @@ async def test_subscribe_voice_assistant_cancels_long_running_handle_start(
|
|||||||
starts.append("never")
|
starts.append("never")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def handle_stop() -> None:
|
async def handle_stop(abort: bool) -> None:
|
||||||
|
if abort:
|
||||||
|
aborts.append(True)
|
||||||
|
else:
|
||||||
stops.append(True)
|
stops.append(True)
|
||||||
|
|
||||||
unsub = client.subscribe_voice_assistant(
|
unsub = client.subscribe_voice_assistant(
|
||||||
@ -2416,6 +2432,7 @@ async def test_subscribe_voice_assistant_cancels_long_running_handle_start(
|
|||||||
unsub()
|
unsub()
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
assert not stops
|
assert not stops
|
||||||
|
assert not aborts
|
||||||
assert starts == [
|
assert starts == [
|
||||||
(
|
(
|
||||||
"theone",
|
"theone",
|
||||||
@ -2441,6 +2458,7 @@ async def test_subscribe_voice_assistant_api_audio(
|
|||||||
send = patch_send(client)
|
send = patch_send(client)
|
||||||
starts = []
|
starts = []
|
||||||
stops = []
|
stops = []
|
||||||
|
aborts = []
|
||||||
data_received = 0
|
data_received = 0
|
||||||
|
|
||||||
async def handle_start(
|
async def handle_start(
|
||||||
@ -2452,7 +2470,10 @@ async def test_subscribe_voice_assistant_api_audio(
|
|||||||
starts.append((conversation_id, flags, audio_settings, wake_word_phrase))
|
starts.append((conversation_id, flags, audio_settings, wake_word_phrase))
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
async def handle_stop() -> None:
|
async def handle_stop(abort: bool) -> None:
|
||||||
|
if abort:
|
||||||
|
aborts.append(True)
|
||||||
|
else:
|
||||||
stops.append(True)
|
stops.append(True)
|
||||||
|
|
||||||
async def handle_audio(data: bytes) -> None:
|
async def handle_audio(data: bytes) -> None:
|
||||||
@ -2493,7 +2514,8 @@ async def test_subscribe_voice_assistant_api_audio(
|
|||||||
"okay nabu",
|
"okay nabu",
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
assert stops == []
|
assert not stops
|
||||||
|
assert not aborts
|
||||||
send.assert_called_once_with(VoiceAssistantResponse(port=0))
|
send.assert_called_once_with(VoiceAssistantResponse(port=0))
|
||||||
send.reset_mock()
|
send.reset_mock()
|
||||||
|
|
||||||
@ -2523,7 +2545,8 @@ async def test_subscribe_voice_assistant_api_audio(
|
|||||||
)
|
)
|
||||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
assert stops == [True, True]
|
assert stops == [True]
|
||||||
|
assert aborts == [True]
|
||||||
send.reset_mock()
|
send.reset_mock()
|
||||||
unsub()
|
unsub()
|
||||||
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False))
|
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False))
|
||||||
|
Loading…
Reference in New Issue
Block a user