From 4bf8492577cfe3c228f6fd16d1a1b523ddafe2bc Mon Sep 17 00:00:00 2001 From: Simon-Berger Date: Wed, 29 Mar 2023 15:47:05 +0200 Subject: [PATCH 1/3] Add smaller changes --- Modules.make | 10 +++- src/Bliss/Lexicon.hh | 8 +-- src/Flf/SegmentwiseSpeechProcessor.cc | 7 ++- src/Flf/SegmentwiseSpeechProcessor.hh | 2 +- src/Lattice/Lattice.cc | 27 ++++++++- src/Lattice/Lattice.hh | 7 ++- src/Nn/LabelScorer.cc | 20 +++++++ src/Nn/LabelScorer.hh | 34 +++++++++++ src/Nn/Makefile | 11 ++++ src/Nn/Module.cc | 17 ++++++ src/Nn/Module.hh | 4 ++ src/Nn/TFLabelScorer.cc | 20 +++++++ src/Nn/TFLabelScorer.hh | 29 ++++++++++ .../AdvancedTreeSearch/AdvancedTreeSearch.cc | 2 +- src/Search/GenericSeq2SeqTreeSearch/Makefile | 36 ++++++++++++ .../Seq2SeqAligner.cc | 20 +++++++ .../Seq2SeqAligner.hh | 33 +++++++++++ .../Seq2SeqTreeSearch.cc | 51 +++++++++++++++++ .../Seq2SeqTreeSearch.hh | 56 +++++++++++++++++++ src/Search/Histogram.hh | 4 +- src/Search/Makefile | 5 ++ src/Search/Module.cc | 12 ++++ src/Search/Module.hh | 1 + src/Speech/AlignmentNode.cc | 19 +++++++ src/Speech/AlignmentNode.hh | 22 ++++++++ src/Speech/Makefile | 3 + src/Speech/ModelCombination.cc | 6 ++ src/Speech/ModelCombination.hh | 9 ++- src/Speech/Module.cc | 4 ++ src/Speech/Recognizer.cc | 1 + src/Tensorflow/Tensor.cc | 10 ++++ src/Test/Makefile | 3 + src/Tools/Archiver/Makefile | 3 + src/Tools/NnTrainer/Makefile | 3 + 34 files changed, 482 insertions(+), 17 deletions(-) create mode 100644 src/Nn/LabelScorer.cc create mode 100644 src/Nn/LabelScorer.hh create mode 100644 src/Nn/TFLabelScorer.cc create mode 100644 src/Nn/TFLabelScorer.hh create mode 100644 src/Search/GenericSeq2SeqTreeSearch/Makefile create mode 100644 src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.cc create mode 100644 src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.hh create mode 100644 src/Search/GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.cc create mode 100644 src/Search/GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.hh diff --git a/Modules.make b/Modules.make index 604924046..765fd3849 100644 --- a/Modules.make +++ b/Modules.make @@ -69,13 +69,14 @@ MODULES += MODULE_THEANO_INTERFACE MODULES += MODULE_PYTHON # ****** OpenFst ****** - MODULES += MODULE_OPENFST +MODULES += MODULE_OPENFST # ****** Search ****** - MODULES += MODULE_SEARCH_MBR - MODULES += MODULE_SEARCH_WFST +MODULES += MODULE_SEARCH_MBR +MODULES += MODULE_SEARCH_WFST MODULES += MODULE_SEARCH_LINEAR MODULES += MODULE_ADVANCED_TREE_SEARCH +MODULES += MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH # ****** Signal ****** MODULES += MODULE_SIGNAL_GAMMATONE @@ -151,3 +152,6 @@ endif ifdef MODULE_ADVANCED_TREE_SEARCH LIBS_SEARCH += src/Search/AdvancedTreeSearch/libSprintAdvancedTreeSearch.$(a) endif +ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH +LIBS_SEARCH += src/Search/GenericSeq2SeqTreeSearch/libSprintGenericSeq2SeqTreeSearch.$(a) +endif diff --git a/src/Bliss/Lexicon.hh b/src/Bliss/Lexicon.hh index 215bdc406..47563327a 100644 --- a/src/Bliss/Lexicon.hh +++ b/src/Bliss/Lexicon.hh @@ -642,7 +642,7 @@ public: LemmaIterator(lemmas_.end())); } -#ifdef OBSOLETE +#if defined(OBSOLETE) || defined(MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH) /** * Find a lemma via ID number. * Remember that lemma IDs must not be associated with a @@ -657,13 +657,13 @@ public: */ const Lemma* lemma(Lemma::Id id) const { const Lemma* result = 0; - if (id < lemmas_.size()) { - result = lemmas_[id]; + if (id >=0 && (u32)id < lemmas_.size()) { + result = (const Lemma*)(lemmas_[id]); ensure(result->id() == id); } return result; } -#endif // OBSOLETE +#endif /** * Find a lemma via ID string. This name can either be given diff --git a/src/Flf/SegmentwiseSpeechProcessor.cc b/src/Flf/SegmentwiseSpeechProcessor.cc index edf478d06..b7dc35b90 100644 --- a/src/Flf/SegmentwiseSpeechProcessor.cc +++ b/src/Flf/SegmentwiseSpeechProcessor.cc @@ -20,8 +20,11 @@ namespace Flf { // ------------------------------------------------------------------------- -AcousticModelRef getAm(const Core::Configuration& config) { - return Am::Module::instance().createAcousticModel(config, Bliss::LexiconRef(Lexicon::us())); +AcousticModelRef getAm(const Core::Configuration& config, bool useMixture) { + if (useMixture) + return Am::Module::instance().createAcousticModel(config, Bliss::LexiconRef(Lexicon::us())); + else + return Am::Module::instance().createAcousticModel(config, Bliss::LexiconRef(Lexicon::us()), Am::AcousticModel::noEmissions); } ScaledLanguageModelRef getLm(const Core::Configuration& config) { diff --git a/src/Flf/SegmentwiseSpeechProcessor.hh b/src/Flf/SegmentwiseSpeechProcessor.hh index 759f6bddb..8c9f271c1 100644 --- a/src/Flf/SegmentwiseSpeechProcessor.hh +++ b/src/Flf/SegmentwiseSpeechProcessor.hh @@ -38,7 +38,7 @@ typedef Core::Ref LanguageModelRef; typedef Core::Ref ScaledLanguageModelRef; typedef Core::Ref ModelCombinationRef; -AcousticModelRef getAm(const Core::Configuration& config); +AcousticModelRef getAm(const Core::Configuration& config, bool useMixture=true); ScaledLanguageModelRef getLm(const Core::Configuration& config); ModelCombinationRef getModelCombination(const Core::Configuration& config, AcousticModelRef acousticModel, ScaledLanguageModelRef languageModel = ScaledLanguageModelRef()); diff --git a/src/Lattice/Lattice.cc b/src/Lattice/Lattice.cc index 5940c368c..6a33fcf8a 100644 --- a/src/Lattice/Lattice.cc +++ b/src/Lattice/Lattice.cc @@ -193,18 +193,27 @@ Speech::TimeframeIndex WordLattice::maximumTime() const { return d.getMaximumTime(); } -StandardWordLattice::StandardWordLattice(Core::Ref lexicon) { +StandardWordLattice::StandardWordLattice(Core::Ref lexicon, + bool useLemmaAlphabet) { parts_.addChoice("acoustic", 0); parts_.addChoice("lm", 1); acoustic_ = Core::ref(new Fsa::StaticAutomaton); acoustic_->setType(Fsa::TypeAcceptor); - acoustic_->setInputAlphabet(lexicon->lemmaPronunciationAlphabet()); + if (useLemmaAlphabet) { + acoustic_->setInputAlphabet(lexicon->lemmaAlphabet()); + } else { + acoustic_->setInputAlphabet(lexicon->lemmaPronunciationAlphabet()); + } acoustic_->setSemiring(Fsa::TropicalSemiring); lm_ = Core::ref(new Fsa::StaticAutomaton); lm_->setType(Fsa::TypeAcceptor); - lm_->setInputAlphabet(lexicon->lemmaPronunciationAlphabet()); + if (useLemmaAlphabet) { + lm_->setInputAlphabet(lexicon->lemmaAlphabet()); + } else { + lm_->setInputAlphabet(lexicon->lemmaPronunciationAlphabet()); + } lm_->setSemiring(Fsa::TropicalSemiring); fsas_.push_back(acoustic_); @@ -239,6 +248,18 @@ void StandardWordLattice::newArc( lm_->state(source->id())->newArc(target->id(), Fsa::Weight(lm), (lemmaPronunciation) ? lemmaPronunciation->id() : Fsa::Epsilon); } +void StandardWordLattice::newArc( + Fsa::State* source, + Fsa::State* target, + const Bliss::Lemma* lemma, + Speech::Score acoustic, Speech::Score lm) { + source->newArc(target->id(), Fsa::Weight(acoustic), + (lemma) ? lemma->id() : Fsa::Epsilon); + + lm_->state(source->id())->newArc(target->id(), Fsa::Weight(lm), + (lemma) ? lemma->id() : Fsa::Epsilon); +} + void StandardWordLattice::addAcyclicProperty() { if (Fsa::isAcyclic(acoustic_)) { acoustic_->setProperties(Fsa::PropertyAcyclic, Fsa::PropertyAcyclic); diff --git a/src/Lattice/Lattice.hh b/src/Lattice/Lattice.hh index d0a386349..11d755f65 100644 --- a/src/Lattice/Lattice.hh +++ b/src/Lattice/Lattice.hh @@ -273,7 +273,7 @@ private: Fsa::State * initialState_, *finalState_; public: - StandardWordLattice(Bliss::LexiconRef); + StandardWordLattice(Bliss::LexiconRef, bool useLemmaAlphabet = false); Fsa::State* newState(); Fsa::State* initialState() { @@ -288,6 +288,11 @@ public: const Bliss::LemmaPronunciation*, Speech::Score acoustic, Speech::Score lm); + void newArc(Fsa::State *source, + Fsa::State *target, + const Bliss::Lemma*, + Speech::Score acoustic, Speech::Score lm); + void addAcyclicProperty(); }; diff --git a/src/Nn/LabelScorer.cc b/src/Nn/LabelScorer.cc new file mode 100644 index 000000000..6fc1f0721 --- /dev/null +++ b/src/Nn/LabelScorer.cc @@ -0,0 +1,20 @@ +/** Copyright 2020 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * author: Wei Zhou + */ + +#include "LabelScorer.hh" + +using namespace Nn; diff --git a/src/Nn/LabelScorer.hh b/src/Nn/LabelScorer.hh new file mode 100644 index 000000000..5f8b37821 --- /dev/null +++ b/src/Nn/LabelScorer.hh @@ -0,0 +1,34 @@ +/** Copyright 2020 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * author: Wei Zhou + */ + +#ifndef LABEL_SCORER_HH +#define LABEL_SCORER_HH + +#include +#include + +namespace Nn { + + +// base class of models for label scoring (basic supports except scoring) +class LabelScorer : public virtual Core::Component, + public Core::ReferenceCounted { +}; + +} // namespace Nn + +#endif diff --git a/src/Nn/Makefile b/src/Nn/Makefile index c141b9387..a3723a974 100644 --- a/src/Nn/Makefile +++ b/src/Nn/Makefile @@ -77,6 +77,17 @@ ifdef MODULE_PYTHON LIBSPRINTNN_O += $(OBJDIR)/PythonLayer.o endif +ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH + LIBSPRINTNN_O += $(OBJDIR)/LabelScorer.o + ifdef MODULE_TENSORFLOW + LIBSPRINTNN_O += $(OBJDIR)/TFLabelScorer.o + + CXXFLAGS += $(TF_CXXFLAGS) + LDFLAGS += $(TF_LDFLAGS) + CHECK_O += ../Tensorflow/libSprintTensorflow.$(a) + endif +endif + # ----------------------------------------------------------------------------- all: $(TARGETS) diff --git a/src/Nn/Module.cc b/src/Nn/Module.cc index d2bcc08f7..efb1485d5 100644 --- a/src/Nn/Module.cc +++ b/src/Nn/Module.cc @@ -34,6 +34,13 @@ #include "PythonFeatureScorer.hh" #endif +#ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH +#include "LabelScorer.hh" +#ifdef MODULE_TENSORFLOW +#include "TFLabelScorer.hh" +#endif +#endif + using namespace Nn; Module_::Module_() @@ -80,3 +87,13 @@ Core::FormatSet& Module_::formats() { } return *formats_; } + + +Core::Ref Module_::createLabelScorer(const Core::Configuration& config) const { +#ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH + LabelScorer* labelScorer = nullptr; + return Core::ref(labelScorer); +#else + Core::Application::us()->criticalError("Module MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH not available!"); +#endif +} diff --git a/src/Nn/Module.hh b/src/Nn/Module.hh index bb923ba2b..304d27a61 100644 --- a/src/Nn/Module.hh +++ b/src/Nn/Module.hh @@ -26,6 +26,8 @@ class FormatSet; namespace Nn { +class LabelScorer; + class Module_ { private: Core::FormatSet* formats_; @@ -47,6 +49,8 @@ public: /** Set of file format class. */ Core::FormatSet& formats(); + + Core::Ref createLabelScorer(const Core::Configuration& config) const; }; typedef Core::SingletonHolder Module; diff --git a/src/Nn/TFLabelScorer.cc b/src/Nn/TFLabelScorer.cc new file mode 100644 index 000000000..0c07418f5 --- /dev/null +++ b/src/Nn/TFLabelScorer.cc @@ -0,0 +1,20 @@ +/** Copyright 2020 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * author: Wei Zhou + */ + +#include "TFLabelScorer.hh" + +using namespace Nn; diff --git a/src/Nn/TFLabelScorer.hh b/src/Nn/TFLabelScorer.hh new file mode 100644 index 000000000..bf53b1140 --- /dev/null +++ b/src/Nn/TFLabelScorer.hh @@ -0,0 +1,29 @@ +/** Copyright 2020 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * author: Wei Zhou + */ + +#ifndef TF_LABEL_SCORER_HH +#define TF_LABEL_SCORER_HH + +#include "LabelScorer.hh" + +namespace Nn { + + +} // namesapce + +#endif + diff --git a/src/Search/AdvancedTreeSearch/AdvancedTreeSearch.cc b/src/Search/AdvancedTreeSearch/AdvancedTreeSearch.cc index 60fe42e6b..3e3a17142 100644 --- a/src/Search/AdvancedTreeSearch/AdvancedTreeSearch.cc +++ b/src/Search/AdvancedTreeSearch/AdvancedTreeSearch.cc @@ -471,7 +471,7 @@ Core::Ref AdvancedTreeSearchManager::buildLatticeForTrace( wordBoundaries->set((Fsa::StateId) /* WARNING: this looses a few bits */ (long)result->initialState(), Lattice::WordBoundary(0)); Fsa::State* finalState = result->newState(); wordBoundaries->set(finalState->id(), Lattice::WordBoundary(time_)); - result->newArc(result->initialState(), finalState, 0, 0, 0); + result->newArc(result->initialState(), finalState, (const Bliss::LemmaPronunciation*)0, 0, 0); result->setWordBoundaries(wordBoundaries); result->addAcyclicProperty(); return Core::ref(new Lattice::WordLatticeAdaptor(result)); diff --git a/src/Search/GenericSeq2SeqTreeSearch/Makefile b/src/Search/GenericSeq2SeqTreeSearch/Makefile new file mode 100644 index 000000000..5f281bf53 --- /dev/null +++ b/src/Search/GenericSeq2SeqTreeSearch/Makefile @@ -0,0 +1,36 @@ +#!gmake + +TOPDIR = ../../.. + +include $(TOPDIR)/Makefile.cfg + +# ----------------------------------------------------------------------------- + +SUBDIRS = +TARGETS = libSprintGenericSeq2SeqTreeSearch.$(a) + +LIBSPRINTSEQ2SEQTREESEARCH_O = $(OBJDIR)/Seq2SeqTreeSearch.o \ + $(OBJDIR)/Seq2SeqAligner.o + +ifeq ($(OS),darwin) +CCFLAGS += -fexceptions +endif + +# These flags make the compilation slow, but are required to inline some critical functions in SearchSpace (copied from AdvancedTreeSearch) +CCFLAGS += -Wno-sign-compare -Winline --param max-inline-insns-auto=10000 --param max-inline-insns-single=10000 --param large-function-growth=25000 --param inline-unit-growth=400 +CXXFLAGS += -Wno-sign-compare -Winline --param max-inline-insns-auto=10000 --param max-inline-insns-single=10000 --param large-function-growth=25000 --param inline-unit-growth=400 + +# ----------------------------------------------------------------------------- + +all: $(TARGETS) + +libSprintGenericSeq2SeqTreeSearch.$(a): $(LIBSPRINTSEQ2SEQTREESEARCH_O) + $(MAKELIB) $@ $^ + +check$(exe): $(CHECK_O) + $(LD) $(CHECK_O) -o check$(exe) $(LDFLAGS) + +include $(TOPDIR)/Rules.make + +sinclude $(LIBSPRINTSEQ2SEQTREESEARCH_O:.o=.d) +include $(patsubst %.o,%.d,$(filter %.o,$(CHECK_O))) diff --git a/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.cc b/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.cc new file mode 100644 index 000000000..de20eca30 --- /dev/null +++ b/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.cc @@ -0,0 +1,20 @@ +/** Copyright 2020 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * author: Wei Zhou + */ + +#include "Seq2SeqAligner.hh" + +using namespace Search; diff --git a/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.hh b/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.hh new file mode 100644 index 000000000..da168c8f9 --- /dev/null +++ b/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.hh @@ -0,0 +1,33 @@ +/** Copyright 2020 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * author: Wei Zhou + */ + +#ifndef SEQ2SEQ_ALIGNER_HH +#define SEQ2SEQ_ALIGNER_HH + +#include + +namespace Search { + + +// Integrated alignment interface and search space +// So far: Viterbi only +class Seq2SeqAligner : public Core::Component { +}; + +} // namespace + +#endif // SEQ2SEQ_ALIGNER_HH diff --git a/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.cc b/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.cc new file mode 100644 index 000000000..355dd987a --- /dev/null +++ b/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.cc @@ -0,0 +1,51 @@ +/** Copyright 2020 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * author: Wei Zhou + */ + +#include "Seq2SeqTreeSearch.hh" + +using namespace Search; + + +Seq2SeqTreeSearchManager::Seq2SeqTreeSearchManager(const Core::Configuration& c) + : Core::Component(c), + SearchAlgorithm(c) { +} + + +bool Seq2SeqTreeSearchManager::setModelCombination(const Speech::ModelCombination& modelCombination) { +} + + +void Seq2SeqTreeSearchManager::setGrammar(Fsa::ConstAutomatonRef g) { +} + +void Seq2SeqTreeSearchManager::resetStatistics() { +} + +void Seq2SeqTreeSearchManager::logStatistics() const { +} + +void Seq2SeqTreeSearchManager::restart() { +} + + +void Seq2SeqTreeSearchManager::getCurrentBestSentence(Traceback &result) const { +} + +Core::Ref Seq2SeqTreeSearchManager::getCurrentWordLattice() const { +} + diff --git a/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.hh b/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.hh new file mode 100644 index 000000000..45112ce37 --- /dev/null +++ b/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.hh @@ -0,0 +1,56 @@ +/** Copyright 2020 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * author: Wei Zhou + */ + +#ifndef SEQ2SEQ_TREE_SEARCH_HH +#define SEQ2SEQ_TREE_SEARCH_HH + +#include +#include + +// search manager: interface between search space and high level recognizer +// - manage step-wise decoding: expansion, pruning, recombination, etc. +// - results pulling (traceback) + +namespace Search { + + +class Seq2SeqTreeSearchManager : public SearchAlgorithm { + + public: + Seq2SeqTreeSearchManager(const Core::Configuration &); + + // ---- from SearchAlgorithm (overwrite required) ---- + virtual bool setModelCombination(const Speech::ModelCombination& modelCombination); + virtual void setGrammar(Fsa::ConstAutomatonRef); + virtual void restart(); + + // replaced by decode and decodeNext + virtual void feed(const Mm::FeatureScorer::Scorer&) {} + + // TODO partial result not supported yet + + virtual void getCurrentBestSentence(Traceback &result) const; + virtual Core::Ref getCurrentWordLattice() const; + + virtual void resetStatistics(); + virtual void logStatistics() const; + }; + +} // namespace + +#endif + diff --git a/src/Search/Histogram.hh b/src/Search/Histogram.hh index 0bcaa2122..02007dabd 100644 --- a/src/Search/Histogram.hh +++ b/src/Search/Histogram.hh @@ -64,7 +64,7 @@ public: void operator+=(Score s) { bins_[bin(s)] += 1; } - Score quantile(Count nn) const { + Score quantile(Count nn, bool safe=false) const { Bin b = 0; for (s32 n = nn; b < bins_.size(); ++b) { // n must be signed! n -= bins_[b]; @@ -72,6 +72,8 @@ public: break; } verify(b <= bins_.size()); + if (b == 0 && safe) + b = 1; // at least one bin, otherwise only single best left Score result = Score(b) / scale_ + lower_; ensure(lower_ <= result); ensure(result < upper_ + 2.0 / scale_); diff --git a/src/Search/Makefile b/src/Search/Makefile index abeb7774d..11514593b 100644 --- a/src/Search/Makefile +++ b/src/Search/Makefile @@ -41,6 +41,9 @@ endif ifdef MODULE_SEARCH_LINEAR LIBSPRINTSEARCH_O += $(OBJDIR)/LinearSearch.o endif +ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH +SUBDIRS += GenericSeq2SeqTreeSearch +endif # ----------------------------------------------------------------------------- @@ -62,6 +65,8 @@ Wfst: AdvancedTreeSearch: $(MAKE) -C $@ libSprintAdvancedTreeSearch.$(a) +GenericSeq2SeqTreeSearch: + $(MAKE) -C $@ libSprintGenericSeq2SeqTreeSearch.$(a) include $(TOPDIR)/Rules.make diff --git a/src/Search/Module.cc b/src/Search/Module.cc index 3dfb12aca..2dd703b21 100644 --- a/src/Search/Module.cc +++ b/src/Search/Module.cc @@ -26,6 +26,9 @@ #ifdef MODULE_ADVANCED_TREE_SEARCH #include "AdvancedTreeSearch/AdvancedTreeSearch.hh" #endif +#ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH +#include "GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.hh" +#endif using namespace Search; @@ -60,6 +63,15 @@ SearchAlgorithm* Module_::createRecognizer(SearchType type, const Core::Configur Core::Application::us()->criticalError("Module MODULE_SEARCH_LINEAR not available!"); #endif break; + + case GenericSeq2SeqTreeSearchType: +#ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH + recognizer = new Search::Seq2SeqTreeSearchManager(config); +#else + Core::Application::us()->criticalError("Module MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH not available!"); +#endif + break; + default: Core::Application::us()->criticalError("unknown recognizer type: %d", type); break; diff --git a/src/Search/Module.hh b/src/Search/Module.hh index 04b3c33fb..573a3a7d6 100644 --- a/src/Search/Module.hh +++ b/src/Search/Module.hh @@ -28,6 +28,7 @@ enum SearchType { AdvancedTreeSearch, LinearSearchType, ExpandingFsaSearchType, + GenericSeq2SeqTreeSearchType }; class Module_ { diff --git a/src/Speech/AlignmentNode.cc b/src/Speech/AlignmentNode.cc index 4d0a0dbe5..318e63f74 100644 --- a/src/Speech/AlignmentNode.cc +++ b/src/Speech/AlignmentNode.cc @@ -374,6 +374,25 @@ void AlignmentNode::logTraceback(Lattice::ConstWordLatticeRef wordLattice) const tracebackChannel_ << Core::XmlClose("traceback"); } + +#ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH +/** Seq2SeqAlignmentNode +*/ + +Seq2SeqAlignmentNode::Seq2SeqAlignmentNode(const Core::Configuration& c) : + Core::Component(c), + Precursor(c) { +} + + +void Seq2SeqAlignmentNode::createModel() { +} + +bool Seq2SeqAlignmentNode::work(Flow::PortId p) { +} +#endif + + /** AlignmentDumpNode */ const Core::ParameterString AlignmentDumpNode::paramFilename( diff --git a/src/Speech/AlignmentNode.hh b/src/Speech/AlignmentNode.hh index 201a86ab2..d3bd47721 100644 --- a/src/Speech/AlignmentNode.hh +++ b/src/Speech/AlignmentNode.hh @@ -18,6 +18,9 @@ #include #include #include +#ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH +#include +#endif #include "Alignment.hh" #include "ModelCombination.hh" @@ -115,6 +118,25 @@ public: virtual bool work(Flow::PortId); }; +#ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH +/** Seq2Seq AlignmentNode */ +class Seq2SeqAlignmentNode : public AlignmentBaseNode { + typedef AlignmentBaseNode Precursor; + +public: + static std::string filterName() { + return "speech-seq2seq-alignment"; + } + +public: + Seq2SeqAlignmentNode(const Core::Configuration&); + virtual bool work(Flow::PortId); + +protected: + void createModel(); +}; +#endif + /** Dumps alignments in a plain text format */ class AlignmentDumpNode : public Flow::Node { typedef Flow::Node Precursor; diff --git a/src/Speech/Makefile b/src/Speech/Makefile index 4b5ab3528..60bee95d3 100644 --- a/src/Speech/Makefile +++ b/src/Speech/Makefile @@ -143,6 +143,9 @@ endif ifdef MODULE_ADVANCED_TREE_SEARCH CHECK_O += ../Search/AdvancedTreeSearch/libSprintAdvancedTreeSearch.$(a) endif +ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH +CHECK_O += ../Search/GenericSeq2SeqTreeSearch/libSprintGenericSeq2SeqTreeSearch.$(a) +endif ifdef MODULE_PYTHON CHECK_O += ../Python/libSprintPython.$(a) endif diff --git a/src/Speech/ModelCombination.cc b/src/Speech/ModelCombination.cc index 7ab089d27..e7d8ca106 100644 --- a/src/Speech/ModelCombination.cc +++ b/src/Speech/ModelCombination.cc @@ -15,6 +15,8 @@ #include "ModelCombination.hh" #include #include +#include + using namespace Speech; @@ -112,3 +114,7 @@ void ModelCombination::getDependencies(Core::DependencySet& dependencies) const dependencies.add(name(), d); Mc::Component::getDependencies(dependencies); } + +void ModelCombination::createLabelScorer() { + labelScorer_ = Nn::Module::instance().createLabelScorer(select("label-scorer")); +} diff --git a/src/Speech/ModelCombination.hh b/src/Speech/ModelCombination.hh index f00e4a249..2400e9279 100644 --- a/src/Speech/ModelCombination.hh +++ b/src/Speech/ModelCombination.hh @@ -21,10 +21,12 @@ #include #include #include +#include + namespace Speech { -/** Combination of a lexicon, an acoustic model, and a language model. +/** Combination of a lexicon, an acoustic model or label scorer, and a language model. * It supports creation and initialization of these three mutually dependent objects. * * Usage: @@ -49,6 +51,7 @@ protected: Mm::Score pronunciationScale_; Core::Ref acousticModel_; Core::Ref languageModel_; + Core::Ref labelScorer_; private: void setPronunciationScale(Mm::Score scale) { @@ -83,6 +86,10 @@ public: return languageModel_; } void setLanguageModel(Core::Ref); + + void createLabelScorer(); + void setLabelScorer(Core::Ref& ls) { labelScorer_ = ls; } + Core::Ref labelScorer() const { return labelScorer_; } }; typedef Core::Ref ModelCombinationRef; diff --git a/src/Speech/Module.cc b/src/Speech/Module.cc index fa31fa230..e55993655 100644 --- a/src/Speech/Module.cc +++ b/src/Speech/Module.cc @@ -65,6 +65,10 @@ Module_::Module_() { registry.registerFilter(); registry.registerDatatype>(); +#ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH + registry.registerFilter(); +#endif + #ifdef MODULE_SPEECH_ALIGNMENT_FLOW_NODES registry.registerFilter(); registry.registerFilter(); diff --git a/src/Speech/Recognizer.cc b/src/Speech/Recognizer.cc index 324de133d..036023ffc 100644 --- a/src/Speech/Recognizer.cc +++ b/src/Speech/Recognizer.cc @@ -45,6 +45,7 @@ const Core::Choice Recognizer::searchTypeChoice_( "advanced-tree-search", Search::AdvancedTreeSearch, "expanding-fsa-search", Search::ExpandingFsaSearchType, "linear-search", Search::LinearSearchType, + "generic-seq2seq-tree-search", Search::GenericSeq2SeqTreeSearchType, Core::Choice::endMark()); const Core::ParameterChoice Recognizer::paramSearch( diff --git a/src/Tensorflow/Tensor.cc b/src/Tensorflow/Tensor.cc index 959a18925..b6476fb71 100644 --- a/src/Tensorflow/Tensor.cc +++ b/src/Tensorflow/Tensor.cc @@ -32,6 +32,7 @@ template struct ToDataType; template struct ToDataType; template struct ToDataType; template struct ToDataType; +template struct ToDataType; // tf::DataTypeToEnum does not have entries for s64 and u64 (as they are (unsigned) long) // instead it has entries for long long and unsigned long long. For our supported data-model @@ -215,6 +216,11 @@ Tensor Tensor::concat(Tensor const& a, Tensor const& b, int axis) { dynamic_rank_concat::Type>(*res.tensor_, *a.tensor_, *b.tensor_, axis); return res; } + case tf::DT_BOOL: { + Tensor res = Tensor::zeros::Type>(new_shape); + dynamic_rank_concat::Type>(*res.tensor_, *a.tensor_, *b.tensor_, axis); + return res; + } default: defect(); } } @@ -242,6 +248,8 @@ Tensor Tensor::concat(const std::vector& inputs, int axis) { return Tensor::concat::Type>(inputs, axis); case tf::DT_UINT8: return Tensor::concat::Type>(inputs, axis); + case tf::DT_BOOL: + return Tensor::concat::Type>(inputs, axis); default: defect(); } } @@ -338,6 +346,7 @@ template Tensor Tensor::concat(const std::vector& inputs, in template Tensor Tensor::concat(const std::vector& inputs, int axis); template Tensor Tensor::concat(const std::vector& inputs, int axis); template Tensor Tensor::concat(const std::vector& inputs, int axis); +template Tensor Tensor::concat(const std::vector& inputs, int axis); /* ------------------------- Getters ------------------------- */ @@ -901,6 +910,7 @@ Tensor Tensor::slice(std::vector const& start, std::vector const& end) case tf::DT_UINT16: dynamic_rank_slice::Type>(*res.tensor_, *tensor_, start_vec, size_vec); break; case tf::DT_INT8: dynamic_rank_slice::Type>(*res.tensor_, *tensor_, start_vec, size_vec); break; case tf::DT_UINT8: dynamic_rank_slice::Type>(*res.tensor_, *tensor_, start_vec, size_vec); break; + case tf::DT_BOOL: dynamic_rank_slice::Type>(*res.tensor_, *tensor_, start_vec, size_vec); break; default: defect(); } diff --git a/src/Test/Makefile b/src/Test/Makefile index b9f3f2fcf..7de6c4d37 100644 --- a/src/Test/Makefile +++ b/src/Test/Makefile @@ -85,6 +85,9 @@ endif ifdef MODULE_ADVANCED_TREE_SEARCH UNIT_TEST_O += ../Search/AdvancedTreeSearch/libSprintAdvancedTreeSearch.$(a) endif +ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH +UNIT_TEST_O += ../Search/GenericSeq2SeqTreeSearch/libSprintGenericSeq2SeqTreeSearch.$(a) +endif ifdef MODULE_PYTHON UNIT_TEST_O += ../Python/libSprintPython.$(a) endif diff --git a/src/Tools/Archiver/Makefile b/src/Tools/Archiver/Makefile index e9000e78a..f31539366 100644 --- a/src/Tools/Archiver/Makefile +++ b/src/Tools/Archiver/Makefile @@ -30,6 +30,9 @@ ARCHIVER_O = $(OBJDIR)/Archiver.o \ ../../Flow/libSprintFlow.$(a) \ ../../Fsa/libSprintFsa.$(a) +ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH +ARCHIVER_O += ../../Search/GenericSeq2SeqTreeSearch/libSprintGenericSeq2SeqTreeSearch.$(a) +endif ifdef MODULE_PYTHON ARCHIVER_O += ../../Python/libSprintPython.$(a) endif diff --git a/src/Tools/NnTrainer/Makefile b/src/Tools/NnTrainer/Makefile index fc71a4bd2..3b1be0d24 100644 --- a/src/Tools/NnTrainer/Makefile +++ b/src/Tools/NnTrainer/Makefile @@ -43,6 +43,9 @@ endif ifdef MODULE_SEARCH_WFST NN_TRAINER_O += ../../Search/Wfst/libSprintSearchWfst.$(a) ../../OpenFst/libSprintOpenFst.$(a) endif +ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH +NN_TRAINER_O += ../../Search/GenericSeq2SeqTreeSearch/libSprintGenericSeq2SeqTreeSearch.$(a) +endif ifdef MODULE_PYTHON NN_TRAINER_O += ../../Python/libSprintPython.$(a) endif From bfa01d85336a3df57182d64a6cf0e73063ced250 Mon Sep 17 00:00:00 2001 From: Simon-Berger Date: Tue, 11 Apr 2023 13:40:48 +0200 Subject: [PATCH 2/3] Implement suggestions --- src/Bliss/Lexicon.hh | 8 +-- src/Lattice/Lattice.cc | 57 +++++++++++-------- src/Lattice/Lattice.hh | 14 ++++- src/Nn/LabelScorer.cc | 2 - src/Nn/LabelScorer.hh | 2 - src/Nn/TFLabelScorer.cc | 2 - src/Nn/TFLabelScorer.hh | 2 - .../AdvancedTreeSearch/AdvancedTreeSearch.cc | 2 +- .../Seq2SeqAligner.cc | 2 - .../Seq2SeqAligner.hh | 2 - .../Seq2SeqTreeSearch.cc | 2 - .../Seq2SeqTreeSearch.hh | 2 - src/Search/Histogram.hh | 4 +- src/Speech/AlignmentNode.hh | 2 +- src/Speech/ModelCombination.cc | 13 +++-- src/Speech/ModelCombination.hh | 4 +- 16 files changed, 62 insertions(+), 58 deletions(-) diff --git a/src/Bliss/Lexicon.hh b/src/Bliss/Lexicon.hh index 47563327a..215bdc406 100644 --- a/src/Bliss/Lexicon.hh +++ b/src/Bliss/Lexicon.hh @@ -642,7 +642,7 @@ public: LemmaIterator(lemmas_.end())); } -#if defined(OBSOLETE) || defined(MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH) +#ifdef OBSOLETE /** * Find a lemma via ID number. * Remember that lemma IDs must not be associated with a @@ -657,13 +657,13 @@ public: */ const Lemma* lemma(Lemma::Id id) const { const Lemma* result = 0; - if (id >=0 && (u32)id < lemmas_.size()) { - result = (const Lemma*)(lemmas_[id]); + if (id < lemmas_.size()) { + result = lemmas_[id]; ensure(result->id() == id); } return result; } -#endif +#endif // OBSOLETE /** * Find a lemma via ID string. This name can either be given diff --git a/src/Lattice/Lattice.cc b/src/Lattice/Lattice.cc index 6a33fcf8a..2be8a6995 100644 --- a/src/Lattice/Lattice.cc +++ b/src/Lattice/Lattice.cc @@ -193,27 +193,32 @@ Speech::TimeframeIndex WordLattice::maximumTime() const { return d.getMaximumTime(); } -StandardWordLattice::StandardWordLattice(Core::Ref lexicon, - bool useLemmaAlphabet) { +StandardWordLattice::StandardWordLattice(Core::Ref lexicon, + AlphabetType alphabetType) + : alphabetType_(alphabetType) { parts_.addChoice("acoustic", 0); parts_.addChoice("lm", 1); acoustic_ = Core::ref(new Fsa::StaticAutomaton); - acoustic_->setType(Fsa::TypeAcceptor); - if (useLemmaAlphabet) { - acoustic_->setInputAlphabet(lexicon->lemmaAlphabet()); - } else { - acoustic_->setInputAlphabet(lexicon->lemmaPronunciationAlphabet()); - } - acoustic_->setSemiring(Fsa::TropicalSemiring); + lm_ = Core::ref(new Fsa::StaticAutomaton); - lm_ = Core::ref(new Fsa::StaticAutomaton); + acoustic_->setType(Fsa::TypeAcceptor); lm_->setType(Fsa::TypeAcceptor); - if (useLemmaAlphabet) { - lm_->setInputAlphabet(lexicon->lemmaAlphabet()); - } else { - lm_->setInputAlphabet(lexicon->lemmaPronunciationAlphabet()); + + switch (alphabetType_) { + case AlphabetType::LemmaAlphabet: + acoustic_->setInputAlphabet(lexicon->lemmaAlphabet()); + lm_->setInputAlphabet(lexicon->lemmaAlphabet()); + break; + case AlphabetType::LemmaPronunciationAlphabet: + acoustic_->setInputAlphabet(lexicon->lemmaPronunciationAlphabet()); + lm_->setInputAlphabet(lexicon->lemmaPronunciationAlphabet()); + break; + default: + defect(); } + + acoustic_->setSemiring(Fsa::TropicalSemiring); lm_->setSemiring(Fsa::TropicalSemiring); fsas_.push_back(acoustic_); @@ -235,17 +240,22 @@ Fsa::State* StandardWordLattice::newState() { return acoustic_->newState(); } +void StandardWordLattice::newArc( + Fsa::State* source, + Fsa::State* target, + Fsa::LabelId id, + Speech::Score acoustic, Speech::Score lm) { + source -> newArc(target->id(), Fsa::Weight(acoustic), id); + lm_->state(source->id())->newArc(target->id(), Fsa::Weight(lm), id); +} + void StandardWordLattice::newArc( Fsa::State* source, Fsa::State* target, const Bliss::LemmaPronunciation* lemmaPronunciation, Speech::Score acoustic, Speech::Score lm) { - source->newArc( - target->id(), - Fsa::Weight(acoustic), - (lemmaPronunciation) ? lemmaPronunciation->id() : Fsa::Epsilon); - - lm_->state(source->id())->newArc(target->id(), Fsa::Weight(lm), (lemmaPronunciation) ? lemmaPronunciation->id() : Fsa::Epsilon); + Fsa::LabelId id = lemmaPronunciation ? lemmaPronunciation->id() : Fsa::Epsilon; + newArc(source, target, id, acoustic, lm); } void StandardWordLattice::newArc( @@ -253,11 +263,8 @@ void StandardWordLattice::newArc( Fsa::State* target, const Bliss::Lemma* lemma, Speech::Score acoustic, Speech::Score lm) { - source->newArc(target->id(), Fsa::Weight(acoustic), - (lemma) ? lemma->id() : Fsa::Epsilon); - - lm_->state(source->id())->newArc(target->id(), Fsa::Weight(lm), - (lemma) ? lemma->id() : Fsa::Epsilon); + Fsa::LabelId id = lemma ? lemma->id() : Fsa::Epsilon; + newArc(source, target, id, acoustic, lm); } void StandardWordLattice::addAcyclicProperty() { diff --git a/src/Lattice/Lattice.hh b/src/Lattice/Lattice.hh index 11d755f65..2f7318269 100644 --- a/src/Lattice/Lattice.hh +++ b/src/Lattice/Lattice.hh @@ -268,12 +268,24 @@ public: * near future, when the WordLattice interface is refactored. */ class StandardWordLattice : public WordLattice { +public: + enum class AlphabetType { + LemmaAlphabet = 1, + LemmaPronunciationAlphabet = 2 + }; + private: Core::Ref acoustic_, lm_; Fsa::State * initialState_, *finalState_; + AlphabetType alphabetType_; + + void newArc(Fsa::State* source, + Fsa::State* target, + Fsa::LabelId id, + Speech::Score acoustic, Speech::Score lm); public: - StandardWordLattice(Bliss::LexiconRef, bool useLemmaAlphabet = false); + StandardWordLattice(Bliss::LexiconRef, AlphabetType alphabetType = AlphabetType::LemmaAlphabet); Fsa::State* newState(); Fsa::State* initialState() { diff --git a/src/Nn/LabelScorer.cc b/src/Nn/LabelScorer.cc index 6fc1f0721..ff1d531f1 100644 --- a/src/Nn/LabelScorer.cc +++ b/src/Nn/LabelScorer.cc @@ -11,8 +11,6 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - * - * author: Wei Zhou */ #include "LabelScorer.hh" diff --git a/src/Nn/LabelScorer.hh b/src/Nn/LabelScorer.hh index 5f8b37821..38d8ea6c2 100644 --- a/src/Nn/LabelScorer.hh +++ b/src/Nn/LabelScorer.hh @@ -11,8 +11,6 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - * - * author: Wei Zhou */ #ifndef LABEL_SCORER_HH diff --git a/src/Nn/TFLabelScorer.cc b/src/Nn/TFLabelScorer.cc index 0c07418f5..8fd296021 100644 --- a/src/Nn/TFLabelScorer.cc +++ b/src/Nn/TFLabelScorer.cc @@ -11,8 +11,6 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - * - * author: Wei Zhou */ #include "TFLabelScorer.hh" diff --git a/src/Nn/TFLabelScorer.hh b/src/Nn/TFLabelScorer.hh index bf53b1140..76d20ffbc 100644 --- a/src/Nn/TFLabelScorer.hh +++ b/src/Nn/TFLabelScorer.hh @@ -11,8 +11,6 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - * - * author: Wei Zhou */ #ifndef TF_LABEL_SCORER_HH diff --git a/src/Search/AdvancedTreeSearch/AdvancedTreeSearch.cc b/src/Search/AdvancedTreeSearch/AdvancedTreeSearch.cc index 3e3a17142..67e3c14ff 100644 --- a/src/Search/AdvancedTreeSearch/AdvancedTreeSearch.cc +++ b/src/Search/AdvancedTreeSearch/AdvancedTreeSearch.cc @@ -471,7 +471,7 @@ Core::Ref AdvancedTreeSearchManager::buildLatticeForTrace( wordBoundaries->set((Fsa::StateId) /* WARNING: this looses a few bits */ (long)result->initialState(), Lattice::WordBoundary(0)); Fsa::State* finalState = result->newState(); wordBoundaries->set(finalState->id(), Lattice::WordBoundary(time_)); - result->newArc(result->initialState(), finalState, (const Bliss::LemmaPronunciation*)0, 0, 0); + result->newArc(result->initialState(), finalState, static_cast(nullptr), 0, 0); result->setWordBoundaries(wordBoundaries); result->addAcyclicProperty(); return Core::ref(new Lattice::WordLatticeAdaptor(result)); diff --git a/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.cc b/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.cc index de20eca30..3d96b1657 100644 --- a/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.cc +++ b/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.cc @@ -11,8 +11,6 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - * - * author: Wei Zhou */ #include "Seq2SeqAligner.hh" diff --git a/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.hh b/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.hh index da168c8f9..729e600d6 100644 --- a/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.hh +++ b/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.hh @@ -11,8 +11,6 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - * - * author: Wei Zhou */ #ifndef SEQ2SEQ_ALIGNER_HH diff --git a/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.cc b/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.cc index 355dd987a..11fb8a664 100644 --- a/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.cc +++ b/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.cc @@ -11,8 +11,6 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - * - * author: Wei Zhou */ #include "Seq2SeqTreeSearch.hh" diff --git a/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.hh b/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.hh index 45112ce37..38313e9af 100644 --- a/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.hh +++ b/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.hh @@ -11,8 +11,6 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - * - * author: Wei Zhou */ #ifndef SEQ2SEQ_TREE_SEARCH_HH diff --git a/src/Search/Histogram.hh b/src/Search/Histogram.hh index 02007dabd..0bcaa2122 100644 --- a/src/Search/Histogram.hh +++ b/src/Search/Histogram.hh @@ -64,7 +64,7 @@ public: void operator+=(Score s) { bins_[bin(s)] += 1; } - Score quantile(Count nn, bool safe=false) const { + Score quantile(Count nn) const { Bin b = 0; for (s32 n = nn; b < bins_.size(); ++b) { // n must be signed! n -= bins_[b]; @@ -72,8 +72,6 @@ public: break; } verify(b <= bins_.size()); - if (b == 0 && safe) - b = 1; // at least one bin, otherwise only single best left Score result = Score(b) / scale_ + lower_; ensure(lower_ <= result); ensure(result < upper_ + 2.0 / scale_); diff --git a/src/Speech/AlignmentNode.hh b/src/Speech/AlignmentNode.hh index d3bd47721..551746835 100644 --- a/src/Speech/AlignmentNode.hh +++ b/src/Speech/AlignmentNode.hh @@ -125,7 +125,7 @@ class Seq2SeqAlignmentNode : public AlignmentBaseNode { public: static std::string filterName() { - return "speech-seq2seq-alignment"; + return "speech-generic-seq2seq-alignment"; } public: diff --git a/src/Speech/ModelCombination.cc b/src/Speech/ModelCombination.cc index e7d8ca106..d68664fe9 100644 --- a/src/Speech/ModelCombination.cc +++ b/src/Speech/ModelCombination.cc @@ -20,10 +20,11 @@ using namespace Speech; -const ModelCombination::Mode ModelCombination::complete = 0xFFFF; +const ModelCombination::Mode ModelCombination::complete = 0x3; const ModelCombination::Mode ModelCombination::useLexicon = 0x0; const ModelCombination::Mode ModelCombination::useAcousticModel = 0x1; const ModelCombination::Mode ModelCombination::useLanguageModel = 0x2; +const ModelCombination::Mode ModelCombination::useLabelScorer = 0x4; //====================================================================================== @@ -55,6 +56,12 @@ ModelCombination::ModelCombination(const Core::Configuration& c, if (!languageModel_) criticalError("failed to initialize language model"); } + + if (mode & useLabelScorer) { + setLabelScorer(Nn::Module::instance().createLabelScorer(select("label-scorer"))); + if (!labelScorer_) + criticalError("failed to initialize label scorer"); + } } ModelCombination::ModelCombination(const Core::Configuration& c, @@ -114,7 +121,3 @@ void ModelCombination::getDependencies(Core::DependencySet& dependencies) const dependencies.add(name(), d); Mc::Component::getDependencies(dependencies); } - -void ModelCombination::createLabelScorer() { - labelScorer_ = Nn::Module::instance().createLabelScorer(select("label-scorer")); -} diff --git a/src/Speech/ModelCombination.hh b/src/Speech/ModelCombination.hh index 2400e9279..7d2b15f3f 100644 --- a/src/Speech/ModelCombination.hh +++ b/src/Speech/ModelCombination.hh @@ -43,6 +43,7 @@ public: static const Mode useLexicon; static const Mode useAcousticModel; static const Mode useLanguageModel; + static const Mode useLabelScorer; static const Core::ParameterFloat paramPronunciationScale; @@ -87,8 +88,7 @@ public: } void setLanguageModel(Core::Ref); - void createLabelScorer(); - void setLabelScorer(Core::Ref& ls) { labelScorer_ = ls; } + void setLabelScorer(Core::Ref ls) { labelScorer_ = ls; } Core::Ref labelScorer() const { return labelScorer_; } }; From cbdcfae4767c017291c50332d3ba0b6798a4ed4e Mon Sep 17 00:00:00 2001 From: Simon-Berger Date: Mon, 17 Apr 2023 10:10:49 +0200 Subject: [PATCH 3/3] Add comment for 'complete' mode. --- src/Speech/ModelCombination.hh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Speech/ModelCombination.hh b/src/Speech/ModelCombination.hh index 7d2b15f3f..ee24267c8 100644 --- a/src/Speech/ModelCombination.hh +++ b/src/Speech/ModelCombination.hh @@ -39,7 +39,7 @@ namespace Speech { class ModelCombination : public Mc::Component, public Core::ReferenceCounted { public: typedef u32 Mode; - static const Mode complete; + static const Mode complete; // Includes lexicon, AM and LM but NOT label scorer; named 'complete' for legacy reasons. static const Mode useLexicon; static const Mode useAcousticModel; static const Mode useLanguageModel;