Skip to content

Commit

Permalink
Generic Seq2Seq Decoder preparation (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimBe195 authored and Marvin84 committed Jun 12, 2023
1 parent d50924e commit 3c93cdb
Show file tree
Hide file tree
Showing 32 changed files with 491 additions and 22 deletions.
10 changes: 7 additions & 3 deletions Modules.make
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,14 @@ MODULES += MODULE_THEANO_INTERFACE
MODULES += MODULE_PYTHON

# ****** OpenFst ******
MODULES += MODULE_OPENFST
MODULES += MODULE_OPENFST

# ****** Search ******
MODULES += MODULE_SEARCH_MBR
MODULES += MODULE_SEARCH_WFST
MODULES += MODULE_SEARCH_MBR
MODULES += MODULE_SEARCH_WFST
MODULES += MODULE_SEARCH_LINEAR
MODULES += MODULE_ADVANCED_TREE_SEARCH
MODULES += MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH

# ****** Signal ******
MODULES += MODULE_SIGNAL_GAMMATONE
Expand Down Expand Up @@ -151,3 +152,6 @@ endif
ifdef MODULE_ADVANCED_TREE_SEARCH
LIBS_SEARCH += src/Search/AdvancedTreeSearch/libSprintAdvancedTreeSearch.$(a)
endif
ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH
LIBS_SEARCH += src/Search/GenericSeq2SeqTreeSearch/libSprintGenericSeq2SeqTreeSearch.$(a)
endif
7 changes: 5 additions & 2 deletions src/Flf/SegmentwiseSpeechProcessor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
namespace Flf {

// -------------------------------------------------------------------------
AcousticModelRef getAm(const Core::Configuration& config) {
return Am::Module::instance().createAcousticModel(config, Bliss::LexiconRef(Lexicon::us()));
AcousticModelRef getAm(const Core::Configuration& config, bool useMixture) {
if (useMixture)
return Am::Module::instance().createAcousticModel(config, Bliss::LexiconRef(Lexicon::us()));
else
return Am::Module::instance().createAcousticModel(config, Bliss::LexiconRef(Lexicon::us()), Am::AcousticModel::noEmissions);
}

ScaledLanguageModelRef getLm(const Core::Configuration& config) {
Expand Down
2 changes: 1 addition & 1 deletion src/Flf/SegmentwiseSpeechProcessor.hh
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ typedef Core::Ref<Lm::LanguageModel> LanguageModelRef;
typedef Core::Ref<Lm::ScaledLanguageModel> ScaledLanguageModelRef;
typedef Core::Ref<Speech::ModelCombination> ModelCombinationRef;

AcousticModelRef getAm(const Core::Configuration& config);
AcousticModelRef getAm(const Core::Configuration& config, bool useMixture=true);
ScaledLanguageModelRef getLm(const Core::Configuration& config);
ModelCombinationRef getModelCombination(const Core::Configuration& config, AcousticModelRef acousticModel, ScaledLanguageModelRef languageModel = ScaledLanguageModelRef());

Expand Down
50 changes: 39 additions & 11 deletions src/Lattice/Lattice.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,18 +193,32 @@ Speech::TimeframeIndex WordLattice::maximumTime() const {
return d.getMaximumTime();
}

StandardWordLattice::StandardWordLattice(Core::Ref<const Bliss::Lexicon> lexicon) {
StandardWordLattice::StandardWordLattice(Core::Ref<const Bliss::Lexicon> lexicon,
AlphabetType alphabetType)
: alphabetType_(alphabetType) {
parts_.addChoice("acoustic", 0);
parts_.addChoice("lm", 1);

acoustic_ = Core::ref(new Fsa::StaticAutomaton);
acoustic_->setType(Fsa::TypeAcceptor);
acoustic_->setInputAlphabet(lexicon->lemmaPronunciationAlphabet());
acoustic_->setSemiring(Fsa::TropicalSemiring);
lm_ = Core::ref(new Fsa::StaticAutomaton);

lm_ = Core::ref(new Fsa::StaticAutomaton);
acoustic_->setType(Fsa::TypeAcceptor);
lm_->setType(Fsa::TypeAcceptor);
lm_->setInputAlphabet(lexicon->lemmaPronunciationAlphabet());

switch (alphabetType_) {
case AlphabetType::LemmaAlphabet:
acoustic_->setInputAlphabet(lexicon->lemmaAlphabet());
lm_->setInputAlphabet(lexicon->lemmaAlphabet());
break;
case AlphabetType::LemmaPronunciationAlphabet:
acoustic_->setInputAlphabet(lexicon->lemmaPronunciationAlphabet());
lm_->setInputAlphabet(lexicon->lemmaPronunciationAlphabet());
break;
default:
defect();
}

acoustic_->setSemiring(Fsa::TropicalSemiring);
lm_->setSemiring(Fsa::TropicalSemiring);

fsas_.push_back(acoustic_);
Expand All @@ -226,17 +240,31 @@ Fsa::State* StandardWordLattice::newState() {
return acoustic_->newState();
}

void StandardWordLattice::newArc(
Fsa::State* source,
Fsa::State* target,
Fsa::LabelId id,
Speech::Score acoustic, Speech::Score lm) {
source -> newArc(target->id(), Fsa::Weight(acoustic), id);
lm_->state(source->id())->newArc(target->id(), Fsa::Weight(lm), id);
}

void StandardWordLattice::newArc(
Fsa::State* source,
Fsa::State* target,
const Bliss::LemmaPronunciation* lemmaPronunciation,
Speech::Score acoustic, Speech::Score lm) {
source->newArc(
target->id(),
Fsa::Weight(acoustic),
(lemmaPronunciation) ? lemmaPronunciation->id() : Fsa::Epsilon);
Fsa::LabelId id = lemmaPronunciation ? lemmaPronunciation->id() : Fsa::Epsilon;
newArc(source, target, id, acoustic, lm);
}

lm_->state(source->id())->newArc(target->id(), Fsa::Weight(lm), (lemmaPronunciation) ? lemmaPronunciation->id() : Fsa::Epsilon);
void StandardWordLattice::newArc(
Fsa::State* source,
Fsa::State* target,
const Bliss::Lemma* lemma,
Speech::Score acoustic, Speech::Score lm) {
Fsa::LabelId id = lemma ? lemma->id() : Fsa::Epsilon;
newArc(source, target, id, acoustic, lm);
}

void StandardWordLattice::addAcyclicProperty() {
Expand Down
19 changes: 18 additions & 1 deletion src/Lattice/Lattice.hh
Original file line number Diff line number Diff line change
Expand Up @@ -268,12 +268,24 @@ public:
* near future, when the WordLattice interface is refactored.
*/
class StandardWordLattice : public WordLattice {
public:
enum class AlphabetType {
LemmaAlphabet = 1,
LemmaPronunciationAlphabet = 2
};

private:
Core::Ref<Fsa::StaticAutomaton> acoustic_, lm_;
Fsa::State * initialState_, *finalState_;
AlphabetType alphabetType_;

void newArc(Fsa::State* source,
Fsa::State* target,
Fsa::LabelId id,
Speech::Score acoustic, Speech::Score lm);

public:
StandardWordLattice(Bliss::LexiconRef);
StandardWordLattice(Bliss::LexiconRef, AlphabetType alphabetType = AlphabetType::LemmaAlphabet);

Fsa::State* newState();
Fsa::State* initialState() {
Expand All @@ -288,6 +300,11 @@ public:
const Bliss::LemmaPronunciation*,
Speech::Score acoustic, Speech::Score lm);

void newArc(Fsa::State *source,
Fsa::State *target,
const Bliss::Lemma*,
Speech::Score acoustic, Speech::Score lm);

void addAcyclicProperty();
};

Expand Down
18 changes: 18 additions & 0 deletions src/Nn/LabelScorer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/** Copyright 2020 RWTH Aachen University. All rights reserved.
*
* Licensed under the RWTH ASR License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "LabelScorer.hh"

using namespace Nn;
32 changes: 32 additions & 0 deletions src/Nn/LabelScorer.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/** Copyright 2020 RWTH Aachen University. All rights reserved.
*
* Licensed under the RWTH ASR License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LABEL_SCORER_HH
#define LABEL_SCORER_HH

#include <Core/Component.hh>
#include <Core/ReferenceCounting.hh>

namespace Nn {


// base class of models for label scoring (basic supports except scoring)
class LabelScorer : public virtual Core::Component,
public Core::ReferenceCounted {
};

} // namespace Nn

#endif
11 changes: 11 additions & 0 deletions src/Nn/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,17 @@ ifdef MODULE_PYTHON
LIBSPRINTNN_O += $(OBJDIR)/PythonLayer.o
endif

ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH
LIBSPRINTNN_O += $(OBJDIR)/LabelScorer.o
ifdef MODULE_TENSORFLOW
LIBSPRINTNN_O += $(OBJDIR)/TFLabelScorer.o

CXXFLAGS += $(TF_CXXFLAGS)
LDFLAGS += $(TF_LDFLAGS)
CHECK_O += ../Tensorflow/libSprintTensorflow.$(a)
endif
endif

# -----------------------------------------------------------------------------
all: $(TARGETS)

Expand Down
17 changes: 17 additions & 0 deletions src/Nn/Module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@
#include "PythonFeatureScorer.hh"
#endif

#ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH
#include "LabelScorer.hh"
#ifdef MODULE_TENSORFLOW
#include "TFLabelScorer.hh"
#endif
#endif

using namespace Nn;

Module_::Module_()
Expand Down Expand Up @@ -80,3 +87,13 @@ Core::FormatSet& Module_::formats() {
}
return *formats_;
}


Core::Ref<LabelScorer> Module_::createLabelScorer(const Core::Configuration& config) const {
#ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH
LabelScorer* labelScorer = nullptr;
return Core::ref(labelScorer);
#else
Core::Application::us()->criticalError("Module MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH not available!");
#endif
}
4 changes: 4 additions & 0 deletions src/Nn/Module.hh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class FormatSet;

namespace Nn {

class LabelScorer;

class Module_ {
private:
Core::FormatSet* formats_;
Expand All @@ -47,6 +49,8 @@ public:
/** Set of file format class.
*/
Core::FormatSet& formats();

Core::Ref<LabelScorer> createLabelScorer(const Core::Configuration& config) const;
};

typedef Core::SingletonHolder<Module_> Module;
Expand Down
18 changes: 18 additions & 0 deletions src/Nn/TFLabelScorer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/** Copyright 2020 RWTH Aachen University. All rights reserved.
*
* Licensed under the RWTH ASR License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "TFLabelScorer.hh"

using namespace Nn;
27 changes: 27 additions & 0 deletions src/Nn/TFLabelScorer.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/** Copyright 2020 RWTH Aachen University. All rights reserved.
*
* Licensed under the RWTH ASR License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef TF_LABEL_SCORER_HH
#define TF_LABEL_SCORER_HH

#include "LabelScorer.hh"

namespace Nn {


} // namesapce

#endif

2 changes: 1 addition & 1 deletion src/Search/AdvancedTreeSearch/AdvancedTreeSearch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ Core::Ref<const LatticeAdaptor> AdvancedTreeSearchManager::buildLatticeForTrace(
wordBoundaries->set((Fsa::StateId) /* WARNING: this looses a few bits */ (long)result->initialState(), Lattice::WordBoundary(0));
Fsa::State* finalState = result->newState();
wordBoundaries->set(finalState->id(), Lattice::WordBoundary(time_));
result->newArc(result->initialState(), finalState, 0, 0, 0);
result->newArc(result->initialState(), finalState, static_cast<const Bliss::LemmaPronunciation*>(nullptr), 0, 0);
result->setWordBoundaries(wordBoundaries);
result->addAcyclicProperty();
return Core::ref(new Lattice::WordLatticeAdaptor(result));
Expand Down
36 changes: 36 additions & 0 deletions src/Search/GenericSeq2SeqTreeSearch/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#!gmake

TOPDIR = ../../..

include $(TOPDIR)/Makefile.cfg

# -----------------------------------------------------------------------------

SUBDIRS =
TARGETS = libSprintGenericSeq2SeqTreeSearch.$(a)

LIBSPRINTSEQ2SEQTREESEARCH_O = $(OBJDIR)/Seq2SeqTreeSearch.o \
$(OBJDIR)/Seq2SeqAligner.o

ifeq ($(OS),darwin)
CCFLAGS += -fexceptions
endif

# These flags make the compilation slow, but are required to inline some critical functions in SearchSpace (copied from AdvancedTreeSearch)
CCFLAGS += -Wno-sign-compare -Winline --param max-inline-insns-auto=10000 --param max-inline-insns-single=10000 --param large-function-growth=25000 --param inline-unit-growth=400
CXXFLAGS += -Wno-sign-compare -Winline --param max-inline-insns-auto=10000 --param max-inline-insns-single=10000 --param large-function-growth=25000 --param inline-unit-growth=400

# -----------------------------------------------------------------------------

all: $(TARGETS)

libSprintGenericSeq2SeqTreeSearch.$(a): $(LIBSPRINTSEQ2SEQTREESEARCH_O)
$(MAKELIB) $@ $^

check$(exe): $(CHECK_O)
$(LD) $(CHECK_O) -o check$(exe) $(LDFLAGS)

include $(TOPDIR)/Rules.make

sinclude $(LIBSPRINTSEQ2SEQTREESEARCH_O:.o=.d)
include $(patsubst %.o,%.d,$(filter %.o,$(CHECK_O)))
18 changes: 18 additions & 0 deletions src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/** Copyright 2020 RWTH Aachen University. All rights reserved.
*
* Licensed under the RWTH ASR License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "Seq2SeqAligner.hh"

using namespace Search;
Loading

0 comments on commit 3c93cdb

Please sign in to comment.