Send/Receive Voice Assistant Audio Messages (#854)

Co-authored-by: Michael Hansen <mike@rhasspy.org>
This commit is contained in:
Jesse Hills 2024-04-08 10:44:10 +12:00 committed by GitHub
parent 15d1949654
commit 27a968df1b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 506 additions and 274 deletions

View File

@ -218,7 +218,8 @@ message DeviceInfoResponse {
string friendly_name = 13; string friendly_name = 13;
uint32 voice_assistant_version = 14; uint32 legacy_voice_assistant_version = 14;
uint32 voice_assistant_feature_flags = 17;
string suggested_area = 16; string suggested_area = 16;
} }
@ -1448,12 +1449,18 @@ message BluetoothDeviceClearCacheResponse {
} }
// ==================== VOICE ASSISTANT ==================== // ==================== VOICE ASSISTANT ====================
enum VoiceAssistantSubscribeFlag {
VOICE_ASSISTANT_SUBSCRIBE_NONE = 0;
VOICE_ASSISTANT_SUBSCRIBE_API_AUDIO = 1;
}
message SubscribeVoiceAssistantRequest { message SubscribeVoiceAssistantRequest {
option (id) = 89; option (id) = 89;
option (source) = SOURCE_CLIENT; option (source) = SOURCE_CLIENT;
option (ifdef) = "USE_VOICE_ASSISTANT"; option (ifdef) = "USE_VOICE_ASSISTANT";
bool subscribe = 1; bool subscribe = 1;
uint32 flags = 2;
} }
message VoiceAssistantAudioSettings { message VoiceAssistantAudioSettings {
@ -1515,6 +1522,15 @@ message VoiceAssistantEventResponse {
repeated VoiceAssistantEventData data = 2; repeated VoiceAssistantEventData data = 2;
} }
message VoiceAssistantAudio {
option (id) = 106;
option (source) = SOURCE_BOTH;
option (ifdef) = "USE_VOICE_ASSISTANT";
bytes data = 1;
bool end = 2;
}
// ==================== ALARM CONTROL PANEL ==================== // ==================== ALARM CONTROL PANEL ====================
enum AlarmControlPanelState { enum AlarmControlPanelState {
ALARM_STATE_DISARMED = 0; ALARM_STATE_DISARMED = 0;

File diff suppressed because one or more lines are too long

View File

@ -67,6 +67,7 @@ from .api_pb2 import ( # type: ignore
TextCommandRequest, TextCommandRequest,
TimeCommandRequest, TimeCommandRequest,
UnsubscribeBluetoothLEAdvertisementsRequest, UnsubscribeBluetoothLEAdvertisementsRequest,
VoiceAssistantAudio,
VoiceAssistantEventData, VoiceAssistantEventData,
VoiceAssistantEventResponse, VoiceAssistantEventResponse,
VoiceAssistantRequest, VoiceAssistantRequest,
@ -121,11 +122,13 @@ from .model import (
MediaPlayerCommand, MediaPlayerCommand,
UserService, UserService,
UserServiceArgType, UserServiceArgType,
VoiceAssistantAudioData,
) )
from .model import VoiceAssistantAudioSettings as VoiceAssistantAudioSettingsModel from .model import VoiceAssistantAudioSettings as VoiceAssistantAudioSettingsModel
from .model import ( from .model import (
VoiceAssistantCommand, VoiceAssistantCommand,
VoiceAssistantEventType, VoiceAssistantEventType,
VoiceAssistantSubscriptionFlag,
message_types_to_names, message_types_to_names,
) )
from .model_conversions import ( from .model_conversions import (
@ -1226,11 +1229,19 @@ class APIClient:
def subscribe_voice_assistant( def subscribe_voice_assistant(
self, self,
*,
handle_start: Callable[ handle_start: Callable[
[str, int, VoiceAssistantAudioSettingsModel, str | None], [str, int, VoiceAssistantAudioSettingsModel, str | None],
Coroutine[Any, Any, int | None], Coroutine[Any, Any, int | None],
], ],
handle_stop: Callable[[], Coroutine[Any, Any, None]], handle_stop: Callable[[], Coroutine[Any, Any, None]],
handle_audio: (
Callable[
[bytes],
Coroutine[Any, Any, None],
]
| None
) = None,
) -> Callable[[], None]: ) -> Callable[[], None]:
"""Subscribes to voice assistant messages from the device. """Subscribes to voice assistant messages from the device.
@ -1276,16 +1287,39 @@ class APIClient:
else: else:
self._create_background_task(handle_stop()) self._create_background_task(handle_stop())
connection.send_message(SubscribeVoiceAssistantRequest(subscribe=True)) remove_callbacks = []
flags = 0
if handle_audio is not None:
flags |= VoiceAssistantSubscriptionFlag.API_AUDIO
remove_callback = connection.add_message_callback( def _on_voice_assistant_audio(msg: VoiceAssistantAudio) -> None:
audio = VoiceAssistantAudioData.from_pb(msg)
if audio.end:
self._create_background_task(handle_stop())
else:
self._create_background_task(handle_audio(audio.data))
remove_callbacks.append(
connection.add_message_callback(
_on_voice_assistant_audio, (VoiceAssistantAudio,)
)
)
connection.send_message(
SubscribeVoiceAssistantRequest(subscribe=True, flags=flags)
)
remove_callbacks.append(
connection.add_message_callback(
_on_voice_assistant_request, (VoiceAssistantRequest,) _on_voice_assistant_request, (VoiceAssistantRequest,)
) )
)
def unsub() -> None: def unsub() -> None:
nonlocal start_task nonlocal start_task
if self._connection is not None: if self._connection is not None:
for remove_callback in remove_callbacks:
remove_callback() remove_callback()
self._connection.send_message( self._connection.send_message(
SubscribeVoiceAssistantRequest(subscribe=False) SubscribeVoiceAssistantRequest(subscribe=False)
@ -1316,6 +1350,10 @@ class APIClient:
) )
self._get_connection().send_message(req) self._get_connection().send_message(req)
def send_voice_assistant_audio(self, data: bytes) -> None:
req = VoiceAssistantAudio(data=data)
self._get_connection().send_message(req)
def alarm_control_panel_command( def alarm_control_panel_command(
self, self,
key: int, key: int,

View File

@ -145,7 +145,7 @@ CONNECTION_STATE_CLOSED = ConnectionState.CLOSED
def _make_hello_request(client_info: str) -> HelloRequest: def _make_hello_request(client_info: str) -> HelloRequest:
"""Make a HelloRequest.""" """Make a HelloRequest."""
return HelloRequest( return HelloRequest(
client_info=client_info, api_version_major=1, api_version_minor=9 client_info=client_info, api_version_major=1, api_version_minor=10
) )

View File

@ -107,6 +107,7 @@ from .api_pb2 import ( # type: ignore
TimeCommandRequest, TimeCommandRequest,
TimeStateResponse, TimeStateResponse,
UnsubscribeBluetoothLEAdvertisementsRequest, UnsubscribeBluetoothLEAdvertisementsRequest,
VoiceAssistantAudio,
VoiceAssistantEventResponse, VoiceAssistantEventResponse,
VoiceAssistantRequest, VoiceAssistantRequest,
VoiceAssistantResponse, VoiceAssistantResponse,
@ -366,4 +367,5 @@ MESSAGE_TYPE_TO_PROTO = {
103: ListEntitiesTimeResponse, 103: ListEntitiesTimeResponse,
104: TimeStateResponse, 104: TimeStateResponse,
105: TimeCommandRequest, 105: TimeCommandRequest,
106: VoiceAssistantAudio,
} }

View File

@ -120,6 +120,16 @@ class BluetoothProxySubscriptionFlag(enum.IntFlag):
RAW_ADVERTISEMENTS = 1 << 0 RAW_ADVERTISEMENTS = 1 << 0
class VoiceAssistantFeature(enum.IntFlag):
VOICE_ASSISTANT = 1 << 0
SPEAKER = 1 << 1
API_AUDIO = 1 << 2
class VoiceAssistantSubscriptionFlag(enum.IntFlag):
API_AUDIO = 1 << 2
@_frozen_dataclass_decorator @_frozen_dataclass_decorator
class DeviceInfo(APIModelBase): class DeviceInfo(APIModelBase):
uses_password: bool = False uses_password: bool = False
@ -134,7 +144,8 @@ class DeviceInfo(APIModelBase):
project_name: str = "" project_name: str = ""
project_version: str = "" project_version: str = ""
webserver_port: int = 0 webserver_port: int = 0
voice_assistant_version: int = 0 legacy_voice_assistant_version: int = 0
voice_assistant_feature_flags: int = 0
legacy_bluetooth_proxy_version: int = 0 legacy_bluetooth_proxy_version: int = 0
bluetooth_proxy_feature_flags: int = 0 bluetooth_proxy_feature_flags: int = 0
suggested_area: str = "" suggested_area: str = ""
@ -155,6 +166,16 @@ class DeviceInfo(APIModelBase):
return flags return flags
return self.bluetooth_proxy_feature_flags return self.bluetooth_proxy_feature_flags
def voice_assistant_feature_flags_compat(self, api_version: APIVersion) -> int:
if api_version < APIVersion(1, 10):
flags: int = 0
if self.legacy_voice_assistant_version >= 1:
flags |= VoiceAssistantFeature.VOICE_ASSISTANT
if self.legacy_voice_assistant_version == 2:
flags |= VoiceAssistantFeature.SPEAKER
return flags
return self.voice_assistant_feature_flags
class EntityCategory(APIIntEnum): class EntityCategory(APIIntEnum):
NONE = 0 NONE = 0
@ -1152,6 +1173,12 @@ class VoiceAssistantCommand(APIModelBase):
wake_word_phrase: str = "" wake_word_phrase: str = ""
@_frozen_dataclass_decorator
class VoiceAssistantAudioData(APIModelBase):
data: bytes = field(default_factory=bytes) # pylint: disable=invalid-field-call
end: bool = False
class LogLevel(APIIntEnum): class LogLevel(APIIntEnum):
LOG_LEVEL_NONE = 0 LOG_LEVEL_NONE = 0
LOG_LEVEL_ERROR = 1 LOG_LEVEL_ERROR = 1

View File

@ -64,6 +64,7 @@ from aioesphomeapi.api_pb2 import (
SwitchCommandRequest, SwitchCommandRequest,
TextCommandRequest, TextCommandRequest,
TimeCommandRequest, TimeCommandRequest,
VoiceAssistantAudio,
VoiceAssistantAudioSettings, VoiceAssistantAudioSettings,
VoiceAssistantEventData, VoiceAssistantEventData,
VoiceAssistantEventResponse, VoiceAssistantEventResponse,
@ -2107,7 +2108,9 @@ async def test_subscribe_voice_assistant(
async def handle_stop() -> None: async def handle_stop() -> None:
stops.append(True) stops.append(True)
unsub = client.subscribe_voice_assistant(handle_start, handle_stop) unsub = client.subscribe_voice_assistant(
handle_start=handle_start, handle_stop=handle_stop
)
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=True)) send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=True))
send.reset_mock() send.reset_mock()
audio_settings = VoiceAssistantAudioSettings( audio_settings = VoiceAssistantAudioSettings(
@ -2183,7 +2186,9 @@ async def test_subscribe_voice_assistant_failure(
async def handle_stop() -> None: async def handle_stop() -> None:
stops.append(True) stops.append(True)
unsub = client.subscribe_voice_assistant(handle_start, handle_stop) unsub = client.subscribe_voice_assistant(
handle_start=handle_start, handle_stop=handle_stop
)
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=True)) send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=True))
send.reset_mock() send.reset_mock()
audio_settings = VoiceAssistantAudioSettings( audio_settings = VoiceAssistantAudioSettings(
@ -2260,7 +2265,9 @@ async def test_subscribe_voice_assistant_cancels_long_running_handle_start(
async def handle_stop() -> None: async def handle_stop() -> None:
stops.append(True) stops.append(True)
unsub = client.subscribe_voice_assistant(handle_start, handle_stop) unsub = client.subscribe_voice_assistant(
handle_start=handle_start, handle_stop=handle_stop
)
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=True)) send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=True))
send.reset_mock() send.reset_mock()
audio_settings = VoiceAssistantAudioSettings( audio_settings = VoiceAssistantAudioSettings(
@ -2294,6 +2301,111 @@ async def test_subscribe_voice_assistant_cancels_long_running_handle_start(
] ]
@pytest.mark.asyncio
async def test_subscribe_voice_assistant_api_audio(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test subscribe_voice_assistant."""
client, connection, transport, protocol = api_client
send = patch_send(client)
starts = []
stops = []
data_received = 0
async def handle_start(
conversation_id: str,
flags: int,
audio_settings: VoiceAssistantAudioSettings,
wake_word_phrase: str | None,
) -> int | None:
starts.append((conversation_id, flags, audio_settings, wake_word_phrase))
return 0
async def handle_stop() -> None:
stops.append(True)
async def handle_audio(data: bytes) -> None:
nonlocal data_received
data_received += len(data)
unsub = client.subscribe_voice_assistant(
handle_start=handle_start, handle_stop=handle_stop, handle_audio=handle_audio
)
send.assert_called_once_with(
SubscribeVoiceAssistantRequest(subscribe=True, flags=4)
)
send.reset_mock()
audio_settings = VoiceAssistantAudioSettings(
noise_suppression_level=42,
auto_gain=42,
volume_multiplier=42,
)
response: message.Message = VoiceAssistantRequest(
conversation_id="theone",
start=True,
flags=42,
audio_settings=audio_settings,
wake_word_phrase="okay nabu",
)
mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0)
await asyncio.sleep(0)
assert starts == [
(
"theone",
42,
VoiceAssistantAudioSettingsModel(
noise_suppression_level=42,
auto_gain=42,
volume_multiplier=42,
),
"okay nabu",
)
]
assert stops == []
send.assert_called_once_with(VoiceAssistantResponse(port=0))
send.reset_mock()
response: message.Message = VoiceAssistantAudio(
data=bytes([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]),
)
mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0)
assert data_received == 10
response: message.Message = VoiceAssistantAudio(
end=True,
)
mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0)
assert stops == [True]
send.reset_mock()
client.send_voice_assistant_audio(bytes([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))
send.assert_called_once_with(
VoiceAssistantAudio(data=bytes([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))
)
response: message.Message = VoiceAssistantRequest(
conversation_id="theone",
start=False,
)
mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0)
assert stops == [True, True]
send.reset_mock()
unsub()
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False))
send.reset_mock()
await client.disconnect(force=True)
# Ensure abort callback is a no-op after disconnect
# and does not raise
unsub()
assert len(send.mock_calls) == 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_api_version_after_connection_closed( async def test_api_version_after_connection_closed(
api_client: tuple[ api_client: tuple[

View File

@ -105,6 +105,7 @@ from aioesphomeapi.model import (
UserService, UserService,
UserServiceArg, UserServiceArg,
UserServiceArgType, UserServiceArgType,
VoiceAssistantFeature,
build_unique_id, build_unique_id,
converter_field, converter_field,
) )
@ -432,6 +433,24 @@ def test_bluetooth_backcompat_for_device_info(
assert info.bluetooth_proxy_feature_flags_compat(APIVersion(1, 9)) == 42 assert info.bluetooth_proxy_feature_flags_compat(APIVersion(1, 9)) == 42
# Add va compat test
@pytest.mark.parametrize(
("version", "flags"),
[
(1, VoiceAssistantFeature.VOICE_ASSISTANT),
(2, VoiceAssistantFeature.VOICE_ASSISTANT | VoiceAssistantFeature.SPEAKER),
],
)
def test_voice_assistant_backcompat_for_device_info(
version: int, flags: VoiceAssistantFeature
) -> None:
info = DeviceInfo(
legacy_voice_assistant_version=version, voice_assistant_feature_flags=42
)
assert info.voice_assistant_feature_flags_compat(APIVersion(1, 9)) is flags
assert info.voice_assistant_feature_flags_compat(APIVersion(1, 10)) == 42
@pytest.mark.parametrize( @pytest.mark.parametrize(
( (
"legacy_supports_brightness", "legacy_supports_brightness",