This commit is contained in:
Kevin Ahrendt 2024-05-02 15:45:58 +12:00 committed by GitHub
commit 74193da912
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 736 additions and 351 deletions

View File

@ -30,6 +30,7 @@ from esphome.const import (
CONF_USERNAME,
CONF_PASSWORD,
CONF_RAW_DATA_ID,
CONF_THRESHOLD,
TYPE_GIT,
TYPE_LOCAL,
)
@ -41,9 +42,13 @@ CODEOWNERS = ["@kahrendt", "@jesserockz"]
DEPENDENCIES = ["microphone"]
DOMAIN = "micro_wake_word"
CONF_MODELS = "models"
CONF_PROBABILITY_CUTOFF = "probability_cutoff"
CONF_SLIDING_WINDOW_AVERAGE_SIZE = "sliding_window_average_size"
CONF_ON_WAKE_WORD_DETECTED = "on_wake_word_detected"
CONF_VAD_MODEL = "vad_model"
CONF_UPPER = "upper"
CONF_LOWER = "lower"
TYPE_HTTP = "http"
@ -260,18 +265,42 @@ MODEL_SOURCE_SCHEMA = cv.Any(
msg="Not a valid model name, local path, http(s) url, or github shorthand",
)
# Validation schema for a single wake word model entry in the `models` list.
MODEL_SCHEMA = cv.Schema(
{
# Where to find the model (validated by MODEL_SOURCE_SCHEMA).
cv.Required(CONF_MODEL): MODEL_SOURCE_SCHEMA,
# Optional per-model overrides; when omitted, to_code() falls back to the
# values stored in the model's manifest.
cv.Optional(CONF_PROBABILITY_CUTOFF): cv.percentage,
cv.Optional(CONF_SLIDING_WINDOW_AVERAGE_SIZE): cv.positive_int,
# ID of the generated uint8 array holding the raw model data.
cv.GenerateID(CONF_RAW_DATA_ID): cv.declare_id(cg.uint8),
}
)
# Validation schema for the optional voice activity detection (VAD) model.
VAD_MODEL_SCHEMA = cv.Schema(
{
# Where to find the model (validated by MODEL_SOURCE_SCHEMA).
cv.Required(CONF_MODEL): MODEL_SOURCE_SCHEMA,
# Either a single percentage (to_code() then uses it as both the upper and
# the lower threshold) or a mapping with explicit `upper` and `lower`
# percentages. Defaults to 0.5.
cv.Optional(CONF_THRESHOLD, default=0.5): cv.Any(
cv.percentage,
cv.Schema(
{
cv.Required(CONF_UPPER): cv.percentage,
cv.Required(CONF_LOWER): cv.percentage,
}
),
),
# Optional override; when omitted, to_code() falls back to the value stored
# in the model's manifest.
cv.Optional(CONF_SLIDING_WINDOW_AVERAGE_SIZE): cv.positive_int,
# ID of the generated uint8 array holding the raw model data.
cv.GenerateID(CONF_RAW_DATA_ID): cv.declare_id(cg.uint8),
}
)
CONFIG_SCHEMA = cv.All(
cv.Schema(
{
cv.GenerateID(): cv.declare_id(MicroWakeWord),
cv.GenerateID(CONF_MICROPHONE): cv.use_id(microphone.Microphone),
cv.Optional(CONF_PROBABILITY_CUTOFF): cv.percentage,
cv.Optional(CONF_SLIDING_WINDOW_AVERAGE_SIZE): cv.positive_int,
cv.Required(CONF_MODELS): cv.ensure_list(MODEL_SCHEMA),
cv.Optional(CONF_ON_WAKE_WORD_DETECTED): automation.validate_automation(
single=True
),
cv.Required(CONF_MODEL): MODEL_SOURCE_SCHEMA,
cv.GenerateID(CONF_RAW_DATA_ID): cv.declare_id(cg.uint8),
cv.Optional(CONF_VAD_MODEL): VAD_MODEL_SCHEMA,
}
).extend(cv.COMPONENT_SCHEMA),
cv.only_with_esp_idf,
@ -302,13 +331,6 @@ async def to_code(config):
mic = await cg.get_variable(config[CONF_MICROPHONE])
cg.add(var.set_microphone(mic))
if on_wake_word_detection_config := config.get(CONF_ON_WAKE_WORD_DETECTED):
await automation.build_automation(
var.get_wake_word_detected_trigger(),
[(cg.std_string, "wake_word")],
on_wake_word_detection_config,
)
esp32.add_idf_component(
name="esp-tflite-micro",
repo="https://github.com/espressif/esp-tflite-micro",
@ -318,39 +340,97 @@ async def to_code(config):
cg.add_build_flag("-DTF_LITE_DISABLE_X86_NEON")
cg.add_build_flag("-DESP_NN")
model_config = config.get(CONF_MODEL)
data = []
if model_config[CONF_TYPE] == TYPE_GIT:
# compute path to model file
key = f"{model_config[CONF_URL]}@{model_config.get(CONF_REF)}"
base_dir = Path(CORE.data_dir) / DOMAIN
h = hashlib.new("sha256")
h.update(key.encode())
file: Path = base_dir / h.hexdigest()[:8] / model_config[CONF_FILE]
if on_wake_word_detection_config := config.get(CONF_ON_WAKE_WORD_DETECTED):
await automation.build_automation(
var.get_wake_word_detected_trigger(),
[(cg.std_string, "wake_word")],
on_wake_word_detection_config,
)
elif model_config[CONF_TYPE] == TYPE_LOCAL:
file = model_config[CONF_PATH]
if vad_model := config.get(CONF_VAD_MODEL):
cg.add_define("USE_MWW_VAD")
model_config = vad_model.get(CONF_MODEL)
data = []
if model_config[CONF_TYPE] == TYPE_GIT:
# compute path to model file
key = f"{model_config[CONF_URL]}@{model_config.get(CONF_REF)}"
base_dir = Path(CORE.data_dir) / DOMAIN
h = hashlib.new("sha256")
h.update(key.encode())
file: Path = base_dir / h.hexdigest()[:8] / model_config[CONF_FILE]
elif model_config[CONF_TYPE] == TYPE_HTTP:
file = _compute_local_file_path(model_config) / "manifest.json"
elif model_config[CONF_TYPE] == TYPE_LOCAL:
file = Path(model_config[CONF_PATH])
manifest, data = _load_model_data(file)
elif model_config[CONF_TYPE] == TYPE_HTTP:
file = _compute_local_file_path(model_config) / "manifest.json"
rhs = [HexInt(x) for x in data]
prog_arr = cg.progmem_array(config[CONF_RAW_DATA_ID], rhs)
cg.add(var.set_model_start(prog_arr))
manifest, data = _load_model_data(file)
probability_cutoff = config.get(
CONF_PROBABILITY_CUTOFF, manifest[KEY_MICRO][CONF_PROBABILITY_CUTOFF]
)
cg.add(var.set_probability_cutoff(probability_cutoff))
sliding_window_average_size = config.get(
CONF_SLIDING_WINDOW_AVERAGE_SIZE,
manifest[KEY_MICRO][CONF_SLIDING_WINDOW_AVERAGE_SIZE],
)
cg.add(var.set_sliding_window_average_size(sliding_window_average_size))
rhs = [HexInt(x) for x in data]
prog_arr = cg.progmem_array(vad_model[CONF_RAW_DATA_ID], rhs)
cg.add(var.set_wake_word(manifest[KEY_WAKE_WORD]))
sliding_window_average_size = vad_model.get(
CONF_SLIDING_WINDOW_AVERAGE_SIZE,
manifest[KEY_MICRO][CONF_SLIDING_WINDOW_AVERAGE_SIZE],
)
if isinstance(vad_model[CONF_THRESHOLD], float):
upper_threshold = vad_model[CONF_THRESHOLD]
lower_threshold = vad_model[CONF_THRESHOLD]
else:
upper_threshold = vad_model[CONF_THRESHOLD][CONF_UPPER]
lower_threshold = vad_model[CONF_THRESHOLD][CONF_LOWER]
cg.add(
var.add_vad_model(
prog_arr,
upper_threshold,
lower_threshold,
sliding_window_average_size,
22000, # Tensor arena size for VAD model
)
)
for model_parameters in config[CONF_MODELS]:
model_config = model_parameters.get(CONF_MODEL)
data = []
if model_config[CONF_TYPE] == TYPE_GIT:
# compute path to model file
key = f"{model_config[CONF_URL]}@{model_config.get(CONF_REF)}"
base_dir = Path(CORE.data_dir) / DOMAIN
h = hashlib.new("sha256")
h.update(key.encode())
file: Path = base_dir / h.hexdigest()[:8] / model_config[CONF_FILE]
elif model_config[CONF_TYPE] == TYPE_LOCAL:
file = Path(model_config[CONF_PATH])
elif model_config[CONF_TYPE] == TYPE_HTTP:
file = _compute_local_file_path(model_config) / "manifest.json"
manifest, data = _load_model_data(file)
rhs = [HexInt(x) for x in data]
prog_arr = cg.progmem_array(model_parameters[CONF_RAW_DATA_ID], rhs)
probability_cutoff = model_parameters.get(
CONF_PROBABILITY_CUTOFF, manifest[KEY_MICRO][CONF_PROBABILITY_CUTOFF]
)
sliding_window_average_size = model_parameters.get(
CONF_SLIDING_WINDOW_AVERAGE_SIZE,
manifest[KEY_MICRO][CONF_SLIDING_WINDOW_AVERAGE_SIZE],
)
cg.add(
var.add_wake_word_model(
prog_arr,
probability_cutoff,
sliding_window_average_size,
manifest[KEY_WAKE_WORD],
45672, # Tensor arena size for original Inception-based models
)
)
# Schema that resolves the ID of an existing MicroWakeWord component; judging by
# its name it is shared by this component's actions (registration not shown here).
MICRO_WAKE_WORD_ACTION_SCHEMA = cv.Schema({cv.GenerateID(): cv.use_id(MicroWakeWord)})

View File

@ -1,4 +1,5 @@
#include "micro_wake_word.h"
#include "streaming_model.h"
/**
* This is a workaround until we can figure out a way to get
@ -28,7 +29,7 @@ namespace micro_wake_word {
static const char *const TAG = "micro_wake_word";
static const size_t SAMPLE_RATE_HZ = 16000; // 16 kHz
static const size_t BUFFER_LENGTH = 500; // 0.5 seconds
static const size_t BUFFER_LENGTH = 100; // 0.1 seconds
static const size_t BUFFER_SIZE = SAMPLE_RATE_HZ / 1000 * BUFFER_LENGTH;
static const size_t INPUT_BUFFER_SIZE = 32 * SAMPLE_RATE_HZ / 1000; // 32ms * 16kHz / 1000ms
@ -54,32 +55,24 @@ static const LogString *micro_wake_word_state_to_string(State state) {
}
void MicroWakeWord::dump_config() {
ESP_LOGCONFIG(TAG, "microWakeWord:");
ESP_LOGCONFIG(TAG, " Wake Word: %s", this->get_wake_word().c_str());
ESP_LOGCONFIG(TAG, " Probability cutoff: %.3f", this->probability_cutoff_);
ESP_LOGCONFIG(TAG, " Sliding window size: %d", this->sliding_window_average_size_);
ESP_LOGCONFIG(TAG, "microWakeWord models:");
for (auto &model : this->wake_word_models_) {
model.log_model_config();
}
#ifdef USE_MWW_VAD
this->vad_model_->log_model_config();
#endif
}
void MicroWakeWord::setup() {
ESP_LOGCONFIG(TAG, "Setting up microWakeWord...");
if (!this->initialize_models()) {
ESP_LOGE(TAG, "Failed to initialize models");
if (!this->register_streaming_ops_(this->streaming_op_resolver_)) {
this->mark_failed();
return;
}
ExternalRAMAllocator<int16_t> allocator(ExternalRAMAllocator<int16_t>::ALLOW_FAILURE);
this->input_buffer_ = allocator.allocate(INPUT_BUFFER_SIZE * sizeof(int16_t));
if (this->input_buffer_ == nullptr) {
ESP_LOGW(TAG, "Could not allocate input buffer");
this->mark_failed();
return;
}
this->ring_buffer_ = RingBuffer::create(BUFFER_SIZE * sizeof(int16_t));
if (this->ring_buffer_ == nullptr) {
ESP_LOGW(TAG, "Could not allocate ring buffer");
if (!this->register_preprocessor_ops_(this->preprocessor_op_resolver_)) {
this->mark_failed();
return;
}
@ -87,26 +80,21 @@ void MicroWakeWord::setup() {
ESP_LOGCONFIG(TAG, "Micro Wake Word initialized");
}
int MicroWakeWord::read_microphone_() {
size_t bytes_read = this->microphone_->read(this->input_buffer_, INPUT_BUFFER_SIZE * sizeof(int16_t));
if (bytes_read == 0) {
return 0;
}
size_t bytes_free = this->ring_buffer_->free();
if (bytes_free < bytes_read) {
ESP_LOGW(TAG,
"Not enough free bytes in ring buffer to store incoming audio data (free bytes=%d, incoming bytes=%d). "
"Resetting the ring buffer. Wake word detection accuracy will be reduced.",
bytes_free, bytes_read);
this->ring_buffer_->reset();
}
return this->ring_buffer_->write((void *) this->input_buffer_, bytes_read);
/** Registers a wake word model with this component.
 *
 * @param model_start Pointer to the start of the raw model data
 * @param probability_cutoff Cutoff forwarded to the WakeWordModel constructor
 * @param sliding_window_average_size Sliding window size forwarded to the WakeWordModel constructor
 * @param wake_word The wake word phrase associated with this model
 * @param tensor_arena_size Tensor arena size (in bytes) forwarded to the WakeWordModel constructor
 */
void MicroWakeWord::add_wake_word_model(const uint8_t *model_start, float probability_cutoff,
                                        size_t sliding_window_average_size, const std::string &wake_word,
                                        size_t tensor_arena_size) {
  // Construct the model in place inside the vector rather than building a temporary
  // WakeWordModel and then moving/copying it in.
  this->wake_word_models_.emplace_back(model_start, probability_cutoff, sliding_window_average_size, wake_word,
                                       tensor_arena_size);
}
#ifdef USE_MWW_VAD
/** Creates the VAD model used by this component and stores it in vad_model_.
 *
 * @param model_start Pointer to the start of the raw model data
 * @param upper_threshold Upper threshold forwarded to the VADModel constructor
 * @param lower_threshold Lower threshold forwarded to the VADModel constructor
 * @param sliding_window_size Sliding window size forwarded to the VADModel constructor
 * @param tensor_arena_size Tensor arena size (in bytes) forwarded to the VADModel constructor
 */
void MicroWakeWord::add_vad_model(const uint8_t *model_start, float upper_threshold, float lower_threshold,
                                  size_t sliding_window_size, size_t tensor_arena_size) {
  auto *created_model =
      new VADModel(model_start, upper_threshold, lower_threshold, sliding_window_size, tensor_arena_size);
  this->vad_model_ = created_model;
}
#endif
void MicroWakeWord::loop() {
switch (this->state_) {
case State::IDLE:
@ -124,8 +112,9 @@ void MicroWakeWord::loop() {
break;
case State::DETECTING_WAKE_WORD:
this->read_microphone_();
if (this->detect_wake_word_()) {
ESP_LOGD(TAG, "Wake Word Detected");
this->update_model_probabilities_();
if (this->detect_wake_words_()) {
ESP_LOGD(TAG, "Wake Word '%s' Detected", (this->detected_wake_word_).c_str());
this->detected_ = true;
this->set_state_(State::STOP_MICROPHONE);
}
@ -135,13 +124,16 @@ void MicroWakeWord::loop() {
this->microphone_->stop();
this->set_state_(State::STOPPING_MICROPHONE);
this->high_freq_.stop();
this->unload_models_();
this->deallocate_buffers_();
break;
case State::STOPPING_MICROPHONE:
if (this->microphone_->is_stopped()) {
this->set_state_(State::IDLE);
if (this->detected_) {
this->wake_word_detected_trigger_->trigger(this->detected_wake_word_);
this->detected_ = false;
this->wake_word_detected_trigger_->trigger(this->wake_word_);
this->detected_wake_word_ = "";
}
}
break;
@ -149,14 +141,34 @@ void MicroWakeWord::loop() {
}
/** Starts wake word detection.
 *
 * Does nothing (besides logging) if the component is not ready, is marked as failed, or is
 * not currently idle. Otherwise it loads the models and allocates the audio buffers; on
 * failure the component's error status is set and detection is not started. On success the
 * error status is cleared, the detection state is reset, and the state machine moves to
 * START_MICROPHONE.
 */
void MicroWakeWord::start() {
  if (!this->is_ready()) {
    ESP_LOGW(TAG, "Wake word detection can't start as the component hasn't been setup yet");
    return;
  }

  if (this->is_failed()) {
    ESP_LOGW(TAG, "Wake word component is marked as failed. Please check setup logs");
    return;
  }

  // Check the running state BEFORE touching any memory. The stop path unloads the models and
  // frees the buffers before the state returns to IDLE, so loading here while a stop is still
  // in progress would leave everything allocated after the component goes idle.
  if (this->state_ != State::IDLE) {
    ESP_LOGW(TAG, "Wake word is already running");
    return;
  }

  if (!this->load_models_() || !this->allocate_buffers_()) {
    ESP_LOGE(TAG, "Failed to load the wake word model(s) or allocate buffers");
    this->status_set_error();
  } else {
    this->status_clear_error();
  }

  if (this->status_has_error()) {
    ESP_LOGW(TAG, "Wake word component has an error. Please check logs");
    return;
  }

  this->reset_states_();
  this->set_state_(State::START_MICROPHONE);
}
@ -178,207 +190,173 @@ void MicroWakeWord::set_state_(State state) {
this->state_ = state;
}
bool MicroWakeWord::initialize_models() {
ExternalRAMAllocator<uint8_t> arena_allocator(ExternalRAMAllocator<uint8_t>::ALLOW_FAILURE);
ExternalRAMAllocator<int8_t> features_allocator(ExternalRAMAllocator<int8_t>::ALLOW_FAILURE);
/** Moves freshly captured audio from the microphone into the ring buffer.
 *
 * Reads up to INPUT_BUFFER_SIZE int16 samples into input_buffer_. If the ring buffer does not
 * have room for everything that was read, a warning is logged and the ring buffer is reset
 * before the new data is written.
 *
 * @return 0 if the microphone returned no data, otherwise the result of the ring buffer write
 */
size_t MicroWakeWord::read_microphone_() {
  const size_t incoming_bytes = this->microphone_->read(this->input_buffer_, INPUT_BUFFER_SIZE * sizeof(int16_t));
  if (incoming_bytes == 0) {
    return 0;
  }

  const size_t free_bytes = this->ring_buffer_->free();
  if (incoming_bytes > free_bytes) {
    // Not enough room: drop what is buffered so the newest audio can be stored.
    ESP_LOGW(TAG,
             "Not enough free bytes in ring buffer to store incoming audio data (free bytes=%d, incoming bytes=%d). "
             "Resetting the ring buffer. Wake word detection accuracy will be reduced.",
             free_bytes, incoming_bytes);
    this->ring_buffer_->reset();
  }

  return this->ring_buffer_->write(static_cast<void *>(this->input_buffer_), incoming_bytes);
}
bool MicroWakeWord::allocate_buffers_() {
ExternalRAMAllocator<int16_t> audio_samples_allocator(ExternalRAMAllocator<int16_t>::ALLOW_FAILURE);
this->streaming_tensor_arena_ = arena_allocator.allocate(STREAMING_MODEL_ARENA_SIZE);
if (this->streaming_tensor_arena_ == nullptr) {
ESP_LOGE(TAG, "Could not allocate the streaming model's tensor arena.");
return false;
if (this->input_buffer_ == nullptr) {
this->input_buffer_ = audio_samples_allocator.allocate(INPUT_BUFFER_SIZE * sizeof(int16_t));
if (this->input_buffer_ == nullptr) {
ESP_LOGE(TAG, "Could not allocate input buffer");
return false;
}
}
this->streaming_var_arena_ = arena_allocator.allocate(STREAMING_MODEL_VARIABLE_ARENA_SIZE);
if (this->streaming_var_arena_ == nullptr) {
ESP_LOGE(TAG, "Could not allocate the streaming model variable's tensor arena.");
return false;
}
this->preprocessor_tensor_arena_ = arena_allocator.allocate(PREPROCESSOR_ARENA_SIZE);
if (this->preprocessor_tensor_arena_ == nullptr) {
ESP_LOGE(TAG, "Could not allocate the audio preprocessor model's tensor arena.");
return false;
}
this->new_features_data_ = features_allocator.allocate(PREPROCESSOR_FEATURE_SIZE);
if (this->new_features_data_ == nullptr) {
ESP_LOGE(TAG, "Could not allocate the audio features buffer.");
return false;
}
this->preprocessor_audio_buffer_ = audio_samples_allocator.allocate(SAMPLE_DURATION_COUNT);
if (this->preprocessor_audio_buffer_ == nullptr) {
ESP_LOGE(TAG, "Could not allocate the audio preprocessor's buffer.");
return false;
this->preprocessor_audio_buffer_ = audio_samples_allocator.allocate(SAMPLE_DURATION_COUNT);
if (this->preprocessor_audio_buffer_ == nullptr) {
ESP_LOGE(TAG, "Could not allocate the audio preprocessor's buffer.");
return false;
}
}
this->preprocessor_model_ = tflite::GetModel(G_AUDIO_PREPROCESSOR_INT8_TFLITE);
if (this->preprocessor_model_->version() != TFLITE_SCHEMA_VERSION) {
ESP_LOGE(TAG, "Wake word's audio preprocessor model's schema is not supported");
return false;
}
this->streaming_model_ = tflite::GetModel(this->model_start_);
if (this->streaming_model_->version() != TFLITE_SCHEMA_VERSION) {
ESP_LOGE(TAG, "Wake word's streaming model's schema is not supported");
return false;
}
static tflite::MicroMutableOpResolver<18> preprocessor_op_resolver;
static tflite::MicroMutableOpResolver<17> streaming_op_resolver;
if (!this->register_preprocessor_ops_(preprocessor_op_resolver))
return false;
if (!this->register_streaming_ops_(streaming_op_resolver))
return false;
tflite::MicroAllocator *ma =
tflite::MicroAllocator::Create(this->streaming_var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE);
this->mrv_ = tflite::MicroResourceVariables::Create(ma, 15);
static tflite::MicroInterpreter static_preprocessor_interpreter(
this->preprocessor_model_, preprocessor_op_resolver, this->preprocessor_tensor_arena_, PREPROCESSOR_ARENA_SIZE);
static tflite::MicroInterpreter static_streaming_interpreter(this->streaming_model_, streaming_op_resolver,
this->streaming_tensor_arena_,
STREAMING_MODEL_ARENA_SIZE, this->mrv_);
this->preprocessor_interperter_ = &static_preprocessor_interpreter;
this->streaming_interpreter_ = &static_streaming_interpreter;
// Allocate tensors for each models.
if (this->preprocessor_interperter_->AllocateTensors() != kTfLiteOk) {
ESP_LOGE(TAG, "Failed to allocate tensors for the audio preprocessor");
return false;
}
if (this->streaming_interpreter_->AllocateTensors() != kTfLiteOk) {
ESP_LOGE(TAG, "Failed to allocate tensors for the streaming model");
return false;
}
// Verify input tensor matches expected values
TfLiteTensor *input = this->streaming_interpreter_->input(0);
if ((input->dims->size != 3) || (input->dims->data[0] != 1) || (input->dims->data[0] != 1) ||
(input->dims->data[1] != 1) || (input->dims->data[2] != PREPROCESSOR_FEATURE_SIZE)) {
ESP_LOGE(TAG, "Wake word detection model tensor input dimensions is not 1x1x%u", input->dims->data[2]);
return false;
}
if (input->type != kTfLiteInt8) {
ESP_LOGE(TAG, "Wake word detection model tensor input is not int8.");
return false;
}
// Verify output tensor matches expected values
TfLiteTensor *output = this->streaming_interpreter_->output(0);
if ((output->dims->size != 2) || (output->dims->data[0] != 1) || (output->dims->data[1] != 1)) {
ESP_LOGE(TAG, "Wake word detection model tensor output dimensions is not 1x1.");
}
if (output->type != kTfLiteUInt8) {
ESP_LOGE(TAG, "Wake word detection model tensor input is not uint8.");
return false;
}
this->recent_streaming_probabilities_.resize(this->sliding_window_average_size_, 0.0);
return true;
}
bool MicroWakeWord::update_features_() {
// Retrieve strided audio samples
int16_t *audio_samples = nullptr;
if (!this->stride_audio_samples_(&audio_samples)) {
return false;
}
// Compute the features for the newest audio samples
if (!this->generate_single_feature_(audio_samples, SAMPLE_DURATION_COUNT, this->new_features_data_)) {
return false;
if (this->ring_buffer_ == nullptr) {
this->ring_buffer_ = RingBuffer::create(BUFFER_SIZE * sizeof(int16_t));
if (this->ring_buffer_ == nullptr) {
ESP_LOGE(TAG, "Could not allocate ring buffer");
return false;
}
}
return true;
}
float MicroWakeWord::perform_streaming_inference_() {
TfLiteTensor *input = this->streaming_interpreter_->input(0);
size_t bytes_to_copy = input->bytes;
memcpy((void *) (tflite::GetTensorData<int8_t>(input)), (const void *) (this->new_features_data_), bytes_to_copy);
uint32_t prior_invoke = millis();
TfLiteStatus invoke_status = this->streaming_interpreter_->Invoke();
if (invoke_status != kTfLiteOk) {
ESP_LOGW(TAG, "Streaming Interpreter Invoke failed");
return false;
}
ESP_LOGV(TAG, "Streaming Inference Latency=%u ms", (millis() - prior_invoke));
TfLiteTensor *output = this->streaming_interpreter_->output(0);
return static_cast<float>(output->data.uint8[0]) / 255.0;
void MicroWakeWord::deallocate_buffers_() {
ExternalRAMAllocator<int16_t> audio_samples_allocator(ExternalRAMAllocator<int16_t>::ALLOW_FAILURE);
audio_samples_allocator.deallocate(this->input_buffer_, PREPROCESSOR_ARENA_SIZE);
this->input_buffer_ = nullptr;
audio_samples_allocator.deallocate(this->preprocessor_audio_buffer_, PREPROCESSOR_ARENA_SIZE);
this->preprocessor_audio_buffer_ = nullptr;
}
bool MicroWakeWord::detect_wake_word_() {
// Preprocess the newest audio samples into features
if (!this->update_features_()) {
bool MicroWakeWord::load_models_() {
// Setup preprocessor feature generator
if (this->preprocessor_tensor_arena_ == nullptr) {
ExternalRAMAllocator<uint8_t> arena_allocator(ExternalRAMAllocator<uint8_t>::ALLOW_FAILURE);
this->preprocessor_tensor_arena_ = arena_allocator.allocate(PREPROCESSOR_ARENA_SIZE);
if (this->preprocessor_tensor_arena_ == nullptr) {
ESP_LOGE(TAG, "Could not allocate the audio preprocessor model's tensor arena.");
return false;
}
}
if (this->preprocessor_interpreter_ == nullptr) {
this->preprocessor_interpreter_ = new tflite::MicroInterpreter(
tflite::GetModel(G_AUDIO_PREPROCESSOR_INT8_TFLITE), this->preprocessor_op_resolver_,
this->preprocessor_tensor_arena_, PREPROCESSOR_ARENA_SIZE);
if (this->preprocessor_interpreter_->AllocateTensors() != kTfLiteOk) {
ESP_LOGE(TAG, "Failed to allocate tensors for the audio preprocessor");
return false;
}
}
// Setup streaming models
for (auto &model : this->wake_word_models_) {
if (!model.load_model(this->streaming_op_resolver_)) {
ESP_LOGE(TAG, "Failed to initialize a wake word model.");
return false;
}
}
#ifdef USE_MWW_VAD
if (!this->vad_model_->load_model(this->streaming_op_resolver_)) {
ESP_LOGE(TAG, "Failed to initialize VAD model.");
return false;
}
#endif
// Perform inference
float streaming_prob = this->perform_streaming_inference_();
return true;
}
// Add the most recent probability to the sliding window
this->recent_streaming_probabilities_[this->last_n_index_] = streaming_prob;
++this->last_n_index_;
if (this->last_n_index_ == this->sliding_window_average_size_)
this->last_n_index_ = 0;
void MicroWakeWord::unload_models_() {
delete (this->preprocessor_interpreter_);
this->preprocessor_interpreter_ = nullptr;
float sum = 0.0;
for (auto &prob : this->recent_streaming_probabilities_) {
sum += prob;
ExternalRAMAllocator<uint8_t> arena_allocator(ExternalRAMAllocator<uint8_t>::ALLOW_FAILURE);
arena_allocator.deallocate(this->preprocessor_tensor_arena_, PREPROCESSOR_ARENA_SIZE);
this->preprocessor_tensor_arena_ = nullptr;
for (auto &model : this->wake_word_models_) {
model.unload_model();
}
#ifdef USE_MWW_VAD
this->vad_model_->unload_model();
#endif
}
void MicroWakeWord::update_model_probabilities_() {
if (!this->stride_audio_samples_()) {
return;
}
float sliding_window_average = sum / static_cast<float>(this->sliding_window_average_size_);
int8_t audio_features[PREPROCESSOR_FEATURE_SIZE];
// Ensure we have enough samples since the last positive detection
if (!this->generate_features_for_window_(audio_features)) {
return;
}
// Increase the counter since the last positive detection
this->ignore_windows_ = std::min(this->ignore_windows_ + 1, 0);
for (auto &model : this->wake_word_models_) {
// Perform inference
model.perform_streaming_inference(audio_features);
}
#ifdef USE_MWW_VAD
this->vad_model_->perform_streaming_inference(audio_features);
#endif
}
bool MicroWakeWord::detect_wake_words_() {
// Verify we have processed samples since the last positive detection
if (this->ignore_windows_ < 0) {
return false;
}
// Detect the wake word if the sliding window average is above the cutoff
if (sliding_window_average > this->probability_cutoff_) {
this->ignore_windows_ = -MIN_SLICES_BEFORE_DETECTION;
for (auto &prob : this->recent_streaming_probabilities_) {
prob = 0;
}
#ifdef USE_MWW_VAD
bool vad_state = this->vad_model_->determine_detected();
#endif
ESP_LOGD(TAG, "Wake word sliding average probability is %.3f and most recent probability is %.3f",
sliding_window_average, streaming_prob);
return true;
for (auto &model : this->wake_word_models_) {
if (model.determine_detected()) {
#ifdef USE_MWW_VAD
if (vad_state) {
#endif
this->detected_wake_word_ = model.get_wake_word();
return true;
#ifdef USE_MWW_VAD
} else {
ESP_LOGD(TAG, "Wake word model predicts %s, but VAD model doesn't.", model.get_wake_word().c_str());
}
#endif
}
}
return false;
}
void MicroWakeWord::set_sliding_window_average_size(size_t size) {
this->sliding_window_average_size_ = size;
this->recent_streaming_probabilities_.resize(this->sliding_window_average_size_, 0.0);
}
bool MicroWakeWord::slice_available_() {
size_t available = this->ring_buffer_->available();
return available > (NEW_SAMPLES_TO_GET * sizeof(int16_t));
}
bool MicroWakeWord::stride_audio_samples_(int16_t **audio_samples) {
if (!this->slice_available_()) {
bool MicroWakeWord::stride_audio_samples_() {
// Ensure we have enough new audio samples in the ring buffer for a full window
if (this->ring_buffer_->available() < NEW_SAMPLES_TO_GET * sizeof(int16_t)) {
return false;
}
@ -400,25 +378,35 @@ bool MicroWakeWord::stride_audio_samples_(int16_t **audio_samples) {
return false;
}
*audio_samples = this->preprocessor_audio_buffer_;
return true;
}
bool MicroWakeWord::generate_single_feature_(const int16_t *audio_data, const int audio_data_size,
int8_t feature_output[PREPROCESSOR_FEATURE_SIZE]) {
TfLiteTensor *input = this->preprocessor_interperter_->input(0);
TfLiteTensor *output = this->preprocessor_interperter_->output(0);
std::copy_n(audio_data, audio_data_size, tflite::GetTensorData<int16_t>(input));
bool MicroWakeWord::generate_features_for_window_(int8_t features[PREPROCESSOR_FEATURE_SIZE]) {
TfLiteTensor *input = this->preprocessor_interpreter_->input(0);
TfLiteTensor *output = this->preprocessor_interpreter_->output(0);
std::copy_n(this->preprocessor_audio_buffer_, SAMPLE_DURATION_COUNT, tflite::GetTensorData<int16_t>(input));
if (this->preprocessor_interperter_->Invoke() != kTfLiteOk) {
if (this->preprocessor_interpreter_->Invoke() != kTfLiteOk) {
ESP_LOGE(TAG, "Failed to preprocess audio for local wake word.");
return false;
}
std::memcpy(feature_output, tflite::GetTensorData<int8_t>(output), PREPROCESSOR_FEATURE_SIZE * sizeof(int8_t));
std::memcpy(features, tflite::GetTensorData<int8_t>(output), PREPROCESSOR_FEATURE_SIZE * sizeof(int8_t));
return true;
}
void MicroWakeWord::reset_states_() {
ESP_LOGD(TAG, "Resetting buffers and probabilities");
this->ring_buffer_->reset();
this->ignore_windows_ = -MIN_SLICES_BEFORE_DETECTION;
for (auto &model : this->wake_word_models_) {
model.reset_probabilities();
}
#ifdef USE_MWW_VAD
this->vad_model_->reset_probabilities();
#endif
}
bool MicroWakeWord::register_preprocessor_ops_(tflite::MicroMutableOpResolver<18> &op_resolver) {
if (op_resolver.AddReshape() != kTfLiteOk)
return false;

View File

@ -10,6 +10,9 @@
#ifdef USE_ESP_IDF
#include "preprocessor_settings.h"
#include "streaming_model.h"
#include "esphome/core/automation.h"
#include "esphome/core/component.h"
#include "esphome/core/ring_buffer.h"
@ -23,35 +26,6 @@
namespace esphome {
namespace micro_wake_word {
// The following are dictated by the preprocessor model
//
// The number of features the audio preprocessor generates per slice
static const uint8_t PREPROCESSOR_FEATURE_SIZE = 40;
// How frequently the preprocessor generates a new set of features
static const uint8_t FEATURE_STRIDE_MS = 20;
// Duration of each slice used as input into the preprocessor
static const uint8_t FEATURE_DURATION_MS = 30;
// Audio sample frequency in hertz
static const uint16_t AUDIO_SAMPLE_FREQUENCY = 16000;
// The number of old audio samples that are saved to be part of the next feature window
static const uint16_t HISTORY_SAMPLES_TO_KEEP =
((FEATURE_DURATION_MS - FEATURE_STRIDE_MS) * (AUDIO_SAMPLE_FREQUENCY / 1000));
// The number of new audio samples to receive to be included with the next feature window
static const uint16_t NEW_SAMPLES_TO_GET = (FEATURE_STRIDE_MS * (AUDIO_SAMPLE_FREQUENCY / 1000));
// The total number of audio samples included in the feature window
static const uint16_t SAMPLE_DURATION_COUNT = FEATURE_DURATION_MS * AUDIO_SAMPLE_FREQUENCY / 1000;
// Number of bytes in memory needed for the preprocessor arena
static const uint32_t PREPROCESSOR_ARENA_SIZE = 9528;
// The following configure the streaming wake word model
//
// The number of audio slices to process before accepting a positive detection
static const uint8_t MIN_SLICES_BEFORE_DETECTION = 74;
// Number of bytes in memory needed for the streaming wake word model
static const uint32_t STREAMING_MODEL_ARENA_SIZE = 64000;
static const uint32_t STREAMING_MODEL_VARIABLE_ARENA_SIZE = 1024;
enum State {
IDLE,
START_MICROPHONE,
@ -61,6 +35,9 @@ enum State {
STOPPING_MICROPHONE,
};
// The number of audio slices to process before accepting a positive detection
static const uint8_t MIN_SLICES_BEFORE_DETECTION = 74;
class MicroWakeWord : public Component {
public:
void setup() override;
@ -73,28 +50,19 @@ class MicroWakeWord : public Component {
bool is_running() const { return this->state_ != State::IDLE; }
bool initialize_models();
std::string get_wake_word() { return this->wake_word_; }
// Increasing either of these will reduce the rate of false acceptances while increasing the false rejection rate
void set_probability_cutoff(float probability_cutoff) { this->probability_cutoff_ = probability_cutoff; }
void set_sliding_window_average_size(size_t size);
void set_microphone(microphone::Microphone *microphone) { this->microphone_ = microphone; }
Trigger<std::string> *get_wake_word_detected_trigger() const { return this->wake_word_detected_trigger_; }
void set_model_start(const uint8_t *model_start) { this->model_start_ = model_start; }
void set_wake_word(const std::string &wake_word) { this->wake_word_ = wake_word; }
void add_wake_word_model(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_average_size,
const std::string &wake_word, size_t tensor_arena_size);
#ifdef USE_MWW_VAD
void add_vad_model(const uint8_t *model_start, float upper_threshold, float lower_threshold,
size_t sliding_window_size, size_t tensor_arena_size);
#endif
protected:
void set_state_(State state);
int read_microphone_();
const uint8_t *model_start_;
std::string wake_word_;
microphone::Microphone *microphone_{nullptr};
Trigger<std::string> *wake_word_detected_trigger_ = new Trigger<std::string>();
State state_{State::IDLE};
@ -102,79 +70,92 @@ class MicroWakeWord : public Component {
std::unique_ptr<RingBuffer> ring_buffer_;
int16_t *input_buffer_;
std::vector<WakeWordModel> wake_word_models_;
const tflite::Model *preprocessor_model_{nullptr};
const tflite::Model *streaming_model_{nullptr};
tflite::MicroInterpreter *streaming_interpreter_{nullptr};
tflite::MicroInterpreter *preprocessor_interperter_{nullptr};
#ifdef USE_MWW_VAD
VADModel *vad_model_;
#endif
std::vector<float> recent_streaming_probabilities_;
size_t last_n_index_{0};
tflite::MicroMutableOpResolver<17> streaming_op_resolver_;
tflite::MicroMutableOpResolver<18> preprocessor_op_resolver_;
float probability_cutoff_{0.5};
size_t sliding_window_average_size_{10};
tflite::MicroInterpreter *preprocessor_interpreter_{nullptr};
// When the wake word detection first starts or after the word has been detected once, we ignore this many audio
// feature slices before accepting a positive detection again
// When the wake word detection first starts, we ignore this many audio
// feature slices before accepting a positive detection
int16_t ignore_windows_{-MIN_SLICES_BEFORE_DETECTION};
uint8_t *streaming_var_arena_{nullptr};
uint8_t *streaming_tensor_arena_{nullptr};
uint8_t *preprocessor_tensor_arena_{nullptr};
int8_t *new_features_data_{nullptr};
tflite::MicroResourceVariables *mrv_{nullptr};
// Stores audio fed into feature generator preprocessor
int16_t *preprocessor_audio_buffer_;
// Stores audio read from the microphone before being added to the ring buffer.
int16_t *input_buffer_{nullptr};
// Stores audio fed into feature generator preprocessor. Also used for striding samples in each window.
int16_t *preprocessor_audio_buffer_{nullptr};
bool detected_{false};
std::string detected_wake_word_{""};
/** Detects if wake word has been said
void set_state_(State state);
/** Reads audio from microphone into the ring buffer
*
* Audio data (16 kHz with int16 samples) is read into the input_buffer_.
* Verifies the ring buffer has enough space for all audio data. If not, it logs
* a warning and resets the ring buffer entirely.
* @return Number of bytes written to the ring buffer
*/
size_t read_microphone_();
/// @brief Allocates memory for input_buffer_, preprocessor_audio_buffer_, and ring_buffer_
/// @return True if successful, false otherwise
bool allocate_buffers_();
/// @brief Frees memory allocated for input_buffer_ and preprocessor_audio_buffer_
void deallocate_buffers_();
/// @brief Loads streaming models
/// @return True if successful, false otherwise
bool load_models_();
/// @brief Deletes each model's TFLite interpreters and frees tensor arena memory
void unload_models_();
/** Performs inference with each configured model
*
* If enough audio samples are available, it will generate one slice of new features.
* If the streaming model predicts the wake word, then the nonstreaming model confirms it.
* @param ring_Buffer Ring buffer containing raw audio samples
* @return True if the wake word is detected, false otherwise
* It then loops through and performs inference with each of the loaded models.
*/
bool detect_wake_word_();
void update_model_probabilities_();
/// @brief Returns true if there are enough audio samples in the buffer to generate another slice of features
bool slice_available_();
/** Shifts previous feature slices over by one and generates a new slice of features
/** Checks every model's recent probabilities to determine if the wake word has been predicted
*
* @param ring_buffer ring buffer containing raw audio samples
* @return True if a new slice of features was generated, false otherwise
* Verifies the models have processed enough new samples for accurate predictions.
* Sets detected_wake_word_ to the wake word, if one is detected.
* @return True if a wake word is predicted, false otherwise
*/
bool update_features_();
bool detect_wake_words_();
/** Generates features from audio samples
/** Reads in new audio data from ring buffer to create the next sample window
*
* Adapted from TFLite micro speech example
* @param audio_data Pointer to array with the audio samples
* @param audio_data_size The number of samples to use as input to the preprocessor model
* @param feature_output Array that will store the features
* @return True if successful, false otherwise.
*/
bool generate_single_feature_(const int16_t *audio_data, int audio_data_size,
int8_t feature_output[PREPROCESSOR_FEATURE_SIZE]);
/** Performs inference over the most recent feature slice with the streaming model
*
* @return Probability of the wake word between 0.0 and 1.0
*/
float perform_streaming_inference_();
/** Strides the audio samples by keeping the last 10 ms of the previous slice
*
* Adapted from the TFLite micro speech example
* @param ring_buffer Ring buffer containing raw audio samples
* @param audio_samples Pointer to an array that will store the strided audio samples
* Moves the last 10 ms of audio from the previous window to the start of the new window.
* The next 20 ms of audio is copied from the ring buffer and inserted into the new window.
* The new window's audio samples are stored in preprocessor_audio_buffer_.
* Adapted from the TFLite micro speech example.
* @return True if successful, false otherwise
*/
bool stride_audio_samples_(int16_t **audio_samples);
bool stride_audio_samples_();
/** Generates features for a window of audio samples
*
* Feeds the strided audio samples in preprocessor_audio_buffer_ into the preprocessor.
* Adapted from TFLite micro speech example.
* @param features int8_t array to store the audio features
* @return True if successful, false otherwise.
*/
bool generate_features_for_window_(int8_t features[PREPROCESSOR_FEATURE_SIZE]);
/// @brief Resets the ring buffer, ignore_windows_, and sliding window probabilities
void reset_states_();
/// @brief Returns true if successfully registered the preprocessor's TensorFlow operations
bool register_preprocessor_ops_(tflite::MicroMutableOpResolver<18> &op_resolver);

View File

@ -0,0 +1,31 @@
#pragma once
#ifdef USE_ESP_IDF
#include <cstdint>
namespace esphome {
namespace micro_wake_word {

// Compile-time settings shared by the audio preprocessor and the streaming models.

// Count of features produced by the audio preprocessor for every slice
static constexpr uint8_t PREPROCESSOR_FEATURE_SIZE = 40;

// Time step, in milliseconds, between two consecutive sets of features
static constexpr uint8_t FEATURE_STRIDE_MS = 20;

// Length, in milliseconds, of the audio slice fed into the preprocessor
static constexpr uint8_t FEATURE_DURATION_MS = 30;

// Sample rate of the audio, in hertz
static constexpr uint16_t AUDIO_SAMPLE_FREQUENCY = 16000;

// Old audio samples carried over into the next feature window
// (the part of a window that overlaps with the previous one)
static constexpr uint16_t HISTORY_SAMPLES_TO_KEEP =
    (FEATURE_DURATION_MS - FEATURE_STRIDE_MS) * (AUDIO_SAMPLE_FREQUENCY / 1000);

// Fresh audio samples that must be received for the next feature window
static constexpr uint16_t NEW_SAMPLES_TO_GET = FEATURE_STRIDE_MS * (AUDIO_SAMPLE_FREQUENCY / 1000);

// Total audio samples contained in a single feature window
static constexpr uint16_t SAMPLE_DURATION_COUNT = (FEATURE_DURATION_MS * AUDIO_SAMPLE_FREQUENCY) / 1000;

// Size, in bytes, of the memory arena required by the preprocessor
static constexpr uint32_t PREPROCESSOR_ARENA_SIZE = 9528;

}  // namespace micro_wake_word
}  // namespace esphome
#endif

View File

@ -0,0 +1,211 @@
/**
* This is a workaround until we can figure out a way to get
* the tflite-micro idf component code available in CI
*
* */
//
#ifndef CLANG_TIDY
#ifdef USE_ESP_IDF
#include "streaming_model.h"
#include "esphome/core/hal.h"
#include "esphome/core/helpers.h"
#include "esphome/core/log.h"
static const char *const TAG = "micro_wake_word";
namespace esphome {
namespace micro_wake_word {
/// @brief Logs this wake word model's name, probability cutoff, and sliding window size at config level
void WakeWordModel::log_model_config() {
  ESP_LOGCONFIG(TAG, " - Wake Word: %s", this->wake_word_.c_str());
  ESP_LOGCONFIG(TAG, " Probability cutoff: %.3f", this->probability_cutoff_);
  // sliding_window_size_ is a size_t, so it needs the %zu conversion rather than %d
  ESP_LOGCONFIG(TAG, " Sliding window size: %zu", this->sliding_window_size_);
}
/// @brief Logs the VAD model's upper/lower thresholds and sliding window size at config level
void VADModel::log_model_config() {
  ESP_LOGCONFIG(TAG, " - VAD Model");
  ESP_LOGCONFIG(TAG, " Upper threshold: %.3f", this->upper_threshold_);
  ESP_LOGCONFIG(TAG, " Lower threshold: %.3f", this->lower_threshold_);
  // sliding_window_size_ is a size_t, so it needs the %zu conversion rather than %d
  ESP_LOGCONFIG(TAG, " Sliding window size: %zu", this->sliding_window_size_);
}
/** Allocates the model's arenas (when not yet allocated) and sets up its TFLite Micro interpreter
 *
 * Each step is skipped when its result already exists, so calling this on an already loaded model
 * only re-checks the model's schema version.
 *
 * After the interpreter is created, its input tensor must be int8 with dimensions
 * 1x1xPREPROCESSOR_FEATURE_SIZE and its output tensor must be uint8 with dimensions 1x1.
 * If tensor allocation or any of these checks fails, the interpreter is destroyed again so that a
 * later call starts from scratch instead of reporting success for an unvalidated interpreter.
 * The arenas stay allocated on failure; unload_model() releases them.
 *
 * @param op_resolver Operation resolver handed to the interpreter
 * @return True if the model is loaded and passed validation, false otherwise
 */
bool StreamingModel::load_model(tflite::MicroMutableOpResolver<17> &op_resolver) {
  ExternalRAMAllocator<uint8_t> arena_allocator(ExternalRAMAllocator<uint8_t>::ALLOW_FAILURE);

  if (this->tensor_arena_ == nullptr) {
    this->tensor_arena_ = arena_allocator.allocate(this->tensor_arena_size_);
    if (this->tensor_arena_ == nullptr) {
      ESP_LOGE(TAG, "Could not allocate the streaming model's tensor arena.");
      return false;
    }
  }

  if (this->var_arena_ == nullptr) {
    this->var_arena_ = arena_allocator.allocate(STREAMING_MODEL_VARIABLE_ARENA_SIZE);
    if (this->var_arena_ == nullptr) {
      ESP_LOGE(TAG, "Could not allocate the streaming model's variable tensor arena.");
      return false;
    }
    // The micro allocator and the resource variables are both built on top of var_arena_
    this->ma_ = tflite::MicroAllocator::Create(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE);
    this->mrv_ = tflite::MicroResourceVariables::Create(this->ma_, 20);
  }

  const tflite::Model *model = tflite::GetModel(this->model_start_);
  if (model->version() != TFLITE_SCHEMA_VERSION) {
    ESP_LOGE(TAG, "Streaming model's schema is not supported");
    return false;
  }

  if (this->interpreter_ == nullptr) {
    this->interpreter_ = new tflite::MicroInterpreter(model, op_resolver, this->tensor_arena_,
                                                      this->tensor_arena_size_, this->mrv_);

    // Destroys the freshly created interpreter and reports failure. Without this, interpreter_ would
    // stay non-null and the next load_model() call would skip all validation and return true.
    auto discard_interpreter = [this]() -> bool {
      delete this->interpreter_;
      this->interpreter_ = nullptr;
      return false;
    };

    if (this->interpreter_->AllocateTensors() != kTfLiteOk) {
      ESP_LOGE(TAG, "Failed to allocate tensors for the streaming model");
      return discard_interpreter();
    }

    // Verify input tensor matches expected values
    TfLiteTensor *input = this->interpreter_->input(0);
    if ((input->dims->size != 3) || (input->dims->data[0] != 1) || (input->dims->data[1] != 1) ||
        (input->dims->data[2] != PREPROCESSOR_FEATURE_SIZE)) {
      ESP_LOGE(TAG, "Streaming model tensor input dimensions is not 1x1x%u", PREPROCESSOR_FEATURE_SIZE);
      return discard_interpreter();
    }

    if (input->type != kTfLiteInt8) {
      ESP_LOGE(TAG, "Streaming model tensor input is not int8.");
      return discard_interpreter();
    }

    // Verify output tensor matches expected values
    TfLiteTensor *output = this->interpreter_->output(0);
    if ((output->dims->size != 2) || (output->dims->data[0] != 1) || (output->dims->data[1] != 1)) {
      ESP_LOGE(TAG, "Streaming model tensor output dimension is not 1x1.");
      // Treated as a failure like every other tensor mismatch
      return discard_interpreter();
    }

    if (output->type != kTfLiteUInt8) {
      ESP_LOGE(TAG, "Streaming model tensor output is not uint8.");
      return discard_interpreter();
    }
  }

  return true;
}
/** Destroys the interpreter and releases the tensor and variable arenas
 *
 * All related pointers are reset to nullptr afterwards, so load_model() rebuilds everything on its
 * next call.
 */
void StreamingModel::unload_model() {
  delete (this->interpreter_);
  this->interpreter_ = nullptr;

  ExternalRAMAllocator<uint8_t> arena_allocator(ExternalRAMAllocator<uint8_t>::ALLOW_FAILURE);

  arena_allocator.deallocate(this->tensor_arena_, this->tensor_arena_size_);
  this->tensor_arena_ = nullptr;

  arena_allocator.deallocate(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE);
  this->var_arena_ = nullptr;

  // ma_ and mrv_ were created on top of var_arena_ in load_model(); clear them so they do not dangle
  // into the memory that was just released. load_model() recreates both when it reallocates var_arena_.
  this->mrv_ = nullptr;
  this->ma_ = nullptr;
}
/** Runs the model on one slice of features and records the resulting probability
 *
 * The features are copied into the interpreter's input tensor, the interpreter is invoked, and the
 * first uint8 value of the output tensor is written to the next slot of
 * recent_streaming_probabilities_, wrapping around at the end of the window.
 *
 * @param features Feature slice to copy into the model's input tensor
 * @return True if inference ran successfully, false if the interpreter is not loaded or invoking it failed
 */
bool StreamingModel::perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]) {
  if (this->interpreter_ == nullptr) {
    ESP_LOGE(TAG, "Streaming interpreter is not initialized.");
    return false;
  }

  TfLiteTensor *input = this->interpreter_->input(0);
  memcpy(tflite::GetTensorData<int8_t>(input), features, input->bytes);

  if (this->interpreter_->Invoke() != kTfLiteOk) {
    ESP_LOGW(TAG, "Streaming interpreter invoke failed");
    return false;
  }

  // With an empty window there is no slot to store the probability in; writing anyway would be an
  // out-of-bounds access because the index could never wrap back to 0.
  if (this->recent_streaming_probabilities_.empty()) {
    return true;
  }

  TfLiteTensor *output = this->interpreter_->output(0);

  ++this->last_n_index_;
  if (this->last_n_index_ >= this->recent_streaming_probabilities_.size())
    this->last_n_index_ = 0;
  this->recent_streaming_probabilities_[this->last_n_index_] = output->data.uint8[0];  // probability;

  return true;
}
/// @brief Overwrites every stored probability in the sliding window with 0, keeping the window's length
void StreamingModel::reset_probabilities() {
  const size_t window_length = this->recent_streaming_probabilities_.size();
  this->recent_streaming_probabilities_.assign(window_length, 0);
}
/** Stores the configuration of a wake word model
 *
 * @param model_start Pointer to the start of the model's data
 * @param probability_cutoff Sliding window average that must be exceeded for a detection
 * @param sliding_window_average_size Number of recent probabilities that are averaged
 * @param wake_word Wake word reported for this model
 * @param tensor_arena_size Size in bytes of the tensor arena allocated when the model is loaded
 */
WakeWordModel::WakeWordModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_average_size,
                             const std::string &wake_word, size_t tensor_arena_size) {
  // Model data and the memory it needs
  this->model_start_ = model_start;
  this->tensor_arena_size_ = tensor_arena_size;

  // Detection settings
  this->wake_word_ = wake_word;
  this->probability_cutoff_ = probability_cutoff;

  // The window starts out with every probability at 0
  this->sliding_window_size_ = sliding_window_average_size;
  this->recent_streaming_probabilities_.assign(sliding_window_average_size, 0);
}
/** Stores the configuration of a voice activity detection model
 *
 * @param model_start Pointer to the start of the model's data
 * @param upper_threshold Sliding window average that must be exceeded to start reporting a detection
 * @param lower_threshold Sliding window average that must be exceeded to keep reporting a detection
 * @param sliding_window_size Number of recent probabilities that are averaged
 * @param tensor_arena_size Size in bytes of the tensor arena allocated when the model is loaded
 */
VADModel::VADModel(const uint8_t *model_start, float upper_threshold, float lower_threshold, size_t sliding_window_size,
                   size_t tensor_arena_size) {
  // Model data and the memory it needs
  this->model_start_ = model_start;
  this->tensor_arena_size_ = tensor_arena_size;

  // Detection thresholds
  this->lower_threshold_ = lower_threshold;
  this->upper_threshold_ = upper_threshold;

  // The window starts out with every probability at 0
  this->sliding_window_size_ = sliding_window_size;
  this->recent_streaming_probabilities_.assign(sliding_window_size, 0);
}
/** Checks whether the recent probabilities indicate the wake word
 *
 * Averages the stored uint8 probabilities, scaled to the 0.0 - 1.0 range, and compares the result
 * against probability_cutoff_. A detection is logged at debug level.
 *
 * @return True if the sliding window average is above the probability cutoff, false otherwise
 */
bool WakeWordModel::determine_detected() {
  int32_t probability_total = 0;
  for (size_t slot = 0; slot < this->recent_streaming_probabilities_.size(); ++slot) {
    probability_total += this->recent_streaming_probabilities_[slot];
  }

  // Each stored probability is in 0..255, so the maximum possible total is 255 per window slot
  const float max_total = static_cast<float>(255 * this->sliding_window_size_);
  const float sliding_window_average = static_cast<float>(probability_total) / max_total;

  // Not above the cutoff means no detection
  if (!(sliding_window_average > this->probability_cutoff_)) {
    return false;
  }

  ESP_LOGD(TAG, "The '%s' model sliding average probability is %.3f and most recent probability is %.3f",
           this->wake_word_.c_str(), sliding_window_average,
           this->recent_streaming_probabilities_[this->last_n_index_] / (255.0));
  return true;
}
/** Checks whether the recent probabilities indicate voice activity
 *
 * Averages the stored uint8 probabilities, scaled to the 0.0 - 1.0 range, then decides in this order:
 *  - above upper_threshold_: activity starts (or continues) and the clear countdown is reset to 10
 *  - activity already flagged and above lower_threshold_: activity continues
 *  - clear countdown not yet expired: the countdown is decremented and activity is still reported
 *  - otherwise: the activity flag is cleared
 *
 * @return True if voice activity is reported, false otherwise
 */
bool VADModel::determine_detected() {
  int32_t probability_total = 0;
  for (size_t slot = 0; slot < this->recent_streaming_probabilities_.size(); ++slot) {
    probability_total += this->recent_streaming_probabilities_[slot];
  }

  // Each stored probability is in 0..255, so the maximum possible total is 255 per window slot
  const float max_total = static_cast<float>(255 * this->sliding_window_size_);
  const float sliding_window_average = static_cast<float>(probability_total) / max_total;

  if (sliding_window_average > this->upper_threshold_) {
    this->vad_state_ = true;
    this->clear_countdown_ = 10;
    return true;
  }

  if (this->vad_state_ && (sliding_window_average > this->lower_threshold_)) {
    return true;
  }

  if (this->clear_countdown_ > 0) {
    --this->clear_countdown_;
    return true;
  }

  this->vad_state_ = false;
  return false;
}
} // namespace micro_wake_word
} // namespace esphome
#endif
#endif

View File

@ -0,0 +1,90 @@
#pragma once
/**
* This is a workaround until we can figure out a way to get
* the tflite-micro idf component code available in CI
*
* */
//
#ifndef CLANG_TIDY
#ifdef USE_ESP_IDF
#include "preprocessor_settings.h"
#include <tensorflow/lite/core/c/common.h>
#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
namespace esphome {
namespace micro_wake_word {
static const uint32_t STREAMING_MODEL_VARIABLE_ARENA_SIZE = 1024;
/** Base class for models that process one slice of audio features at a time
 *
 * Owns the TFLite Micro interpreter along with its tensor and variable arenas, and keeps the most
 * recent output probabilities in a circular buffer. Derived classes decide how those probabilities
 * turn into a detection.
 */
class StreamingModel {
 public:
  // Virtual so that a derived model can be safely destroyed through a StreamingModel pointer
  virtual ~StreamingModel() = default;

  /// @brief Logs the model's configuration
  virtual void log_model_config() = 0;

  /// @brief Evaluates the recent probabilities
  /// @return True if the model considers its target detected, false otherwise
  virtual bool determine_detected() = 0;

  /// @brief Runs inference on one slice of features and stores the resulting probability
  /// @param features Feature slice to feed into the model
  /// @return True if inference ran successfully, false otherwise
  bool perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]);

  /// @brief Sets all recent_streaming_probabilities to 0
  void reset_probabilities();

  /// @brief Allocates tensor and variable arenas and sets up the model interpreter
  /// @param op_resolver MicroMutableOpResolver object that must exist until the model is unloaded
  /// @return True if successful, false otherwise
  bool load_model(tflite::MicroMutableOpResolver<17> &op_resolver);

  /// @brief Destroys the TFLite interpreter and frees the tensor and variable arenas' memory
  void unload_model();

 protected:
  // Number of recent probabilities that are averaged; set by the derived class's constructor
  size_t sliding_window_size_{0};
  // Index of the most recently written entry in recent_streaming_probabilities_
  size_t last_n_index_{0};
  // Size in bytes of tensor_arena_; set by the derived class's constructor
  size_t tensor_arena_size_{0};
  // Circular buffer of the most recent uint8 output probabilities
  std::vector<uint8_t> recent_streaming_probabilities_;

  // Start of the model's data; set by the derived class's constructor
  const uint8_t *model_start_{nullptr};
  // Arena passed to the interpreter for its tensors
  uint8_t *tensor_arena_{nullptr};
  // Arena backing ma_ and mrv_
  uint8_t *var_arena_{nullptr};

  tflite::MicroInterpreter *interpreter_{nullptr};
  tflite::MicroResourceVariables *mrv_{nullptr};
  tflite::MicroAllocator *ma_{nullptr};
};
/// @brief Streaming model that detects a wake word by comparing its sliding window average to a cutoff
class WakeWordModel : public StreamingModel {
 public:
  /// @param model_start Pointer to the start of the model's data
  /// @param probability_cutoff Sliding window average that must be exceeded for a detection
  /// @param sliding_window_average_size Number of recent probabilities that are averaged
  /// @param wake_word Wake word reported for this model
  /// @param tensor_arena_size Size in bytes of the tensor arena allocated when the model is loaded
  WakeWordModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_average_size,
                const std::string &wake_word, size_t tensor_arena_size);

  /// @brief Logs the wake word, probability cutoff, and sliding window size
  void log_model_config() override;

  /// @brief Checks if the sliding window average is above the probability cutoff
  /// @return True if the wake word is detected, false otherwise
  bool determine_detected() override;

  /// @brief Returns a copy of the wake word; const so it can also be called on a const model
  std::string get_wake_word() const { return this->wake_word_; }

 protected:
  float probability_cutoff_;
  std::string wake_word_;
};
/// @brief Streaming model that reports voice activity using an upper and a lower threshold
class VADModel : public StreamingModel {
public:
/// @param model_start Pointer to the start of the model's data
/// @param upper_threshold Sliding window average that must be exceeded to start reporting activity
/// @param lower_threshold Sliding window average that must be exceeded to keep reporting activity
/// @param sliding_window_size Number of recent probabilities that are averaged
/// @param tensor_arena_size Size in bytes of the tensor arena allocated when the model is loaded
VADModel(const uint8_t *model_start, float upper_threshold, float lower_threshold, size_t sliding_window_size,
size_t tensor_arena_size);
/// @brief Logs the upper threshold, lower threshold, and sliding window size
void log_model_config() override;
/// @brief Compares the sliding window average against the thresholds and updates the activity state
/// @return True if voice activity is reported, false otherwise
bool determine_detected() override;
protected:
// Remaining determine_detected() calls that still report activity after the average is no longer
// above the thresholds; reset to 10 whenever the average exceeds upper_threshold_
uint8_t clear_countdown_{10};
// Set once the average exceeds upper_threshold_; cleared after clear_countdown_ has run out
bool vad_state_{false};
float upper_threshold_;
float lower_threshold_;
};
} // namespace micro_wake_word
} // namespace esphome
#endif
#endif

View File

@ -10,6 +10,10 @@ microphone:
pdm: true
micro_wake_word:
model: hey_jarvis
on_wake_word_detected:
- logger.log: "Wake word detected"
models:
- model: hey_jarvis
probability_cutoff: 0.7
- model: okay_nabu
sliding_window_average_size: 5