From 76e7a0afc13e30e2903bca4bfcdedc74da30742f Mon Sep 17 00:00:00 2001 From: Prabhsimran Singh Date: Tue, 18 Feb 2020 18:03:19 +0530 Subject: [PATCH] clean up decoder --- src/decoder.hpp | 563 +++++++++++++++++++++++------------------------- src/server.hpp | 58 +++-- 2 files changed, 289 insertions(+), 332 deletions(-) diff --git a/src/decoder.hpp b/src/decoder.hpp index 61b523b..6294104 100644 --- a/src/decoder.hpp +++ b/src/decoder.hpp @@ -30,6 +30,7 @@ // local includes #include "utils.hpp" + struct Word { float start_time, end_time, confidence; std::string word; @@ -52,14 +53,16 @@ struct DecoderOptions { // Result for one continuous utterance using utterance_results_t = std::vector; + // Find confidence by merging lm and am scores. Taken from -// https://github.com/dialogflow/asr-server/blob/master/src/OnlineDecoder.cc#L90 +// https://github.com/dialogflow/asr-server/blob/master/src/OnlineDecoder.cc#L90 // NOTE: This might not be very useful for us right now. Depending on the // situation, we might actually want to weigh components differently. 
inline double calculate_confidence(const float &lm_score, const float &am_score, const std::size_t &n_words) noexcept { return std::max(0.0, std::min(1.0, -0.0001466488 * (2.388449 * lm_score + am_score) / (n_words + 1) + 0.956)); } + inline void print_wav_info(const kaldi::WaveInfo &wave_info) noexcept { std::cout << "sample freq: " << wave_info.SampFreq() << ENDL << "sample count: " << wave_info.SampleCount() << ENDL @@ -70,6 +73,7 @@ inline void print_wav_info(const kaldi::WaveInfo &wave_info) noexcept { << "block align: " << wave_info.BlockAlign() << ENDL; } + void read_raw_wav_stream(std::istream &wav_stream, const size_t &data_bytes, kaldi::Matrix &wav_data) { @@ -104,83 +108,92 @@ void read_raw_wav_stream(std::istream &wav_stream, } } -class Decoder final { - - private: - std::unique_ptr word_syms_; - kaldi::WordBoundaryInfo* wb_info_; - - public: - fst::Fst *const decode_fst_; - mutable kaldi::nnet3::AmNnetSimple am_nnet_; - kaldi::TransitionModel trans_model_; - DecoderOptions options; - std::unique_ptr feature_info_; - std::unique_ptr decodable_info_; +class Decoder final { - kaldi::LatticeFasterDecoderConfig lattice_faster_decoder_config_; - kaldi::nnet3::NnetSimpleLoopedComputationOptions decodable_opts_; + public: + explicit Decoder(const kaldi::BaseFloat &beam, + const std::size_t &min_active, + const std::size_t &max_active, + const kaldi::BaseFloat &lattice_beam, + const kaldi::BaseFloat &acoustic_scale, + const std::size_t &frame_subsampling_factor, + const std::string &model_dir) noexcept; - explicit Decoder(const kaldi::BaseFloat &, const std::size_t &, - const std::size_t &, const kaldi::BaseFloat &, - const kaldi::BaseFloat &, const std::size_t &, - const std::string &, fst::Fst *const) noexcept; + ~Decoder() noexcept; - ~Decoder(); + // SETUP METHODS + void start_decoding() noexcept; - void _find_alternatives(const kaldi::CompactLattice &clat, - const std::size_t &n_best, - utterance_results_t &results, - const bool &word_level) const 
noexcept; + void free_decoder() noexcept; - // Decoding processes - void _decode_wave(kaldi::OnlineNnet2FeaturePipeline &, - kaldi::OnlineSilenceWeighting &, - kaldi::SingleUtteranceNnet3Decoder &, - kaldi::SubVector &, - std::vector> &, - const kaldi::BaseFloat &) const; + // STREAMING METHODS // decode an intermediate frame/chunk of a wav audio stream - void decode_stream_wav_chunk(kaldi::OnlineNnet2FeaturePipeline &, - kaldi::OnlineSilenceWeighting &, - kaldi::SingleUtteranceNnet3Decoder &, - std::istream &) const; + void decode_stream_wav_chunk(std::istream &wav_stream); // decode an intermediate frame/chunk of a raw headerless wav audio stream - void decode_stream_raw_wav_chunk(kaldi::OnlineNnet2FeaturePipeline &, - kaldi::OnlineSilenceWeighting &, - kaldi::SingleUtteranceNnet3Decoder &, - std::istream &, - const kaldi::BaseFloat &, - const size_t &) const; + void decode_stream_raw_wav_chunk(std::istream &wav_stream, + const kaldi::BaseFloat& samp_freq, + const size_t &data_bytes); + + // NON-STREAMING METHODS // decodes an (independent) wav audio stream // internally chunks a wav audio stream and decodes them - void decode_wav_audio(std::istream &, - const size_t &, - utterance_results_t &, - const bool &, - const kaldi::BaseFloat & = 1) const; + void decode_wav_audio(std::istream &wav_stream, + const kaldi::BaseFloat &chunk_size=1); // decodes an (independent) raw headerless wav audio stream // internally chunks a wav audio stream and decodes them - void decode_raw_wav_audio(std::istream &, - const kaldi::BaseFloat &, - const size_t &, - const size_t &, - utterance_results_t &, - const bool &, - const kaldi::BaseFloat & = 1) const; + void decode_raw_wav_audio(std::istream &wav_stream, + const kaldi::BaseFloat &samp_freq, + const size_t &data_bytes, + const kaldi::BaseFloat &chunk_size=1); + + // LATTICE DECODING METHODS // get the final utterances based on the compact lattice - void decode_stream_final(kaldi::OnlineNnet2FeaturePipeline &, - 
kaldi::SingleUtteranceNnet3Decoder &, - const std::size_t &, - utterance_results_t &, - const bool &, - const bool & = false) const; + void get_decoded_results(const std::size_t &n_best, + utterance_results_t &results, + const bool &word_level, + const bool &bidi_streaming=false); + + DecoderOptions options; + + private: + // decodes an intermediate wavepart + void _decode_wave(kaldi::SubVector &wave_part, + std::vector> &delta_weights, + const kaldi::BaseFloat &samp_freq); + + // gets the final decoded transcripts from lattice + void _find_alternatives(const kaldi::CompactLattice &clat, + const std::size_t &n_best, + utterance_results_t &results, + const bool &word_level) const; + + // model vars + std::unique_ptr> decode_fst_; + kaldi::nnet3::AmNnetSimple am_nnet_; + kaldi::TransitionModel trans_model_; + + std::unique_ptr word_syms_; + + std::unique_ptr wb_info_; + std::unique_ptr feature_info_; + + kaldi::LatticeFasterDecoderConfig lattice_faster_decoder_config_; + kaldi::nnet3::NnetSimpleLoopedComputationOptions decodable_opts_; + + // decoder vars (per utterance) + kaldi::SingleUtteranceNnet3Decoder *decoder_; + kaldi::OnlineNnet2FeaturePipeline *feature_pipeline_; + + // decoder vars (per decoder) + std::unique_ptr adaptation_state_; + std::unique_ptr silence_weighting_; + std::unique_ptr decodable_info_; }; Decoder::Decoder(const kaldi::BaseFloat &beam, @@ -189,10 +202,7 @@ Decoder::Decoder(const kaldi::BaseFloat &beam, const kaldi::BaseFloat &lattice_beam, const kaldi::BaseFloat &acoustic_scale, const std::size_t &frame_subsampling_factor, - const std::string &model_dir, - fst::Fst *const decode_fst) noexcept - : decode_fst_(decode_fst) { - + const std::string &model_dir) noexcept { try { lattice_faster_decoder_config_.min_active = min_active; lattice_faster_decoder_config_.max_active = max_active; @@ -201,13 +211,17 @@ Decoder::Decoder(const kaldi::BaseFloat &beam, decodable_opts_.acoustic_scale = acoustic_scale; decodable_opts_.frame_subsampling_factor = 
frame_subsampling_factor; + std::string hclg_filepath = join_path(model_dir, "HCLG.fst"); + std::string model_filepath = join_path(model_dir, "final.mdl"); std::string word_syms_filepath = join_path(model_dir, "words.txt"); std::string word_boundary_filepath = join_path(model_dir, "word_boundary.int"); - std::string model_filepath = join_path(model_dir, "final.mdl"); + std::string conf_dir = join_path(model_dir, "conf"); std::string mfcc_conf_filepath = join_path(conf_dir, "mfcc.conf"); std::string ivector_conf_filepath = join_path(conf_dir, "ivector_extractor.conf"); + decode_fst_ = std::unique_ptr>(fst::ReadFstKaldiGeneric(hclg_filepath)); + { bool binary; kaldi::Input ki(model_filepath, &binary); @@ -226,7 +240,7 @@ Decoder::Decoder(const kaldi::BaseFloat &beam, if (exists(word_boundary_filepath)) { kaldi::WordBoundaryInfoNewOpts word_boundary_opts; - wb_info_ = new kaldi::WordBoundaryInfo(word_boundary_opts, word_boundary_filepath); + wb_info_ = std::make_unique(word_boundary_opts, word_boundary_filepath); options.enable_word_level = true; } else { KALDI_WARN << "Word boundary file" << word_boundary_filepath @@ -253,24 +267,181 @@ Decoder::Decoder(const kaldi::BaseFloat &beam, feature_info_->ivector_extractor_info.Init(ivector_extraction_opts); + // decoder vars initialization + decoder_ = NULL; + feature_pipeline_ = NULL; + adaptation_state_ = std::make_unique(feature_info_->ivector_extractor_info); + silence_weighting_ = std::make_unique(trans_model_, + feature_info_->silence_weighting_config, + decodable_opts_.frame_subsampling_factor); decodable_info_ = std::make_unique(decodable_opts_, &am_nnet_); - } - catch (const std::exception &e) { + + } catch (const std::exception &e) { KALDI_ERR << e.what(); } } -Decoder::~Decoder() { - if (wb_info_ != nullptr) delete wb_info_; +Decoder::~Decoder() noexcept { + free_decoder(); +} + +void Decoder::start_decoding() noexcept { + free_decoder(); + + feature_pipeline_ = new 
kaldi::OnlineNnet2FeaturePipeline(*feature_info_); + feature_pipeline_->SetAdaptationState(*adaptation_state_); + + decoder_ = new kaldi::SingleUtteranceNnet3Decoder(lattice_faster_decoder_config_, + trans_model_, *decodable_info_, + *decode_fst_, feature_pipeline_); +} + +void Decoder::free_decoder() noexcept { + if (decoder_) { + delete decoder_; + decoder_ = NULL; + } + if (feature_pipeline_) { + delete feature_pipeline_; + feature_pipeline_ = NULL; + } +} + +void Decoder::decode_stream_wav_chunk(std::istream &wav_stream) { + kaldi::WaveData wave_data; + wave_data.Read(wav_stream); + + const kaldi::BaseFloat samp_freq = wave_data.SampFreq(); + + // get the data for channel zero (if the signal is not mono, we only + // take the first channel). + kaldi::SubVector wave_part(wave_data.Data(), 0); + std::vector> delta_weights; + _decode_wave(wave_part, delta_weights, samp_freq); } +void Decoder::decode_stream_raw_wav_chunk(std::istream &wav_stream, + const kaldi::BaseFloat& samp_freq, + const size_t &data_bytes) { + kaldi::Matrix wave_matrix; + read_raw_wav_stream(wav_stream, data_bytes, wave_matrix); + + // get the data for channel zero (if the signal is not mono, we only + // take the first channel). + kaldi::SubVector wave_part(wave_matrix, 0); + std::vector> delta_weights; + _decode_wave(wave_part, delta_weights, samp_freq); +} + +void Decoder::decode_wav_audio(std::istream &wav_stream, + const kaldi::BaseFloat &chunk_size) { + kaldi::WaveData wave_data; + wave_data.Read(wav_stream); + + // get the data for channel zero (if the signal is not mono, we only + // take the first channel). 
+ kaldi::SubVector data(wave_data.Data(), 0); + const kaldi::BaseFloat samp_freq = wave_data.SampFreq(); + + int32 chunk_length; + if (chunk_size > 0) { + chunk_length = int32(samp_freq * chunk_size); + if (chunk_length == 0) + chunk_length = 1; + } else { + chunk_length = std::numeric_limits::max(); + } + + int32 samp_offset = 0; + std::vector> delta_weights; + + while (samp_offset < data.Dim()) { + int32 samp_remaining = data.Dim() - samp_offset; + int32 num_samp = chunk_length < samp_remaining ? chunk_length : samp_remaining; + + kaldi::SubVector wave_part(data, samp_offset, num_samp); + _decode_wave(wave_part, delta_weights, samp_freq); + + samp_offset += num_samp; + } +} + +void Decoder::decode_raw_wav_audio(std::istream &wav_stream, + const kaldi::BaseFloat &samp_freq, + const size_t &data_bytes, + const kaldi::BaseFloat &chunk_size) { + kaldi::Matrix wave_matrix; + read_raw_wav_stream(wav_stream, data_bytes, wave_matrix); + + // get the data for channel zero (if the signal is not mono, we only + // take the first channel). + kaldi::SubVector data(wave_matrix, 0); + + int32 chunk_length; + if (chunk_size > 0) { + chunk_length = int32(samp_freq * chunk_size); + if (chunk_length == 0) + chunk_length = 1; + } else { + chunk_length = std::numeric_limits::max(); + } + + int32 samp_offset = 0; + std::vector> delta_weights; + + while (samp_offset < data.Dim()) { + int32 samp_remaining = data.Dim() - samp_offset; + int32 num_samp = chunk_length < samp_remaining ? 
chunk_length : samp_remaining; + + kaldi::SubVector wave_part(data, samp_offset, num_samp); + _decode_wave(wave_part, delta_weights, samp_freq); + + samp_offset += num_samp; + } +} + +void Decoder::get_decoded_results(const std::size_t &n_best, + utterance_results_t &results, + const bool &word_level, + const bool &bidi_streaming) { + if (!bidi_streaming) { + feature_pipeline_->InputFinished(); + decoder_->FinalizeDecoding(); + } + + if (decoder_->NumFramesDecoded() == 0) { + KALDI_WARN << "audio may be empty :: decoded no frames"; + return; + } + + kaldi::CompactLattice clat; + try { + decoder_->GetLattice(true, &clat); + _find_alternatives(clat, n_best, results, word_level); + } catch (std::exception &e) { + KALDI_ERR << "unexpected error during decoding lattice :: " << e.what(); + } +} + +void Decoder::_decode_wave(kaldi::SubVector &wave_part, + std::vector> &delta_weights, + const kaldi::BaseFloat &samp_freq) { + + feature_pipeline_->AcceptWaveform(samp_freq, wave_part); + + if (silence_weighting_->Active() && feature_pipeline_->IvectorFeature() != NULL) { + silence_weighting_->ComputeCurrentTraceback(decoder_->Decoder()); + silence_weighting_->GetDeltaWeights(feature_pipeline_->NumFramesReady(), + &delta_weights); + feature_pipeline_->IvectorFeature()->UpdateFrameWeights(delta_weights); + } + decoder_->AdvanceDecoding(); +} -// Computes n-best alternative from lattice. Output symbols are converted to words -// based on word-syms. 
void Decoder::_find_alternatives(const kaldi::CompactLattice &clat, const std::size_t &n_best, utterance_results_t &results, - const bool &word_level) const noexcept { + const bool &word_level) const { if (clat.NumStates() == 0) { KALDI_LOG << "Empty lattice."; } @@ -280,7 +451,7 @@ void Decoder::_find_alternatives(const kaldi::CompactLattice &clat, kaldi::Lattice nbest_lat; std::vector nbest_lats; - fst::ShortestPath(*lat.get(), &nbest_lat, n_best); + fst::ShortestPath(*lat, &nbest_lat, n_best); fst::ConvertNbestToVector(nbest_lat, &nbest_lats); if (nbest_lats.empty()) { @@ -380,232 +551,34 @@ void Decoder::_find_alternatives(const kaldi::CompactLattice &clat, } } -void Decoder::_decode_wave(kaldi::OnlineNnet2FeaturePipeline &feature_pipeline, - kaldi::OnlineSilenceWeighting &silence_weighting, - kaldi::SingleUtteranceNnet3Decoder &decoder, - kaldi::SubVector &wave_part, - std::vector> &delta_weights, - const kaldi::BaseFloat &samp_freq) const { - - feature_pipeline.AcceptWaveform(samp_freq, wave_part); - - if (silence_weighting.Active() && feature_pipeline.IvectorFeature() != NULL) { - silence_weighting.ComputeCurrentTraceback(decoder.Decoder()); - silence_weighting.GetDeltaWeights(feature_pipeline.NumFramesReady(), - &delta_weights); - feature_pipeline.IvectorFeature()->UpdateFrameWeights(delta_weights); - } - decoder.AdvanceDecoding(); -} - -void Decoder::decode_stream_wav_chunk(kaldi::OnlineNnet2FeaturePipeline &feature_pipeline, - kaldi::OnlineSilenceWeighting &silence_weighting, - kaldi::SingleUtteranceNnet3Decoder &decoder, - std::istream &wav_stream) const { - - kaldi::WaveData wave_data; - wave_data.Read(wav_stream); - - const kaldi::BaseFloat samp_freq = wave_data.SampFreq(); - - // get the data for channel zero (if the signal is not mono, we only - // take the first channel). 
- kaldi::SubVector wave_part(wave_data.Data(), 0); - std::vector> delta_weights; - _decode_wave(feature_pipeline, silence_weighting, decoder, wave_part, delta_weights, samp_freq); -} - -void Decoder::decode_stream_raw_wav_chunk(kaldi::OnlineNnet2FeaturePipeline &feature_pipeline, - kaldi::OnlineSilenceWeighting &silence_weighting, - kaldi::SingleUtteranceNnet3Decoder &decoder, - std::istream &wav_stream, - const kaldi::BaseFloat& samp_freq, - const size_t &data_bytes) const { - - kaldi::Matrix wave_matrix; - read_raw_wav_stream(wav_stream, data_bytes, wave_matrix); - - // get the data for channel zero (if the signal is not mono, we only - // take the first channel). - kaldi::SubVector wave_part(wave_matrix, 0); - std::vector> delta_weights; - _decode_wave(feature_pipeline, silence_weighting, decoder, wave_part, delta_weights, samp_freq); -} - -void Decoder::decode_wav_audio(std::istream &wav_stream, - const size_t &n_best, - utterance_results_t &results, - const bool &word_level, - const kaldi::BaseFloat &chunk_size) const { - // decoder state variables need to be statically initialized - kaldi::OnlineIvectorExtractorAdaptationState adaptation_state(feature_info_->ivector_extractor_info); - kaldi::OnlineNnet2FeaturePipeline feature_pipeline(*feature_info_); - feature_pipeline.SetAdaptationState(adaptation_state); - - kaldi::OnlineSilenceWeighting silence_weighting(trans_model_, feature_info_->silence_weighting_config, - decodable_opts_.frame_subsampling_factor); - kaldi::SingleUtteranceNnet3Decoder decoder(lattice_faster_decoder_config_, - trans_model_, *decodable_info_.get(), *decode_fst_, - &feature_pipeline); - - kaldi::WaveData wave_data; - wave_data.Read(wav_stream); - - // get the data for channel zero (if the signal is not mono, we only - // take the first channel). 
- kaldi::SubVector data(wave_data.Data(), 0); - const kaldi::BaseFloat samp_freq = wave_data.SampFreq(); - - int32 chunk_length; - if (chunk_size > 0) { - chunk_length = int32(samp_freq * chunk_size); - if (chunk_length == 0) - chunk_length = 1; - } else { - chunk_length = std::numeric_limits::max(); - } - - int32 samp_offset = 0; - std::vector> delta_weights; - - while (samp_offset < data.Dim()) { - int32 samp_remaining = data.Dim() - samp_offset; - int32 num_samp = chunk_length < samp_remaining ? chunk_length : samp_remaining; - - kaldi::SubVector wave_part(data, samp_offset, num_samp); - _decode_wave(feature_pipeline, silence_weighting, decoder, wave_part, delta_weights, samp_freq); - - samp_offset += num_samp; - } - - decode_stream_final(feature_pipeline, decoder, n_best, results, word_level); -} - -void Decoder::decode_raw_wav_audio(std::istream &wav_stream, - const kaldi::BaseFloat &samp_freq, - const size_t &data_bytes, - const size_t &n_best, - utterance_results_t &results, - const bool &word_level, - const kaldi::BaseFloat &chunk_size) const { - // decoder state variables need to be statically initialized - kaldi::OnlineIvectorExtractorAdaptationState adaptation_state(feature_info_->ivector_extractor_info); - kaldi::OnlineNnet2FeaturePipeline feature_pipeline(*feature_info_); - feature_pipeline.SetAdaptationState(adaptation_state); - - kaldi::OnlineSilenceWeighting silence_weighting(trans_model_, feature_info_->silence_weighting_config, - decodable_opts_.frame_subsampling_factor); - kaldi::SingleUtteranceNnet3Decoder decoder(lattice_faster_decoder_config_, - trans_model_, *decodable_info_.get(), *decode_fst_, - &feature_pipeline); - - kaldi::Matrix wave_matrix; - read_raw_wav_stream(wav_stream, data_bytes, wave_matrix); - - // get the data for channel zero (if the signal is not mono, we only - // take the first channel). 
- kaldi::SubVector data(wave_matrix, 0); - - int32 chunk_length; - if (chunk_size > 0) { - chunk_length = int32(samp_freq * chunk_size); - if (chunk_length == 0) - chunk_length = 1; - } else { - chunk_length = std::numeric_limits::max(); - } - - int32 samp_offset = 0; - std::vector> delta_weights; - - while (samp_offset < data.Dim()) { - int32 samp_remaining = data.Dim() - samp_offset; - int32 num_samp = chunk_length < samp_remaining ? chunk_length : samp_remaining; - - kaldi::SubVector wave_part(data, samp_offset, num_samp); - _decode_wave(feature_pipeline, silence_weighting, decoder, wave_part, delta_weights, samp_freq); - - samp_offset += num_samp; - } - - decode_stream_final(feature_pipeline, decoder, n_best, results, word_level); -} - -void Decoder::decode_stream_final(kaldi::OnlineNnet2FeaturePipeline &feature_pipeline, - kaldi::SingleUtteranceNnet3Decoder &decoder, - const std::size_t &n_best, - utterance_results_t &results, - const bool &word_level, - const bool &bidi_streaming) const { - - if (!bidi_streaming) { - feature_pipeline.InputFinished(); - decoder.FinalizeDecoding(); - } - - if (decoder.NumFramesDecoded() == 0) { - KALDI_WARN << "audio may be empty :: decoded no frames"; - return; - } - - kaldi::CompactLattice clat; - try { - decoder.GetLattice(true, &clat); - _find_alternatives(clat, n_best, results, word_level); - } catch (std::exception &e) { - KALDI_ERR << "unexpected error during decoding lattice :: " << e.what(); - } -} - // Factory for creating decoders with shared decoding graph and model parameters // Caches the graph and params to be able to produce decoders on demand. 
class DecoderFactory final { - - private: - const std::unique_ptr> decode_fst_; - - const std::string model_dir_; - const kaldi::BaseFloat beam_; - const std::size_t min_active_; - const std::size_t max_active_; - const kaldi::BaseFloat lattice_beam_; - const kaldi::BaseFloat acoustic_scale_; - const std::size_t frame_subsampling_factor_; - public: - explicit DecoderFactory(const std::string &, - const kaldi::BaseFloat &, - const std::size_t &, - const std::size_t &, - const kaldi::BaseFloat &, - const kaldi::BaseFloat &, - const std::size_t &) noexcept; + ModelSpec model_spec; + + explicit DecoderFactory(const ModelSpec &model_spec); - inline Decoder *produce() const noexcept; + inline Decoder *produce() const; // friendly alias for the producer method - inline Decoder *operator()() const noexcept; + inline Decoder *operator()() const; }; -DecoderFactory::DecoderFactory(const std::string &model_dir, - const kaldi::BaseFloat &beam, - const std::size_t &min_active, - const std::size_t &max_active, - const kaldi::BaseFloat &lattice_beam, - const kaldi::BaseFloat &acoustic_scale, - const std::size_t &frame_subsampling_factor) noexcept - : decode_fst_(fst::ReadFstKaldiGeneric(join_path(model_dir, "HCLG.fst"))), - model_dir_ (model_dir), - beam_(beam), min_active_(min_active), max_active_(max_active), lattice_beam_(lattice_beam), - acoustic_scale_(acoustic_scale), frame_subsampling_factor_(frame_subsampling_factor) {} - -inline Decoder *DecoderFactory::produce() const noexcept { - return new Decoder(beam_, min_active_, max_active_, lattice_beam_, - acoustic_scale_, frame_subsampling_factor_, - model_dir_, decode_fst_.get()); +DecoderFactory::DecoderFactory(const ModelSpec &model_spec) : model_spec(model_spec) { +} + +inline Decoder *DecoderFactory::produce() const { + return new Decoder(model_spec.beam, + model_spec.min_active, + model_spec.max_active, + model_spec.lattice_beam, + model_spec.acoustic_scale, + model_spec.frame_subsampling_factor, + model_spec.path); } 
-inline Decoder *DecoderFactory::operator()() const noexcept { +inline Decoder *DecoderFactory::operator()() const { return produce(); } @@ -638,7 +611,7 @@ class DecoderQueue final { DecoderQueue &operator=(const DecoderQueue &) = delete; // disable assignment - ~DecoderQueue() noexcept; + ~DecoderQueue(); // friendly alias for `pop` inline Decoder *acquire(); @@ -655,13 +628,7 @@ DecoderQueue::DecoderQueue(const ModelSpec &model_spec) { // LOG MODELS LOAD TIME --> START start_time = std::chrono::system_clock::now(); } - decoder_factory_ = std::unique_ptr(new DecoderFactory(model_spec.path, - model_spec.beam, - model_spec.min_active, - model_spec.max_active, - model_spec.lattice_beam, - model_spec.acoustic_scale, - model_spec.frame_subsampling_factor)); + decoder_factory_ = std::unique_ptr(new DecoderFactory(model_spec)); for (size_t i = 0; i < model_spec.n_decoders; i++) { queue_.push(decoder_factory_->produce()); } @@ -674,7 +641,7 @@ DecoderQueue::DecoderQueue(const ModelSpec &model_spec) { } } -DecoderQueue::~DecoderQueue() noexcept { +DecoderQueue::~DecoderQueue() { while (!queue_.empty()) { auto decoder = queue_.front(); queue_.pop(); diff --git a/src/server.hpp b/src/server.hpp index 7421cda..4119200 100644 --- a/src/server.hpp +++ b/src/server.hpp @@ -26,7 +26,9 @@ #include "kaldi_serve.grpc.pb.h" -void add_alternatives_to_response(const utterance_results_t &results, kaldi_serve::RecognizeResponse *response, const kaldi_serve::RecognitionConfig &config) noexcept { +void add_alternatives_to_response(const utterance_results_t &results, + kaldi_serve::RecognizeResponse *response, + const kaldi_serve::RecognitionConfig &config) noexcept { kaldi_serve::SpeechRecognitionResult *sr_result = response->add_results(); kaldi_serve::SpeechRecognitionAlternative *alternative; @@ -53,6 +55,7 @@ void add_alternatives_to_response(const utterance_results_t &results, kaldi_serv } } + // KaldiServeImpl :: // Defines the core server logic and request/response handlers. 
// Keeps `Decoder` instances cached in a thread-safe @@ -122,6 +125,7 @@ grpc::Status KaldiServeImpl::Recognize(grpc::ServerContext *const context, // - Waits here until lock on queue is attained. // - Each new audio stream gets separate decoder object. Decoder *decoder_ = decoder_queue_map_[model_id]->acquire(); + decoder_->start_decoding(); std::chrono::system_clock::time_point start_time; if (DEBUG) { @@ -131,14 +135,12 @@ grpc::Status KaldiServeImpl::Recognize(grpc::ServerContext *const context, kaldi_serve::RecognitionAudio audio = request->audio(); std::stringstream input_stream(audio.content()); - utterance_results_t k_results_; - // decode speech signals in chunks try { if (config.raw()) { - decoder_->decode_raw_wav_audio(input_stream, sample_rate_hertz, config.data_bytes(), n_best, k_results_, config.word_level()); + decoder_->decode_raw_wav_audio(input_stream, sample_rate_hertz, config.data_bytes()); } else { - decoder_->decode_wav_audio(input_stream, n_best, k_results_, config.word_level()); + decoder_->decode_wav_audio(input_stream); } } catch (kaldi::KaldiFatalError &e) { decoder_queue_map_[model_id]->release(decoder_); @@ -149,11 +151,15 @@ grpc::Status KaldiServeImpl::Recognize(grpc::ServerContext *const context, return grpc::Status(grpc::StatusCode::INTERNAL, e.what()); } + utterance_results_t k_results_; + decoder_->get_decoded_results(n_best, k_results_, config.word_level()); + add_alternatives_to_response(k_results_, response, config); // Decoder Release :: // - Releases the lock on the decoder and pushes back into queue. // - Notifies another request handler thread of availability. + decoder_->free_decoder(); decoder_queue_map_[model_id]->release(decoder_); if (DEBUG) { @@ -191,17 +197,7 @@ grpc::Status KaldiServeImpl::StreamingRecognize(grpc::ServerContext *const conte // - Waits here until lock on queue is attained. // - Each new audio stream gets separate decoder object. 
Decoder *decoder_ = decoder_queue_map_[model_id]->acquire(); - - // decoder state variables need to be statically initialized - kaldi::OnlineIvectorExtractorAdaptationState adaptation_state(decoder_->feature_info_->ivector_extractor_info); - kaldi::OnlineNnet2FeaturePipeline feature_pipeline(*decoder_->feature_info_); - feature_pipeline.SetAdaptationState(adaptation_state); - - kaldi::OnlineSilenceWeighting silence_weighting(decoder_->trans_model_, decoder_->feature_info_->silence_weighting_config, - decoder_->decodable_opts_.frame_subsampling_factor); - kaldi::SingleUtteranceNnet3Decoder decoder(decoder_->lattice_faster_decoder_config_, - decoder_->trans_model_, *decoder_->decodable_info_.get(), *decoder_->decode_fst_, - &feature_pipeline); + decoder_->start_decoding(); std::chrono::system_clock::time_point start_time, start_time_req; if (DEBUG) { @@ -242,9 +238,9 @@ grpc::Status KaldiServeImpl::StreamingRecognize(grpc::ServerContext *const conte // Assuming: audio stream has already been chunked into desired length try { if (config.raw()) { - decoder_->decode_stream_raw_wav_chunk(feature_pipeline, silence_weighting, decoder, input_stream_chunk, sample_rate_hertz, config.data_bytes()); + decoder_->decode_stream_raw_wav_chunk(input_stream_chunk, sample_rate_hertz, config.data_bytes()); } else { - decoder_->decode_stream_wav_chunk(feature_pipeline, silence_weighting, decoder, input_stream_chunk); + decoder_->decode_stream_wav_chunk(input_stream_chunk); } } catch (kaldi::KaldiFatalError &e) { decoder_queue_map_[model_id]->release(decoder_); @@ -276,13 +272,14 @@ grpc::Status KaldiServeImpl::StreamingRecognize(grpc::ServerContext *const conte } utterance_results_t k_results_; - decoder_->decode_stream_final(feature_pipeline, decoder, n_best, k_results_, config.word_level()); + decoder_->get_decoded_results(n_best, k_results_, config.word_level()); add_alternatives_to_response(k_results_, response, config); // Decoder Release :: // - Releases the lock on the decoder 
and pushes back into queue. // - Notifies another request handler thread of availability. + decoder_->free_decoder(); decoder_queue_map_[model_id]->release(decoder_); if (DEBUG) { @@ -322,17 +319,7 @@ grpc::Status KaldiServeImpl::BidiStreamingRecognize(grpc::ServerContext *const c // - Waits here until lock on queue is attained. // - Each new audio stream gets separate decoder object. Decoder *decoder_ = decoder_queue_map_[model_id]->acquire(); - - // decoder state variables need to be statically initialized - kaldi::OnlineIvectorExtractorAdaptationState adaptation_state(decoder_->feature_info_->ivector_extractor_info); - kaldi::OnlineNnet2FeaturePipeline feature_pipeline(*decoder_->feature_info_); - feature_pipeline.SetAdaptationState(adaptation_state); - - kaldi::OnlineSilenceWeighting silence_weighting(decoder_->trans_model_, decoder_->feature_info_->silence_weighting_config, - decoder_->decodable_opts_.frame_subsampling_factor); - kaldi::SingleUtteranceNnet3Decoder decoder(decoder_->lattice_faster_decoder_config_, - decoder_->trans_model_, *decoder_->decodable_info_.get(), *decoder_->decode_fst_, - &feature_pipeline); + decoder_->start_decoding(); std::chrono::system_clock::time_point start_time, start_time_req; if (DEBUG) { @@ -372,13 +359,13 @@ grpc::Status KaldiServeImpl::BidiStreamingRecognize(grpc::ServerContext *const c // Assuming: audio stream has already been chunked into desired length try { if (config.raw()) { - decoder_->decode_stream_raw_wav_chunk(feature_pipeline, silence_weighting, decoder, input_stream_chunk, sample_rate_hertz, config.data_bytes()); + decoder_->decode_stream_raw_wav_chunk(input_stream_chunk, sample_rate_hertz, config.data_bytes()); } else { - decoder_->decode_stream_wav_chunk(feature_pipeline, silence_weighting, decoder, input_stream_chunk); + decoder_->decode_stream_wav_chunk(input_stream_chunk); } utterance_results_t k_results_; - decoder_->decode_stream_final(feature_pipeline, decoder, n_best, k_results_, config.word_level(), 
true); + decoder_->get_decoded_results(n_best, k_results_, config.word_level(), true); kaldi_serve::RecognizeResponse response_; add_alternatives_to_response(k_results_, &response_, config); @@ -415,7 +402,7 @@ grpc::Status KaldiServeImpl::BidiStreamingRecognize(grpc::ServerContext *const c } utterance_results_t k_results_; - decoder_->decode_stream_final(feature_pipeline, decoder, n_best, k_results_, config.word_level()); + decoder_->get_decoded_results(n_best, k_results_, config.word_level()); kaldi_serve::RecognizeResponse response_; add_alternatives_to_response(k_results_, &response_, config); @@ -426,6 +413,7 @@ grpc::Status KaldiServeImpl::BidiStreamingRecognize(grpc::ServerContext *const c // Decoder Release :: // - Releases the lock on the decoder and pushes back into queue. // - Notifies another request handler thread of availability. + decoder_->free_decoder(); decoder_queue_map_[model_id]->release(decoder_); if (DEBUG) { @@ -441,6 +429,7 @@ grpc::Status KaldiServeImpl::BidiStreamingRecognize(grpc::ServerContext *const c return grpc::Status::OK; } + // Runs the Server with the Kaldi Service void run_server(const std::vector &model_specs) { KaldiServeImpl service(model_specs); @@ -457,6 +446,7 @@ void run_server(const std::vector &model_specs) { server->Wait(); } + /** NOTES: ------