Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic Seq2Seq Decoder preparation #49

Merged
merged 3 commits into from
Apr 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions Modules.make
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,14 @@ MODULES += MODULE_THEANO_INTERFACE
MODULES += MODULE_PYTHON

# ****** OpenFst ******
MODULES += MODULE_OPENFST
MODULES += MODULE_OPENFST

# ****** Search ******
MODULES += MODULE_SEARCH_MBR
MODULES += MODULE_SEARCH_WFST
MODULES += MODULE_SEARCH_MBR
MODULES += MODULE_SEARCH_WFST
MODULES += MODULE_SEARCH_LINEAR
MODULES += MODULE_ADVANCED_TREE_SEARCH
MODULES += MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH

# ****** Signal ******
MODULES += MODULE_SIGNAL_GAMMATONE
Expand Down Expand Up @@ -151,3 +152,6 @@ endif
ifdef MODULE_ADVANCED_TREE_SEARCH
LIBS_SEARCH += src/Search/AdvancedTreeSearch/libSprintAdvancedTreeSearch.$(a)
endif
ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH
LIBS_SEARCH += src/Search/GenericSeq2SeqTreeSearch/libSprintGenericSeq2SeqTreeSearch.$(a)
endif
7 changes: 5 additions & 2 deletions src/Flf/SegmentwiseSpeechProcessor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
namespace Flf {

// -------------------------------------------------------------------------
AcousticModelRef getAm(const Core::Configuration& config) {
return Am::Module::instance().createAcousticModel(config, Bliss::LexiconRef(Lexicon::us()));
/**
 * Creates an acoustic model for the lexicon returned by Lexicon::us().
 *
 * @param config      configuration handed to the acoustic model factory
 * @param useMixture  true: the factory is called without an explicit mode argument;
 *                    false: the factory is called with Am::AcousticModel::noEmissions
 * @return the acoustic model created by Am::Module
 */
AcousticModelRef getAm(const Core::Configuration& config, bool useMixture) {
    if (!useMixture) {
        // presumably this avoids loading the emission/mixture model -- confirm against Am::AcousticModel
        return Am::Module::instance().createAcousticModel(config,
                                                          Bliss::LexiconRef(Lexicon::us()),
                                                          Am::AcousticModel::noEmissions);
    }
    return Am::Module::instance().createAcousticModel(config, Bliss::LexiconRef(Lexicon::us()));
}

ScaledLanguageModelRef getLm(const Core::Configuration& config) {
Expand Down
2 changes: 1 addition & 1 deletion src/Flf/SegmentwiseSpeechProcessor.hh
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ typedef Core::Ref<Lm::LanguageModel> LanguageModelRef;
typedef Core::Ref<Lm::ScaledLanguageModel> ScaledLanguageModelRef;
typedef Core::Ref<Speech::ModelCombination> ModelCombinationRef;

AcousticModelRef getAm(const Core::Configuration& config);
AcousticModelRef getAm(const Core::Configuration& config, bool useMixture=true);
ScaledLanguageModelRef getLm(const Core::Configuration& config);
ModelCombinationRef getModelCombination(const Core::Configuration& config, AcousticModelRef acousticModel, ScaledLanguageModelRef languageModel = ScaledLanguageModelRef());

Expand Down
50 changes: 39 additions & 11 deletions src/Lattice/Lattice.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,18 +193,32 @@ Speech::TimeframeIndex WordLattice::maximumTime() const {
return d.getMaximumTime();
}

StandardWordLattice::StandardWordLattice(Core::Ref<const Bliss::Lexicon> lexicon) {
StandardWordLattice::StandardWordLattice(Core::Ref<const Bliss::Lexicon> lexicon,
AlphabetType alphabetType)
: alphabetType_(alphabetType) {
parts_.addChoice("acoustic", 0);
parts_.addChoice("lm", 1);

acoustic_ = Core::ref(new Fsa::StaticAutomaton);
acoustic_->setType(Fsa::TypeAcceptor);
acoustic_->setInputAlphabet(lexicon->lemmaPronunciationAlphabet());
acoustic_->setSemiring(Fsa::TropicalSemiring);
lm_ = Core::ref(new Fsa::StaticAutomaton);

lm_ = Core::ref(new Fsa::StaticAutomaton);
acoustic_->setType(Fsa::TypeAcceptor);
lm_->setType(Fsa::TypeAcceptor);
lm_->setInputAlphabet(lexicon->lemmaPronunciationAlphabet());

switch (alphabetType_) {
case AlphabetType::LemmaAlphabet:
acoustic_->setInputAlphabet(lexicon->lemmaAlphabet());
lm_->setInputAlphabet(lexicon->lemmaAlphabet());
break;
case AlphabetType::LemmaPronunciationAlphabet:
acoustic_->setInputAlphabet(lexicon->lemmaPronunciationAlphabet());
lm_->setInputAlphabet(lexicon->lemmaPronunciationAlphabet());
break;
default:
defect();
}

acoustic_->setSemiring(Fsa::TropicalSemiring);
lm_->setSemiring(Fsa::TropicalSemiring);

fsas_.push_back(acoustic_);
Expand All @@ -226,17 +240,31 @@ Fsa::State* StandardWordLattice::newState() {
return acoustic_->newState();
}

/**
 * Adds one arc labelled @c id from @c source to @c target to each of the two automata.
 * The arc created directly on @c source carries the acoustic score; the arc created on
 * the lm_ state with the same state id carries the lm score.
 */
void StandardWordLattice::newArc(Fsa::State* source, Fsa::State* target, Fsa::LabelId id,
                                 Speech::Score acoustic, Speech::Score lm) {
    const auto targetId = target->id();

    // source is presumably a state of acoustic_ (as handed out by newState()) -- confirm
    source->newArc(targetId, Fsa::Weight(acoustic), id);

    // mirror the arc in the lm automaton, addressed via the source state's id
    lm_->state(source->id())->newArc(targetId, Fsa::Weight(lm), id);
}

/**
 * Adds an arc labelled with a lemma pronunciation to both the acoustic and the lm automaton.
 *
 * @param source              state the arc leaves from
 * @param target              state the arc leads to
 * @param lemmaPronunciation  label of the arc; a null pointer yields an Fsa::Epsilon arc
 * @param acoustic            weight of the arc in the acoustic automaton
 * @param lm                  weight of the arc in the lm automaton
 */
void StandardWordLattice::newArc(
        Fsa::State*                      source,
        Fsa::State*                      target,
        const Bliss::LemmaPronunciation* lemmaPronunciation,
        Speech::Score acoustic, Speech::Score lm) {
    // Only delegate to the label-id overload: it inserts the arc into both automata.
    // Inserting the acoustic arc here as well would duplicate it and leave the
    // acoustic and lm automata with different arc counts.
    const Fsa::LabelId id = lemmaPronunciation ? lemmaPronunciation->id() : Fsa::Epsilon;
    newArc(source, target, id, acoustic, lm);
}

lm_->state(source->id())->newArc(target->id(), Fsa::Weight(lm), (lemmaPronunciation) ? lemmaPronunciation->id() : Fsa::Epsilon);
void StandardWordLattice::newArc(
SimBe195 marked this conversation as resolved.
Show resolved Hide resolved
Fsa::State* source,
Fsa::State* target,
const Bliss::Lemma* lemma,
Speech::Score acoustic, Speech::Score lm) {
Fsa::LabelId id = lemma ? lemma->id() : Fsa::Epsilon;
newArc(source, target, id, acoustic, lm);
}

void StandardWordLattice::addAcyclicProperty() {
Expand Down
19 changes: 18 additions & 1 deletion src/Lattice/Lattice.hh
Original file line number Diff line number Diff line change
Expand Up @@ -268,12 +268,24 @@ public:
* near future, when the WordLattice interface is refactored.
*/
class StandardWordLattice : public WordLattice {
public:
/** Selects which alphabet of the lexicon is set as input alphabet of the lattice automata. */
enum class AlphabetType {
LemmaAlphabet = 1,  // use lexicon->lemmaAlphabet()
LemmaPronunciationAlphabet = 2  // use lexicon->lemmaPronunciationAlphabet()
};

private:
Core::Ref<Fsa::StaticAutomaton> acoustic_, lm_;
Fsa::State * initialState_, *finalState_;
AlphabetType alphabetType_;

/**
 * Shared implementation of the public newArc() overloads: adds an arc with the
 * given label id from source to target, weighted with the acoustic score in the
 * acoustic automaton and with the lm score in the lm automaton.
 */
void newArc(Fsa::State* source,
Fsa::State* target,
Fsa::LabelId id,
Speech::Score acoustic, Speech::Score lm);

public:
StandardWordLattice(Bliss::LexiconRef);
/**
 * @param alphabetType  input alphabet used for both the acoustic and the lm automaton
 *
 * NOTE(review): the default is LemmaAlphabet; confirm that callers which rely on the
 * default and add arcs via the LemmaPronunciation overload of newArc() get the
 * alphabet they expect.
 */
StandardWordLattice(Bliss::LexiconRef, AlphabetType alphabetType = AlphabetType::LemmaAlphabet);

Fsa::State* newState();
Fsa::State* initialState() {
Expand All @@ -288,6 +300,11 @@ public:
const Bliss::LemmaPronunciation*,
Speech::Score acoustic, Speech::Score lm);

/**
 * Adds an arc labelled with the given lemma (epsilon if the pointer is null)
 * from source to target, with separate acoustic and lm scores.
 */
void newArc(Fsa::State *source,
Fsa::State *target,
const Bliss::Lemma*,
Speech::Score acoustic, Speech::Score lm);

void addAcyclicProperty();
};

Expand Down
18 changes: 18 additions & 0 deletions src/Nn/LabelScorer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/** Copyright 2020 RWTH Aachen University. All rights reserved.
*
* Licensed under the RWTH ASR License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "LabelScorer.hh"

using namespace Nn;
32 changes: 32 additions & 0 deletions src/Nn/LabelScorer.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/** Copyright 2020 RWTH Aachen University. All rights reserved.
*
* Licensed under the RWTH ASR License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LABEL_SCORER_HH
#define LABEL_SCORER_HH

#include <Core/Component.hh>
#include <Core/ReferenceCounting.hh>

namespace Nn {


/**
 * Base class of models used for label scoring.
 * Currently an empty placeholder: it only combines Core::Component (inherited
 * virtually) with Core::ReferenceCounted and does not declare a scoring
 * interface yet.
 */
class LabelScorer : public virtual Core::Component,
public Core::ReferenceCounted {
};

} // namespace Nn

#endif
11 changes: 11 additions & 0 deletions src/Nn/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,17 @@ ifdef MODULE_PYTHON
LIBSPRINTNN_O += $(OBJDIR)/PythonLayer.o
endif

# Label scorer objects are only built when the generic seq2seq tree search module is enabled.
ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH
LIBSPRINTNN_O += $(OBJDIR)/LabelScorer.o
# The Tensorflow label scorer additionally requires MODULE_TENSORFLOW.
ifdef MODULE_TENSORFLOW
LIBSPRINTNN_O += $(OBJDIR)/TFLabelScorer.o

# Tensorflow compiler/linker flags and the Tensorflow library for the check target.
CXXFLAGS += $(TF_CXXFLAGS)
LDFLAGS += $(TF_LDFLAGS)
CHECK_O += ../Tensorflow/libSprintTensorflow.$(a)
endif
endif

# -----------------------------------------------------------------------------
all: $(TARGETS)

Expand Down
17 changes: 17 additions & 0 deletions src/Nn/Module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@
#include "PythonFeatureScorer.hh"
#endif

#ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH
#include "LabelScorer.hh"
#ifdef MODULE_TENSORFLOW
#include "TFLabelScorer.hh"
#endif
#endif

using namespace Nn;

Module_::Module_()
Expand Down Expand Up @@ -80,3 +87,13 @@ Core::FormatSet& Module_::formats() {
}
return *formats_;
}


Core::Ref<LabelScorer> Module_::createLabelScorer(const Core::Configuration& config) const {
#ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH
LabelScorer* labelScorer = nullptr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this is temporary. Later we should make sure that the current version is created based on a type parameter that currently only has one option, tf-label-scorer.

return Core::ref(labelScorer);
#else
Core::Application::us()->criticalError("Module MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH not available!");
#endif
}
4 changes: 4 additions & 0 deletions src/Nn/Module.hh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class FormatSet;

namespace Nn {

class LabelScorer;

class Module_ {
private:
Core::FormatSet* formats_;
Expand All @@ -47,6 +49,8 @@ public:
/** Set of file format class.
*/
Core::FormatSet& formats();

Core::Ref<LabelScorer> createLabelScorer(const Core::Configuration& config) const;
};

typedef Core::SingletonHolder<Module_> Module;
Expand Down
18 changes: 18 additions & 0 deletions src/Nn/TFLabelScorer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/** Copyright 2020 RWTH Aachen University. All rights reserved.
*
* Licensed under the RWTH ASR License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "TFLabelScorer.hh"

using namespace Nn;
27 changes: 27 additions & 0 deletions src/Nn/TFLabelScorer.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/** Copyright 2020 RWTH Aachen University. All rights reserved.
*
* Licensed under the RWTH ASR License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef TF_LABEL_SCORER_HH
#define TF_LABEL_SCORER_HH

#include "LabelScorer.hh"

namespace Nn {


} // namespace Nn

#endif

2 changes: 1 addition & 1 deletion src/Search/AdvancedTreeSearch/AdvancedTreeSearch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ Core::Ref<const LatticeAdaptor> AdvancedTreeSearchManager::buildLatticeForTrace(
wordBoundaries->set((Fsa::StateId) /* WARNING: this looses a few bits */ (long)result->initialState(), Lattice::WordBoundary(0));
Fsa::State* finalState = result->newState();
wordBoundaries->set(finalState->id(), Lattice::WordBoundary(time_));
result->newArc(result->initialState(), finalState, 0, 0, 0);
result->newArc(result->initialState(), finalState, static_cast<const Bliss::LemmaPronunciation*>(nullptr), 0, 0);
result->setWordBoundaries(wordBoundaries);
result->addAcyclicProperty();
return Core::ref(new Lattice::WordLatticeAdaptor(result));
Expand Down
36 changes: 36 additions & 0 deletions src/Search/GenericSeq2SeqTreeSearch/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#!gmake

TOPDIR = ../../..

include $(TOPDIR)/Makefile.cfg

# -----------------------------------------------------------------------------

SUBDIRS =
TARGETS = libSprintGenericSeq2SeqTreeSearch.$(a)

# Object files archived into libSprintGenericSeq2SeqTreeSearch.$(a) by the library rule below.
LIBSPRINTSEQ2SEQTREESEARCH_O = $(OBJDIR)/Seq2SeqTreeSearch.o \
$(OBJDIR)/Seq2SeqAligner.o

ifeq ($(OS),darwin)
CCFLAGS += -fexceptions
endif

# These flags make the compilation slow, but are required to inline some critical functions in SearchSpace (copied from AdvancedTreeSearch)
CCFLAGS += -Wno-sign-compare -Winline --param max-inline-insns-auto=10000 --param max-inline-insns-single=10000 --param large-function-growth=25000 --param inline-unit-growth=400
CXXFLAGS += -Wno-sign-compare -Winline --param max-inline-insns-auto=10000 --param max-inline-insns-single=10000 --param large-function-growth=25000 --param inline-unit-growth=400

# -----------------------------------------------------------------------------

all: $(TARGETS)

libSprintGenericSeq2SeqTreeSearch.$(a): $(LIBSPRINTSEQ2SEQTREESEARCH_O)
$(MAKELIB) $@ $^

check$(exe): $(CHECK_O)
$(LD) $(CHECK_O) -o check$(exe) $(LDFLAGS)

include $(TOPDIR)/Rules.make

sinclude $(LIBSPRINTSEQ2SEQTREESEARCH_O:.o=.d)
include $(patsubst %.o,%.d,$(filter %.o,$(CHECK_O)))
18 changes: 18 additions & 0 deletions src/Search/GenericSeq2SeqTreeSearch/Seq2SeqAligner.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/** Copyright 2020 RWTH Aachen University. All rights reserved.
*
* Licensed under the RWTH ASR License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "Seq2SeqAligner.hh"

using namespace Search;
Loading