From 193bac94f4863fac2ac404b0a44663f6da83e976 Mon Sep 17 00:00:00 2001 From: Jesse Hills <3060199+jesserockz@users.noreply.github.com> Date: Tue, 31 Oct 2023 11:16:42 +1300 Subject: [PATCH] Add on_client_connected and disconnected to voice assistant (#5629) --- esphome/components/api/__init__.py | 4 +- esphome/components/api/api_connection.cpp | 22 +++++++-- esphome/components/api/api_connection.h | 10 ++-- esphome/components/api/api_server.cpp | 24 ---------- esphome/components/api/api_server.h | 6 --- .../components/voice_assistant/__init__.py | 22 +++++++++ .../voice_assistant/voice_assistant.cpp | 46 ++++++++++++++++--- .../voice_assistant/voice_assistant.h | 13 +++++- esphome/const.py | 2 + 9 files changed, 98 insertions(+), 51 deletions(-) diff --git a/esphome/components/api/__init__.py b/esphome/components/api/__init__.py index ec1a56bd2c..d6b4416af8 100644 --- a/esphome/components/api/__init__.py +++ b/esphome/components/api/__init__.py @@ -18,6 +18,8 @@ from esphome.const import ( CONF_TRIGGER_ID, CONF_EVENT, CONF_TAG, + CONF_ON_CLIENT_CONNECTED, + CONF_ON_CLIENT_DISCONNECTED, ) from esphome.core import coroutine_with_priority @@ -45,8 +47,6 @@ SERVICE_ARG_NATIVE_TYPES = { "string[]": cg.std_vector.template(cg.std_string), } CONF_ENCRYPTION = "encryption" -CONF_ON_CLIENT_CONNECTED = "on_client_connected" -CONF_ON_CLIENT_DISCONNECTED = "on_client_disconnected" def validate_encryption_key(value): diff --git a/esphome/components/api/api_connection.cpp b/esphome/components/api/api_connection.cpp index d1e7513d11..0389df215f 100644 --- a/esphome/components/api/api_connection.cpp +++ b/esphome/components/api/api_connection.cpp @@ -60,6 +60,11 @@ APIConnection::~APIConnection() { bluetooth_proxy::global_bluetooth_proxy->unsubscribe_api_connection(this); } #endif +#ifdef USE_VOICE_ASSISTANT + if (voice_assistant::global_voice_assistant->get_api_connection() == this) { + voice_assistant::global_voice_assistant->client_subscription(this, false); + } +#endif } void APIConnection::loop() { @@ -950,14 +955,17 @@ BluetoothConnectionsFreeResponse APIConnection::subscribe_bluetooth_connections_ #endif #ifdef USE_VOICE_ASSISTANT -bool APIConnection::request_voice_assistant(const VoiceAssistantRequest &msg) { - if (!this->voice_assistant_subscription_) - return false; - - return this->send_voice_assistant_request(msg); +void APIConnection::subscribe_voice_assistant(const SubscribeVoiceAssistantRequest &msg) { + if (voice_assistant::global_voice_assistant != nullptr) { + voice_assistant::global_voice_assistant->client_subscription(this, msg.subscribe); + } } void APIConnection::on_voice_assistant_response(const VoiceAssistantResponse &msg) { if (voice_assistant::global_voice_assistant != nullptr) { + if (voice_assistant::global_voice_assistant->get_api_connection() != this) { + return; + } + if (msg.error) { voice_assistant::global_voice_assistant->failed_to_start(); return; @@ -970,6 +978,10 @@ void APIConnection::on_voice_assistant_response(const VoiceAssistantResponse &ms }; void APIConnection::on_voice_assistant_event_response(const VoiceAssistantEventResponse &msg) { if (voice_assistant::global_voice_assistant != nullptr) { + if (voice_assistant::global_voice_assistant->get_api_connection() != this) { + return; + } + voice_assistant::global_voice_assistant->on_event(msg); } } diff --git a/esphome/components/api/api_connection.h b/esphome/components/api/api_connection.h index 21ee85daab..09b595bb71 100644 --- a/esphome/components/api/api_connection.h +++ b/esphome/components/api/api_connection.h @@ -126,10 +126,7 @@ class APIConnection : public APIServerConnection { #endif #ifdef USE_VOICE_ASSISTANT - void subscribe_voice_assistant(const SubscribeVoiceAssistantRequest &msg) override { - this->voice_assistant_subscription_ = msg.subscribe; - } - bool request_voice_assistant(const VoiceAssistantRequest &msg); + void subscribe_voice_assistant(const SubscribeVoiceAssistantRequest &msg) override; void on_voice_assistant_response(const VoiceAssistantResponse &msg) override; void on_voice_assistant_event_response(const VoiceAssistantEventResponse &msg) override; #endif @@ -188,6 +185,8 @@ class APIConnection : public APIServerConnection { } bool send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) override; + std::string get_client_combined_info() const { return this->client_combined_info_; } + protected: friend APIServer; @@ -220,9 +219,6 @@ class APIConnection : public APIServerConnection { uint32_t last_traffic_; bool sent_ping_{false}; bool service_call_subscription_{false}; -#ifdef USE_VOICE_ASSISTANT - bool voice_assistant_subscription_{false}; -#endif bool next_close_ = false; APIServer *parent_; InitialStateIterator initial_state_iterator_; diff --git a/esphome/components/api/api_server.cpp b/esphome/components/api/api_server.cpp index 5268b30132..0348112fcd 100644 --- a/esphome/components/api/api_server.cpp +++ b/esphome/components/api/api_server.cpp @@ -332,30 +332,6 @@ void APIServer::on_shutdown() { delay(10); } -#ifdef USE_VOICE_ASSISTANT -bool APIServer::start_voice_assistant(const std::string &conversation_id, uint32_t flags, - const api::VoiceAssistantAudioSettings &audio_settings) { - VoiceAssistantRequest msg; - msg.start = true; - msg.conversation_id = conversation_id; - msg.flags = flags; - msg.audio_settings = audio_settings; - for (auto &c : this->clients_) { - if (c->request_voice_assistant(msg)) - return true; - } - return false; -} -void APIServer::stop_voice_assistant() { - VoiceAssistantRequest msg; - msg.start = false; - for (auto &c : this->clients_) { - if (c->request_voice_assistant(msg)) - return; - } -} -#endif - #ifdef USE_ALARM_CONTROL_PANEL void APIServer::on_alarm_control_panel_update(alarm_control_panel::AlarmControlPanel *obj) { if (obj->is_internal()) diff --git a/esphome/components/api/api_server.h b/esphome/components/api/api_server.h index f1fb31fa8b..9605a196b3 100644 --- a/esphome/components/api/api_server.h +++ b/esphome/components/api/api_server.h @@ -84,12 +84,6 @@ class APIServer : public Component, public Controller { void request_time(); #endif -#ifdef USE_VOICE_ASSISTANT - bool start_voice_assistant(const std::string &conversation_id, uint32_t flags, - const api::VoiceAssistantAudioSettings &audio_settings); - void stop_voice_assistant(); -#endif - #ifdef USE_ALARM_CONTROL_PANEL void on_alarm_control_panel_update(alarm_control_panel::AlarmControlPanel *obj) override; #endif diff --git a/esphome/components/voice_assistant/__init__.py b/esphome/components/voice_assistant/__init__.py index 14176ad7cf..3270b9f370 100644 --- a/esphome/components/voice_assistant/__init__.py +++ b/esphome/components/voice_assistant/__init__.py @@ -6,6 +6,8 @@ from esphome.const import ( CONF_MICROPHONE, CONF_SPEAKER, CONF_MEDIA_PLAYER, + CONF_ON_CLIENT_CONNECTED, + CONF_ON_CLIENT_DISCONNECTED, ) from esphome import automation from esphome.automation import register_action, register_condition @@ -80,6 +82,12 @@ CONFIG_SCHEMA = cv.All( cv.Optional(CONF_ON_TTS_END): automation.validate_automation(single=True), cv.Optional(CONF_ON_END): automation.validate_automation(single=True), cv.Optional(CONF_ON_ERROR): automation.validate_automation(single=True), + cv.Optional(CONF_ON_CLIENT_CONNECTED): automation.validate_automation( + single=True + ), + cv.Optional(CONF_ON_CLIENT_DISCONNECTED): automation.validate_automation( + single=True + ), } ).extend(cv.COMPONENT_SCHEMA), ) @@ -155,6 +163,20 @@ async def to_code(config): config[CONF_ON_ERROR], ) + if CONF_ON_CLIENT_CONNECTED in config: + await automation.build_automation( + var.get_client_connected_trigger(), + [], + config[CONF_ON_CLIENT_CONNECTED], + ) + + if CONF_ON_CLIENT_DISCONNECTED in config: + await automation.build_automation( + var.get_client_disconnected_trigger(), + [], + config[CONF_ON_CLIENT_DISCONNECTED], + ) + cg.add_define("USE_VOICE_ASSISTANT") diff --git a/esphome/components/voice_assistant/voice_assistant.cpp b/esphome/components/voice_assistant/voice_assistant.cpp index df7853156d..d15d702d4b 100644 --- a/esphome/components/voice_assistant/voice_assistant.cpp +++ b/esphome/components/voice_assistant/voice_assistant.cpp @@ -127,8 +127,8 @@ int VoiceAssistant::read_microphone_() { } void VoiceAssistant::loop() { - if (this->state_ != State::IDLE && this->state_ != State::STOP_MICROPHONE && - this->state_ != State::STOPPING_MICROPHONE && !api::global_api_server->is_connected()) { + if (this->api_client_ == nullptr && this->state_ != State::IDLE && this->state_ != State::STOP_MICROPHONE && + this->state_ != State::STOPPING_MICROPHONE) { if (this->mic_->is_running() || this->state_ == State::STARTING_MICROPHONE) { this->set_state_(State::STOP_MICROPHONE, State::IDLE); } else { @@ -213,7 +213,14 @@ void VoiceAssistant::loop() { audio_settings.noise_suppression_level = this->noise_suppression_level_; audio_settings.auto_gain = this->auto_gain_; audio_settings.volume_multiplier = this->volume_multiplier_; - if (!api::global_api_server->start_voice_assistant(this->conversation_id_, flags, audio_settings)) { + + api::VoiceAssistantRequest msg; + msg.start = true; + msg.conversation_id = this->conversation_id_; + msg.flags = flags; + msg.audio_settings = audio_settings; + + if (this->api_client_ == nullptr || !this->api_client_->send_voice_assistant_request(msg)) { ESP_LOGW(TAG, "Could not request start."); this->error_trigger_->trigger("not-connected", "Could not request start."); this->continuous_ = false; @@ -326,6 +333,28 @@ void VoiceAssistant::loop() { } } +void VoiceAssistant::client_subscription(api::APIConnection *client, bool subscribe) { + if (!subscribe) { + if (this->api_client_ == nullptr || client != this->api_client_) { + ESP_LOGE(TAG, "Client attempting to unsubscribe that is not the current API Client"); + return; + } + this->api_client_ = nullptr; + this->client_disconnected_trigger_->trigger(); + return; + } + + if (this->api_client_ != nullptr) { + ESP_LOGE(TAG, "Multiple API Clients attempting to connect to Voice Assistant"); + ESP_LOGE(TAG, "Current client: %s", this->api_client_->get_client_combined_info().c_str()); + ESP_LOGE(TAG, "New client: %s", client->get_client_combined_info().c_str()); + return; + } + + this->api_client_ = client; + this->client_connected_trigger_->trigger(); +} + static const LogString *voice_assistant_state_to_string(State state) { switch (state) { case State::IDLE: @@ -408,7 +437,7 @@ void VoiceAssistant::start_streaming(struct sockaddr_storage *addr, uint16_t por } void VoiceAssistant::request_start(bool continuous, bool silence_detection) { - if (!api::global_api_server->is_connected()) { + if (this->api_client_ == nullptr) { ESP_LOGE(TAG, "No API client connected"); this->set_state_(State::IDLE, State::IDLE); this->continuous_ = false; @@ -459,9 +488,14 @@ void VoiceAssistant::request_stop() { } void VoiceAssistant::signal_stop_() { - ESP_LOGD(TAG, "Signaling stop..."); - api::global_api_server->stop_voice_assistant(); memset(&this->dest_addr_, 0, sizeof(this->dest_addr_)); + if (this->api_client_ == nullptr) { + return; + } + ESP_LOGD(TAG, "Signaling stop..."); + api::VoiceAssistantRequest msg; + msg.start = false; + this->api_client_->send_voice_assistant_request(msg); } void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) { diff --git a/esphome/components/voice_assistant/voice_assistant.h b/esphome/components/voice_assistant/voice_assistant.h index cd448293db..a265522bca 100644 --- a/esphome/components/voice_assistant/voice_assistant.h +++ b/esphome/components/voice_assistant/voice_assistant.h @@ -8,8 +8,8 @@ #include "esphome/core/component.h" #include "esphome/core/helpers.h" +#include "esphome/components/api/api_connection.h" #include "esphome/components/api/api_pb2.h" -#include "esphome/components/api/api_server.h" #include "esphome/components/microphone/microphone.h" #ifdef USE_SPEAKER #include "esphome/components/speaker/speaker.h" @@ -109,6 +109,12 @@ class VoiceAssistant : public Component { Trigger<> *get_end_trigger() const { return this->end_trigger_; } Trigger *get_error_trigger() const { return this->error_trigger_; } + Trigger<> *get_client_connected_trigger() const { return this->client_connected_trigger_; } + Trigger<> *get_client_disconnected_trigger() const { return this->client_disconnected_trigger_; } + + void client_subscription(api::APIConnection *client, bool subscribe); + api::APIConnection *get_api_connection() const { return this->api_client_; } + protected: int read_microphone_(); void set_state_(State state); @@ -127,6 +133,11 @@ class VoiceAssistant : public Component { Trigger<> *end_trigger_ = new Trigger<>(); Trigger *error_trigger_ = new Trigger(); + Trigger<> *client_connected_trigger_ = new Trigger<>(); + Trigger<> *client_disconnected_trigger_ = new Trigger<>(); + + api::APIConnection *api_client_{nullptr}; + microphone::Microphone *mic_{nullptr}; #ifdef USE_SPEAKER speaker::Speaker *speaker_{nullptr}; diff --git a/esphome/const.py b/esphome/const.py index 6dde15303a..9457958863 100644 --- a/esphome/const.py +++ b/esphome/const.py @@ -485,6 +485,8 @@ CONF_ON_BLE_MANUFACTURER_DATA_ADVERTISE = "on_ble_manufacturer_data_advertise" CONF_ON_BLE_SERVICE_DATA_ADVERTISE = "on_ble_service_data_advertise" CONF_ON_BOOT = "on_boot" CONF_ON_CLICK = "on_click" +CONF_ON_CLIENT_CONNECTED = "on_client_connected" +CONF_ON_CLIENT_DISCONNECTED = "on_client_disconnected" CONF_ON_CONNECT = "on_connect" CONF_ON_CONTROL = "on_control" CONF_ON_DISCONNECT = "on_disconnect"