From 571c0eb82703e29de70f5c82ed018b4dde2d9dc3 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Tue, 17 Sep 2024 18:38:39 -0500 Subject: [PATCH] Add voice assistant methods for configuration (#7459) Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com> --- esphome/components/api/api.proto | 8 ++-- esphome/components/api/api_connection.cpp | 33 +++++++++++++++ esphome/components/api/api_connection.h | 3 ++ esphome/components/api/api_pb2.cpp | 41 ++++++++----------- esphome/components/api/api_pb2.h | 9 ++-- esphome/components/api/api_pb2_service.cpp | 29 +++++++++++++ esphome/components/api/api_pb2_service.h | 13 ++++++ .../voice_assistant/voice_assistant.h | 16 ++++++++ 8 files changed, 119 insertions(+), 33 deletions(-) diff --git a/esphome/components/api/api.proto b/esphome/components/api/api.proto index 92fb57b711..684540ffa6 100644 --- a/esphome/components/api/api.proto +++ b/esphome/components/api/api.proto @@ -62,6 +62,8 @@ service APIConnection { rpc unsubscribe_bluetooth_le_advertisements(UnsubscribeBluetoothLEAdvertisementsRequest) returns (void) {} rpc subscribe_voice_assistant(SubscribeVoiceAssistantRequest) returns (void) {} + rpc voice_assistant_get_configuration(VoiceAssistantConfigurationRequest) returns (VoiceAssistantConfigurationResponse) {} + rpc voice_assistant_set_configuration(VoiceAssistantSetConfiguration) returns (void) {} rpc alarm_control_panel_command (AlarmControlPanelCommandRequest) returns (void) {} } @@ -1572,7 +1574,7 @@ message VoiceAssistantAnnounceFinished { } message VoiceAssistantWakeWord { - uint32 id = 1; + string id = 1; string wake_word = 2; repeated string trained_languages = 3; } @@ -1589,7 +1591,7 @@ message VoiceAssistantConfigurationResponse { option (ifdef) = "USE_VOICE_ASSISTANT"; repeated VoiceAssistantWakeWord available_wake_words = 1; - repeated uint32 active_wake_words = 2; + repeated string active_wake_words = 2; uint32 max_active_wake_words = 3; } @@ -1598,7 +1600,7 @@ message VoiceAssistantSetConfiguration { option (source) = SOURCE_CLIENT; option (ifdef) = "USE_VOICE_ASSISTANT"; - repeated uint32 active_wake_words = 1; + repeated string active_wake_words = 1; } // ==================== ALARM CONTROL PANEL ==================== diff --git a/esphome/components/api/api_connection.cpp b/esphome/components/api/api_connection.cpp index 20e9a45314..7ea52e9a9e 100644 --- a/esphome/components/api/api_connection.cpp +++ b/esphome/components/api/api_connection.cpp @@ -1224,6 +1224,39 @@ void APIConnection::on_voice_assistant_announce_request(const VoiceAssistantAnno } } +VoiceAssistantConfigurationResponse APIConnection::voice_assistant_get_configuration( + const VoiceAssistantConfigurationRequest &msg) { + VoiceAssistantConfigurationResponse resp; + if (voice_assistant::global_voice_assistant != nullptr) { + if (voice_assistant::global_voice_assistant->get_api_connection() != this) { + return resp; + } + + auto &config = voice_assistant::global_voice_assistant->get_configuration(); + for (auto &wake_word : config.available_wake_words) { + VoiceAssistantWakeWord resp_wake_word; + resp_wake_word.id = wake_word.id; + resp_wake_word.wake_word = wake_word.wake_word; + for (const auto &lang : wake_word.trained_languages) { + resp_wake_word.trained_languages.push_back(lang); + } + resp.available_wake_words.push_back(std::move(resp_wake_word)); + } + resp.max_active_wake_words = config.max_active_wake_words; + } + return resp; +} + +void APIConnection::voice_assistant_set_configuration(const VoiceAssistantSetConfiguration &msg) { + if (voice_assistant::global_voice_assistant != nullptr) { + if (voice_assistant::global_voice_assistant->get_api_connection() != this) { + return; + } + + voice_assistant::global_voice_assistant->on_set_configuration(msg.active_wake_words); + } +} + #endif #ifdef USE_ALARM_CONTROL_PANEL diff --git a/esphome/components/api/api_connection.h b/esphome/components/api/api_connection.h index e8d66a5e07..f176cf7c56 100644 --- a/esphome/components/api/api_connection.h +++ b/esphome/components/api/api_connection.h @@ -152,6 +152,9 @@ class APIConnection : public APIServerConnection { void on_voice_assistant_audio(const VoiceAssistantAudio &msg) override; void on_voice_assistant_timer_event_response(const VoiceAssistantTimerEventResponse &msg) override; void on_voice_assistant_announce_request(const VoiceAssistantAnnounceRequest &msg) override; + VoiceAssistantConfigurationResponse voice_assistant_get_configuration( + const VoiceAssistantConfigurationRequest &msg) override; + void voice_assistant_set_configuration(const VoiceAssistantSetConfiguration &msg) override; #endif #ifdef USE_ALARM_CONTROL_PANEL diff --git a/esphome/components/api/api_pb2.cpp b/esphome/components/api/api_pb2.cpp index 791de29511..8df152881c 100644 --- a/esphome/components/api/api_pb2.cpp +++ b/esphome/components/api/api_pb2.cpp @@ -7124,18 +7124,12 @@ void VoiceAssistantAnnounceFinished::dump_to(std::string &out) const { out.append("}"); } #endif -bool VoiceAssistantWakeWord::decode_varint(uint32_t field_id, ProtoVarInt value) { - switch (field_id) { - case 1: { - this->id = value.as_uint32(); - return true; - } - default: - return false; - } -} bool VoiceAssistantWakeWord::decode_length(uint32_t field_id, ProtoLengthDelimited value) { switch (field_id) { + case 1: { + this->id = value.as_string(); + return true; + } case 2: { this->wake_word = value.as_string(); return true; @@ -7149,7 +7143,7 @@ bool VoiceAssistantWakeWord::decode_length(uint32_t field_id, ProtoLengthDelimit } } void VoiceAssistantWakeWord::encode(ProtoWriteBuffer buffer) const { - buffer.encode_uint32(1, this->id); + buffer.encode_string(1, this->id); buffer.encode_string(2, this->wake_word); for (auto &it : this->trained_languages) { buffer.encode_string(3, it, true); @@ -7160,8 +7154,7 @@ void VoiceAssistantWakeWord::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; out.append("VoiceAssistantWakeWord {\n"); out.append(" id: "); - sprintf(buffer, "%" PRIu32, this->id); - out.append(buffer); + out.append("'").append(this->id).append("'"); out.append("\n"); out.append(" wake_word: "); @@ -7184,10 +7177,6 @@ void VoiceAssistantConfigurationRequest::dump_to(std::string &out) const { #endif bool VoiceAssistantConfigurationResponse::decode_varint(uint32_t field_id, ProtoVarInt value) { switch (field_id) { - case 2: { - this->active_wake_words.push_back(value.as_uint32()); - return true; - } case 3: { this->max_active_wake_words = value.as_uint32(); return true; @@ -7202,6 +7191,10 @@ bool VoiceAssistantConfigurationResponse::decode_length(uint32_t field_id, Proto this->available_wake_words.push_back(value.as_message()); return true; } + case 2: { + this->active_wake_words.push_back(value.as_string()); + return true; + } default: return false; } @@ -7211,7 +7204,7 @@ void VoiceAssistantConfigurationResponse::encode(ProtoWriteBuffer buffer) const buffer.encode_message(1, it, true); } for (auto &it : this->active_wake_words) { - buffer.encode_uint32(2, it, true); + buffer.encode_string(2, it, true); } buffer.encode_uint32(3, this->max_active_wake_words); } @@ -7227,8 +7220,7 @@ void VoiceAssistantConfigurationResponse::dump_to(std::string &out) const { for (const auto &it : this->active_wake_words) { out.append(" active_wake_words: "); - sprintf(buffer, "%" PRIu32, it); - out.append(buffer); + out.append("'").append(it).append("'"); out.append("\n"); } @@ -7239,10 +7231,10 @@ void VoiceAssistantConfigurationResponse::dump_to(std::string &out) const { out.append("}"); } #endif -bool VoiceAssistantSetConfiguration::decode_varint(uint32_t field_id, ProtoVarInt value) { +bool VoiceAssistantSetConfiguration::decode_length(uint32_t field_id, ProtoLengthDelimited value) { switch (field_id) { case 1: { - this->active_wake_words.push_back(value.as_uint32()); + this->active_wake_words.push_back(value.as_string()); return true; } default: @@ -7251,7 +7243,7 @@ bool VoiceAssistantSetConfiguration::decode_varint(uint32_t field_id, ProtoVarIn } void VoiceAssistantSetConfiguration::encode(ProtoWriteBuffer buffer) const { for (auto &it : this->active_wake_words) { - buffer.encode_uint32(1, it, true); + buffer.encode_string(1, it, true); } } #ifdef HAS_PROTO_MESSAGE_DUMP @@ -7260,8 +7252,7 @@ void VoiceAssistantSetConfiguration::dump_to(std::string &out) const { out.append("VoiceAssistantSetConfiguration {\n"); for (const auto &it : this->active_wake_words) { out.append(" active_wake_words: "); - sprintf(buffer, "%" PRIu32, it); - out.append(buffer); + out.append("'").append(it).append("'"); out.append("\n"); } out.append("}"); diff --git a/esphome/components/api/api_pb2.h b/esphome/components/api/api_pb2.h index 8631884f94..063c217bf7 100644 --- a/esphome/components/api/api_pb2.h +++ b/esphome/components/api/api_pb2.h @@ -1851,7 +1851,7 @@ class VoiceAssistantAnnounceFinished : public ProtoMessage { }; class VoiceAssistantWakeWord : public ProtoMessage { public: - uint32_t id{0}; + std::string id{}; std::string wake_word{}; std::vector trained_languages{}; void encode(ProtoWriteBuffer buffer) const override; @@ -1861,7 +1861,6 @@ class VoiceAssistantWakeWord : public ProtoMessage { protected: bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; - bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; class VoiceAssistantConfigurationRequest : public ProtoMessage { public: @@ -1875,7 +1874,7 @@ class VoiceAssistantConfigurationRequest : public ProtoMessage { class VoiceAssistantConfigurationResponse : public ProtoMessage { public: std::vector available_wake_words{}; - std::vector active_wake_words{}; + std::vector active_wake_words{}; uint32_t max_active_wake_words{0}; void encode(ProtoWriteBuffer buffer) const override; #ifdef HAS_PROTO_MESSAGE_DUMP @@ -1888,14 +1887,14 @@ class VoiceAssistantConfigurationResponse : public ProtoMessage { }; class VoiceAssistantSetConfiguration : public ProtoMessage { public: - std::vector active_wake_words{}; + std::vector active_wake_words{}; void encode(ProtoWriteBuffer buffer) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif protected: - bool decode_varint(uint32_t field_id, ProtoVarInt value) override; + bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; }; class ListEntitiesAlarmControlPanelResponse : public ProtoMessage { public: diff --git a/esphome/components/api/api_pb2_service.cpp b/esphome/components/api/api_pb2_service.cpp index 454f20d50a..6e11d7169d 100644 --- a/esphome/components/api/api_pb2_service.cpp +++ b/esphome/components/api/api_pb2_service.cpp @@ -1681,6 +1681,35 @@ void APIServerConnection::on_subscribe_voice_assistant_request(const SubscribeVo this->subscribe_voice_assistant(msg); } #endif +#ifdef USE_VOICE_ASSISTANT +void APIServerConnection::on_voice_assistant_configuration_request(const VoiceAssistantConfigurationRequest &msg) { + if (!this->is_connection_setup()) { + this->on_no_setup_connection(); + return; + } + if (!this->is_authenticated()) { + this->on_unauthenticated_access(); + return; + } + VoiceAssistantConfigurationResponse ret = this->voice_assistant_get_configuration(msg); + if (!this->send_voice_assistant_configuration_response(ret)) { + this->on_fatal_error(); + } +} +#endif +#ifdef USE_VOICE_ASSISTANT +void APIServerConnection::on_voice_assistant_set_configuration(const VoiceAssistantSetConfiguration &msg) { + if (!this->is_connection_setup()) { + this->on_no_setup_connection(); + return; + } + if (!this->is_authenticated()) { + this->on_unauthenticated_access(); + return; + } + this->voice_assistant_set_configuration(msg); +} +#endif #ifdef USE_ALARM_CONTROL_PANEL void APIServerConnection::on_alarm_control_panel_command_request(const AlarmControlPanelCommandRequest &msg) { if (!this->is_connection_setup()) { diff --git a/esphome/components/api/api_pb2_service.h b/esphome/components/api/api_pb2_service.h index e69954f7ab..51b94bf530 100644 --- a/esphome/components/api/api_pb2_service.h +++ b/esphome/components/api/api_pb2_service.h @@ -434,6 +434,13 @@ class APIServerConnection : public APIServerConnectionBase { #ifdef USE_VOICE_ASSISTANT virtual void subscribe_voice_assistant(const SubscribeVoiceAssistantRequest &msg) = 0; #endif +#ifdef USE_VOICE_ASSISTANT + virtual VoiceAssistantConfigurationResponse voice_assistant_get_configuration( + const VoiceAssistantConfigurationRequest &msg) = 0; +#endif +#ifdef USE_VOICE_ASSISTANT + virtual void voice_assistant_set_configuration(const VoiceAssistantSetConfiguration &msg) = 0; +#endif #ifdef USE_ALARM_CONTROL_PANEL virtual void alarm_control_panel_command(const AlarmControlPanelCommandRequest &msg) = 0; #endif @@ -535,6 +542,12 @@ class APIServerConnection : public APIServerConnectionBase { #ifdef USE_VOICE_ASSISTANT void on_subscribe_voice_assistant_request(const SubscribeVoiceAssistantRequest &msg) override; #endif +#ifdef USE_VOICE_ASSISTANT + void on_voice_assistant_configuration_request(const VoiceAssistantConfigurationRequest &msg) override; +#endif +#ifdef USE_VOICE_ASSISTANT + void on_voice_assistant_set_configuration(const VoiceAssistantSetConfiguration &msg) override; +#endif #ifdef USE_ALARM_CONTROL_PANEL void on_alarm_control_panel_command_request(const AlarmControlPanelCommandRequest &msg) override; #endif diff --git a/esphome/components/voice_assistant/voice_assistant.h b/esphome/components/voice_assistant/voice_assistant.h index b0a172332f..56ada0e75a 100644 --- a/esphome/components/voice_assistant/voice_assistant.h +++ b/esphome/components/voice_assistant/voice_assistant.h @@ -77,6 +77,18 @@ struct Timer { } }; +struct WakeWord { + std::string id; + std::string wake_word; + std::vector trained_languages; +}; + +struct Configuration { + std::vector available_wake_words; + std::vector active_wake_words; + uint32_t max_active_wake_words; +}; + class VoiceAssistant : public Component { public: void setup() override; @@ -133,6 +145,8 @@ class VoiceAssistant : public Component { void on_audio(const api::VoiceAssistantAudio &msg); void on_timer_event(const api::VoiceAssistantTimerEventResponse &msg); void on_announce(const api::VoiceAssistantAnnounceRequest &msg); + void on_set_configuration(const std::vector &active_wake_words){}; + const Configuration &get_configuration() { return this->config_; }; bool is_running() const { return this->state_ != State::IDLE; } void set_continuous(bool continuous) { this->continuous_ = continuous; } @@ -279,6 +293,8 @@ class VoiceAssistant : public Component { AudioMode audio_mode_{AUDIO_MODE_UDP}; bool udp_socket_running_{false}; bool start_udp_socket_(); + + Configuration config_{}; }; template class StartAction : public Action, public Parented {