mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-01-02 18:38:05 +01:00
Send/Receive Voice Assistant Audio Messages (#854)
Co-authored-by: Michael Hansen <mike@rhasspy.org>
This commit is contained in:
parent
15d1949654
commit
27a968df1b
@ -218,7 +218,8 @@ message DeviceInfoResponse {
|
||||
|
||||
string friendly_name = 13;
|
||||
|
||||
uint32 voice_assistant_version = 14;
|
||||
uint32 legacy_voice_assistant_version = 14;
|
||||
uint32 voice_assistant_feature_flags = 17;
|
||||
|
||||
string suggested_area = 16;
|
||||
}
|
||||
@ -1448,12 +1449,18 @@ message BluetoothDeviceClearCacheResponse {
|
||||
}
|
||||
|
||||
// ==================== VOICE ASSISTANT ====================
|
||||
// Flag values for voice assistant subscriptions.
// Presumably these are the values carried in SubscribeVoiceAssistantRequest.flags
// (the field is declared as a plain uint32) -- TODO confirm.
// NOTE(review): the Python client defines VoiceAssistantSubscriptionFlag.API_AUDIO
// as 1 << 2 (4), which does not match VOICE_ASSISTANT_SUBSCRIBE_API_AUDIO = 1
// below; confirm which value the device firmware expects.
enum VoiceAssistantSubscribeFlag {
  VOICE_ASSISTANT_SUBSCRIBE_NONE = 0;
  VOICE_ASSISTANT_SUBSCRIBE_API_AUDIO = 1;
}
|
||||
|
||||
// Client -> device request (SOURCE_CLIENT) to subscribe to, or unsubscribe
// from, voice assistant messages.
message SubscribeVoiceAssistantRequest {
  option (id) = 89;
  option (source) = SOURCE_CLIENT;
  option (ifdef) = "USE_VOICE_ASSISTANT";

  // The Python client sends true when subscribing and false when unsubscribing.
  bool subscribe = 1;
  // Bitmask of subscription flags; the Python client ORs in
  // VoiceAssistantSubscriptionFlag.API_AUDIO when an audio handler is given.
  uint32 flags = 2;
}
|
||||
|
||||
message VoiceAssistantAudioSettings {
|
||||
@ -1515,6 +1522,15 @@ message VoiceAssistantEventResponse {
|
||||
repeated VoiceAssistantEventData data = 2;
|
||||
}
|
||||
|
||||
// Voice assistant audio message; may be sent by either side (SOURCE_BOTH).
message VoiceAssistantAudio {
  option (id) = 106;
  option (source) = SOURCE_BOTH;
  option (ifdef) = "USE_VOICE_ASSISTANT";

  // Audio payload bytes.
  bytes data = 1;
  // When set, the Python client calls its stop handler instead of its
  // audio handler for this message.
  bool end = 2;
}
|
||||
|
||||
// ==================== ALARM CONTROL PANEL ====================
|
||||
enum AlarmControlPanelState {
|
||||
ALARM_STATE_DISARMED = 0;
|
||||
|
File diff suppressed because one or more lines are too long
@ -67,6 +67,7 @@ from .api_pb2 import ( # type: ignore
|
||||
TextCommandRequest,
|
||||
TimeCommandRequest,
|
||||
UnsubscribeBluetoothLEAdvertisementsRequest,
|
||||
VoiceAssistantAudio,
|
||||
VoiceAssistantEventData,
|
||||
VoiceAssistantEventResponse,
|
||||
VoiceAssistantRequest,
|
||||
@ -121,11 +122,13 @@ from .model import (
|
||||
MediaPlayerCommand,
|
||||
UserService,
|
||||
UserServiceArgType,
|
||||
VoiceAssistantAudioData,
|
||||
)
|
||||
from .model import VoiceAssistantAudioSettings as VoiceAssistantAudioSettingsModel
|
||||
from .model import (
|
||||
VoiceAssistantCommand,
|
||||
VoiceAssistantEventType,
|
||||
VoiceAssistantSubscriptionFlag,
|
||||
message_types_to_names,
|
||||
)
|
||||
from .model_conversions import (
|
||||
@ -1226,11 +1229,19 @@ class APIClient:
|
||||
|
||||
def subscribe_voice_assistant(
|
||||
self,
|
||||
*,
|
||||
handle_start: Callable[
|
||||
[str, int, VoiceAssistantAudioSettingsModel, str | None],
|
||||
Coroutine[Any, Any, int | None],
|
||||
],
|
||||
handle_stop: Callable[[], Coroutine[Any, Any, None]],
|
||||
handle_audio: (
|
||||
Callable[
|
||||
[bytes],
|
||||
Coroutine[Any, Any, None],
|
||||
]
|
||||
| None
|
||||
) = None,
|
||||
) -> Callable[[], None]:
|
||||
"""Subscribes to voice assistant messages from the device.
|
||||
|
||||
@ -1276,16 +1287,39 @@ class APIClient:
|
||||
else:
|
||||
self._create_background_task(handle_stop())
|
||||
|
||||
connection.send_message(SubscribeVoiceAssistantRequest(subscribe=True))
|
||||
remove_callbacks = []
|
||||
flags = 0
|
||||
if handle_audio is not None:
|
||||
flags |= VoiceAssistantSubscriptionFlag.API_AUDIO
|
||||
|
||||
remove_callback = connection.add_message_callback(
|
||||
def _on_voice_assistant_audio(msg: VoiceAssistantAudio) -> None:
|
||||
audio = VoiceAssistantAudioData.from_pb(msg)
|
||||
if audio.end:
|
||||
self._create_background_task(handle_stop())
|
||||
else:
|
||||
self._create_background_task(handle_audio(audio.data))
|
||||
|
||||
remove_callbacks.append(
|
||||
connection.add_message_callback(
|
||||
_on_voice_assistant_audio, (VoiceAssistantAudio,)
|
||||
)
|
||||
)
|
||||
|
||||
connection.send_message(
|
||||
SubscribeVoiceAssistantRequest(subscribe=True, flags=flags)
|
||||
)
|
||||
|
||||
remove_callbacks.append(
|
||||
connection.add_message_callback(
|
||||
_on_voice_assistant_request, (VoiceAssistantRequest,)
|
||||
)
|
||||
)
|
||||
|
||||
def unsub() -> None:
|
||||
nonlocal start_task
|
||||
|
||||
if self._connection is not None:
|
||||
for remove_callback in remove_callbacks:
|
||||
remove_callback()
|
||||
self._connection.send_message(
|
||||
SubscribeVoiceAssistantRequest(subscribe=False)
|
||||
@ -1316,6 +1350,10 @@ class APIClient:
|
||||
)
|
||||
self._get_connection().send_message(req)
|
||||
|
||||
def send_voice_assistant_audio(self, data: bytes) -> None:
    """Send a chunk of voice assistant audio.

    Wraps ``data`` in a ``VoiceAssistantAudio`` message (the ``end`` field
    is left at its default) and sends it on the connection returned by
    ``_get_connection()``.
    """
    # Build the message before looking up the connection, as the original did.
    audio_msg = VoiceAssistantAudio(data=data)
    connection = self._get_connection()
    connection.send_message(audio_msg)
|
||||
|
||||
def alarm_control_panel_command(
|
||||
self,
|
||||
key: int,
|
||||
|
@ -145,7 +145,7 @@ CONNECTION_STATE_CLOSED = ConnectionState.CLOSED
|
||||
def _make_hello_request(client_info: str) -> HelloRequest:
|
||||
"""Make a HelloRequest."""
|
||||
return HelloRequest(
|
||||
client_info=client_info, api_version_major=1, api_version_minor=9
|
||||
client_info=client_info, api_version_major=1, api_version_minor=10
|
||||
)
|
||||
|
||||
|
||||
|
@ -107,6 +107,7 @@ from .api_pb2 import ( # type: ignore
|
||||
TimeCommandRequest,
|
||||
TimeStateResponse,
|
||||
UnsubscribeBluetoothLEAdvertisementsRequest,
|
||||
VoiceAssistantAudio,
|
||||
VoiceAssistantEventResponse,
|
||||
VoiceAssistantRequest,
|
||||
VoiceAssistantResponse,
|
||||
@ -366,4 +367,5 @@ MESSAGE_TYPE_TO_PROTO = {
|
||||
103: ListEntitiesTimeResponse,
|
||||
104: TimeStateResponse,
|
||||
105: TimeCommandRequest,
|
||||
106: VoiceAssistantAudio,
|
||||
}
|
||||
|
@ -120,6 +120,16 @@ class BluetoothProxySubscriptionFlag(enum.IntFlag):
|
||||
RAW_ADVERTISEMENTS = 1 << 0
|
||||
|
||||
|
||||
class VoiceAssistantFeature(enum.IntFlag):
    """Bit flags for the voice assistant features of a device.

    ``DeviceInfo.voice_assistant_feature_flags_compat`` builds a value from
    these flags for devices that only report a legacy version number.
    """

    # Set for legacy voice assistant versions >= 1.
    VOICE_ASSISTANT = 1 << 0
    # Additionally set for legacy voice assistant version 2.
    SPEAKER = 1 << 1
    # Never derived from a legacy version number.
    API_AUDIO = 1 << 2
|
||||
|
||||
|
||||
class VoiceAssistantSubscriptionFlag(enum.IntFlag):
    """Flags the client places in ``SubscribeVoiceAssistantRequest.flags``."""

    # Requested when the subscriber supplies an audio handler.
    # NOTE(review): api.proto declares VOICE_ASSISTANT_SUBSCRIBE_API_AUDIO = 1,
    # whereas this value is 1 << 2 (4), the same bit as
    # VoiceAssistantFeature.API_AUDIO. Confirm which value the device expects
    # before changing it; the client tests pin flags=4.
    API_AUDIO = 1 << 2
|
||||
|
||||
|
||||
@_frozen_dataclass_decorator
|
||||
class DeviceInfo(APIModelBase):
|
||||
uses_password: bool = False
|
||||
@ -134,7 +144,8 @@ class DeviceInfo(APIModelBase):
|
||||
project_name: str = ""
|
||||
project_version: str = ""
|
||||
webserver_port: int = 0
|
||||
voice_assistant_version: int = 0
|
||||
legacy_voice_assistant_version: int = 0
|
||||
voice_assistant_feature_flags: int = 0
|
||||
legacy_bluetooth_proxy_version: int = 0
|
||||
bluetooth_proxy_feature_flags: int = 0
|
||||
suggested_area: str = ""
|
||||
@ -155,6 +166,16 @@ class DeviceInfo(APIModelBase):
|
||||
return flags
|
||||
return self.bluetooth_proxy_feature_flags
|
||||
|
||||
def voice_assistant_feature_flags_compat(self, api_version: APIVersion) -> int:
    """Return the voice assistant feature flags for ``api_version``.

    For API versions older than 1.10 the flags are derived from
    ``legacy_voice_assistant_version``: a version of at least 1 sets
    ``VoiceAssistantFeature.VOICE_ASSISTANT`` and a version of exactly 2
    also sets ``VoiceAssistantFeature.SPEAKER``. For any other API version
    ``voice_assistant_feature_flags`` is returned as-is.
    """
    uses_legacy_version = api_version < APIVersion(1, 10)
    if not uses_legacy_version:
        return self.voice_assistant_feature_flags

    legacy_version = self.legacy_voice_assistant_version
    flags: int = 0
    if legacy_version >= 1:
        flags |= VoiceAssistantFeature.VOICE_ASSISTANT
    if legacy_version == 2:
        flags |= VoiceAssistantFeature.SPEAKER
    return flags
|
||||
|
||||
|
||||
class EntityCategory(APIIntEnum):
|
||||
NONE = 0
|
||||
@ -1152,6 +1173,12 @@ class VoiceAssistantCommand(APIModelBase):
|
||||
wake_word_phrase: str = ""
|
||||
|
||||
|
||||
@_frozen_dataclass_decorator
class VoiceAssistantAudioData(APIModelBase):
    """Model for a ``VoiceAssistantAudio`` message.

    The client builds it with ``from_pb`` from a received message and then
    reads ``end`` and ``data``.
    """

    # Audio payload; defaults to empty bytes.
    data: bytes = field(default_factory=bytes)  # pylint: disable=invalid-field-call
    # When true the client schedules its stop handler instead of passing
    # ``data`` to the audio handler.
    end: bool = False
|
||||
|
||||
|
||||
class LogLevel(APIIntEnum):
|
||||
LOG_LEVEL_NONE = 0
|
||||
LOG_LEVEL_ERROR = 1
|
||||
|
@ -64,6 +64,7 @@ from aioesphomeapi.api_pb2 import (
|
||||
SwitchCommandRequest,
|
||||
TextCommandRequest,
|
||||
TimeCommandRequest,
|
||||
VoiceAssistantAudio,
|
||||
VoiceAssistantAudioSettings,
|
||||
VoiceAssistantEventData,
|
||||
VoiceAssistantEventResponse,
|
||||
@ -2107,7 +2108,9 @@ async def test_subscribe_voice_assistant(
|
||||
async def handle_stop() -> None:
|
||||
stops.append(True)
|
||||
|
||||
unsub = client.subscribe_voice_assistant(handle_start, handle_stop)
|
||||
unsub = client.subscribe_voice_assistant(
|
||||
handle_start=handle_start, handle_stop=handle_stop
|
||||
)
|
||||
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=True))
|
||||
send.reset_mock()
|
||||
audio_settings = VoiceAssistantAudioSettings(
|
||||
@ -2183,7 +2186,9 @@ async def test_subscribe_voice_assistant_failure(
|
||||
async def handle_stop() -> None:
|
||||
stops.append(True)
|
||||
|
||||
unsub = client.subscribe_voice_assistant(handle_start, handle_stop)
|
||||
unsub = client.subscribe_voice_assistant(
|
||||
handle_start=handle_start, handle_stop=handle_stop
|
||||
)
|
||||
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=True))
|
||||
send.reset_mock()
|
||||
audio_settings = VoiceAssistantAudioSettings(
|
||||
@ -2260,7 +2265,9 @@ async def test_subscribe_voice_assistant_cancels_long_running_handle_start(
|
||||
async def handle_stop() -> None:
|
||||
stops.append(True)
|
||||
|
||||
unsub = client.subscribe_voice_assistant(handle_start, handle_stop)
|
||||
unsub = client.subscribe_voice_assistant(
|
||||
handle_start=handle_start, handle_stop=handle_stop
|
||||
)
|
||||
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=True))
|
||||
send.reset_mock()
|
||||
audio_settings = VoiceAssistantAudioSettings(
|
||||
@ -2294,6 +2301,111 @@ async def test_subscribe_voice_assistant_cancels_long_running_handle_start(
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_subscribe_voice_assistant_api_audio(
    api_client: tuple[
        APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
    ],
) -> None:
    """Test subscribe_voice_assistant with an audio handler.

    Covers the subscribe request flags, dispatch of incoming audio to
    ``handle_audio``, the ``end`` marker triggering ``handle_stop``,
    ``send_voice_assistant_audio`` and unsubscribing.
    """
    client, _connection, _transport, protocol = api_client
    send = patch_send(client)
    started = []
    stopped = []
    chunk_sizes = []
    payload = bytes(range(1, 11))  # ten bytes: 1..10

    def deliver(msg: message.Message) -> None:
        # Feed a message to the client as if it arrived from the device.
        mock_data_received(protocol, generate_plaintext_packet(msg))

    async def handle_start(
        conversation_id: str,
        flags: int,
        audio_settings: VoiceAssistantAudioSettings,
        wake_word_phrase: str | None,
    ) -> int | None:
        started.append((conversation_id, flags, audio_settings, wake_word_phrase))
        return 0

    async def handle_stop() -> None:
        stopped.append(True)

    async def handle_audio(data: bytes) -> None:
        chunk_sizes.append(len(data))

    unsub = client.subscribe_voice_assistant(
        handle_start=handle_start, handle_stop=handle_stop, handle_audio=handle_audio
    )
    # 4 is VoiceAssistantSubscriptionFlag.API_AUDIO (1 << 2).
    send.assert_called_once_with(
        SubscribeVoiceAssistantRequest(subscribe=True, flags=4)
    )
    send.reset_mock()

    # Device starts a pipeline run.
    start_request: message.Message = VoiceAssistantRequest(
        conversation_id="theone",
        start=True,
        flags=42,
        audio_settings=VoiceAssistantAudioSettings(
            noise_suppression_level=42,
            auto_gain=42,
            volume_multiplier=42,
        ),
        wake_word_phrase="okay nabu",
    )
    deliver(start_request)
    await asyncio.sleep(0)
    await asyncio.sleep(0)
    expected_settings = VoiceAssistantAudioSettingsModel(
        noise_suppression_level=42,
        auto_gain=42,
        volume_multiplier=42,
    )
    assert started == [("theone", 42, expected_settings, "okay nabu")]
    assert stopped == []
    send.assert_called_once_with(VoiceAssistantResponse(port=0))
    send.reset_mock()

    # Incoming audio goes to handle_audio.
    audio_chunk: message.Message = VoiceAssistantAudio(data=payload)
    deliver(audio_chunk)
    await asyncio.sleep(0)
    assert sum(chunk_sizes) == 10

    # A message with end set goes to handle_stop instead.
    audio_end: message.Message = VoiceAssistantAudio(end=True)
    deliver(audio_end)
    await asyncio.sleep(0)
    assert stopped == [True]

    # Outgoing audio is wrapped in a VoiceAssistantAudio message.
    send.reset_mock()
    client.send_voice_assistant_audio(payload)
    send.assert_called_once_with(VoiceAssistantAudio(data=payload))

    # Device stops the pipeline run.
    stop_request: message.Message = VoiceAssistantRequest(
        conversation_id="theone",
        start=False,
    )
    deliver(stop_request)
    await asyncio.sleep(0)
    assert stopped == [True, True]
    send.reset_mock()

    unsub()
    send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False))
    send.reset_mock()

    await client.disconnect(force=True)
    # Ensure abort callback is a no-op after disconnect
    # and does not raise
    unsub()
    assert not send.mock_calls
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_version_after_connection_closed(
|
||||
api_client: tuple[
|
||||
|
@ -105,6 +105,7 @@ from aioesphomeapi.model import (
|
||||
UserService,
|
||||
UserServiceArg,
|
||||
UserServiceArgType,
|
||||
VoiceAssistantFeature,
|
||||
build_unique_id,
|
||||
converter_field,
|
||||
)
|
||||
@ -432,6 +433,24 @@ def test_bluetooth_backcompat_for_device_info(
|
||||
assert info.bluetooth_proxy_feature_flags_compat(APIVersion(1, 9)) == 42
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("version", "flags"),
    [
        # Version 0 sets no flags.
        (0, VoiceAssistantFeature(0)),
        (1, VoiceAssistantFeature.VOICE_ASSISTANT),
        (2, VoiceAssistantFeature.VOICE_ASSISTANT | VoiceAssistantFeature.SPEAKER),
    ],
)
def test_voice_assistant_backcompat_for_device_info(
    version: int, flags: VoiceAssistantFeature
) -> None:
    """Test voice assistant feature flags derived from the legacy version.

    Before API 1.10 the flags come from ``legacy_voice_assistant_version``;
    from 1.10 on ``voice_assistant_feature_flags`` is returned unchanged.
    """
    info = DeviceInfo(
        legacy_voice_assistant_version=version, voice_assistant_feature_flags=42
    )
    # Compare by value: `is` would depend on enum member caching and on the
    # method returning an IntFlag rather than a plain int.
    assert info.voice_assistant_feature_flags_compat(APIVersion(1, 9)) == flags
    assert info.voice_assistant_feature_flags_compat(APIVersion(1, 10)) == 42
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
(
|
||||
"legacy_supports_brightness",
|
||||
|
Loading…
Reference in New Issue
Block a user