From 3c93cdbba386fd3946d2f3b6b8f7ff20aec17b9a Mon Sep 17 00:00:00 2001 From: SimBe195 <37951951+SimBe195@users.noreply.github.com> Date: Mon, 17 Apr 2023 12:25:52 +0200 Subject: [PATCH] Generic Seq2Seq Decoder preparation (#49) --- Modules.make | 10 ++-- src/Flf/SegmentwiseSpeechProcessor.cc | 7 ++- src/Flf/SegmentwiseSpeechProcessor.hh | 2 +- src/Lattice/Lattice.cc | 50 +++++++++++++---- src/Lattice/Lattice.hh | 19 ++++++- src/Nn/LabelScorer.cc | 18 +++++++ src/Nn/LabelScorer.hh | 32 +++++++++++ src/Nn/Makefile | 11 ++++ src/Nn/Module.cc | 17 ++++++ src/Nn/Module.hh | 4 ++ src/Nn/TFLabelScorer.cc | 18 +++++++ src/Nn/TFLabelScorer.hh | 27 ++++++++++ .../AdvancedTreeSearch/AdvancedTreeSearch.cc | 2 +- src/Search/GenericSeq2SeqTreeSearch/Makefile | 36 +++++++++++++ .../Seq2SeqAligner.cc | 18 +++++++ .../Seq2SeqAligner.hh | 31 +++++++++++ .../Seq2SeqTreeSearch.cc | 49 +++++++++++++++++ .../Seq2SeqTreeSearch.hh | 54 +++++++++++++++++++ src/Search/Makefile | 5 ++ src/Search/Module.cc | 12 +++++ src/Search/Module.hh | 1 + src/Speech/AlignmentNode.cc | 19 +++++++ src/Speech/AlignmentNode.hh | 22 ++++++++ src/Speech/Makefile | 3 ++ src/Speech/ModelCombination.cc | 11 +++- src/Speech/ModelCombination.hh | 11 +++- src/Speech/Module.cc | 4 ++ src/Speech/Recognizer.cc | 1 + src/Tensorflow/Tensor.cc | 10 ++++ src/Test/Makefile | 3 ++ src/Tools/Archiver/Makefile | 3 ++ src/Tools/NnTrainer/Makefile | 3 ++ 32 files changed, 491 insertions(+), 22 deletions(-) create mode 100644 src/Nn/LabelScorer.cc create mode 100644 src/Nn/LabelScorer.hh create mode 100644 src/Nn/TFLabelScorer.cc create mode 100644 src/Nn/TFLabelScorer.hh create mode 100644 src/Search/GenericSeq2SeqTreeSearch/Makefile create mode 100644 src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.cc create mode 100644 src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.hh create mode 100644 src/Search/GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.cc create mode 100644 src/Search/GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.hh diff --git a/Modules.make b/Modules.make index ee37a45cc..971a91bc7 100644 --- a/Modules.make +++ b/Modules.make @@ -69,13 +69,14 @@ MODULES += MODULE_THEANO_INTERFACE MODULES += MODULE_PYTHON # ****** OpenFst ****** - MODULES += MODULE_OPENFST +MODULES += MODULE_OPENFST # ****** Search ****** - MODULES += MODULE_SEARCH_MBR - MODULES += MODULE_SEARCH_WFST +MODULES += MODULE_SEARCH_MBR +MODULES += MODULE_SEARCH_WFST MODULES += MODULE_SEARCH_LINEAR MODULES += MODULE_ADVANCED_TREE_SEARCH +MODULES += MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH # ****** Signal ****** MODULES += MODULE_SIGNAL_GAMMATONE @@ -151,3 +152,6 @@ endif ifdef MODULE_ADVANCED_TREE_SEARCH LIBS_SEARCH += src/Search/AdvancedTreeSearch/libSprintAdvancedTreeSearch.$(a) endif +ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH +LIBS_SEARCH += src/Search/GenericSeq2SeqTreeSearch/libSprintGenericSeq2SeqTreeSearch.$(a) +endif diff --git a/src/Flf/SegmentwiseSpeechProcessor.cc b/src/Flf/SegmentwiseSpeechProcessor.cc index edf478d06..b7dc35b90 100644 --- a/src/Flf/SegmentwiseSpeechProcessor.cc +++ b/src/Flf/SegmentwiseSpeechProcessor.cc @@ -20,8 +20,11 @@ namespace Flf { // ------------------------------------------------------------------------- -AcousticModelRef getAm(const Core::Configuration& config) { - return Am::Module::instance().createAcousticModel(config, Bliss::LexiconRef(Lexicon::us())); +AcousticModelRef getAm(const Core::Configuration& config, bool useMixture) { + if (useMixture) + return Am::Module::instance().createAcousticModel(config, Bliss::LexiconRef(Lexicon::us())); + else + return Am::Module::instance().createAcousticModel(config, Bliss::LexiconRef(Lexicon::us()), Am::AcousticModel::noEmissions); } ScaledLanguageModelRef getLm(const Core::Configuration& config) { diff --git a/src/Flf/SegmentwiseSpeechProcessor.hh b/src/Flf/SegmentwiseSpeechProcessor.hh index 759f6bddb..8c9f271c1 100644 --- a/src/Flf/SegmentwiseSpeechProcessor.hh +++ b/src/Flf/SegmentwiseSpeechProcessor.hh @@ -38,7 +38,7 @@ typedef Core::Ref LanguageModelRef; typedef Core::Ref ScaledLanguageModelRef; typedef Core::Ref ModelCombinationRef; -AcousticModelRef getAm(const Core::Configuration& config); +AcousticModelRef getAm(const Core::Configuration& config, bool useMixture=true); ScaledLanguageModelRef getLm(const Core::Configuration& config); ModelCombinationRef getModelCombination(const Core::Configuration& config, AcousticModelRef acousticModel, ScaledLanguageModelRef languageModel = ScaledLanguageModelRef()); diff --git a/src/Lattice/Lattice.cc b/src/Lattice/Lattice.cc index 5940c368c..2be8a6995 100644 --- a/src/Lattice/Lattice.cc +++ b/src/Lattice/Lattice.cc @@ -193,18 +193,32 @@ Speech::TimeframeIndex WordLattice::maximumTime() const { return d.getMaximumTime(); } -StandardWordLattice::StandardWordLattice(Core::Ref lexicon) { +StandardWordLattice::StandardWordLattice(Core::Ref lexicon, + AlphabetType alphabetType) + : alphabetType_(alphabetType) { parts_.addChoice("acoustic", 0); parts_.addChoice("lm", 1); acoustic_ = Core::ref(new Fsa::StaticAutomaton); - acoustic_->setType(Fsa::TypeAcceptor); - acoustic_->setInputAlphabet(lexicon->lemmaPronunciationAlphabet()); - acoustic_->setSemiring(Fsa::TropicalSemiring); + lm_ = Core::ref(new Fsa::StaticAutomaton); - lm_ = Core::ref(new Fsa::StaticAutomaton); + acoustic_->setType(Fsa::TypeAcceptor); lm_->setType(Fsa::TypeAcceptor); - lm_->setInputAlphabet(lexicon->lemmaPronunciationAlphabet()); + + switch (alphabetType_) { + case AlphabetType::LemmaAlphabet: + acoustic_->setInputAlphabet(lexicon->lemmaAlphabet()); + lm_->setInputAlphabet(lexicon->lemmaAlphabet()); + break; + case AlphabetType::LemmaPronunciationAlphabet: + acoustic_->setInputAlphabet(lexicon->lemmaPronunciationAlphabet()); + lm_->setInputAlphabet(lexicon->lemmaPronunciationAlphabet()); + break; + default: + defect(); + } + + acoustic_->setSemiring(Fsa::TropicalSemiring); lm_->setSemiring(Fsa::TropicalSemiring); fsas_.push_back(acoustic_); @@ -226,17 +240,31 @@ Fsa::State* StandardWordLattice::newState() { return acoustic_->newState(); } +void StandardWordLattice::newArc( + Fsa::State* source, + Fsa::State* target, + Fsa::LabelId id, + Speech::Score acoustic, Speech::Score lm) { + source -> newArc(target->id(), Fsa::Weight(acoustic), id); + lm_->state(source->id())->newArc(target->id(), Fsa::Weight(lm), id); +} + void StandardWordLattice::newArc( Fsa::State* source, Fsa::State* target, const Bliss::LemmaPronunciation* lemmaPronunciation, Speech::Score acoustic, Speech::Score lm) { - source->newArc( - target->id(), - Fsa::Weight(acoustic), - (lemmaPronunciation) ? lemmaPronunciation->id() : Fsa::Epsilon); + Fsa::LabelId id = lemmaPronunciation ? lemmaPronunciation->id() : Fsa::Epsilon; + newArc(source, target, id, acoustic, lm); +} - lm_->state(source->id())->newArc(target->id(), Fsa::Weight(lm), (lemmaPronunciation) ? lemmaPronunciation->id() : Fsa::Epsilon); +void StandardWordLattice::newArc( + Fsa::State* source, + Fsa::State* target, + const Bliss::Lemma* lemma, + Speech::Score acoustic, Speech::Score lm) { + Fsa::LabelId id = lemma ? lemma->id() : Fsa::Epsilon; + newArc(source, target, id, acoustic, lm); } void StandardWordLattice::addAcyclicProperty() { diff --git a/src/Lattice/Lattice.hh b/src/Lattice/Lattice.hh index d0a386349..2f7318269 100644 --- a/src/Lattice/Lattice.hh +++ b/src/Lattice/Lattice.hh @@ -268,12 +268,24 @@ public: * near future, when the WordLattice interface is refactored. */ class StandardWordLattice : public WordLattice { +public: + enum class AlphabetType { + LemmaAlphabet = 1, + LemmaPronunciationAlphabet = 2 + }; + private: Core::Ref acoustic_, lm_; Fsa::State * initialState_, *finalState_; + AlphabetType alphabetType_; + + void newArc(Fsa::State* source, + Fsa::State* target, + Fsa::LabelId id, + Speech::Score acoustic, Speech::Score lm); public: - StandardWordLattice(Bliss::LexiconRef); + StandardWordLattice(Bliss::LexiconRef, AlphabetType alphabetType = AlphabetType::LemmaAlphabet); Fsa::State* newState(); Fsa::State* initialState() { @@ -288,6 +300,11 @@ public: const Bliss::LemmaPronunciation*, Speech::Score acoustic, Speech::Score lm); + void newArc(Fsa::State *source, + Fsa::State *target, + const Bliss::Lemma*, + Speech::Score acoustic, Speech::Score lm); + void addAcyclicProperty(); }; diff --git a/src/Nn/LabelScorer.cc b/src/Nn/LabelScorer.cc new file mode 100644 index 000000000..ff1d531f1 --- /dev/null +++ b/src/Nn/LabelScorer.cc @@ -0,0 +1,18 @@ +/** Copyright 2020 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "LabelScorer.hh" + +using namespace Nn; diff --git a/src/Nn/LabelScorer.hh b/src/Nn/LabelScorer.hh new file mode 100644 index 000000000..38d8ea6c2 --- /dev/null +++ b/src/Nn/LabelScorer.hh @@ -0,0 +1,32 @@ +/** Copyright 2020 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LABEL_SCORER_HH +#define LABEL_SCORER_HH + +#include +#include + +namespace Nn { + + +// base class of models for label scoring (basic supports except scoring) +class LabelScorer : public virtual Core::Component, + public Core::ReferenceCounted { +}; + +} // namespace Nn + +#endif diff --git a/src/Nn/Makefile b/src/Nn/Makefile index c141b9387..a3723a974 100644 --- a/src/Nn/Makefile +++ b/src/Nn/Makefile @@ -77,6 +77,17 @@ ifdef MODULE_PYTHON LIBSPRINTNN_O += $(OBJDIR)/PythonLayer.o endif +ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH + LIBSPRINTNN_O += $(OBJDIR)/LabelScorer.o + ifdef MODULE_TENSORFLOW + LIBSPRINTNN_O += $(OBJDIR)/TFLabelScorer.o + + CXXFLAGS += $(TF_CXXFLAGS) + LDFLAGS += $(TF_LDFLAGS) + CHECK_O += ../Tensorflow/libSprintTensorflow.$(a) + endif +endif + # ----------------------------------------------------------------------------- all: $(TARGETS) diff --git a/src/Nn/Module.cc b/src/Nn/Module.cc index d2bcc08f7..efb1485d5 100644 --- a/src/Nn/Module.cc +++ b/src/Nn/Module.cc @@ -34,6 +34,13 @@ #include "PythonFeatureScorer.hh" #endif +#ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH +#include "LabelScorer.hh" +#ifdef MODULE_TENSORFLOW +#include "TFLabelScorer.hh" +#endif +#endif + using namespace Nn; Module_::Module_() @@ -80,3 +87,13 @@ Core::FormatSet& Module_::formats() { } return *formats_; } + + +Core::Ref Module_::createLabelScorer(const Core::Configuration& config) const { +#ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH + LabelScorer* labelScorer = nullptr; + return Core::ref(labelScorer); +#else + Core::Application::us()->criticalError("Module MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH not available!"); +#endif +} diff --git a/src/Nn/Module.hh b/src/Nn/Module.hh index bb923ba2b..304d27a61 100644 --- a/src/Nn/Module.hh +++ b/src/Nn/Module.hh @@ -26,6 +26,8 @@ class FormatSet; namespace Nn { +class LabelScorer; + class Module_ { private: Core::FormatSet* formats_; @@ -47,6 +49,8 @@ public: /** Set of file format class. */ Core::FormatSet& formats(); + + Core::Ref createLabelScorer(const Core::Configuration& config) const; }; typedef Core::SingletonHolder Module; diff --git a/src/Nn/TFLabelScorer.cc b/src/Nn/TFLabelScorer.cc new file mode 100644 index 000000000..8fd296021 --- /dev/null +++ b/src/Nn/TFLabelScorer.cc @@ -0,0 +1,18 @@ +/** Copyright 2020 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TFLabelScorer.hh" + +using namespace Nn; diff --git a/src/Nn/TFLabelScorer.hh b/src/Nn/TFLabelScorer.hh new file mode 100644 index 000000000..76d20ffbc --- /dev/null +++ b/src/Nn/TFLabelScorer.hh @@ -0,0 +1,27 @@ +/** Copyright 2020 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TF_LABEL_SCORER_HH +#define TF_LABEL_SCORER_HH + +#include "LabelScorer.hh" + +namespace Nn { + + +} // namesapce + +#endif + diff --git a/src/Search/AdvancedTreeSearch/AdvancedTreeSearch.cc b/src/Search/AdvancedTreeSearch/AdvancedTreeSearch.cc index 60fe42e6b..67e3c14ff 100644 --- a/src/Search/AdvancedTreeSearch/AdvancedTreeSearch.cc +++ b/src/Search/AdvancedTreeSearch/AdvancedTreeSearch.cc @@ -471,7 +471,7 @@ Core::Ref AdvancedTreeSearchManager::buildLatticeForTrace( wordBoundaries->set((Fsa::StateId) /* WARNING: this looses a few bits */ (long)result->initialState(), Lattice::WordBoundary(0)); Fsa::State* finalState = result->newState(); wordBoundaries->set(finalState->id(), Lattice::WordBoundary(time_)); - result->newArc(result->initialState(), finalState, 0, 0, 0); + result->newArc(result->initialState(), finalState, static_cast(nullptr), 0, 0); result->setWordBoundaries(wordBoundaries); result->addAcyclicProperty(); return Core::ref(new Lattice::WordLatticeAdaptor(result)); diff --git a/src/Search/GenericSeq2SeqTreeSearch/Makefile b/src/Search/GenericSeq2SeqTreeSearch/Makefile new file mode 100644 index 000000000..5f281bf53 --- /dev/null +++ b/src/Search/GenericSeq2SeqTreeSearch/Makefile @@ -0,0 +1,36 @@ +#!gmake + +TOPDIR = ../../.. + +include $(TOPDIR)/Makefile.cfg + +# ----------------------------------------------------------------------------- + +SUBDIRS = +TARGETS = libSprintGenericSeq2SeqTreeSearch.$(a) + +LIBSPRINTSEQ2SEQTREESEARCH_O = $(OBJDIR)/Seq2SeqTreeSearch.o \ + $(OBJDIR)/Seq2SeqAligner.o + +ifeq ($(OS),darwin) +CCFLAGS += -fexceptions +endif + +# These flags make the compilation slow, but are required to inline some critical functions in SearchSpace (copied from AdvancedTreeSearch) +CCFLAGS += -Wno-sign-compare -Winline --param max-inline-insns-auto=10000 --param max-inline-insns-single=10000 --param large-function-growth=25000 --param inline-unit-growth=400 +CXXFLAGS += -Wno-sign-compare -Winline --param max-inline-insns-auto=10000 --param max-inline-insns-single=10000 --param large-function-growth=25000 --param inline-unit-growth=400 + +# ----------------------------------------------------------------------------- + +all: $(TARGETS) + +libSprintGenericSeq2SeqTreeSearch.$(a): $(LIBSPRINTSEQ2SEQTREESEARCH_O) + $(MAKELIB) $@ $^ + +check$(exe): $(CHECK_O) + $(LD) $(CHECK_O) -o check$(exe) $(LDFLAGS) + +include $(TOPDIR)/Rules.make + +sinclude $(LIBSPRINTSEQ2SEQTREESEARCH_O:.o=.d) +include $(patsubst %.o,%.d,$(filter %.o,$(CHECK_O))) diff --git a/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.cc b/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.cc new file mode 100644 index 000000000..3d96b1657 --- /dev/null +++ b/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.cc @@ -0,0 +1,18 @@ +/** Copyright 2020 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Seq2SeqAligner.hh" + +using namespace Search; diff --git a/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.hh b/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.hh new file mode 100644 index 000000000..729e600d6 --- /dev/null +++ b/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.hh @@ -0,0 +1,31 @@ +/** Copyright 2020 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SEQ2SEQ_ALIGNER_HH +#define SEQ2SEQ_ALIGNER_HH + +#include + +namespace Search { + + +// Integrated alignment interface and search space +// So far: Viterbi only +class Seq2SeqAligner : public Core::Component { +}; + +} // namespace + +#endif // SEQ2SEQ_ALIGNER_HH diff --git a/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.cc b/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.cc new file mode 100644 index 000000000..11fb8a664 --- /dev/null +++ b/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.cc @@ -0,0 +1,49 @@ +/** Copyright 2020 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Seq2SeqTreeSearch.hh" + +using namespace Search; + + +Seq2SeqTreeSearchManager::Seq2SeqTreeSearchManager(const Core::Configuration& c) + : Core::Component(c), + SearchAlgorithm(c) { +} + + +bool Seq2SeqTreeSearchManager::setModelCombination(const Speech::ModelCombination& modelCombination) { +} + + +void Seq2SeqTreeSearchManager::setGrammar(Fsa::ConstAutomatonRef g) { +} + +void Seq2SeqTreeSearchManager::resetStatistics() { +} + +void Seq2SeqTreeSearchManager::logStatistics() const { +} + +void Seq2SeqTreeSearchManager::restart() { +} + + +void Seq2SeqTreeSearchManager::getCurrentBestSentence(Traceback &result) const { +} + +Core::Ref Seq2SeqTreeSearchManager::getCurrentWordLattice() const { +} + diff --git a/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.hh b/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.hh new file mode 100644 index 000000000..38313e9af --- /dev/null +++ b/src/Search/GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.hh @@ -0,0 +1,54 @@ +/** Copyright 2020 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SEQ2SEQ_TREE_SEARCH_HH +#define SEQ2SEQ_TREE_SEARCH_HH + +#include +#include + +// search manager: interface between search space and high level recognizer +// - manage step-wise decoding: expansion, pruning, recombination, etc. +// - results pulling (traceback) + +namespace Search { + + +class Seq2SeqTreeSearchManager : public SearchAlgorithm { + + public: + Seq2SeqTreeSearchManager(const Core::Configuration &); + + // ---- from SearchAlgorithm (overwrite required) ---- + virtual bool setModelCombination(const Speech::ModelCombination& modelCombination); + virtual void setGrammar(Fsa::ConstAutomatonRef); + virtual void restart(); + + // replaced by decode and decodeNext + virtual void feed(const Mm::FeatureScorer::Scorer&) {} + + // TODO partial result not supported yet + + virtual void getCurrentBestSentence(Traceback &result) const; + virtual Core::Ref getCurrentWordLattice() const; + + virtual void resetStatistics(); + virtual void logStatistics() const; + }; + +} // namespace + +#endif + diff --git a/src/Search/Makefile b/src/Search/Makefile index abeb7774d..11514593b 100644 --- a/src/Search/Makefile +++ b/src/Search/Makefile @@ -41,6 +41,9 @@ endif ifdef MODULE_SEARCH_LINEAR LIBSPRINTSEARCH_O += $(OBJDIR)/LinearSearch.o endif +ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH +SUBDIRS += GenericSeq2SeqTreeSearch +endif # ----------------------------------------------------------------------------- @@ -62,6 +65,8 @@ Wfst: AdvancedTreeSearch: $(MAKE) -C $@ libSprintAdvancedTreeSearch.$(a) +GenericSeq2SeqTreeSearch: + $(MAKE) -C $@ libSprintGenericSeq2SeqTreeSearch.$(a) include $(TOPDIR)/Rules.make diff --git a/src/Search/Module.cc b/src/Search/Module.cc index 3dfb12aca..2dd703b21 100644 --- a/src/Search/Module.cc +++ b/src/Search/Module.cc @@ -26,6 +26,9 @@ #ifdef MODULE_ADVANCED_TREE_SEARCH #include "AdvancedTreeSearch/AdvancedTreeSearch.hh" #endif +#ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH +#include "GenericSeq2SeqTreeSearch/Seq2SeqTreeSearch.hh" +#endif using namespace Search; @@ -60,6 +63,15 @@ SearchAlgorithm* Module_::createRecognizer(SearchType type, const Core::Configur Core::Application::us()->criticalError("Module MODULE_SEARCH_LINEAR not available!"); #endif break; + + case GenericSeq2SeqTreeSearchType: +#ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH + recognizer = new Search::Seq2SeqTreeSearchManager(config); +#else + Core::Application::us()->criticalError("Module MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH not available!"); +#endif + break; + default: Core::Application::us()->criticalError("unknown recognizer type: %d", type); break; diff --git a/src/Search/Module.hh b/src/Search/Module.hh index 04b3c33fb..573a3a7d6 100644 --- a/src/Search/Module.hh +++ b/src/Search/Module.hh @@ -28,6 +28,7 @@ enum SearchType { AdvancedTreeSearch, LinearSearchType, ExpandingFsaSearchType, + GenericSeq2SeqTreeSearchType }; class Module_ { diff --git a/src/Speech/AlignmentNode.cc b/src/Speech/AlignmentNode.cc index 4d0a0dbe5..318e63f74 100644 --- a/src/Speech/AlignmentNode.cc +++ b/src/Speech/AlignmentNode.cc @@ -374,6 +374,25 @@ void AlignmentNode::logTraceback(Lattice::ConstWordLatticeRef wordLattice) const tracebackChannel_ << Core::XmlClose("traceback"); } + +#ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH +/** Seq2SeqAlignmentNode +*/ + +Seq2SeqAlignmentNode::Seq2SeqAlignmentNode(const Core::Configuration& c) : + Core::Component(c), + Precursor(c) { +} + + +void Seq2SeqAlignmentNode::createModel() { +} + +bool Seq2SeqAlignmentNode::work(Flow::PortId p) { +} +#endif + + /** AlignmentDumpNode */ const Core::ParameterString AlignmentDumpNode::paramFilename( diff --git a/src/Speech/AlignmentNode.hh b/src/Speech/AlignmentNode.hh index 201a86ab2..551746835 100644 --- a/src/Speech/AlignmentNode.hh +++ b/src/Speech/AlignmentNode.hh @@ -18,6 +18,9 @@ #include #include #include +#ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH +#include +#endif #include "Alignment.hh" #include "ModelCombination.hh" @@ -115,6 +118,25 @@ public: virtual bool work(Flow::PortId); }; +#ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH +/** Seq2Seq AlignmentNode */ +class Seq2SeqAlignmentNode : public AlignmentBaseNode { + typedef AlignmentBaseNode Precursor; + +public: + static std::string filterName() { + return "speech-generic-seq2seq-alignment"; + } + +public: + Seq2SeqAlignmentNode(const Core::Configuration&); + virtual bool work(Flow::PortId); + +protected: + void createModel(); +}; +#endif + /** Dumps alignments in a plain text format */ class AlignmentDumpNode : public Flow::Node { typedef Flow::Node Precursor; diff --git a/src/Speech/Makefile b/src/Speech/Makefile index 4b5ab3528..60bee95d3 100644 --- a/src/Speech/Makefile +++ b/src/Speech/Makefile @@ -143,6 +143,9 @@ endif ifdef MODULE_ADVANCED_TREE_SEARCH CHECK_O += ../Search/AdvancedTreeSearch/libSprintAdvancedTreeSearch.$(a) endif +ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH +CHECK_O += ../Search/GenericSeq2SeqTreeSearch/libSprintGenericSeq2SeqTreeSearch.$(a) +endif ifdef MODULE_PYTHON CHECK_O += ../Python/libSprintPython.$(a) endif diff --git a/src/Speech/ModelCombination.cc b/src/Speech/ModelCombination.cc index 7ab089d27..d68664fe9 100644 --- a/src/Speech/ModelCombination.cc +++ b/src/Speech/ModelCombination.cc @@ -15,13 +15,16 @@ #include "ModelCombination.hh" #include #include +#include + using namespace Speech; -const ModelCombination::Mode ModelCombination::complete = 0xFFFF; +const ModelCombination::Mode ModelCombination::complete = 0x3; const ModelCombination::Mode ModelCombination::useLexicon = 0x0; const ModelCombination::Mode ModelCombination::useAcousticModel = 0x1; const ModelCombination::Mode ModelCombination::useLanguageModel = 0x2; +const ModelCombination::Mode ModelCombination::useLabelScorer = 0x4; //====================================================================================== @@ -53,6 +56,12 @@ ModelCombination::ModelCombination(const Core::Configuration& c, if (!languageModel_) criticalError("failed to initialize language model"); } + + if (mode & useLabelScorer) { + setLabelScorer(Nn::Module::instance().createLabelScorer(select("label-scorer"))); + if (!labelScorer_) + criticalError("failed to initialize label scorer"); + } } ModelCombination::ModelCombination(const Core::Configuration& c, diff --git a/src/Speech/ModelCombination.hh b/src/Speech/ModelCombination.hh index f00e4a249..ee24267c8 100644 --- a/src/Speech/ModelCombination.hh +++ b/src/Speech/ModelCombination.hh @@ -21,10 +21,12 @@ #include #include #include +#include + namespace Speech { -/** Combination of a lexicon, an acoustic model, and a language model. +/** Combination of a lexicon, an acoustic model or label scorer, and a language model. * It supports creation and initialization of these three mutually dependent objects. * * Usage: @@ -37,10 +39,11 @@ namespace Speech { class ModelCombination : public Mc::Component, public Core::ReferenceCounted { public: typedef u32 Mode; - static const Mode complete; + static const Mode complete; // Includes lexicon, AM and LM but NOT label scorer; named 'complete' for legacy reasons. static const Mode useLexicon; static const Mode useAcousticModel; static const Mode useLanguageModel; + static const Mode useLabelScorer; static const Core::ParameterFloat paramPronunciationScale; @@ -49,6 +52,7 @@ protected: Mm::Score pronunciationScale_; Core::Ref acousticModel_; Core::Ref languageModel_; + Core::Ref labelScorer_; private: void setPronunciationScale(Mm::Score scale) { @@ -83,6 +87,9 @@ public: return languageModel_; } void setLanguageModel(Core::Ref); + + void setLabelScorer(Core::Ref ls) { labelScorer_ = ls; } + Core::Ref labelScorer() const { return labelScorer_; } }; typedef Core::Ref ModelCombinationRef; diff --git a/src/Speech/Module.cc b/src/Speech/Module.cc index fa31fa230..e55993655 100644 --- a/src/Speech/Module.cc +++ b/src/Speech/Module.cc @@ -65,6 +65,10 @@ Module_::Module_() { registry.registerFilter(); registry.registerDatatype>(); +#ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH + registry.registerFilter(); +#endif + #ifdef MODULE_SPEECH_ALIGNMENT_FLOW_NODES registry.registerFilter(); registry.registerFilter(); diff --git a/src/Speech/Recognizer.cc b/src/Speech/Recognizer.cc index 324de133d..036023ffc 100644 --- a/src/Speech/Recognizer.cc +++ b/src/Speech/Recognizer.cc @@ -45,6 +45,7 @@ const Core::Choice Recognizer::searchTypeChoice_( "advanced-tree-search", Search::AdvancedTreeSearch, "expanding-fsa-search", Search::ExpandingFsaSearchType, "linear-search", Search::LinearSearchType, + "generic-seq2seq-tree-search", Search::GenericSeq2SeqTreeSearchType, Core::Choice::endMark()); const Core::ParameterChoice Recognizer::paramSearch( diff --git a/src/Tensorflow/Tensor.cc b/src/Tensorflow/Tensor.cc index 959a18925..b6476fb71 100644 --- a/src/Tensorflow/Tensor.cc +++ b/src/Tensorflow/Tensor.cc @@ -32,6 +32,7 @@ template struct ToDataType; template struct ToDataType; template struct ToDataType; template struct ToDataType; +template struct ToDataType; // tf::DataTypeToEnum does not have entries for s64 and u64 (as they are (unsigned) long) // instead it has entries for long long and unsigned long long. For our supported data-model @@ -215,6 +216,11 @@ Tensor Tensor::concat(Tensor const& a, Tensor const& b, int axis) { dynamic_rank_concat::Type>(*res.tensor_, *a.tensor_, *b.tensor_, axis); return res; } + case tf::DT_BOOL: { + Tensor res = Tensor::zeros::Type>(new_shape); + dynamic_rank_concat::Type>(*res.tensor_, *a.tensor_, *b.tensor_, axis); + return res; + } default: defect(); } } @@ -242,6 +248,8 @@ Tensor Tensor::concat(const std::vector& inputs, int axis) { return Tensor::concat::Type>(inputs, axis); case tf::DT_UINT8: return Tensor::concat::Type>(inputs, axis); + case tf::DT_BOOL: + return Tensor::concat::Type>(inputs, axis); default: defect(); } } @@ -338,6 +346,7 @@ template Tensor Tensor::concat(const std::vector& inputs, in template Tensor Tensor::concat(const std::vector& inputs, int axis); template Tensor Tensor::concat(const std::vector& inputs, int axis); template Tensor Tensor::concat(const std::vector& inputs, int axis); +template Tensor Tensor::concat(const std::vector& inputs, int axis); /* ------------------------- Getters ------------------------- */ @@ -901,6 +910,7 @@ Tensor Tensor::slice(std::vector const& start, std::vector const& end) case tf::DT_UINT16: dynamic_rank_slice::Type>(*res.tensor_, *tensor_, start_vec, size_vec); break; case tf::DT_INT8: dynamic_rank_slice::Type>(*res.tensor_, *tensor_, start_vec, size_vec); break; case tf::DT_UINT8: dynamic_rank_slice::Type>(*res.tensor_, *tensor_, start_vec, size_vec); break; + case tf::DT_BOOL: dynamic_rank_slice::Type>(*res.tensor_, *tensor_, start_vec, size_vec); break; default: defect(); } diff --git a/src/Test/Makefile b/src/Test/Makefile index b9f3f2fcf..7de6c4d37 100644 --- a/src/Test/Makefile +++ b/src/Test/Makefile @@ -85,6 +85,9 @@ endif ifdef MODULE_ADVANCED_TREE_SEARCH UNIT_TEST_O += ../Search/AdvancedTreeSearch/libSprintAdvancedTreeSearch.$(a) endif +ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH +UNIT_TEST_O += ../Search/GenericSeq2SeqTreeSearch/libSprintGenericSeq2SeqTreeSearch.$(a) +endif ifdef MODULE_PYTHON UNIT_TEST_O += ../Python/libSprintPython.$(a) endif diff --git a/src/Tools/Archiver/Makefile b/src/Tools/Archiver/Makefile index e9000e78a..f31539366 100644 --- a/src/Tools/Archiver/Makefile +++ b/src/Tools/Archiver/Makefile @@ -30,6 +30,9 @@ ARCHIVER_O = $(OBJDIR)/Archiver.o \ ../../Flow/libSprintFlow.$(a) \ ../../Fsa/libSprintFsa.$(a) +ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH +ARCHIVER_O += ../../Search/GenericSeq2SeqTreeSearch/libSprintGenericSeq2SeqTreeSearch.$(a) +endif ifdef MODULE_PYTHON ARCHIVER_O += ../../Python/libSprintPython.$(a) endif diff --git a/src/Tools/NnTrainer/Makefile b/src/Tools/NnTrainer/Makefile index fc71a4bd2..3b1be0d24 100644 --- a/src/Tools/NnTrainer/Makefile +++ b/src/Tools/NnTrainer/Makefile @@ -43,6 +43,9 @@ endif ifdef MODULE_SEARCH_WFST NN_TRAINER_O += ../../Search/Wfst/libSprintSearchWfst.$(a) ../../OpenFst/libSprintOpenFst.$(a) endif +ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH +NN_TRAINER_O += ../../Search/GenericSeq2SeqTreeSearch/libSprintGenericSeq2SeqTreeSearch.$(a) +endif ifdef MODULE_PYTHON NN_TRAINER_O += ../../Python/libSprintPython.$(a) endif