missing stateless NLM evaluation
ZhouW321 committed Jan 24, 2024
1 parent ed5efcc commit 7842520
Showing 4 changed files with 323 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/Lm/Makefile
@@ -50,6 +50,7 @@ LIBSPRINTLM_O += $(OBJDIR)/QuantizedCompressedVectorFactory.o
 LIBSPRINTLM_O += $(OBJDIR)/ReducedPrecisionCompressedVectorFactory.o
 LIBSPRINTLM_O += $(OBJDIR)/TransformerStateManager.o
 LIBSPRINTLM_O += $(OBJDIR)/TFRecurrentLanguageModel.o
+LIBSPRINTLM_O += $(OBJDIR)/TFSimpleTransformerLm.o
 #MODF DummyCompressedVectorFactory.hh
 #MODF SoftmaxAdapter.hh
 #MODF StateManager.hh
7 changes: 6 additions & 1 deletion src/Lm/Module.cc
@@ -31,6 +31,7 @@
 #endif
 #ifdef MODULE_LM_TFRNN
 #include "TFRecurrentLanguageModel.hh"
+#include "TFSimpleTransformerLm.hh"
 #endif
 #include "CombineLm.hh"

@@ -55,7 +56,8 @@ enum LanguageModelType {
     lmTypeCombine,
     lmTypeTFRNN,
     lmTypeCheatingSegment,
-    lmTypeSimpleHistory
+    lmTypeSimpleHistory,
+    lmTypeTFSimpleTransformer
 };
 }

@@ -69,6 +71,8 @@ const Core::Choice Module_::lmTypeChoice(
         "tfrnn", lmTypeTFRNN,
         "cheating-segment", lmTypeCheatingSegment,
         "simple-history", lmTypeSimpleHistory,
+        "simple-transformer", lmTypeTFSimpleTransformer,  // backwards compatibility
+        "tf-simple-transformer", lmTypeTFSimpleTransformer,
         Core::Choice::endMark());

 const Core::ParameterChoice Module_::lmTypeParam(
@@ -97,6 +101,7 @@ Core::Ref<LanguageModel> Module_::createLanguageModel(
         case lmTypeCombine: result = Core::ref(new CombineLanguageModel(c, l)); break;
 #ifdef MODULE_LM_TFRNN
         case lmTypeTFRNN: result = Core::ref(new TFRecurrentLanguageModel(c, l)); break;
+        case lmTypeTFSimpleTransformer: result = Core::ref(new TFSimpleTransformerLm(c, l)); break;
 #endif
         case lmTypeSimpleHistory: result = Core::ref(new SimpleHistoryLm(c, l)); break;
         default:
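Both choice strings map to the same enum value, so configurations written for the old "simple-transformer" name keep working. A minimal standalone sketch of that aliasing, using a plain std::map rather than Sprint's Core::Choice:

    #include <cassert>
    #include <map>
    #include <string>

    enum LanguageModelType { lmTypeTFRNN, lmTypeTFSimpleTransformer };

    int main() {
        // both names resolve to the same enum value; old configs stay usable
        std::map<std::string, LanguageModelType> choice = {
                {"tfrnn", lmTypeTFRNN},
                {"simple-transformer", lmTypeTFSimpleTransformer},    // backwards compatibility
                {"tf-simple-transformer", lmTypeTFSimpleTransformer},
        };
        assert(choice.at("simple-transformer") == choice.at("tf-simple-transformer"));
        return 0;
    }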
189 changes: 189 additions & 0 deletions src/Lm/TFSimpleTransformerLm.cc
@@ -0,0 +1,189 @@
/** Copyright 2020 RWTH Aachen University. All rights reserved.
*
* Licensed under the RWTH ASR License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* author: Wei Zhou
*/

#include "TFSimpleTransformerLm.hh"

using namespace Lm;

const Core::ParameterBool paramTransformOutputLog(
"transform-output-log",
"apply log to tensorflow output",
false);

const Core::ParameterBool paramTransformOutputNegate(
"transform-output-negate",
"negate tensorflow output (after log)",
false);

const Core::ParameterInt paramMaxBatchSize(
"max-batch-size",
"maximum number of histories forwarded in one go",
64, 1);

TFSimpleTransformerLm::TFSimpleTransformerLm(const Core::Configuration& c, Bliss::LexiconRef l) :
Core::Component(c),
Precursor(c, l),
session_(select("session")),
loader_(Tensorflow::Module::instance().createGraphLoader(select("loader"))),
graph_(loader_->load_graph()), // tf::GraphDef, libraries and necessary param names
max_batch_size_(paramMaxBatchSize(config)) {
bool transform_output_log = paramTransformOutputLog(config);
bool transform_output_negate = paramTransformOutputNegate(config);
if (transform_output_log and transform_output_negate) {
output_transform_function_ = [](Score v){ return -std::log(v); };
Core::Application::us()->log() << "apply -log(.) to model output";
} else if (transform_output_log) {
output_transform_function_ = [](Score v){ return std::log(v); };
Core::Application::us()->log() << "apply log(.) to model output";
} else if (transform_output_negate) {
output_transform_function_ = [](Score v){ return -v; };
Core::Application::us()->log() << "apply -(.) to model output";
}
}
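A standalone sketch (not part of the commit, Score assumed to be a float type) of what the combined flags do: with both transform-output-log and transform-output-negate set, raw softmax probabilities become -log scores via the same std::transform call used later in scoreBatch().

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <functional>
    #include <vector>

    int main() {
        // transform-output-log + transform-output-negate => -log(.)
        std::function<float(float)> f = [](float v) { return -std::log(v); };
        std::vector<float> probs = {0.5f, 0.25f, 0.25f};
        std::transform(probs.begin(), probs.end(), probs.begin(), f);
        for (float s : probs)
            std::printf("%.4f\n", s); // 0.6931, 1.3863, 1.3863
        return 0;
    }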

TFSimpleTransformerLm::~TFSimpleTransformerLm() {
startHistory_ = History();
delete historyManager_;
}

// initialization: vocabulary, model graph and start history
void TFSimpleTransformerLm::load() {
loadVocabulary();
// create tf::Session with graph(tf::GraphDef) and default initialization of variables
session_.addGraph(*graph_);
// restore model checkpoint
loader_->initialize(session_);

// hard-coded IO names
Tensorflow::TensorInputMap input_map(select("input-map"));
input_tensor_name = input_map.get_info("word").tensor_name();
input_length_tensor_name = input_map.get_info("word").seq_length_tensor_name();

Tensorflow::TensorOutputMap output_map(select("output-map"));
output_tensor_names_.push_back(output_map.get_info("softmax").tensor_name());

// no state_vars to be handled in this simple version
// Note: model graph should always have the default initial state for each run

// use SimpleScoreHistoryManager for simplicity and flexibility
delete historyManager_;
historyManager_ = new SimpleScoreHistoryManager();
startHistory_ = startHistory();
// TODO: compute the scores already at init?
}

History TFSimpleTransformerLm::startHistory() const {
if (startHistory_.isValid())
return startHistory_;
// once only
Bliss::Token::Id wId = lexicon_mapping_.at(sentenceBeginToken()->id());
verify(wId < num_outputs_);
SimpleScoreHistoryManager* hm = static_cast<SimpleScoreHistoryManager*>(historyManager_);
HistoryDescriptor* nhd = new HistoryDescriptor(wId);
std::pair<SimpleHistoryCache::iterator, bool> result = hm->updateCache(nhd);
verify(result.second); // must be the only one
cacheHashQueue_.push_back(nhd->cacheHash);
return history(nhd);
}

History TFSimpleTransformerLm::extendedHistory(const History& h, Token w) const {
Bliss::Token::Id wId = lexicon_mapping_.at(w->id());
verify(wId < num_outputs_);
SimpleScoreHistoryManager* hm = static_cast<SimpleScoreHistoryManager*>(historyManager_);
const HistoryDescriptor* chd = static_cast<const HistoryDescriptor*>(h.handle());
HistoryDescriptor* nhd = new HistoryDescriptor(chd->tokIdSeq, wId);

std::pair<SimpleHistoryCache::iterator, bool> result = hm->updateCache(nhd);
if (result.second) { // new one
cacheHashQueue_.push_back(nhd->cacheHash);
} else { // use the existing one
delete nhd;
nhd = result.first->second;
}
return history(nhd);
}
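The update above relies on std::unordered_map::insert reporting whether the key was new: a descriptor is allocated speculatively, and dropped again if an equivalent history is already cached. A standalone sketch of that insert-or-reuse pattern (types here are illustrative, not Sprint's):

    #include <cassert>
    #include <unordered_map>

    struct Desc {
        size_t hash;
    };

    int main() {
        std::unordered_map<size_t, Desc*> cache;
        cache[42] = new Desc{42};   // pre-existing cache entry
        Desc* fresh = new Desc{42}; // speculatively allocated, like nhd above
        auto result = cache.insert({fresh->hash, fresh});
        if (!result.second) { // key already present: drop the new one, reuse the cached one
            delete fresh;
            fresh = result.first->second;
        }
        assert(cache.size() == 1 && fresh == cache.at(42));
        delete cache.at(42);
        return 0;
    }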

Score TFSimpleTransformerLm::score(const History& h, Token w) const {
size_t wId = lexicon_mapping_.at(w->id());
verify(wId < num_outputs_);
const HistoryDescriptor* chd = static_cast<const HistoryDescriptor*>(h.handle());
if (!chd->scores.empty())
return chd->scores[wId];

HistoryDescriptor* hd = const_cast<HistoryDescriptor*>(chd);
makeBatch(hd);
verify(batch_.size() > 0 && max_batch_len_ > 0);
scoreBatch();
batch_.clear();
max_batch_len_ = 0;

verify(hd->scores.size() >= num_outputs_);
return hd->scores[wId];
}

void TFSimpleTransformerLm::makeBatch(HistoryDescriptor* hd) const {
// sort by length? typical search behavior already yields similar lengths in the ordered queue
// more important is score caching, which avoids redundant computation due to pruning
batch_.push_back(hd);
max_batch_len_ = hd->tokIdSeq.size();

const SimpleHistoryCache& cache = static_cast<SimpleScoreHistoryManager*>(historyManager_)->getCache();
while (batch_.size() < max_batch_size_ && !cacheHashQueue_.empty()) {
size_t hash = cacheHashQueue_.front();
cacheHashQueue_.pop_front();
if (cache.count(hash) == 0 || hash == hd->cacheHash)
continue;
HistoryDescriptor* bhd = cache.at(hash);
if (!bhd->scores.empty())
continue;
batch_.push_back(bhd);
if (bhd->tokIdSeq.size() > max_batch_len_)
max_batch_len_ = bhd->tokIdSeq.size();
}
}
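A standalone sketch of the lazy draining of cacheHashQueue_ above (plain sets stand in for the history cache and the scored check): hashes whose histories were released from the cache, or that already carry scores, are skipped when filling a batch rather than eagerly removed from the queue.

    #include <cassert>
    #include <deque>
    #include <unordered_set>
    #include <vector>

    int main() {
        std::unordered_set<size_t> alive = {1, 3}; // hashes still in the cache
        std::unordered_set<size_t> scored = {3};   // histories that already have scores
        std::deque<size_t> queue = {1, 2, 3};      // pending hashes, oldest first
        std::vector<size_t> batch;
        const size_t maxBatchSize = 64;
        while (batch.size() < maxBatchSize && !queue.empty()) {
            size_t h = queue.front();
            queue.pop_front();
            if (alive.count(h) == 0 || scored.count(h) > 0)
                continue; // stale or already scored: skip
            batch.push_back(h);
        }
        assert(batch.size() == 1 && batch[0] == 1);
        return 0;
    }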

void TFSimpleTransformerLm::scoreBatch() const {
// default initializer always 0 ?
Math::FastMatrix<s32> tokMat(batch_.size(), max_batch_len_);
Math::FastVector<s32> lenVec(batch_.size());
for (u32 bIdx = 0; bIdx < batch_.size(); ++bIdx) {
const TokenIdSequence& tokSeq = batch_[bIdx]->tokIdSeq;
verify(tokSeq.size() <= max_batch_len_);
lenVec[bIdx] = tokSeq.size();
for (u32 tIdx = 0; tIdx < tokSeq.size(); ++tIdx)
tokMat.at(bIdx, tIdx) = tokSeq[tIdx];
for (u32 tIdx = tokSeq.size(); tIdx < max_batch_len_; ++tIdx)
tokMat.at(bIdx, tIdx) = 0;
}

BatchInput inputs;
BatchOutput outputs;
inputs.emplace_back(std::make_pair(input_tensor_name, Tensorflow::Tensor::create(tokMat)));
inputs.emplace_back(std::make_pair(input_length_tensor_name, Tensorflow::Tensor::create(lenVec)));
// reading the output tensor values triggers the computation automatically (no state_vars to update)
session_.run(inputs, output_tensor_names_, {}, outputs);

// process scores: expect only the last output position, shape (B, V)
verify(outputs.size() == 1);
for (u32 bIdx = 0; bIdx < batch_.size(); ++bIdx) {
std::vector<Score>& scores = batch_[bIdx]->scores;
outputs[0].get(bIdx, scores);
if (output_transform_function_)
std::transform(scores.begin(), scores.end(), scores.begin(), output_transform_function_);
}
}
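The batch fed to the graph is a right-zero-padded (B, T) matrix of token ids plus a (B,) vector of true lengths. A standalone sketch of the same layout (a flat row-major std::vector stands in for tokMat/lenVec; no claim about FastMatrix internals is made):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
        std::vector<std::vector<int32_t>> seqs = {{2, 7, 9}, {2, 5}}; // token id sequences
        const size_t B = seqs.size(); // batch size
        const size_t T = 3;           // max_batch_len_: longest sequence in the batch
        std::vector<int32_t> tokMat(B * T, 0); // flat (B, T), zero-padded
        std::vector<int32_t> lenVec(B);        // true lengths
        for (size_t b = 0; b < B; ++b) {
            lenVec[b] = static_cast<int32_t>(seqs[b].size());
            for (size_t t = 0; t < seqs[b].size(); ++t)
                tokMat[b * T + t] = seqs[b][t];
        }
        assert(tokMat[1 * T + 2] == 0); // padded position of the shorter sequence
        assert(lenVec[1] == 2);
        return 0;
    }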

127 changes: 127 additions & 0 deletions src/Lm/TFSimpleTransformerLm.hh
@@ -0,0 +1,127 @@
/** Copyright 2020 RWTH Aachen University. All rights reserved.
*
* Licensed under the RWTH ASR License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* author: Wei Zhou
*/

#ifndef _LM_SIMPLE_TRANSFORMER_LM_HH
#define _LM_SIMPLE_TRANSFORMER_LM_HH

#include <Tensorflow/GraphLoader.hh>
#include <Tensorflow/Module.hh>
#include <Tensorflow/Session.hh>
#include <Tensorflow/TensorMap.hh>
#include "AbstractNNLanguageModel.hh"
#include "SimpleHistoryLm.hh"
#include <deque>

namespace Lm {

struct SimpleScoreHistory: public SimpleHistory {
// tokIdSeq and refCount in base
std::vector<Score> scores;
size_t cacheHash;

typedef SimpleHistory Precursor;
SimpleScoreHistory(Bliss::Token::Id tid): Precursor(tid), cacheHash(0) {}
SimpleScoreHistory(const TokenIdSequence& r, Bliss::Token::Id tid) :
Precursor(r, tid), cacheHash(0) {}
};

typedef std::unordered_map<size_t, SimpleScoreHistory*> SimpleHistoryCache;

class SimpleScoreHistoryManager : public SimpleHistoryManager {
protected:
SimpleHistoryCache historyCache_;

public:
SimpleScoreHistoryManager() {}
~SimpleScoreHistoryManager() {
for (SimpleHistoryCache::iterator iter=historyCache_.begin(); iter!=historyCache_.end(); ++iter)
delete (iter->second);
}

void release(HistoryHandle handle) {
const SimpleScoreHistory* sh = static_cast<const SimpleScoreHistory*>(handle);
--(sh->refCount); // mutable
if (sh->refCount == 0) {
historyCache_.erase(sh->cacheHash);
delete sh;
}
}

const SimpleHistoryCache& getCache() const { return historyCache_; }

std::pair<SimpleHistoryCache::iterator, bool> updateCache(SimpleScoreHistory* sh) {
sh->cacheHash = token_id_sequence_hash(sh->tokIdSeq);
return historyCache_.insert(std::make_pair(sh->cacheHash, sh));
}
};
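A standalone sketch of the release logic above (illustrative types, not Sprint's): the refCount is decremented through a const pointer thanks to its mutable declaration, and with the last reference the cache entry is erased and the descriptor freed.

    #include <cassert>
    #include <unordered_map>

    struct Desc {
        size_t hash = 7;
        mutable int refCount = 2; // mutable, as noted in release() above
    };

    int main() {
        std::unordered_map<size_t, const Desc*> cache;
        const Desc* d = new Desc();
        cache[d->hash] = d;
        for (int i = 2; i > 0; --i) { // two holders release their reference
            --(d->refCount);
            if (d->refCount == 0) {
                cache.erase(d->hash);
                delete d;
            }
        }
        assert(cache.empty());
        return 0;
    }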

typedef std::vector<std::pair<std::string, Tensorflow::Tensor>> BatchInput;
typedef std::vector<Tensorflow::Tensor> BatchOutput;

// simple TF Transformer LM: mainly for E2E systems with a small search space
// trades speed for simplicity: always feeds the full sequence and reads only the last output scores
// Note: slicing out the last position should be done in the model graph
class TFSimpleTransformerLm: public AbstractNNLanguageModel {
typedef AbstractNNLanguageModel Precursor;
typedef SimpleScoreHistory HistoryDescriptor;

protected:
// Note: graph-related members follow the Python naming scheme
mutable Tensorflow::Session session_;
std::unique_ptr<Tensorflow::GraphLoader> loader_;
std::unique_ptr<Tensorflow::Graph> graph_;

// expects a single input and a single output tensor
std::string input_tensor_name;
std::string input_length_tensor_name;
std::vector<std::string> output_tensor_names_;

protected:
std::function<Score(Score)> output_transform_function_;
u32 max_batch_size_; // B
mutable u32 max_batch_len_; // T
mutable std::deque<size_t> cacheHashQueue_; // only histories not yet scored
mutable std::vector<HistoryDescriptor*> batch_;

History startHistory_; // held permanently in the cache: its scoring is always the same

protected:
void load();

// none of these are logically const for an NNLM: constness is legacy from the LM interface
void makeBatch(HistoryDescriptor* hd) const;
void scoreBatch() const;

// cache the most recently scored histories to avoid redundant computation due to pruning
// this could also be done by the lookahead-table caching scheme (it just needs to hold the history),
// but a smaller dedicated cache is preferable for memory

public:
TFSimpleTransformerLm(const Core::Configuration& c, Bliss::LexiconRef l);
~TFSimpleTransformerLm();

// history (no reduction)
History startHistory() const;
History extendedHistory(const History& h, Token w) const;

// scoring
Score score(const History& h, Token w) const;
};

} // namespace Lm

#endif // _LM_SIMPLE_TRANSFORMER_LM_HH
