From 78425206d40b6cee47560811d71616479ad20f54 Mon Sep 17 00:00:00 2001
From: Wei Zhou
Date: Wed, 24 Jan 2024 18:09:07 +0100
Subject: [PATCH] Add missing stateless NLM evaluation

---
 src/Lm/Makefile                 |   1 +
 src/Lm/Module.cc                |   7 +-
 src/Lm/TFSimpleTransformerLm.cc | 189 ++++++++++++++++++++++++++++++++
 src/Lm/TFSimpleTransformerLm.hh | 127 +++++++++++++++++++++
 4 files changed, 323 insertions(+), 1 deletion(-)
 create mode 100644 src/Lm/TFSimpleTransformerLm.cc
 create mode 100644 src/Lm/TFSimpleTransformerLm.hh

diff --git a/src/Lm/Makefile b/src/Lm/Makefile
index 39883305b..cbf1a30d1 100644
--- a/src/Lm/Makefile
+++ b/src/Lm/Makefile
@@ -50,6 +50,7 @@ LIBSPRINTLM_O += $(OBJDIR)/QuantizedCompressedVectorFactory.o
 LIBSPRINTLM_O += $(OBJDIR)/ReducedPrecisionCompressedVectorFactory.o
 LIBSPRINTLM_O += $(OBJDIR)/TransformerStateManager.o
 LIBSPRINTLM_O += $(OBJDIR)/TFRecurrentLanguageModel.o
+LIBSPRINTLM_O += $(OBJDIR)/TFSimpleTransformerLm.o
 #MODF DummyCompressedVectorFactory.hh
 #MODF SoftmaxAdapter.hh
 #MODF StateManager.hh
diff --git a/src/Lm/Module.cc b/src/Lm/Module.cc
index 845a5f789..810e8929c 100644
--- a/src/Lm/Module.cc
+++ b/src/Lm/Module.cc
@@ -31,6 +31,7 @@
 #endif
 #ifdef MODULE_LM_TFRNN
 #include "TFRecurrentLanguageModel.hh"
+#include "TFSimpleTransformerLm.hh"
 #endif
 
 #include "CombineLm.hh"
@@ -55,7 +56,8 @@
 enum LanguageModelType {
     lmTypeCombine,
     lmTypeTFRNN,
     lmTypeCheatingSegment,
-    lmTypeSimpleHistory
+    lmTypeSimpleHistory,
+    lmTypeTFSimpleTransformer
 };
 }
@@ -69,6 +71,8 @@ const Core::Choice Module_::lmTypeChoice(
         "tfrnn", lmTypeTFRNN,
         "cheating-segment", lmTypeCheatingSegment,
         "simple-history", lmTypeSimpleHistory,
+        "simple-transformer", lmTypeTFSimpleTransformer,  // backwards compatibility
+        "tf-simple-transformer", lmTypeTFSimpleTransformer,
         Core::Choice::endMark());
 
 const Core::ParameterChoice Module_::lmTypeParam(
@@ -97,6 +101,7 @@ Core::Ref<LanguageModel> Module_::createLanguageModel(
         case lmTypeCombine: result = Core::ref(new CombineLanguageModel(c, l)); break;
 #ifdef MODULE_LM_TFRNN
         case lmTypeTFRNN: result = Core::ref(new TFRecurrentLanguageModel(c, l)); break;
+        case lmTypeTFSimpleTransformer: result = Core::ref(new TFSimpleTransformerLm(c, l)); break;
 #endif
         case lmTypeSimpleHistory: result = Core::ref(new SimpleHistoryLm(c, l)); break;
         default:
diff --git a/src/Lm/TFSimpleTransformerLm.cc b/src/Lm/TFSimpleTransformerLm.cc
new file mode 100644
index 000000000..b434e4aef
--- /dev/null
+++ b/src/Lm/TFSimpleTransformerLm.cc
@@ -0,0 +1,189 @@
+/** Copyright 2020 RWTH Aachen University. All rights reserved.
+ *
+ * Licensed under the RWTH ASR License (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * author: Wei Zhou
+ */
+
+#include "TFSimpleTransformerLm.hh"
+
+using namespace Lm;
+
+const Core::ParameterBool paramTransformOutputLog(
+        "transform-output-log",
+        "apply log to tensorflow output",
+        false);
+
+const Core::ParameterBool paramTransformOutputNegate(
+        "transform-output-negate",
+        "negate tensorflow output (after log)",
+        false);
+
+const Core::ParameterInt paramMaxBatchSize(
+        "max-batch-size",
+        "maximum number of histories forwarded in one go",
+        64, 1);
+
+TFSimpleTransformerLm::TFSimpleTransformerLm(const Core::Configuration& c, Bliss::LexiconRef l) :
+        Core::Component(c),
+        Precursor(c, l),
+        session_(select("session")),
+        loader_(Tensorflow::Module::instance().createGraphLoader(select("loader"))),
+        graph_(loader_->load_graph()),  // tf::GraphDef, libraries and necessary param names
+        max_batch_size_(paramMaxBatchSize(config)) {
+    bool transform_output_log    = paramTransformOutputLog(config);
+    bool transform_output_negate = paramTransformOutputNegate(config);
+    if (transform_output_log and transform_output_negate) {
+        output_transform_function_ = [](Score v) { return -std::log(v); };
+        Core::Application::us()->log() << "apply -log(.) to model output";
+    } else if (transform_output_log) {
+        output_transform_function_ = [](Score v) { return std::log(v); };
+        Core::Application::us()->log() << "apply log(.) to model output";
+    } else if (transform_output_negate) {
+        output_transform_function_ = [](Score v) { return -v; };
+        Core::Application::us()->log() << "apply -(.) to model output";
+    }
+}
+
+TFSimpleTransformerLm::~TFSimpleTransformerLm() {
+    startHistory_ = History();
+    delete historyManager_;
+}
+
+// initialization: vocabulary, model graph and start history
+void TFSimpleTransformerLm::load() {
+    loadVocabulary();
+    // create tf::Session with graph (tf::GraphDef) and default initialization of variables
+    session_.addGraph(*graph_);
+    // restore model checkpoint
+    loader_->initialize(session_);
+
+    // hard-coded IO names
+    Tensorflow::TensorInputMap input_map(select("input-map"));
+    input_tensor_name        = input_map.get_info("word").tensor_name();
+    input_length_tensor_name = input_map.get_info("word").seq_length_tensor_name();
+
+    Tensorflow::TensorOutputMap output_map(select("output-map"));
+    output_tensor_names_.push_back(output_map.get_info("softmax").tensor_name());
+
+    // no state_vars to be handled in this simple version
+    // Note: the model graph should always use the default initial state for each run
+
+    // use SimpleScoreHistoryManager for simplicity and flexibility
+    delete historyManager_;
+    historyManager_ = new SimpleScoreHistoryManager();
+    startHistory_   = startHistory();
+    // TODO: compute the scores at init already ?
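+    // Note: nothing is scored here; the start history (like any later history) gets its
+    // scores lazily on the first score() call via makeBatch() + scoreBatch().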
+}
+
+History TFSimpleTransformerLm::startHistory() const {
+    if (startHistory_.isValid())
+        return startHistory_;
+    // once only
+    Bliss::Token::Id wId = lexicon_mapping_.at(sentenceBeginToken()->id());
+    verify(wId < num_outputs_);
+    SimpleScoreHistoryManager* hm = static_cast<SimpleScoreHistoryManager*>(historyManager_);
+    HistoryDescriptor* nhd = new HistoryDescriptor(wId);
+    std::pair<SimpleHistoryCache::iterator, bool> result = hm->updateCache(nhd);
+    verify(result.second);  // must be the only one
+    cacheHashQueue_.push_back(nhd->cacheHash);
+    return history(nhd);
+}
+
+History TFSimpleTransformerLm::extendedHistory(const History& h, Token w) const {
+    Bliss::Token::Id wId = lexicon_mapping_.at(w->id());
+    verify(wId < num_outputs_);
+    SimpleScoreHistoryManager* hm = static_cast<SimpleScoreHistoryManager*>(historyManager_);
+    const HistoryDescriptor* chd = static_cast<const HistoryDescriptor*>(h.handle());
+    HistoryDescriptor* nhd = new HistoryDescriptor(chd->tokIdSeq, wId);
+
+    std::pair<SimpleHistoryCache::iterator, bool> result = hm->updateCache(nhd);
+    if (result.second) {  // new one
+        cacheHashQueue_.push_back(nhd->cacheHash);
+    } else {  // use the existing one
+        delete nhd;
+        nhd = result.first->second;
+    }
+    return history(nhd);
+}
+
+Score TFSimpleTransformerLm::score(const History& h, Token w) const {
+    size_t wId = lexicon_mapping_.at(w->id());
+    verify(wId < num_outputs_);
+    const HistoryDescriptor* chd = static_cast<const HistoryDescriptor*>(h.handle());
+    if (!chd->scores.empty())
+        return chd->scores[wId];
+
+    HistoryDescriptor* hd = const_cast<HistoryDescriptor*>(chd);
+    makeBatch(hd);
+    verify(batch_.size() > 0 && max_batch_len_ > 0);
+    scoreBatch();
+    batch_.clear();
+    max_batch_len_ = 0;
+
+    verify(hd->scores.size() >= num_outputs_);
+    return hd->scores[wId];
+}
+
+void TFSimpleTransformerLm::makeBatch(HistoryDescriptor* hd) const {
+    // sort by length? the general search behavior already ensures similar lengths in the ordered queue;
+    // more important is the score caching to avoid redundant computation due to pruning
+    batch_.push_back(hd);
+    max_batch_len_ = hd->tokIdSeq.size();
+
+    const SimpleHistoryCache& cache = static_cast<SimpleScoreHistoryManager*>(historyManager_)->getCache();
+    while (batch_.size() < max_batch_size_ && !cacheHashQueue_.empty()) {
+        size_t hash = cacheHashQueue_.front();
+        cacheHashQueue_.pop_front();
+        if (cache.count(hash) == 0 || hash == hd->cacheHash)
+            continue;
+        HistoryDescriptor* bhd = cache.at(hash);
+        if (!bhd->scores.empty())
+            continue;
+        batch_.push_back(bhd);
+        if (bhd->tokIdSeq.size() > max_batch_len_)
+            max_batch_len_ = bhd->tokIdSeq.size();
+    }
+}
+
+void TFSimpleTransformerLm::scoreBatch() const {
+    // default initializer always 0 ?
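+    // Batch layout: tokMat is a (B x T) matrix of token ids, one row per history and
+    // zero-padded beyond each sequence's true length; lenVec holds the true lengths.
+    // Both are fed to the graph under the configured input-map tensor names.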
+    Math::FastMatrix<s32> tokMat(batch_.size(), max_batch_len_);
+    Math::FastVector<s32> lenVec(batch_.size());
+    for (u32 bIdx = 0; bIdx < batch_.size(); ++bIdx) {
+        const TokenIdSequence& tokSeq = batch_[bIdx]->tokIdSeq;
+        verify(tokSeq.size() <= max_batch_len_);
+        lenVec[bIdx] = tokSeq.size();
+        for (u32 tIdx = 0; tIdx < tokSeq.size(); ++tIdx)
+            tokMat.at(bIdx, tIdx) = tokSeq[tIdx];
+        for (u32 tIdx = tokSeq.size(); tIdx < max_batch_len_; ++tIdx)
+            tokMat.at(bIdx, tIdx) = 0;
+    }
+
+    BatchInput inputs;
+    BatchOutput outputs;
+    inputs.emplace_back(std::make_pair(input_tensor_name, Tensorflow::Tensor::create(tokMat)));
+    inputs.emplace_back(std::make_pair(input_length_tensor_name, Tensorflow::Tensor::create(lenVec)));
+    // reading the output tensor values should trigger the computation automatically (no state_vars to be updated)
+    session_.run(inputs, output_tensor_names_, {}, outputs);
+
+    // process scores: always expect only the last output position, i.e. shape (B, V)
+    verify(outputs.size() == 1);
+    for (u32 bIdx = 0; bIdx < batch_.size(); ++bIdx) {
+        std::vector<Score>& scores = batch_[bIdx]->scores;
+        outputs[0].get(bIdx, scores);
+        if (output_transform_function_)
+            std::transform(scores.begin(), scores.end(), scores.begin(), output_transform_function_);
+    }
+}
+
diff --git a/src/Lm/TFSimpleTransformerLm.hh b/src/Lm/TFSimpleTransformerLm.hh
new file mode 100644
index 000000000..87ff1fc1f
--- /dev/null
+++ b/src/Lm/TFSimpleTransformerLm.hh
@@ -0,0 +1,127 @@
+/** Copyright 2020 RWTH Aachen University. All rights reserved.
+ *
+ * Licensed under the RWTH ASR License (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * author: Wei Zhou
+ */
+
+#ifndef _LM_SIMPLE_TRANSFORMER_LM_HH
+#define _LM_SIMPLE_TRANSFORMER_LM_HH
+
+#include <deque>
+#include <functional>
+#include <memory>
+#include <unordered_map>
+#include "AbstractNNLanguageModel.hh"
+#include "SimpleHistoryLm.hh"
+#include <Tensorflow/Graph.hh>
+#include <Tensorflow/Module.hh>
+#include <Tensorflow/Session.hh>
+#include <Tensorflow/TensorMap.hh>
+
+namespace Lm {
+
+struct SimpleScoreHistory : public SimpleHistory {
+    // tokIdSeq and refCount in base
+    std::vector<Score> scores;
+    size_t cacheHash;
+
+    typedef SimpleHistory Precursor;
+    SimpleScoreHistory(Bliss::Token::Id tid) :
+            Precursor(tid), cacheHash(0) {}
+    SimpleScoreHistory(const TokenIdSequence& r, Bliss::Token::Id tid) :
+            Precursor(r, tid), cacheHash(0) {}
+};
+
+typedef std::unordered_map<size_t, SimpleScoreHistory*> SimpleHistoryCache;
+
+class SimpleScoreHistoryManager : public SimpleHistoryManager {
+protected:
+    SimpleHistoryCache historyCache_;
+
+public:
+    SimpleScoreHistoryManager() {}
+    ~SimpleScoreHistoryManager() {
+        for (SimpleHistoryCache::iterator iter = historyCache_.begin(); iter != historyCache_.end(); ++iter)
+            delete (iter->second);
+    }
+
+    void release(HistoryHandle handle) {
+        const SimpleScoreHistory* sh = static_cast<const SimpleScoreHistory*>(handle);
+        --(sh->refCount);  // mutable
+        if (sh->refCount == 0) {
+            historyCache_.erase(sh->cacheHash);
+            delete sh;
+        }
+    }
+
+    const SimpleHistoryCache& getCache() const {
+        return historyCache_;
+    }
+
+    std::pair<SimpleHistoryCache::iterator, bool> updateCache(SimpleScoreHistory* sh) {
+        sh->cacheHash = token_id_sequence_hash(sh->tokIdSeq);
+        return historyCache_.insert(std::make_pair(sh->cacheHash, sh));
+    }
+};
+
+typedef std::vector<std::pair<std::string, Tensorflow::Tensor>> BatchInput;
+typedef std::vector<Tensorflow::Tensor> BatchOutput;
+
+// simple TF Transformer LM: mainly for E2E systems with a small search space
+// trades speed for simplicity: always feed in the full sequence and take the scores of the last output position
+// Note: slicing the last position should be done in the model graph
+class TFSimpleTransformerLm : public AbstractNNLanguageModel {
+    typedef AbstractNNLanguageModel Precursor;
+    typedef SimpleScoreHistory HistoryDescriptor;
+
+protected:
+    // Note: graph-related params follow the python naming scheme
+    mutable Tensorflow::Session session_;
+    std::unique_ptr<Tensorflow::GraphLoader> loader_;
+    std::unique_ptr<Tensorflow::Graph> graph_;
+
+    // should be a single input/output tensor
+    std::string input_tensor_name;
+    std::string input_length_tensor_name;
+    std::vector<std::string> output_tensor_names_;
+
+protected:
+    std::function<Score(Score)> output_transform_function_;
+    u32 max_batch_size_;  // B
+    mutable u32 max_batch_len_;  // T
+    mutable std::deque<size_t> cacheHashQueue_;  // only not-yet-scored histories
+    mutable std::vector<HistoryDescriptor*> batch_;
+
+    History startHistory_;  // always cached: same scoring
+
+protected:
+    void load();
+
+    // actually no const functions at all for an NNLM: just legacy of the LM interface
+    void makeBatch(HistoryDescriptor* hd) const;
+    void scoreBatch() const;
+
+    // caches the most recently scored histories to avoid redundant computation due to pruning;
+    // this could also be done by the lookahead-table caching scheme (just need to hold the history),
+    // but a smaller cache is better for memory
+
+public:
+    TFSimpleTransformerLm(const Core::Configuration& c, Bliss::LexiconRef l);
+    ~TFSimpleTransformerLm();
+
+    // history (no reduction)
+    History startHistory() const;
+    History extendedHistory(const History& h, Token w) const;
+
+    // scoring
+    Score score(const History& h, Token w) const;
+};
+
+}  // namespace Lm
+
+#endif  // _LM_SIMPLE_TRANSFORMER_LM_HH
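
Usage sketch: the new LM is selected through the existing lm-type choice registered in Module.cc;
"simple-transformer" is kept as a backwards-compatible alias of "tf-simple-transformer". The snippet
below is a minimal RASR-style configuration sketch, not a verified setup: the "[*.lm]" selector prefix,
the loader settings and all tensor names are placeholders that depend on the concrete setup and the
exported graph; only the parameter names (max-batch-size, transform-output-log, transform-output-negate)
and the "word"/"softmax" streams of input-map/output-map are the ones read by the code above. Whether the
two transform flags are needed depends on whether the graph already emits negated log-probabilities.

[*.lm]
type                    = tf-simple-transformer
max-batch-size          = 64
transform-output-log    = true
transform-output-negate = true

[*.lm.loader]
type             = meta
meta-graph-file  = graph.meta
saved-model-file = model/checkpoint

[*.lm.input-map.info-0]
param-name             = word
tensor-name            = extern_data/placeholders/delayed/delayed
seq-length-tensor-name = extern_data/placeholders/delayed/delayed_dim0_size

[*.lm.output-map.info-0]
param-name  = softmax
tensor-name = output/output_batch_major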