Continuous voice_assistant and silence detection (#4892)

This commit is contained in:
Jesse Hills 2023-05-31 16:30:53 +12:00 committed by GitHub
parent f9f335e692
commit 1ea5d90ea3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 176 additions and 25 deletions

View File

@ -1397,6 +1397,7 @@ message VoiceAssistantRequest {
option (ifdef) = "USE_VOICE_ASSISTANT"; option (ifdef) = "USE_VOICE_ASSISTANT";
bool start = 1; bool start = 1;
string conversation_id = 2;
} }
message VoiceAssistantResponse { message VoiceAssistantResponse {

View File

@ -895,11 +895,12 @@ BluetoothConnectionsFreeResponse APIConnection::subscribe_bluetooth_connections_
#endif #endif
#ifdef USE_VOICE_ASSISTANT #ifdef USE_VOICE_ASSISTANT
bool APIConnection::request_voice_assistant(bool start) { bool APIConnection::request_voice_assistant(bool start, const std::string &conversation_id) {
if (!this->voice_assistant_subscription_) if (!this->voice_assistant_subscription_)
return false; return false;
VoiceAssistantRequest msg; VoiceAssistantRequest msg;
msg.start = start; msg.start = start;
msg.conversation_id = conversation_id;
return this->send_voice_assistant_request(msg); return this->send_voice_assistant_request(msg);
} }
void APIConnection::on_voice_assistant_response(const VoiceAssistantResponse &msg) { void APIConnection::on_voice_assistant_response(const VoiceAssistantResponse &msg) {

View File

@ -128,7 +128,7 @@ class APIConnection : public APIServerConnection {
void subscribe_voice_assistant(const SubscribeVoiceAssistantRequest &msg) override { void subscribe_voice_assistant(const SubscribeVoiceAssistantRequest &msg) override {
this->voice_assistant_subscription_ = msg.subscribe; this->voice_assistant_subscription_ = msg.subscribe;
} }
bool request_voice_assistant(bool start); bool request_voice_assistant(bool start, const std::string &conversation_id);
void on_voice_assistant_response(const VoiceAssistantResponse &msg) override; void on_voice_assistant_response(const VoiceAssistantResponse &msg) override;
void on_voice_assistant_event_response(const VoiceAssistantEventResponse &msg) override; void on_voice_assistant_event_response(const VoiceAssistantEventResponse &msg) override;
#endif #endif

View File

@ -6187,7 +6187,20 @@ bool VoiceAssistantRequest::decode_varint(uint32_t field_id, ProtoVarInt value)
return false; return false;
} }
} }
void VoiceAssistantRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(1, this->start); } bool VoiceAssistantRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value) {
switch (field_id) {
case 2: {
this->conversation_id = value.as_string();
return true;
}
default:
return false;
}
}
void VoiceAssistantRequest::encode(ProtoWriteBuffer buffer) const {
buffer.encode_bool(1, this->start);
buffer.encode_string(2, this->conversation_id);
}
#ifdef HAS_PROTO_MESSAGE_DUMP #ifdef HAS_PROTO_MESSAGE_DUMP
void VoiceAssistantRequest::dump_to(std::string &out) const { void VoiceAssistantRequest::dump_to(std::string &out) const {
__attribute__((unused)) char buffer[64]; __attribute__((unused)) char buffer[64];
@ -6195,6 +6208,10 @@ void VoiceAssistantRequest::dump_to(std::string &out) const {
out.append(" start: "); out.append(" start: ");
out.append(YESNO(this->start)); out.append(YESNO(this->start));
out.append("\n"); out.append("\n");
out.append(" conversation_id: ");
out.append("'").append(this->conversation_id).append("'");
out.append("\n");
out.append("}"); out.append("}");
} }
#endif #endif

View File

@ -1604,12 +1604,14 @@ class SubscribeVoiceAssistantRequest : public ProtoMessage {
class VoiceAssistantRequest : public ProtoMessage { class VoiceAssistantRequest : public ProtoMessage {
public: public:
bool start{false}; bool start{false};
std::string conversation_id{};
void encode(ProtoWriteBuffer buffer) const override; void encode(ProtoWriteBuffer buffer) const override;
#ifdef HAS_PROTO_MESSAGE_DUMP #ifdef HAS_PROTO_MESSAGE_DUMP
void dump_to(std::string &out) const override; void dump_to(std::string &out) const override;
#endif #endif
protected: protected:
bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;
bool decode_varint(uint32_t field_id, ProtoVarInt value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override;
}; };
class VoiceAssistantResponse : public ProtoMessage { class VoiceAssistantResponse : public ProtoMessage {

View File

@ -428,16 +428,16 @@ void APIServer::on_shutdown() {
} }
#ifdef USE_VOICE_ASSISTANT #ifdef USE_VOICE_ASSISTANT
bool APIServer::start_voice_assistant() { bool APIServer::start_voice_assistant(const std::string &conversation_id) {
for (auto &c : this->clients_) { for (auto &c : this->clients_) {
if (c->request_voice_assistant(true)) if (c->request_voice_assistant(true, conversation_id))
return true; return true;
} }
return false; return false;
} }
void APIServer::stop_voice_assistant() { void APIServer::stop_voice_assistant() {
for (auto &c : this->clients_) { for (auto &c : this->clients_) {
if (c->request_voice_assistant(false)) if (c->request_voice_assistant(false, ""))
return; return;
} }
} }

View File

@ -96,7 +96,7 @@ class APIServer : public Component, public Controller {
#endif #endif
#ifdef USE_VOICE_ASSISTANT #ifdef USE_VOICE_ASSISTANT
bool start_voice_assistant(); bool start_voice_assistant(const std::string &conversation_id);
void stop_voice_assistant(); void stop_voice_assistant();
#endif #endif

View File

@ -62,7 +62,7 @@ BASE_SCHEMA = microphone.MICROPHONE_SCHEMA.extend(
cv.GenerateID(): cv.declare_id(I2SAudioMicrophone), cv.GenerateID(): cv.declare_id(I2SAudioMicrophone),
cv.GenerateID(CONF_I2S_AUDIO_ID): cv.use_id(I2SAudioComponent), cv.GenerateID(CONF_I2S_AUDIO_ID): cv.use_id(I2SAudioComponent),
cv.Optional(CONF_CHANNEL, default="right"): cv.enum(CHANNELS), cv.Optional(CONF_CHANNEL, default="right"): cv.enum(CHANNELS),
cv.Optional(CONF_BITS_PER_SAMPLE, default="16bit"): cv.All( cv.Optional(CONF_BITS_PER_SAMPLE, default="32bit"): cv.All(
_validate_bits, cv.enum(BITS_PER_SAMPLE) _validate_bits, cv.enum(BITS_PER_SAMPLE)
), ),
} }

View File

@ -1,16 +1,23 @@
import esphome.config_validation as cv import esphome.config_validation as cv
import esphome.codegen as cg import esphome.codegen as cg
from esphome.const import CONF_ID, CONF_MICROPHONE, CONF_SPEAKER from esphome.const import (
CONF_ID,
CONF_MICROPHONE,
CONF_SPEAKER,
CONF_MEDIA_PLAYER,
)
from esphome import automation from esphome import automation
from esphome.automation import register_action from esphome.automation import register_action, register_condition
from esphome.components import microphone, speaker from esphome.components import microphone, speaker, media_player
AUTO_LOAD = ["socket"] AUTO_LOAD = ["socket"]
DEPENDENCIES = ["api", "microphone"] DEPENDENCIES = ["api", "microphone"]
CODEOWNERS = ["@jesserockz"] CODEOWNERS = ["@jesserockz"]
CONF_SILENCE_DETECTION = "silence_detection"
CONF_ON_LISTENING = "on_listening"
CONF_ON_START = "on_start" CONF_ON_START = "on_start"
CONF_ON_STT_END = "on_stt_end" CONF_ON_STT_END = "on_stt_end"
CONF_ON_TTS_START = "on_tts_start" CONF_ON_TTS_START = "on_tts_start"
@ -25,16 +32,25 @@ VoiceAssistant = voice_assistant_ns.class_("VoiceAssistant", cg.Component)
StartAction = voice_assistant_ns.class_( StartAction = voice_assistant_ns.class_(
"StartAction", automation.Action, cg.Parented.template(VoiceAssistant) "StartAction", automation.Action, cg.Parented.template(VoiceAssistant)
) )
StartContinuousAction = voice_assistant_ns.class_(
"StartContinuousAction", automation.Action, cg.Parented.template(VoiceAssistant)
)
StopAction = voice_assistant_ns.class_( StopAction = voice_assistant_ns.class_(
"StopAction", automation.Action, cg.Parented.template(VoiceAssistant) "StopAction", automation.Action, cg.Parented.template(VoiceAssistant)
) )
IsRunningCondition = voice_assistant_ns.class_(
"IsRunningCondition", automation.Condition, cg.Parented.template(VoiceAssistant)
)
CONFIG_SCHEMA = cv.Schema( CONFIG_SCHEMA = cv.Schema(
{ {
cv.GenerateID(): cv.declare_id(VoiceAssistant), cv.GenerateID(): cv.declare_id(VoiceAssistant),
cv.GenerateID(CONF_MICROPHONE): cv.use_id(microphone.Microphone), cv.GenerateID(CONF_MICROPHONE): cv.use_id(microphone.Microphone),
cv.Optional(CONF_SPEAKER): cv.use_id(speaker.Speaker), cv.Exclusive(CONF_SPEAKER, "output"): cv.use_id(speaker.Speaker),
cv.Exclusive(CONF_MEDIA_PLAYER, "output"): cv.use_id(media_player.MediaPlayer),
cv.Optional(CONF_SILENCE_DETECTION, default=True): cv.boolean,
cv.Optional(CONF_ON_LISTENING): automation.validate_automation(single=True),
cv.Optional(CONF_ON_START): automation.validate_automation(single=True), cv.Optional(CONF_ON_START): automation.validate_automation(single=True),
cv.Optional(CONF_ON_STT_END): automation.validate_automation(single=True), cv.Optional(CONF_ON_STT_END): automation.validate_automation(single=True),
cv.Optional(CONF_ON_TTS_START): automation.validate_automation(single=True), cv.Optional(CONF_ON_TTS_START): automation.validate_automation(single=True),
@ -56,6 +72,17 @@ async def to_code(config):
spkr = await cg.get_variable(config[CONF_SPEAKER]) spkr = await cg.get_variable(config[CONF_SPEAKER])
cg.add(var.set_speaker(spkr)) cg.add(var.set_speaker(spkr))
if CONF_MEDIA_PLAYER in config:
mp = await cg.get_variable(config[CONF_MEDIA_PLAYER])
cg.add(var.set_media_player(mp))
cg.add(var.set_silence_detection(config[CONF_SILENCE_DETECTION]))
if CONF_ON_LISTENING in config:
await automation.build_automation(
var.get_listening_trigger(), [], config[CONF_ON_LISTENING]
)
if CONF_ON_START in config: if CONF_ON_START in config:
await automation.build_automation( await automation.build_automation(
var.get_start_trigger(), [], config[CONF_ON_START] var.get_start_trigger(), [], config[CONF_ON_START]
@ -96,6 +123,11 @@ async def to_code(config):
VOICE_ASSISTANT_ACTION_SCHEMA = cv.Schema({cv.GenerateID(): cv.use_id(VoiceAssistant)}) VOICE_ASSISTANT_ACTION_SCHEMA = cv.Schema({cv.GenerateID(): cv.use_id(VoiceAssistant)})
@register_action(
"voice_assistant.start_continuous",
StartContinuousAction,
VOICE_ASSISTANT_ACTION_SCHEMA,
)
@register_action("voice_assistant.start", StartAction, VOICE_ASSISTANT_ACTION_SCHEMA) @register_action("voice_assistant.start", StartAction, VOICE_ASSISTANT_ACTION_SCHEMA)
async def voice_assistant_listen_to_code(config, action_id, template_arg, args): async def voice_assistant_listen_to_code(config, action_id, template_arg, args):
var = cg.new_Pvariable(action_id, template_arg) var = cg.new_Pvariable(action_id, template_arg)
@ -108,3 +140,12 @@ async def voice_assistant_stop_to_code(config, action_id, template_arg, args):
var = cg.new_Pvariable(action_id, template_arg) var = cg.new_Pvariable(action_id, template_arg)
await cg.register_parented(var, config[CONF_ID]) await cg.register_parented(var, config[CONF_ID])
return var return var
@register_condition(
"voice_assistant.is_running", IsRunningCondition, VOICE_ASSISTANT_ACTION_SCHEMA
)
async def voice_assistant_is_running_to_code(config, condition_id, template_arg, args):
var = cg.new_Pvariable(condition_id, template_arg)
await cg.register_parented(var, config[CONF_ID])
return var

View File

@ -69,17 +69,42 @@ void VoiceAssistant::setup() {
void VoiceAssistant::loop() { void VoiceAssistant::loop() {
#ifdef USE_SPEAKER #ifdef USE_SPEAKER
if (this->speaker_ == nullptr) { if (this->speaker_ != nullptr) {
return;
}
uint8_t buf[1024]; uint8_t buf[1024];
auto len = this->socket_->read(buf, sizeof(buf)); auto len = this->socket_->read(buf, sizeof(buf));
if (len == -1) { if (len == -1) {
return; return;
} }
this->speaker_->play(buf, len); this->speaker_->play(buf, len);
this->set_timeout("data-incoming", 200, [this]() {
if (this->continuous_) {
this->request_start(true);
}
});
return;
}
#endif #endif
#ifdef USE_MEDIA_PLAYER
if (this->media_player_ != nullptr) {
if (!this->playing_tts_ ||
this->media_player_->state == media_player::MediaPlayerState::MEDIA_PLAYER_STATE_PLAYING) {
return;
}
this->set_timeout("playing-media", 1000, [this]() {
this->playing_tts_ = false;
if (this->continuous_) {
this->request_start(true);
}
});
return;
}
#endif
// Set a 1 second timeout to start the voice assistant again.
this->set_timeout("continuous-no-sound", 1000, [this]() {
if (this->continuous_) {
this->request_start(true);
}
});
} }
void VoiceAssistant::start(struct sockaddr_storage *addr, uint16_t port) { void VoiceAssistant::start(struct sockaddr_storage *addr, uint16_t port) {
@ -100,14 +125,19 @@ void VoiceAssistant::start(struct sockaddr_storage *addr, uint16_t port) {
} }
this->running_ = true; this->running_ = true;
this->mic_->start(); this->mic_->start();
this->listening_trigger_->trigger();
} }
void VoiceAssistant::request_start() { void VoiceAssistant::request_start(bool continuous) {
ESP_LOGD(TAG, "Requesting start..."); ESP_LOGD(TAG, "Requesting start...");
if (!api::global_api_server->start_voice_assistant()) { if (!api::global_api_server->start_voice_assistant(this->conversation_id_)) {
ESP_LOGW(TAG, "Could not request start."); ESP_LOGW(TAG, "Could not request start.");
this->error_trigger_->trigger("not-connected", "Could not request start."); this->error_trigger_->trigger("not-connected", "Could not request start.");
this->continuous_ = false;
return;
} }
this->continuous_ = continuous;
this->set_timeout("reset-conversation_id", 5 * 60 * 1000, [this]() { this->conversation_id_ = ""; });
} }
void VoiceAssistant::signal_stop() { void VoiceAssistant::signal_stop() {
@ -136,9 +166,18 @@ void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) {
return; return;
} }
ESP_LOGD(TAG, "Speech recognised as: \"%s\"", text.c_str()); ESP_LOGD(TAG, "Speech recognised as: \"%s\"", text.c_str());
this->signal_stop();
this->stt_end_trigger_->trigger(text); this->stt_end_trigger_->trigger(text);
break; break;
} }
case api::enums::VOICE_ASSISTANT_INTENT_END: {
for (auto arg : msg.data) {
if (arg.name == "conversation_id") {
this->conversation_id_ = std::move(arg.value);
}
}
break;
}
case api::enums::VOICE_ASSISTANT_TTS_START: { case api::enums::VOICE_ASSISTANT_TTS_START: {
std::string text; std::string text;
for (auto arg : msg.data) { for (auto arg : msg.data) {
@ -166,6 +205,12 @@ void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) {
return; return;
} }
ESP_LOGD(TAG, "Response URL: \"%s\"", url.c_str()); ESP_LOGD(TAG, "Response URL: \"%s\"", url.c_str());
#ifdef USE_MEDIA_PLAYER
if (this->media_player_ != nullptr) {
this->playing_tts_ = true;
this->media_player_->make_call().set_media_url(url).perform();
}
#endif
this->tts_end_trigger_->trigger(url); this->tts_end_trigger_->trigger(url);
break; break;
} }
@ -184,6 +229,8 @@ void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) {
} }
} }
ESP_LOGE(TAG, "Error: %s - %s", code.c_str(), message.c_str()); ESP_LOGE(TAG, "Error: %s - %s", code.c_str(), message.c_str());
this->continuous_ = false;
this->signal_stop();
this->error_trigger_->trigger(code, message); this->error_trigger_->trigger(code, message);
} }
default: default:

View File

@ -15,6 +15,9 @@
#ifdef USE_SPEAKER #ifdef USE_SPEAKER
#include "esphome/components/speaker/speaker.h" #include "esphome/components/speaker/speaker.h"
#endif #endif
#ifdef USE_MEDIA_PLAYER
#include "esphome/components/media_player/media_player.h"
#endif
#include "esphome/components/socket/socket.h" #include "esphome/components/socket/socket.h"
namespace esphome { namespace esphome {
@ -22,8 +25,10 @@ namespace voice_assistant {
// Version 1: Initial version // Version 1: Initial version
// Version 2: Adds raw speaker support // Version 2: Adds raw speaker support
// Version 3: Adds continuous support
static const uint32_t INITIAL_VERSION = 1; static const uint32_t INITIAL_VERSION = 1;
static const uint32_t SPEAKER_SUPPORT = 2; static const uint32_t SPEAKER_SUPPORT = 2;
static const uint32_t SILENCE_DETECTION_SUPPORT = 3;
class VoiceAssistant : public Component { class VoiceAssistant : public Component {
public: public:
@ -36,20 +41,34 @@ class VoiceAssistant : public Component {
#ifdef USE_SPEAKER #ifdef USE_SPEAKER
void set_speaker(speaker::Speaker *speaker) { this->speaker_ = speaker; } void set_speaker(speaker::Speaker *speaker) { this->speaker_ = speaker; }
#endif #endif
#ifdef USE_MEDIA_PLAYER
void set_media_player(media_player::MediaPlayer *media_player) { this->media_player_ = media_player; }
#endif
uint32_t get_version() const { uint32_t get_version() const {
#ifdef USE_SPEAKER #ifdef USE_SPEAKER
if (this->speaker_ != nullptr) if (this->speaker_ != nullptr) {
if (this->silence_detection_) {
return SILENCE_DETECTION_SUPPORT;
}
return SPEAKER_SUPPORT; return SPEAKER_SUPPORT;
}
#endif #endif
return INITIAL_VERSION; return INITIAL_VERSION;
} }
void request_start(); void request_start(bool continuous = false);
void signal_stop(); void signal_stop();
void on_event(const api::VoiceAssistantEventResponse &msg); void on_event(const api::VoiceAssistantEventResponse &msg);
bool is_running() const { return this->running_; }
void set_continuous(bool continuous) { this->continuous_ = continuous; }
bool is_continuous() const { return this->continuous_; }
void set_silence_detection(bool silence_detection) { this->silence_detection_ = silence_detection; }
Trigger<> *get_listening_trigger() const { return this->listening_trigger_; }
Trigger<> *get_start_trigger() const { return this->start_trigger_; } Trigger<> *get_start_trigger() const { return this->start_trigger_; }
Trigger<std::string> *get_stt_end_trigger() const { return this->stt_end_trigger_; } Trigger<std::string> *get_stt_end_trigger() const { return this->stt_end_trigger_; }
Trigger<std::string> *get_tts_start_trigger() const { return this->tts_start_trigger_; } Trigger<std::string> *get_tts_start_trigger() const { return this->tts_start_trigger_; }
@ -61,6 +80,7 @@ class VoiceAssistant : public Component {
std::unique_ptr<socket::Socket> socket_ = nullptr; std::unique_ptr<socket::Socket> socket_ = nullptr;
struct sockaddr_storage dest_addr_; struct sockaddr_storage dest_addr_;
Trigger<> *listening_trigger_ = new Trigger<>();
Trigger<> *start_trigger_ = new Trigger<>(); Trigger<> *start_trigger_ = new Trigger<>();
Trigger<std::string> *stt_end_trigger_ = new Trigger<std::string>(); Trigger<std::string> *stt_end_trigger_ = new Trigger<std::string>();
Trigger<std::string> *tts_start_trigger_ = new Trigger<std::string>(); Trigger<std::string> *tts_start_trigger_ = new Trigger<std::string>();
@ -72,8 +92,16 @@ class VoiceAssistant : public Component {
#ifdef USE_SPEAKER #ifdef USE_SPEAKER
speaker::Speaker *speaker_{nullptr}; speaker::Speaker *speaker_{nullptr};
#endif #endif
#ifdef USE_MEDIA_PLAYER
media_player::MediaPlayer *media_player_{nullptr};
bool playing_tts_{false};
#endif
std::string conversation_id_{""};
bool running_{false}; bool running_{false};
bool continuous_{false};
bool silence_detection_;
}; };
template<typename... Ts> class StartAction : public Action<Ts...>, public Parented<VoiceAssistant> { template<typename... Ts> class StartAction : public Action<Ts...>, public Parented<VoiceAssistant> {
@ -81,9 +109,22 @@ template<typename... Ts> class StartAction : public Action<Ts...>, public Parent
void play(Ts... x) override { this->parent_->request_start(); } void play(Ts... x) override { this->parent_->request_start(); }
}; };
template<typename... Ts> class StartContinuousAction : public Action<Ts...>, public Parented<VoiceAssistant> {
public:
void play(Ts... x) override { this->parent_->request_start(true); }
};
template<typename... Ts> class StopAction : public Action<Ts...>, public Parented<VoiceAssistant> { template<typename... Ts> class StopAction : public Action<Ts...>, public Parented<VoiceAssistant> {
public: public:
void play(Ts... x) override { this->parent_->signal_stop(); } void play(Ts... x) override {
this->parent_->set_continuous(false);
this->parent_->signal_stop();
}
};
template<typename... Ts> class IsRunningCondition : public Condition<Ts...>, public Parented<VoiceAssistant> {
public:
bool check(Ts... x) override { return this->parent_->is_running() || this->parent_->is_continuous(); }
}; };
extern VoiceAssistant *global_voice_assistant; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) extern VoiceAssistant *global_voice_assistant; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)

View File

@ -399,6 +399,7 @@ CONF_MAX_VOLTAGE = "max_voltage"
CONF_MDNS = "mdns" CONF_MDNS = "mdns"
CONF_MEASUREMENT_DURATION = "measurement_duration" CONF_MEASUREMENT_DURATION = "measurement_duration"
CONF_MEASUREMENT_SEQUENCE_NUMBER = "measurement_sequence_number" CONF_MEASUREMENT_SEQUENCE_NUMBER = "measurement_sequence_number"
CONF_MEDIA_PLAYER = "media_player"
CONF_MEDIUM = "medium" CONF_MEDIUM = "medium"
CONF_MEMORY_BLOCKS = "memory_blocks" CONF_MEMORY_BLOCKS = "memory_blocks"
CONF_METHOD = "method" CONF_METHOD = "method"