Showing 4 changed files with 323 additions and 1 deletion.
TFSimpleTransformerLm.cc
@@ -0,0 +1,189 @@
/** Copyright 2020 RWTH Aachen University. All rights reserved.
 *
 * Licensed under the RWTH ASR License (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * author: Wei Zhou
 */

#include "TFSimpleTransformerLm.hh"

using namespace Lm;
const Core::ParameterBool paramTransformOuputLog(
        "transform-output-log",
        "apply log to tensorflow output",
        false);

const Core::ParameterBool paramTransformOuputNegate(
        "transform-output-negate",
        "negate tensorflow output (after log)",
        false);

const Core::ParameterInt paramMaxBatchSize(
        "max-batch-size",
        "maximum number of histories forwarded in one go",
        64, 1);
TFSimpleTransformerLm::TFSimpleTransformerLm(const Core::Configuration& c, Bliss::LexiconRef l) :
        Core::Component(c),
        Precursor(c, l),
        session_(select("session")),
        loader_(Tensorflow::Module::instance().createGraphLoader(select("loader"))),
        graph_(loader_->load_graph()),  // tf::GraphDef, libraries and necessary param names
        max_batch_size_(paramMaxBatchSize(config)) {
    bool transform_output_log    = paramTransformOuputLog(config);
    bool transform_output_negate = paramTransformOuputNegate(config);
    if (transform_output_log and transform_output_negate) {
        output_transform_function_ = [](Score v) { return -std::log(v); };
        Core::Application::us()->log() << "apply -log(.) to model output";
    } else if (transform_output_log) {
        output_transform_function_ = [](Score v) { return std::log(v); };
        Core::Application::us()->log() << "apply log(.) to model output";
    } else if (transform_output_negate) {
        output_transform_function_ = [](Score v) { return -v; };
        Core::Application::us()->log() << "apply -(.) to model output";
    }
}

TFSimpleTransformerLm::~TFSimpleTransformerLm() {
    startHistory_ = History();
    delete historyManager_;
}
// initialization: vocabulary, model graph and start history
void TFSimpleTransformerLm::load() {
    loadVocabulary();
    // create tf::Session with graph (tf::GraphDef) and default initialization of variables
    session_.addGraph(*graph_);
    // restore model checkpoint
    loader_->initialize(session_);

    // hard-coded IO names
    Tensorflow::TensorInputMap input_map(select("input-map"));
    input_tensor_name        = input_map.get_info("word").tensor_name();
    input_length_tensor_name = input_map.get_info("word").seq_length_tensor_name();

    Tensorflow::TensorOutputMap output_map(select("output-map"));
    output_tensor_names_.push_back(output_map.get_info("softmax").tensor_name());

    // no state_vars to be handled in this simple version
    // Note: the model graph should always have the default initial state for each run

    // use SimpleScoreHistoryManager for simplicity and flexibility
    delete historyManager_;
    historyManager_ = new SimpleScoreHistoryManager();
    startHistory_   = startHistory();
    // TODO compute the scores at init already ?
}
History TFSimpleTransformerLm::startHistory() const {
    if (startHistory_.isValid())
        return startHistory_;
    // once only
    Bliss::Token::Id wId = lexicon_mapping_.at(sentenceBeginToken()->id());
    verify(wId < num_outputs_);
    SimpleScoreHistoryManager* hm  = static_cast<SimpleScoreHistoryManager*>(historyManager_);
    HistoryDescriptor*         nhd = new HistoryDescriptor(wId);
    std::pair<SimpleHistoryCache::iterator, bool> result = hm->updateCache(nhd);
    verify(result.second);  // must be the only one
    cacheHashQueue_.push_back(nhd->cacheHash);
    return history(nhd);
}
History TFSimpleTransformerLm::extendedHistory(const History& h, Token w) const {
    Bliss::Token::Id wId = lexicon_mapping_.at(w->id());
    verify(wId < num_outputs_);
    SimpleScoreHistoryManager* hm  = static_cast<SimpleScoreHistoryManager*>(historyManager_);
    const HistoryDescriptor*   chd = static_cast<const HistoryDescriptor*>(h.handle());
    HistoryDescriptor*         nhd = new HistoryDescriptor(chd->tokIdSeq, wId);

    std::pair<SimpleHistoryCache::iterator, bool> result = hm->updateCache(nhd);
    if (result.second) {  // new one
        cacheHashQueue_.push_back(nhd->cacheHash);
    } else {  // use the existing one
        delete nhd;
        nhd = result.first->second;
    }
    return history(nhd);
}
Score TFSimpleTransformerLm::score(const History& h, Token w) const {
    size_t wId = lexicon_mapping_.at(w->id());
    verify(wId < num_outputs_);
    const HistoryDescriptor* chd = static_cast<const HistoryDescriptor*>(h.handle());
    if (!chd->scores.empty())
        return chd->scores[wId];

    HistoryDescriptor* hd = const_cast<HistoryDescriptor*>(chd);
    makeBatch(hd);
    verify(batch_.size() > 0 && max_batch_len_ > 0);
    scoreBatch();
    batch_.clear();
    max_batch_len_ = 0;

    verify(hd->scores.size() >= num_outputs_);
    return hd->scores[wId];
}
void TFSimpleTransformerLm::makeBatch(HistoryDescriptor* hd) const {
    // sort by length? the general search behavior already ensures similar lengths in the ordered queue
    // maybe more important is the score caching to avoid redundant computation due to pruning
    batch_.push_back(hd);
    max_batch_len_ = hd->tokIdSeq.size();

    const SimpleHistoryCache& cache = static_cast<SimpleScoreHistoryManager*>(historyManager_)->getCache();
    while (batch_.size() < max_batch_size_ && !cacheHashQueue_.empty()) {
        size_t hash = cacheHashQueue_.front();
        cacheHashQueue_.pop_front();
        if (cache.count(hash) == 0 || hash == hd->cacheHash)
            continue;
        HistoryDescriptor* bhd = cache.at(hash);
        if (!bhd->scores.empty())
            continue;
        batch_.push_back(bhd);
        if (bhd->tokIdSeq.size() > max_batch_len_)
            max_batch_len_ = bhd->tokIdSeq.size();
    }
}
void TFSimpleTransformerLm::scoreBatch() const {
    // default initializer always 0 ?
    Math::FastMatrix<s32> tokMat(batch_.size(), max_batch_len_);
    Math::FastVector<s32> lenVec(batch_.size());
    for (u32 bIdx = 0; bIdx < batch_.size(); ++bIdx) {
        const TokenIdSequence& tokSeq = batch_[bIdx]->tokIdSeq;
        verify(tokSeq.size() <= max_batch_len_);
        lenVec[bIdx] = tokSeq.size();
        for (u32 tIdx = 0; tIdx < tokSeq.size(); ++tIdx)
            tokMat.at(bIdx, tIdx) = tokSeq[tIdx];
        for (u32 tIdx = tokSeq.size(); tIdx < max_batch_len_; ++tIdx)
            tokMat.at(bIdx, tIdx) = 0;
    }

    BatchInput  inputs;
    BatchOutput outputs;
    inputs.emplace_back(std::make_pair(input_tensor_name, Tensorflow::Tensor::create(tokMat)));
    inputs.emplace_back(std::make_pair(input_length_tensor_name, Tensorflow::Tensor::create(lenVec)));
    // reading the tensor values should trigger the computation automatically (no state_vars to be updated)
    session_.run(inputs, output_tensor_names_, {}, outputs);

    // process scores: expect always only the last output position (B, V)
    verify(outputs.size() == 1);
    for (u32 bIdx = 0; bIdx < batch_.size(); ++bIdx) {
        std::vector<Score>& scores = batch_[bIdx]->scores;
        outputs[0].get(bIdx, scores);
        if (output_transform_function_)
            std::transform(scores.begin(), scores.end(), scores.begin(), output_transform_function_);
    }
}
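For readers skimming the diff, the interplay of score(), makeBatch() and the history cache may be easier to see in isolation. The following is a minimal, framework-free sketch of that caching/batching idea; it is not part of the commit, and the names (HistoryCache, HistoryEntry, hashSequence) are illustrative stand-ins for SimpleScoreHistoryManager, SimpleScoreHistory and token_id_sequence_hash, with a placeholder hash function.

// Hypothetical, framework-free sketch of the caching/batching scheme above:
// histories are keyed by a hash of their token-id sequence, unscored ones are
// queued, and scoring drains the queue into one batch of at most maxBatchSize.
#include <cstddef>
#include <cstdint>
#include <deque>
#include <iostream>
#include <unordered_map>
#include <vector>

using TokenId         = int32_t;
using TokenIdSequence = std::vector<TokenId>;

struct HistoryEntry {
    TokenIdSequence    tokens;
    std::vector<float> scores;  // filled lazily; one entry per vocabulary word once scored
};

// placeholder hash; the commit uses token_id_sequence_hash() instead
static size_t hashSequence(const TokenIdSequence& seq) {
    size_t h = 0;
    for (TokenId t : seq)
        h = h * 1000003u + static_cast<size_t>(t);
    return h;
}

class HistoryCache {
public:
    explicit HistoryCache(size_t maxBatchSize) : maxBatchSize_(maxBatchSize) {}

    // extend a history by one token; reuse the cached entry if it already exists
    HistoryEntry* extend(const HistoryEntry* prev, TokenId w) {
        TokenIdSequence seq = prev ? prev->tokens : TokenIdSequence();
        seq.push_back(w);
        size_t key = hashSequence(seq);
        auto   res = cache_.insert(std::make_pair(key, HistoryEntry{seq, {}}));
        if (res.second)
            pending_.push_back(key);  // not scored yet: remember it for later batching
        return &res.first->second;
    }

    // collect the requested history plus further pending ones into one batch
    std::vector<HistoryEntry*> makeBatch(HistoryEntry* needed) {
        std::vector<HistoryEntry*> batch{needed};
        while (batch.size() < maxBatchSize_ && !pending_.empty()) {
            size_t key = pending_.front();
            pending_.pop_front();
            auto it = cache_.find(key);
            if (it == cache_.end() || &it->second == needed || !it->second.scores.empty())
                continue;  // released, the requested entry itself, or already scored
            batch.push_back(&it->second);
        }
        return batch;
    }

private:
    size_t                                   maxBatchSize_;
    std::unordered_map<size_t, HistoryEntry> cache_;
    std::deque<size_t>                       pending_;
};

int main() {
    HistoryCache  cache(4);
    HistoryEntry* h = cache.extend(nullptr, 0);  // e.g. the sentence-begin token id
    h = cache.extend(h, 17);
    std::vector<HistoryEntry*> batch = cache.makeBatch(h);
    std::cout << "batch size: " << batch.size() << "\n";  // 2: the new history and its prefix
    return 0;
}

In the commit itself the pending queue is cacheHashQueue_, the cache lives inside SimpleScoreHistoryManager, and entries are removed again in release() once their reference count drops to zero.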
TFSimpleTransformerLm.hh
@@ -0,0 +1,127 @@
/** Copyright 2020 RWTH Aachen University. All rights reserved.
 *
 * Licensed under the RWTH ASR License (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * author: Wei Zhou
 */
#ifndef _LM_SIMPLE_TRANSFORMER_LM_HH
#define _LM_SIMPLE_TRANSFORMER_LM_HH

#include <Tensorflow/GraphLoader.hh>
#include <Tensorflow/Module.hh>
#include <Tensorflow/Session.hh>
#include <Tensorflow/TensorMap.hh>
#include "AbstractNNLanguageModel.hh"
#include "SimpleHistoryLm.hh"
#include <deque>

namespace Lm {
struct SimpleScoreHistory : public SimpleHistory {
    // tokIdSeq and refCount in base
    std::vector<Score> scores;
    size_t             cacheHash;

    typedef SimpleHistory Precursor;
    SimpleScoreHistory(Bliss::Token::Id tid) : Precursor(tid), cacheHash(0) {}
    SimpleScoreHistory(const TokenIdSequence& r, Bliss::Token::Id tid) :
            Precursor(r, tid), cacheHash(0) {}
};

typedef std::unordered_map<size_t, SimpleScoreHistory*> SimpleHistoryCache;
class SimpleScoreHistoryManager : public SimpleHistoryManager {
protected:
    SimpleHistoryCache historyCache_;

public:
    SimpleScoreHistoryManager() {}
    ~SimpleScoreHistoryManager() {
        for (SimpleHistoryCache::iterator iter = historyCache_.begin(); iter != historyCache_.end(); ++iter)
            delete (iter->second);
    }

    void release(HistoryHandle handle) {
        const SimpleScoreHistory* sh = static_cast<const SimpleScoreHistory*>(handle);
        --(sh->refCount);  // mutable
        if (sh->refCount == 0) {
            historyCache_.erase(sh->cacheHash);
            delete sh;
        }
    }

    const SimpleHistoryCache& getCache() const { return historyCache_; }

    std::pair<SimpleHistoryCache::iterator, bool> updateCache(SimpleScoreHistory* sh) {
        sh->cacheHash = token_id_sequence_hash(sh->tokIdSeq);
        return historyCache_.insert(std::make_pair(sh->cacheHash, sh));
    }
};

typedef std::vector<std::pair<std::string, Tensorflow::Tensor>> BatchInput;
typedef std::vector<Tensorflow::Tensor>                         BatchOutput;
// simple TF Transformer LM: mainly for E2E systems with a small search space
// trades speed for simplicity: always feed in the full sequence and take the scores of the last output position
// Note: slicing out the last position should be done in the model graph
class TFSimpleTransformerLm : public AbstractNNLanguageModel {
    typedef AbstractNNLanguageModel Precursor;
    typedef SimpleScoreHistory      HistoryDescriptor;

protected:
    // Note: graph-related params follow the python naming scheme
    mutable Tensorflow::Session              session_;
    std::unique_ptr<Tensorflow::GraphLoader> loader_;
    std::unique_ptr<Tensorflow::Graph>       graph_;

    // should be a single input/output tensor
    std::string              input_tensor_name;
    std::string              input_length_tensor_name;
    std::vector<std::string> output_tensor_names_;

protected:
    std::function<Score(Score)> output_transform_function_;
    u32         max_batch_size_;                 // B
    mutable u32 max_batch_len_;                  // T
    mutable std::deque<size_t>               cacheHashQueue_;  // only not-yet-scored histories
    mutable std::vector<HistoryDescriptor*>  batch_;

    History startHistory_;  // always cached: same scoring
protected:
    void load();

    // actually none of these are truly const for an NNLM: const only to match the legacy LM interface
    void makeBatch(HistoryDescriptor* hd) const;
    void scoreBatch() const;

    // cache the most recently scored histories to avoid redundant computation due to pruning
    // this could be done with the lookahead-table caching scheme (it just needs to hold the history),
    // but it is better to keep the cache small to save memory

public:
    TFSimpleTransformerLm(const Core::Configuration& c, Bliss::LexiconRef l);
    ~TFSimpleTransformerLm();

    // history (no reduction)
    History startHistory() const;
    History extendedHistory(const History& h, Token w) const;

    // scoring
    Score score(const History& h, Token w) const;
};

}  // namespace Lm

#endif  // _LM_SIMPLE_TRANSFORMER_LM_HH
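As a side note, the input that scoreBatch() in the .cc above hands to the TensorFlow session is just a zero-padded (B, T) matrix of token ids plus a length vector. The snippet below rebuilds that layout with plain std::vector standing in for Math::FastMatrix / Math::FastVector; it is purely illustrative and the batch contents are made up.

// Illustrative only: assemble the zero-padded token matrix (B x T) and the
// per-row length vector that scoreBatch() feeds to the graph.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    // hypothetical batch of token-id histories of different lengths
    std::vector<std::vector<int32_t>> batch = {{2, 15, 7}, {2, 15}, {2}};

    size_t maxLen = 0;
    for (const auto& seq : batch)
        maxLen = std::max(maxLen, seq.size());

    // (B x T) token matrix, zero-padded, plus the true length of each row
    std::vector<std::vector<int32_t>> tokMat(batch.size(), std::vector<int32_t>(maxLen, 0));
    std::vector<int32_t>              lenVec(batch.size());
    for (size_t b = 0; b < batch.size(); ++b) {
        lenVec[b] = static_cast<int32_t>(batch[b].size());
        for (size_t t = 0; t < batch[b].size(); ++t)
            tokMat[b][t] = batch[b][t];  // positions >= lenVec[b] stay 0 (padding)
    }

    for (size_t b = 0; b < batch.size(); ++b) {
        for (size_t t = 0; t < maxLen; ++t)
            std::cout << tokMat[b][t] << ' ';
        std::cout << "(len " << lenVec[b] << ")\n";
    }
    return 0;
}

In the commit itself the padded matrix and length vector are wrapped into Tensorflow::Tensor objects and passed to session_.run() under the tensor names resolved in load().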