From 51de745c45a9c1693660da6ad3276acba413f214 Mon Sep 17 00:00:00 2001 From: Simon-Berger Date: Mon, 17 Apr 2023 12:22:45 +0200 Subject: [PATCH 1/6] Add LabelScorer code --- src/Nn/LabelHistoryManager.hh | 276 ++++++++ src/Nn/LabelScorer.cc | 348 ++++++++++ src/Nn/LabelScorer.hh | 216 ++++++ src/Nn/Module.cc | 62 +- src/Nn/TFLabelScorer.cc | 1186 +++++++++++++++++++++++++++++++++ src/Nn/TFLabelScorer.hh | 294 +++++++- 6 files changed, 2376 insertions(+), 6 deletions(-) create mode 100644 src/Nn/LabelHistoryManager.hh diff --git a/src/Nn/LabelHistoryManager.hh b/src/Nn/LabelHistoryManager.hh new file mode 100644 index 00000000..20167d77 --- /dev/null +++ b/src/Nn/LabelHistoryManager.hh @@ -0,0 +1,276 @@ +/** Copyright 2020 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * author: Wei Zhou + */ + +#ifndef LABEL_HISTORY_MANAGER_HH +#define LABEL_HISTORY_MANAGER_HH + +#include +#include +#include + +// boost way of hash merging (we use 0 for special case: mostly initial or N.A.) 
+// Note: not 100% collision-free, better with additional safety where it's applied +static size_t updateHashKey(size_t hash, size_t update) { + // nothing to update + if (update == 0) + return hash; + if (hash == 0) + return update; + return hash ^ (update + 0x9e3779b9 + (hash << 6) + (hash >> 2)); +} + +namespace Nn { + +class LabelHistory; + +typedef Search::Index LabelIndex; +typedef std::vector LabelSequence; + +inline size_t label_sequence_hash(const LabelSequence& ls) { + return Core::MurmurHash3_x64_64(reinterpret_cast(ls.data()), + ls.size() * sizeof(LabelSequence::value_type), 0x78b174eb); +} + +// Note: all history have to inherit from LabelHistoryBase +struct LabelHistoryBase { + size_t ref_count, cacheHash; + LabelSequence labelSeq; // always right-most latest + + LabelHistoryBase() + : ref_count(0), cacheHash(0) {} + LabelHistoryBase(const LabelHistoryBase& ref) + : ref_count(0), cacheHash(0), labelSeq(ref.labelSeq) {} + + virtual ~LabelHistoryBase() = default; +}; + +typedef LabelHistoryBase* LabelHistoryHandle; +typedef std::unordered_map HistoryCache; +typedef std::pair CacheUpdateResult; + +// LabelHistoryObject handling (caching, reference counting and clean up ...): all inline +class LabelHistoryManager { +public: + LabelHistoryManager() {} + ~LabelHistoryManager() { + verify_(cache_.empty()); + } + + LabelHistory history(LabelHistoryHandle lhd) const; + void reset() { + cache_.clear(); + } + + bool isEqualSequence(const LabelHistoryHandle lhd, const LabelHistoryHandle rhd) const { + return label_sequence_hash(lhd->labelSeq) == label_sequence_hash(rhd->labelSeq); + } + + bool isEqualSequence(const LabelHistoryHandle lhd, LabelIndex lIdx, const LabelHistoryHandle rhd) const { + return extendedHashKey(lhd, lIdx) == label_sequence_hash(rhd->labelSeq); + } + + const HistoryCache& historyCache() { + return cache_; + } + // check existence for to-be-extended history + CacheUpdateResult checkCache(const LabelHistoryHandle lhd, LabelIndex lIdx, u32 
updateHash); + CacheUpdateResult checkCache(const LabelHistoryHandle lhd, u32 updateHash); + CacheUpdateResult updateCache(LabelHistoryHandle lhd, u32 updateHash); + + size_t hashKey(const LabelHistoryHandle lhd) const { + return label_sequence_hash(lhd->labelSeq); + } + + size_t reducedHashKey(const LabelSequence& labelSeq, s32 limit) const; + size_t reducedHashKey(const LabelHistoryHandle lhd, s32 limit) const { + return reducedHashKey(lhd->labelSeq, limit); + } + + size_t extendedHashKey(const LabelHistoryHandle lhd, LabelIndex lIdx) const; + size_t reducedExtendedHashKey(const LabelHistoryHandle lhd, s32 limit, LabelIndex lIdx) const; + +protected: + friend class LabelHistory; + LabelHistoryHandle acquire(LabelHistoryHandle lhd) const; + void release(LabelHistoryHandle lhd) const; + +private: + mutable HistoryCache cache_; +}; + +class LabelHistory { +public: + LabelHistory() + : mang_(0), desc_(0) {} + LabelHistory(const LabelHistory& ref) + : mang_(ref.mang_), desc_(ref.desc_) { + if (desc_) + mang_->acquire(desc_); + } + + ~LabelHistory() { + if (desc_) + mang_->release(desc_); + } + + const LabelHistory& operator=(const LabelHistory& rhs); + + const LabelHistoryManager* manager() const { + return mang_; + } + const LabelHistoryHandle handle() const { + return desc_; + } + + bool isValid() const { + return mang_ != 0; + } + + size_t hashKey() const; + size_t reducedHashKey(s32 limit) const; + size_t reducedExtendedHashKey(s32 limit, LabelIndex lIdx) const; + + struct Hash { + inline size_t operator()(const LabelHistory& lh) const { + return lh.isValid() ? 
lh.hashKey() : 0; + } + }; + + LabelIndex getLastLabel() const; + + // debug + void format() const; + +private: + friend class LabelHistoryManager; + LabelHistory(const LabelHistoryManager* lhm, LabelHistoryHandle lhd) + : mang_(lhm), desc_(mang_->acquire(lhd)) {} + +private: + const LabelHistoryManager* mang_; + LabelHistoryHandle desc_; +}; + +inline LabelHistoryHandle LabelHistoryManager::acquire(LabelHistoryHandle lhd) const { + if (lhd) + ++(lhd->ref_count); + return lhd; +} + +inline void LabelHistoryManager::release(LabelHistoryHandle lhd) const { + if (lhd) { + require_gt(lhd->ref_count, 0); + --(lhd->ref_count); + if (lhd->ref_count == 0) { + // remove from cache + cache_.erase(lhd->cacheHash); + delete lhd; + } + } +} + +inline size_t LabelHistoryManager::reducedHashKey(const LabelSequence& labelSeq, s32 limit) const { + if (limit < 0 || (u32)limit >= labelSeq.size()) + return label_sequence_hash(labelSeq); + LabelSequence reducedLabelSeq(labelSeq.end() - limit, labelSeq.end()); + return label_sequence_hash(reducedLabelSeq); +} + +inline size_t LabelHistoryManager::extendedHashKey(const LabelHistoryHandle lhd, LabelIndex lIdx) const { + LabelSequence extendedLabelSeq(lhd->labelSeq); + extendedLabelSeq.push_back(lIdx); + return label_sequence_hash(extendedLabelSeq); +} + +inline size_t LabelHistoryManager::reducedExtendedHashKey(const LabelHistoryHandle lhd, s32 limit, LabelIndex lIdx) const { + if (limit < 0 || (u32)limit > lhd->labelSeq.size()) + return extendedHashKey(lhd, lIdx); + LabelSequence reducedLabelSeq(lhd->labelSeq.end() - (limit - 1), lhd->labelSeq.end()); + reducedLabelSeq.push_back(lIdx); + return label_sequence_hash(reducedLabelSeq); +} + +inline LabelHistory LabelHistoryManager::history(LabelHistoryHandle lhd) const { + return LabelHistory(this, lhd); +} + +// check existence for to-be-extended history +inline CacheUpdateResult LabelHistoryManager::checkCache(const LabelHistoryHandle lhd, + LabelIndex lIdx, u32 updateHash) { + size_t hash 
= updateHashKey(extendedHashKey(lhd, lIdx), updateHash); + HistoryCache::iterator iter = cache_.find(hash); + return std::make_pair(iter, iter != cache_.end()); +} + +inline CacheUpdateResult LabelHistoryManager::checkCache(const LabelHistoryHandle lhd, u32 updateHash) { + size_t hash = updateHashKey(hashKey(lhd), updateHash); + HistoryCache::iterator iter = cache_.find(hash); + return std::make_pair(iter, iter != cache_.end()); +} + +inline CacheUpdateResult LabelHistoryManager::updateCache(LabelHistoryHandle lhd, u32 updateHash) { + size_t hash = updateHashKey(hashKey(lhd), updateHash); + lhd->cacheHash = hash; + return cache_.insert(std::make_pair(hash, lhd)); +} + +inline const LabelHistory& LabelHistory::operator=(const LabelHistory& rhs) { + if (rhs.desc_) + rhs.mang_->acquire(rhs.desc_); + if (desc_) + mang_->release(desc_); + mang_ = rhs.mang_; + desc_ = rhs.desc_; + return *this; +} + +inline size_t LabelHistory::hashKey() const { + if (desc_) + return mang_->hashKey(desc_); + return 0; +} + +inline size_t LabelHistory::reducedHashKey(s32 limit) const { + if (desc_ && limit != 0) + return mang_->reducedHashKey(desc_, limit); + return 0; +} + +inline size_t LabelHistory::reducedExtendedHashKey(s32 limit, LabelIndex lIdx) const { + if (desc_ && limit != 0) + return mang_->reducedExtendedHashKey(desc_, limit, lIdx); + return 0; +} + +inline LabelIndex LabelHistory::getLastLabel() const { + if (desc_ && !desc_->labelSeq.empty()) + return desc_->labelSeq.back(); + return Core::Type::max; +} + +// debug +inline void LabelHistory::format() const { + std::cout << " LabelHistory: "; + if (desc_) + for (LabelIndex label : desc_->labelSeq) + std::cout << label << " "; + std::cout << std::endl; +} + +} // namespace Nn + +#endif diff --git a/src/Nn/LabelScorer.cc b/src/Nn/LabelScorer.cc index ff1d531f..e9caf575 100644 --- a/src/Nn/LabelScorer.cc +++ b/src/Nn/LabelScorer.cc @@ -14,5 +14,353 @@ */ #include "LabelScorer.hh" +#include "Prior.hh" + +#include using namespace 
Nn; + +const Core::ParameterString LabelScorer::paramLabelFile( + "label-file", + "label index mapping file", + ""); + +const Core::ParameterInt LabelScorer::paramNumOfClasses( + "number-of-classes", + "number of classes (network output)", + 0); + +const Core::ParameterInt LabelScorer::paramBufferSize( + "buffer-size", + "buffer-wise encoding/decoding (online fashion)", + Core::Type::max); + +const Core::ParameterFloat LabelScorer::paramScale( + "scale", + "scaling for the label scores", + 1.0); + +const Core::ParameterBool LabelScorer::paramUsePrior( + "use-prior", + "whether to use prior", + false); + +const Core::ParameterInt LabelScorer::paramPriorContextSize( + "prior-context-size", + "label context size for prior", + 0, 0); + +const Core::ParameterBool LabelScorer::paramLoopUpdateHistory( + "loop-update-history", + "whether label loop should update label sequence of history (dependency)", + false); + +const Core::ParameterBool LabelScorer::paramBlankUpdateHistory( + "blank-update-history", + "whether blank label should update label sequence of history (dependency)", + false); + +const Core::ParameterBool LabelScorer::paramPositionDependent( + "position-dependent", + "whether model is position dependent", + false); + +const Core::ParameterIntVector LabelScorer::paramReductionFactors( + "reduction-factors", + "input (time) reduction factors of each downsampling layer to compute the maximum length", + ",", 1); + +const Core::ParameterBool LabelScorer::paramUseStartLabel( + "use-start-label", + "force start label to present for start history", + false); + +// only for segmental decoding +const Core::ParameterFloat LabelScorer::paramSegmentLengthScale( + "segment-length-scale", + "scaling for the segment length score", + 1.0); + +const Core::ParameterInt LabelScorer::paramMinSegmentLength( + "min-segment-length", + "minimum segment length in frames (encodings)", + 1); + +const Core::ParameterInt LabelScorer::paramMaxSegmentLength( + "max-segment-length", + 
"maximum segment length in frames (encodings)", + 20); + +LabelScorer::LabelScorer(const Core::Configuration& config) + : Core::Component(config), + dependency_(paramLabelFile(config)), + redFactors_(paramReductionFactors(config)), + scale_(paramScale(config)), + numClasses_(paramNumOfClasses(config)), + usePrior_(paramUsePrior(config)), + priorContextSize_(paramPriorContextSize(config)), + loopUpdateHistory_(paramLoopUpdateHistory(config)), + blankUpdateHistory_(paramBlankUpdateHistory(config)), + needEndProcessing_(false), + isPositionDependent_(paramPositionDependent(config)), + useStartLabel_(paramUseStartLabel(config)), + startLabelIndex_(Core::Type::max), + startPosition_(0), // not configurable, but rather model specific + segLenScale_(paramSegmentLengthScale(config)), + minSegLen_(paramMinSegmentLength(config)), + maxSegLen_(paramMaxSegmentLength(config)), + bufferSize_(paramBufferSize(config)) { + init(); + reset(); +} + +void LabelScorer::init() { + labelHistoryManager_ = new LabelHistoryManager(); + + if (numClasses_ == 0) { + log() << "no number-of-classes given, try to get it from label-file"; + getLabelIndexMap(); + } + log() << "number of classes: " << numClasses_; + + if (usePrior_ && priorContextSize_ == 0) { + // Note: prior scale independent of posterior scale + log() << "use context-independent label priors"; + Prior prior(config); + if (!prior.fileName().empty()) + prior.read(); + else + criticalError() << "no prior file provided"; + u32 size = prior.size(); + verify(size >= numClasses_); + logPriors_.reserve(size); + for (u32 idx = 0; idx < size; ++idx) + logPriors_.push_back(prior.scale() * prior.at(idx)); + log() << "logPrior scale: " << prior.scale(); + } +} + +void LabelScorer::reset() { + inputBuffer_.clear(); + nInput_ = 0; + eos_ = false; + decodeStep_ = 0; + segmentScore_.clear(); + + labelHistoryManager_->reset(); +} + +const LabelIndexMap& LabelScorer::getLabelIndexMap() { + if (!labelIndexMap_.empty()) { + verify(numClasses_ > 0); + 
return labelIndexMap_; + } + + std::string labelFile = paramLabelFile(config); + if (labelFile.empty()) + criticalError() << "no label file provided"; + else + log() << "load label and index from file " << labelFile; + + u32 nClasses = 0; + std::ifstream input(labelFile, std::ios::in); + std::string line; + while (input.good()) { + std::getline(input, line); + if (line.empty()) + continue; + std::stringstream ss(line); + std::string label; + LabelIndex idx; + ss >> label; + ss >> idx; + labelIndexMap_[label] = idx; + nClasses = std::max(nClasses, idx); + } + if (numClasses_ > 0) + verify(nClasses + 1 == numClasses_); + else + numClasses_ = nClasses + 1; + + return labelIndexMap_; +} + +LabelIndex LabelScorer::getSpecialLabelIndex(const std::string& label, const std::string& name) const { + if (labelIndexMap_.count(label) > 0) { + return labelIndexMap_.at(label); + } + else { + Core::ParameterInt paramLabelIndex(name.c_str(), "", Core::Type::max); + LabelIndex index = paramLabelIndex(config); + return index; + } +} + +LabelIndex LabelScorer::getNoContextLabelIndex() const { + LabelIndex index = getEndLabelIndex(); + if (index == Core::Type::max) + index = getBlankLabelIndex(); + if (index == Core::Type::max) { + // if neither eos nor blank, then probably silence (need to specify) + Core::ParameterInt paramLabelIndex("no-context-label-index", "", Core::Type::max); + index = paramLabelIndex(config); + } + return index; +} + +u32 LabelScorer::getReducedLength(u32 len) const { + for (u32 idx = 0; idx < redFactors_.size(); ++idx) + len = (len + redFactors_[idx] - 1) / redFactors_[idx]; + return len; +} + +bool LabelScorer::reachEnd() const { + if (needEndProcessing_ || !bufferFilled()) { + return false; + } + else { + u32 len = inputBuffer_.size(); + // adjust to downsampled input length (including 0-padding) + if (!redFactors_.empty()) + len = getReducedLength(len); + return decodeStep_ >= len; + } +} + +u32 LabelScorer::getEncoderLength() const { + // more to come + if 
(!eos_) + return Core::Type::max; + u32 len = nInput_; + // adjust to downsampled input length (including 0-padding) + if (!redFactors_.empty()) + len = getReducedLength(len); + return len + 1; // plus 1 for ending +} + +bool LabelScorer::maybeFinalSegment(u32 startPos) const { + if (!isPositionDependent_) + return false; + u32 remainLen = getEncoderLength() - 1 - startPos; + return remainLen >= minSegLen_ && remainLen <= maxSegLen_; +} + +// input: vector of log(p) => output: log( sum_p ) +Score LabelScorer::logSumExp(const std::vector& scores) { + Score max = *(std::max_element(scores.begin(), scores.end())); + verify(!std::isinf(max)); + Score sum = 0.0; + for (std::vector::const_iterator iter = scores.begin(); iter != scores.end(); ++iter) + sum += std::exp(*iter - max); + return std::log(sum) + max; +} + +// logSumExp in -log() domain: more efficient for more than 2 terms +Score LabelScorer::computeScoreSum(const std::vector& scores) { + Score best = *(std::min_element(scores.begin(), scores.end())); + verify(best < Core::Type::max); // 0-prob defined in RASR + Score expSum = 0.0; + for (std::vector::const_iterator iter = scores.begin(); iter != scores.end(); ++iter) + if (*iter != Core::Type::max) // filter invalid ones + expSum += std::exp(best - *iter); + return -std::log(expSum) + best; +} + +// ---------------------------- PrecomputedScorer ----------------------------- +const Core::ParameterBool PrecomputedScorer::paramFirstOrder("first-order", "", false); + +PrecomputedScorer::PrecomputedScorer(const Core::Configuration& config) + : Core::Component(config), + Precursor(config), + firstOrder_(paramFirstOrder(config)) { + log() << "use precomputed scorer (log-posterior)"; + redFactors_.clear(); // input is already reduced + isPositionDependent_ = false; + + if (firstOrder_) { + log() << "as 1st-order model score caching"; + useStartLabel_ = true; + startLabelIndex_ = getStartLabelIndex(); + verify(startLabelIndex_ != Core::Type::max); + log() << "use 
start label index " << startLabelIndex_; + + cachedScore_.resize(numClasses_); + cachedHistory_.resize(numClasses_, nullptr); + } + + blankLabelIndex_ = getBlankLabelIndex(); +} + +void PrecomputedScorer::addInput(Core::Ref f) { + Precursor::addInput(f); + if (inputBuffer_.size() == 1) { + if (firstOrder_) + verify(inputBuffer_.front().size() >= numClasses_ * numClasses_); + else + verify(inputBuffer_.front().size() >= numClasses_); + } + + // log(p) + std::vector& scores = inputBuffer_.back(); + // -alpha * log(p) + optional beta * log(prior) + std::transform(scores.begin(), scores.end(), scores.begin(), + std::bind(std::multiplies(), std::placeholders::_1, -scale_)); + if (usePrior_ && priorContextSize_ == 0) { + verify(scores.size() == logPriors_.size()); + std::transform(scores.begin(), scores.end(), logPriors_.begin(), scores.begin(), + std::plus()); + } +} + +LabelHistory PrecomputedScorer::startHistory() { + if (!firstOrder_) + return labelHistoryManager_->history(0); + + LabelHistoryDescriptor* lhd = getHistory(startLabelIndex_); + return labelHistoryManager_->history(lhd); +} + +void PrecomputedScorer::extendLabelHistory(LabelHistory& h, LabelIndex idx, + u32 position, bool isLoop) { + if (firstOrder_) { + if ((idx == blankLabelIndex_ && !blankUpdateHistory_) || (isLoop && !loopUpdateHistory_)) + return; + LabelHistoryDescriptor* lhd = getHistory(idx); + h = labelHistoryManager_->history(lhd); + } +} + +PrecomputedScorer::LabelHistoryDescriptor* PrecomputedScorer::getHistory(LabelIndex idx) { + LabelHistoryDescriptor* lhd = cachedHistory_.at(idx); + if (lhd == nullptr) { + lhd = new LabelHistoryDescriptor(); + lhd->labelSeq.push_back(idx); + CacheUpdateResult result = labelHistoryManager_->updateCache(lhd, 0); + verify(result.second); + ++(lhd->ref_count); // always kept in cache + cachedHistory_[idx] = lhd; + } + return lhd; +} + +const std::vector& PrecomputedScorer::getScores(const LabelHistory& h, bool isLoop) { + const std::vector& scores = 
inputBuffer_.at(decodeStep_); + if (!firstOrder_) + return scores; + + LabelIndex idx = h.getLastLabel(); + std::vector& cs = cachedScore_[idx]; + if (cs.empty()) { + cs.resize(numClasses_); + u32 start = idx * numClasses_; + std::copy(scores.begin() + start, scores.begin() + start + numClasses_, cs.begin()); + } + return cs; +} + +void PrecomputedScorer::cleanUpBeforeExtension(u32 minPos) { + if (firstOrder_) { + cachedScore_.clear(); + cachedScore_.resize(numClasses_); + } +} diff --git a/src/Nn/LabelScorer.hh b/src/Nn/LabelScorer.hh index 38d8ea6c..204c7dd1 100644 --- a/src/Nn/LabelScorer.hh +++ b/src/Nn/LabelScorer.hh @@ -18,13 +18,229 @@ #include #include +#include +#include +#include "LabelHistoryManager.hh" namespace Nn { +typedef Search::Score Score; +typedef std::vector> SegmentScore; +typedef std::unordered_map LabelIndexMap; // base class of models for label scoring (basic supports except scoring) class LabelScorer : public virtual Core::Component, public Core::ReferenceCounted { +public: + // config params + static const Core::ParameterString paramLabelFile; + static const Core::ParameterInt paramNumOfClasses; + static const Core::ParameterBool paramPreComputeEncoding; + static const Core::ParameterInt paramBufferSize; + static const Core::ParameterFloat paramScale; + static const Core::ParameterBool paramUsePrior; + static const Core::ParameterInt paramPriorContextSize; + static const Core::ParameterBool paramLoopUpdateHistory; + static const Core::ParameterBool paramBlankUpdateHistory; + static const Core::ParameterBool paramPositionDependent; + static const Core::ParameterIntVector paramReductionFactors; + static const Core::ParameterBool paramUseStartLabel; + static const Core::ParameterFloat paramSegmentLengthScale; + static const Core::ParameterInt paramMinSegmentLength; + static const Core::ParameterInt paramMaxSegmentLength; + +public: + LabelScorer(const Core::Configuration&); + virtual ~LabelScorer() { + delete labelHistoryManager_; + } + + 
const Core::Dependency& getDependency() const { + return dependency_; + } + + virtual void reset(); + virtual void cleanUpBeforeExtension(u32 minPos) {} // each search step + + // labels + LabelIndex numClasses() const { + return numClasses_; + } + const LabelIndexMap& getLabelIndexMap(); + // special labels: either in the vocab file or configurable (hard-coded naming) + LabelIndex getBlankLabelIndex() const { + return getSpecialLabelIndex("", "blank-label-index"); + } + LabelIndex getStartLabelIndex() const { + return getSpecialLabelIndex("", "start-label-index"); + } + LabelIndex getEndLabelIndex() const { + return getSpecialLabelIndex("", "end-label-index"); + } + LabelIndex getUnknownLabelIndex() const { + return getSpecialLabelIndex("", "unknown-label-index"); + } + LabelIndex getNoContextLabelIndex() const; + + // special flags for various models, e.g. attention, segmental, RNN-T + bool needEndProcess() const { + return needEndProcessing_ || isPositionDependent_; + } + bool isPositionDependent() const { + return isPositionDependent_; + } + virtual bool useRelativePosition() const { + return false; + } + virtual bool useVerticalTransition() const { + return false; + } + + // inputs + virtual void addInput(Core::Ref f) { + inputBuffer_.emplace_back(*(f->mainStream().get())); + ++nInput_; + } + virtual void clearBuffer() { + inputBuffer_.clear(); + decodeStep_ = 0; + } + u32 bufferSize() const { + return inputBuffer_.size(); + } + virtual bool bufferFilled() const { + return eos_ || inputBuffer_.size() >= bufferSize_; + } + void setEOS() { + eos_ = true; + } + bool reachEOS() const { + return eos_; + } + + virtual void increaseDecodeStep() { + ++decodeStep_; + } + // stopping criteria + // - needEndProcessing_: stop by search (additional max input length stop) + // - time synchronous: stop by decodeStep reach end + virtual bool reachEnd() const; + virtual bool maybeFinalSegment(u32 startPos) const; + + // naming after encoder-decoder framework, but can also be 
beyond + virtual void encode() {} + virtual u32 getEncoderLength() const; + + // ---- label history and scores ---- + virtual bool isHistoryDependent() const { + return true; + } + virtual bool loopUpdateHistory() const { + return loopUpdateHistory_; + } + virtual bool blankUpdateHistory() const { + return blankUpdateHistory_; + } + + // start up label history (to be overwritten) + virtual LabelHistory startHistory() = 0; + + // extend history and possibly update caching (to be overwritten) + virtual void extendLabelHistory(LabelHistory& h, LabelIndex idx, u32 position, bool isLoop) = 0; + + // get label scores for the next output position (to be overwritten) + virtual const std::vector& getScores(const LabelHistory& h, bool isLoop) = 0; + + // get segment scores for the next label segment given start position + virtual const SegmentScore& getSegmentScores(const LabelHistory& h, LabelIndex segId, + u32 startPos) { + return segmentScore_; + } + // --------------------------- + + static Score logSumExp(const std::vector& scores); + static Score computeScoreSum(const std::vector& scores); + +protected: + void init(); // Note: not virtual + LabelIndex getSpecialLabelIndex(const std::string&, const std::string&) const; + u32 getReducedLength(u32 len) const; // input length after possible downsampling + +protected: + LabelHistoryManager* labelHistoryManager_; + Core::Dependency dependency_; + + std::vector> inputBuffer_; // hard coded Mm::FeatureType = f32 + u32 nInput_; // total number of inputs + std::vector redFactors_; // input (time) reduction factors + bool eos_; // end of input stream + + f32 scale_; + LabelIndex numClasses_; + + // prior for model bias correction + bool usePrior_; + u32 priorContextSize_; + std::vector logPriors_; // context-independent prior + + bool loopUpdateHistory_; + bool blankUpdateHistory_; + bool needEndProcessing_; + bool isPositionDependent_; + + bool useStartLabel_; + LabelIndex startLabelIndex_; + s32 startPosition_; + u32 
decodeStep_; // global decoding step + + // for segmental decoding + SegmentScore segmentScore_; + f32 segLenScale_; + u32 minSegLen_; + u32 maxSegLen_; // speech only + +private: + LabelIndexMap labelIndexMap_; + u32 bufferSize_; // maximum number for input frames +}; + +// posteriors computed beforehand, e.g. front-end forwarding +// - compatible with any 0-order (or + simple TDP) time-synchronized model (hybrid, ctc, etc.) +// - also support 1st-order model as cached scores for all context (vocab^2) +class PrecomputedScorer : public LabelScorer { + typedef LabelScorer Precursor; + typedef LabelHistoryBase LabelHistoryDescriptor; + +public: + static const Core::ParameterBool paramFirstOrder; + +public: + PrecomputedScorer(const Core::Configuration&); + + // input log posterior scores + void addInput(Core::Ref f); + + // no or 1st-order history + bool isHistoryDependent() const { + return firstOrder_; + } + LabelHistory startHistory(); + void extendLabelHistory(LabelHistory& h, LabelIndex idx, u32 position, bool isLoop); + + // get label scores for the next output position + const std::vector& getScores(const LabelHistory& h, bool isLoop); + + void cleanUpBeforeExtension(u32 minPos); + +private: + LabelHistoryDescriptor* getHistory(LabelIndex idx); + +private: + bool firstOrder_; + std::vector> cachedScore_; // avoid redundant copy + std::vector cachedHistory_; // quick access + + LabelIndex blankLabelIndex_; }; } // namespace Nn diff --git a/src/Nn/Module.cc b/src/Nn/Module.cc index efb1485d..1ff7c587 100644 --- a/src/Nn/Module.cc +++ b/src/Nn/Module.cc @@ -88,12 +88,68 @@ Core::FormatSet& Module_::formats() { return *formats_; } +namespace { +enum LabelScorerType { + // precomputed in front-end flow + PrecomputedLogPosteriorType, + // so far only tensorflow-based models + TFAttentionType, + TFRnnTransducerType, + TFFfnnTransducerType, + TFSegmentalType +}; + +const Core::Choice labelScorerTypeChoice( + "precomputed-log-posterior", PrecomputedLogPosteriorType, + 
"tf-attention", TFAttentionType, + "tf-rnn-transducer", TFRnnTransducerType, + "tf-ffnn-transducer", TFFfnnTransducerType, + "tf-segmental", TFSegmentalType, + Core::Choice::endMark()); + +const Core::ParameterChoice paramLabelScorerType( + "label-scorer-type", &labelScorerTypeChoice, + "select label scorer type", + PrecomputedLogPosteriorType); +} // namespace Core::Ref Module_::createLabelScorer(const Core::Configuration& config) const { #ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH - LabelScorer* labelScorer = nullptr; - return Core::ref(labelScorer); + LabelScorer* labelScorer = nullptr; + LabelScorerType type = static_cast(paramLabelScorerType(config)); + switch (type) { + case PrecomputedLogPosteriorType: + labelScorer = new PrecomputedScorer(config); + break; +#ifdef MODULE_TENSORFLOW + case TFAttentionType: + labelScorer = new TFAttentionModel(config); + break; + case TFRnnTransducerType: + labelScorer = new TFRnnTransducer(config); + break; + case TFFfnnTransducerType: + labelScorer = new TFFfnnTransducer(config); + break; + // TODO + case TFSegmentalType: + Core::Application::us()->criticalError("tf-segmental not implemented yet !"); + break; +#else + case TFAttentionType: + case TFRnnTransducerType: + case TFFfnnTransducerType: + case TFSegmentalType: + Core::Application::us()->criticalError("Module MODULE_TENSORFLOW not available!"); + break; +#endif + default: + Core::Application::us()->criticalError("Unknown label-scorer-type "); + break; + } + verify(labelScorer); + return Core::ref(labelScorer); #else - Core::Application::us()->criticalError("Module MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH not available!"); + Core::Application::us()->criticalError("Module MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH not available!"); #endif } diff --git a/src/Nn/TFLabelScorer.cc b/src/Nn/TFLabelScorer.cc index 8fd29602..5483cc8c 100644 --- a/src/Nn/TFLabelScorer.cc +++ b/src/Nn/TFLabelScorer.cc @@ -14,5 +14,1191 @@ */ #include "TFLabelScorer.hh" +#include "Prior.hh" using namespace Nn; 
+ +const Core::ParameterBool TFModelBase::paramTransformOuputLog( + "transform-output-log", + "apply log to tensorflow output", + false); + +const Core::ParameterBool TFModelBase::paramTransformOuputNegate( + "transform-output-negate", + "negate tensorflow output (after log)", + false); + +const Core::ParameterInt TFModelBase::paramMaxBatchSize( + "max-batch-size", + "maximum number of histories forwarded in one go", + 64, 1); + +TFModelBase::TFModelBase(const Core::Configuration& config) + : Core::Component(config), + Precursor(config), + session_(select("session")), + loader_(Tensorflow::Module::instance().createGraphLoader(select("loader"))), + graph_(loader_->load_graph()), // tf::GraphDef, libraries and necessary param names + maxBatchSize_(paramMaxBatchSize(config)) { + bool transform_output_log = paramTransformOuputLog(config); + bool transform_output_negate = paramTransformOuputNegate(config); + if (transform_output_log && transform_output_negate) { + decoding_output_transform_function_ = [](Score v, Score scale) { return -scale * std::log(v); }; + log() << "apply -log(.) to model output"; + } + else if (transform_output_log) { + decoding_output_transform_function_ = [](Score v, Score scale) { return scale * std::log(v); }; + log() << "apply log(.) to model output"; + } + else if (transform_output_negate) { + decoding_output_transform_function_ = [](Score v, Score scale) { return -scale * v; }; + log() << "apply -(.) 
to model output"; + } + else if (scale_ != 1.0) { + decoding_output_transform_function_ = [](Score v, Score scale) { return scale * v; }; + } + + init(); + reset(); + + // debug + Core::ParameterBool paramDebug("debug", "", false); + debug_ = paramDebug(config); +} + +TFModelBase::~TFModelBase() { + reset(); + delete startHistoryDescriptor_; +} + +void TFModelBase::reset() { + Precursor::reset(); + batch_.clear(); + cacheHashQueue_.clear(); +} + +void TFModelBase::init() { + // create tf::Session with graph(tf::GraphDef) and default initialization of variables + session_.addGraph(*graph_); + // restore model checkpoint + loader_->initialize(session_); + + // --- encoder --- + Tensorflow::TensorInputMap featureInputMap(select("feature-input-map")); + const Tensorflow::TensorInputInfo& info = featureInputMap.get_info("feature"); + encoding_input_tensor_name_ = info.tensor_name(); + if (!info.seq_length_tensor_name().empty()) + encoding_input_seq_length_tensor_name_ = info.seq_length_tensor_name(); + else + encoding_input_seq_length_tensor_name_.clear(); + + // --- decoder --- + initDecoder(); + + // --- step ops --- + encoding_ops_ = graph_->encoding_ops(); + decoding_ops_ = graph_->decoding_ops(); + var_update_ops_ = graph_->update_ops(); + var_post_update_ops_ = graph_->post_update_ops(); + + // each stochastic_var_scores has a corresponding decoding_op + verify(decoding_output_tensor_names_.size() == decoding_ops_.size()); + + // unique start history handle + initStartHistory(); + + // optional static context-dependent prior + if (usePrior_ && priorContextSize_ > 0) + loadPrior(); +} + +void TFModelBase::initDecoder() { + // label-dependent variables (stored in the graph and can be assigned/fetched) + for (const std::string& s : graph_->decoder_input_vars()) { + const auto& var = graph_->getVariable(s); + decoding_input_tensor_names_.push_back(var.initial_value_name); + var_feed_names_.push_back(var.initial_value_name); + 
var_feed_ops_.push_back(var.initializer_name); + u32 ndim = var.shape.size(); + verify(ndim >= 1); + decoding_input_ndims_.push_back(ndim); + } + + for (const std::string& s : graph_->decoder_output_vars()) { + const auto& var = graph_->getVariable(s); + decoding_output_tensor_names_.push_back(var.snapshot_name); + u32 ndim = var.shape.size(); + verify(ndim >= 1); + decoding_output_ndims_.push_back(ndim); + } + + for (const std::string& s : graph_->state_vars()) { + const auto& var = graph_->getVariable(s); + var_feed_names_.push_back(var.initial_value_name); + var_feed_ops_.push_back(var.initializer_name); + var_fetch_names_.push_back(var.snapshot_name); + } + verify(var_fetch_names_.size() == var_feed_names_.size() - decoding_input_tensor_names_.size()); + + for (const std::string& s : graph_->global_vars()) { + const auto& var = graph_->getVariable(s); + global_var_feed_names_.push_back(var.initial_value_name); + global_var_feed_ops_.push_back(var.initializer_name); + } +} + +// also allow (truncated) context-dependent prior (prior scale independent of posterior scale) +void TFModelBase::loadPrior() { + if (!usePrior_ || priorContextSize_ == 0) + return; + + log() << "use context-dependent label pirors (context-size:" << priorContextSize_ << ")"; + Prior prior(config); + if (prior.fileName().empty()) + error() << "no prior file provided"; + log() << "logPrior scale: " << prior.scale(); + std::string baseName = prior.fileName(); + + // sentence begin context: replace invalid context instead of append new + // always assume useStartLabel_: all-0 embedding can also be achieved with safe embedding + verify(useStartLabel_); + LabelIndex noCtxId = getNoContextLabelIndex(); + if (startLabelIndex_ >= numClasses_) + verify(noCtxId < numClasses_); + + // theoretically any context size: generate all permutations of label sequence (column-wise) + // Note: memory cost for higher order context (speed is not crucial for init) + std::vector> context(priorContextSize_); + u32 
size = std::pow(numClasses_, priorContextSize_); + for (u32 ctx = 0; ctx < priorContextSize_; ++ctx) { + // repeat each label within a block and fill in the column with repeating block + u32 labelRepeat = std::pow(numClasses_, priorContextSize_ - ctx - 1); + std::vector block; + block.reserve(labelRepeat * numClasses_); + for (u32 cId = 0; cId < numClasses_; ++cId) { + std::vector vec(labelRepeat, cId); + if (cId == noCtxId) + vec.assign(labelRepeat, startLabelIndex_); + block.insert(block.end(), vec.begin(), vec.end()); + } + context[ctx].reserve(size); + while (context[ctx].size() < size) + context[ctx].insert(context[ctx].end(), block.begin(), block.end()); + verify(context[ctx].size() == size); + } + + // loop over all unique context: load context-dependent prior + for (u32 idx = 0; idx < size; ++idx) { + // Note: fixed format for simplicity (e.g. path/prior.3-2-1.xml) right-most latest + LabelSequence labelSeq; + std::string name = baseName + "."; + bool valid = true; + for (u32 ctx = 0; ctx < priorContextSize_; ++ctx) { + u32 cId = context[ctx][idx]; + if (cId == noCtxId) + valid = false; + labelSeq.push_back(cId); + name += std::to_string(cId) + "-"; + } + if (!valid) + continue; + name.pop_back(); + name += ".xml"; + if (!prior.read(name)) { + // actually may be skipped on purose for impossible context + warning() << "failed to read " << name << " : skip this prior"; + continue; + } + verify(prior.size() == numClasses_); + std::vector& logPrior = contextLogPriors_[label_sequence_hash(labelSeq)]; + verify(logPrior.empty()); + logPrior.reserve(numClasses_); + for (u32 cId = 0; cId < numClasses_; ++cId) + logPrior.push_back(prior.scale() * prior.at(cId)); + } + + log() << "successfully loaded " << contextLogPriors_.size() << " context-dependent label pirors"; +} + +// compute encoding and initialize prev_state_vars in the graph +void TFModelBase::encode() { + if (inputBuffer_.empty()) { + warning() << "no features to feed to encoder ?!"; + return; + } + + 
log() << "encode input features (" << inputBuffer_[0].size() << ", " + << inputBuffer_.size() << ")"; + + MappedTensorList inputs; + std::vector> batchMat; // single sequence: D * T + batchMat.emplace_back(inputBuffer_[0].size(), inputBuffer_.size()); + for (u32 idx = 0, size = inputBuffer_.size(); idx < size; ++idx) { + const std::vector& f = inputBuffer_[idx]; + std::copy(f.begin(), f.end(), &(batchMat.front().at(0, idx))); + } + inputs.emplace_back(std::make_pair(encoding_input_tensor_name_, + Tensorflow::Tensor::create(batchMat, true))); + if (!encoding_input_seq_length_tensor_name_.empty()) { + std::vector seq_length({static_cast(inputBuffer_.size())}); + inputs.emplace_back(std::make_pair(encoding_input_seq_length_tensor_name_, + Tensorflow::Tensor::create(seq_length))); + } + + // init all stat vars including the encoding states (stored in the graph now) + // Note: tile_batch automatically done in the graph + session_.run(inputs, encoding_ops_); + + initComputation(); +} + +void TFModelBase::initComputation() { + LabelHistoryDescriptor* lhd = static_cast(startHistory().handle()); + verify(lhd->scores.empty()); + if (useStartLabel_) { + // not using makeBatch, still need to compute scores later with start label input + batch_.push_back(lhd); + } + else { + makeBatch(lhd); + verify(batch_.size() == 1); + // compute the first score based on default initialized states + computeBatchScores(); + } + // obtain initialized/updated states to startHistory (type/size all hidden in Tensor) + fetchBatchVariables(); + batch_.clear(); +} + +void TFModelBase::initStartHistory() { + startLabelIndex_ = getStartLabelIndex(); + if (useStartLabel_) { + verify(startLabelIndex_ != Core::Type::max); + log() << "use start label index " << startLabelIndex_; + } + startHistoryDescriptor_ = new LabelHistoryDescriptor(); + startHistoryDescriptor_->labelSeq.push_back(startLabelIndex_); + startHistoryDescriptor_->variables.resize(var_fetch_names_.size()); + // + other possible unified 
operations (if always the same) +} + +LabelHistory TFModelBase::startHistory() { + LabelHistoryDescriptor* lhd = new LabelHistoryDescriptor(*startHistoryDescriptor_); + CacheUpdateResult result = labelHistoryManager_->updateCache(lhd, startPosition_); + if (result.second) { + cacheHashQueue_.push_back(lhd->cacheHash); + } + else { + verify_(labelHistoryManager_->isEqualSequence(lhd, result.first->second)); + delete lhd; + lhd = static_cast(result.first->second); + } + return labelHistoryManager_->history(lhd); +} + +void TFModelBase::extendLabelHistory(LabelHistory& h, LabelIndex idx, u32 position, bool isLoop) { + LabelHistoryDescriptor* lhd = static_cast(h.handle()); + // check without creating new (avoid lots of copying) + CacheUpdateResult result = labelHistoryManager_->checkCache(lhd, idx, position); + LabelHistoryDescriptor* nlhd; + if (result.second) { + // existing one: ensure no hash colision w.r.t. position + verify_(labelHistoryManager_->isEqualSequence(lhd, idx, result.first->second)); + nlhd = static_cast(result.first->second); + } + else { // creating new (keep parent's states for next computation) + nlhd = new LabelHistoryDescriptor(*lhd); + nlhd->labelSeq.push_back(idx); + nlhd->isBlank = false; + nlhd->scores.clear(); + nlhd->position = position; + + result = labelHistoryManager_->updateCache(nlhd, position); + if (result.second) { + // caching newly extended label history for batch scoring + cacheHashQueue_.push_back(nlhd->cacheHash); + } + else { // this should not happen ?! 
+ if (position != 0) + verify(labelHistoryManager_->isEqualSequence(nlhd, result.first->second)); + delete nlhd; + nlhd = static_cast(result.first->second); + } + } + h = labelHistoryManager_->history(nlhd); +} + +const std::vector& TFModelBase::getScores(const LabelHistory& h, bool isLoop) { + LabelHistoryDescriptor* lhd = static_cast(h.handle()); + if (!lhd->scores.empty()) + return lhd->scores; + + makeBatch(lhd); + verify(batch_.size() > 0); + decodeBatch(); + + // results: maybe have more scores than numClasses for some special cases + verify(lhd->scores.size() >= numClasses_); + return lhd->scores; +} + +// oldest first, still active, uniq, not-scored +void TFModelBase::makeBatch(LabelHistoryDescriptor* targetLhd) { + batch_.push_back(targetLhd); + const HistoryCache& cache = labelHistoryManager_->historyCache(); + std::unordered_set batchHash; + while (batch_.size() < maxBatchSize_ && !cacheHashQueue_.empty()) { + size_t hash = cacheHashQueue_.front(); + cacheHashQueue_.pop_front(); + if (cache.count(hash) == 0 || batchHash.count(hash) > 0) + continue; + LabelHistoryDescriptor* lhd = static_cast(cache.at(hash)); + if (lhd == targetLhd || !lhd->scores.empty()) + continue; + batch_.push_back(lhd); + batchHash.insert(hash); + } +} + +void TFModelBase::decodeBatch() { + feedBatchVariables(); + updateBatchVariables(); + computeBatchScores(); + fetchBatchVariables(); + batch_.clear(); +} + +void TFModelBase::feedBatchVariables() { + if (var_feed_names_.empty()) + return; + + MappedTensorList inputs; + feedDecodeInput(inputs); + + // all labels are before state variables + u32 shift = decoding_input_tensor_names_.size(); + + // state variables + std::vector batchVars(batch_.size(), nullptr); + for (u32 vIdx = 0, vSize = var_feed_names_.size() - shift; vIdx < vSize; ++vIdx) { + for (u32 bIdx = 0, bSize = batch_.size(); bIdx < bSize; ++bIdx) + batchVars[bIdx] = &(batch_[bIdx]->variables[vIdx]); + inputs.emplace_back(std::make_pair(var_feed_names_[vIdx + shift], + 
Tensorflow::Tensor::concat(batchVars, 0))); + } + + session_.run(inputs, var_feed_ops_); +} + +// mainly label feedback +void TFModelBase::feedDecodeInput(MappedTensorList& inputs) { + for (u32 vIdx = 0, vSize = decoding_input_tensor_names_.size(); vIdx < vSize; ++vIdx) { + if (decoding_input_ndims_[vIdx] == 1) { // sparse + std::vector vec(batch_.size()); + for (u32 bIdx = 0, bSize = batch_.size(); bIdx < bSize; ++bIdx) + vec[bIdx] = batch_[bIdx]->labelSeq.back(); + inputs.emplace_back(std::make_pair(var_feed_names_[vIdx], Tensorflow::Tensor::create(vec))); + } + else if (decoding_input_ndims_[vIdx] == 2) { + u32 len = 1; // Note: no multi-step feedback yet + Math::FastMatrix mat(batch_.size(), len); + for (u32 bIdx = 0, bSize = batch_.size(); bIdx < bSize; ++bIdx) { + // Note: no mask handling, all has to be evaluated for len + verify(batch_[bIdx]->labelSeq.size() >= len); + u32 idx = batch_[bIdx]->labelSeq.size() - len; + for (u32 tIdx = 0; tIdx < len; ++tIdx) + mat.at(bIdx, tIdx) = batch_[bIdx]->labelSeq[idx + tIdx]; + } + inputs.emplace_back(std::make_pair(var_feed_names_[vIdx], Tensorflow::Tensor::create(mat))); + } + else { + criticalError() << "unsupported ndims " << decoding_input_ndims_[vIdx] + << " of decoding input tensor " << decoding_input_tensor_names_[vIdx]; + } + } +} + +void TFModelBase::updateBatchVariables(bool post) { + if (post) { + if (!var_post_update_ops_.empty()) + session_.run({}, var_post_update_ops_); + } + else { + if (!var_update_ops_.empty()) + session_.run({}, var_update_ops_); + } +} + +void TFModelBase::fetchBatchVariables() { + if (var_fetch_names_.empty()) + return; + + TensorList outputs; + session_.run({}, var_fetch_names_, {}, outputs); + verify(batch_[0]->variables.size() == outputs.size()); + + // slice along the batch dim (inclusive) + for (u32 vIdx = 0, vSize = var_fetch_names_.size(); vIdx < vSize; ++vIdx) + for (u32 bIdx = 0, bSize = batch_.size(); bIdx < bSize; ++bIdx) + batch_[bIdx]->variables[vIdx] = 
outputs[vIdx].slice({bIdx}, {bIdx + 1}); +} + +// batch-wise score computation (also update states) +void TFModelBase::computeBatchScores() { + // base class only support single stochastic_var_scores (support multiple in derived classes) + verify(decoding_output_tensor_names_.size() == 1); + verify(decoding_ops_.size() == 1); + + // merge post update to the last scoring to avoid redundant computation + if (var_post_update_ops_.empty()) { + session_.run({}, decoding_ops_); + } + else { + std::vector merge_ops(decoding_ops_); + merge_ops.insert(merge_ops.end(), var_post_update_ops_.begin(), var_post_update_ops_.end()); + session_.run({}, merge_ops); + } + + // fetch scores + TensorList outputs; + session_.run({}, decoding_output_tensor_names_, {}, outputs); + verify(outputs.size() == 1); + processBatchOutput(outputs); + + // optional adding static log priors + if (usePrior_) + addPriorToBatch(); +} + +// assign scores to batch +void TFModelBase::processBatchOutput(const TensorList& outputs) { + if (debug_) { + std::vector fetchNames; + for (const std::string& s : graph_->decoder_input_vars()) { + const auto& var = graph_->getVariable(s); + fetchNames.push_back(var.snapshot_name); + } + fetchNames.insert(fetchNames.end(), var_fetch_names_.begin(), var_fetch_names_.end()); + fetchNames.insert(fetchNames.end(), decoding_output_tensor_names_.begin(), decoding_output_tensor_names_.end()); + debugFetch(fetchNames, "processBatchOutput"); + } + + u32 len = 1; // no multi-step computation + bool spacial = decoding_output_ndims_.front() == 3; + verify_(spacial || decoding_output_ndims_.front() == 2); + + for (u32 bIdx = 0, bSize = batch_.size(); bIdx < bSize; ++bIdx) { + // scores always first + LabelHistoryDescriptor* lhd = batch_[bIdx]; + if (spacial) + outputs[0].get(bIdx, len - 1, lhd->scores); + else + outputs[0].get(bIdx, lhd->scores); + if (decoding_output_transform_function_) + std::transform(lhd->scores.begin(), lhd->scores.end(), lhd->scores.begin(), + 
std::bind(decoding_output_transform_function_, std::placeholders::_1, scale_)); + } +} + +void TFModelBase::addPriorToBatch() { + for (u32 bIdx = 0, bSize = batch_.size(); bIdx < bSize; ++bIdx) { + LabelHistoryDescriptor* lhd = batch_[bIdx]; + if (priorContextSize_ == 0) { // context-independent prior + std::transform(logPriors_.begin(), logPriors_.end(), lhd->scores.begin(), lhd->scores.begin(), std::plus()); + } + else { // (truncated) context-dependent prior + size_t hash = labelHistoryManager_->reducedHashKey(lhd, priorContextSize_); + ScoreCache::iterator iter = contextLogPriors_.find(hash); + verify(iter != contextLogPriors_.end()); + std::transform(iter->second.begin(), iter->second.end(), lhd->scores.begin(), + lhd->scores.begin(), std::plus()); + } + } +} + +// -------------- debug: check related tensor ---------------- +void TFModelBase::debugFetch(const std::vector& fetchNames, std::string msg) { + std::cout << "# " << msg << " ==> debug check batch_size=" << batch_.size() << std::endl; + if (fetchNames.empty()) + return; + + TensorList outputs; + session_.run({}, fetchNames, {}, outputs); + for (u32 idx = 0; idx < fetchNames.size(); ++idx) { + // shape and scalar value + std::cout << " " << fetchNames[idx] << " " << outputs[idx].dimInfo(); + if (outputs[idx].numDims() == 0) { + s32 v; + outputs[idx].get(v); + std::cout << " value=" << v; + } + std::cout << std::endl; + } +} +// ---------------------------------------------------------- + +// --- RNN Transducer --- +const Core::ParameterBool TFRnnTransducer::paramLoopFeedbackAsBlank( + "loop-feedback-as-blank", + "label loop feedback as blank (mainly for masked computation to skip certain computation in the graph)", + false); + +const Core::ParameterBool TFRnnTransducer::paramVerticalTransition( + "use-vertical-transition", + "standard RNNT topology with veritical transition, otherwise strictly-monotonic", + false); + +TFRnnTransducer::TFRnnTransducer(const Core::Configuration& config) + : 
Core::Component(config), + Precursor(config), + loopFeedbackAsBlank_(paramLoopFeedbackAsBlank(config)), + verticalTransition_(paramVerticalTransition(config)) { + blankLabelIndex_ = getBlankLabelIndex(); + if (blankLabelIndex_ == Core::Type::max) + warning() << "no blank label for rnn transducer, assuming posterior HMM"; + else if (blankUpdateHistory_) + log() << "blank label updates history"; + + // topology variants with label loop + if (loopUpdateHistory_) + log() << "label loop updates history"; + else if (loopFeedbackAsBlank_) + log() << "treat label loop feedback as blank"; + + if (verticalTransition_) { // standard RNN-T topology + verify(blankLabelIndex_ != Core::Type::max); + verify(global_var_feed_names_.empty()); + startPosition_ = 0; + needEndProcessing_ = true; + log() << "use veritical transition"; + } + else { // strictly monotonic RNN-T topology (RNA topology) + // position (decodeStep_) starts at 0: distinguish startHistory with first blank + startPosition_ = -1; + } +} + +// either globally set the encoding position once for all at each decode step +// or empty global_vars: each history has its own position state_var in the graph +// model graph should have the topology-dependent update scheme -> update_ops based on feedback +// TODO streaming case where clearBuffer reset decodeStep_: mismatch with encodings ? 
+void TFRnnTransducer::increaseDecodeStep() { + Precursor::increaseDecodeStep(); + if (!global_var_feed_names_.empty()) { + verify(global_var_feed_names_.size() == 1); + if (!isPositionDependent_) + setDecodePosition(decodeStep_); + } +} + +// set global position of encodings to the next step (time synchronous) +// called after each decoding step (position 0 is initialized via encoding_ops_) +void TFRnnTransducer::setDecodePosition(u32 pos) { + MappedTensorList inputs; + inputs.emplace_back(std::make_pair(global_var_feed_names_[0], Tensorflow::Tensor::create(s32(pos)))); + session_.run(inputs, global_var_feed_ops_); +} + +// history extension and position update based on topology +// cacheHash depends on both label history and position +// additional special blank status to feed in blank label for next computation +void TFRnnTransducer::extendLabelHistory(LabelHistory& h, LabelIndex idx, u32 position, bool isLoop) { + // position updated by search if vertical transition or segmental decoding + // otherwise use the global decode step + // for simplicity: so far we don't link this position with state_var if existing, + // but expect that the model graph has a equivalent update scheme (topology) + if (!verticalTransition_ && !isPositionDependent_) + position = decodeStep_; + + // output forward or alignment sequence dependency (blank or loop update history) + // update label and states for next computation as usual + if ((idx != blankLabelIndex_ || blankUpdateHistory_) && (!isLoop || loopUpdateHistory_)) { + Precursor::extendLabelHistory(h, idx, position, isLoop); + return; + } + + // blank or loop, but output sequence dependency + // still create new history at this new position for scoring (also update states if needed) + LabelHistoryDescriptor* lhd = static_cast(h.handle()); + CacheUpdateResult result = labelHistoryManager_->checkCache(lhd, position); + LabelHistoryDescriptor* nlhd; + if (result.second) { // existing one + // enusre no hash colision w.r.t. 
position + verify_(labelHistoryManager_->isEqualSequence(lhd, result.first->second)); + nlhd = static_cast(result.first->second); + } + else { // create new (keep parent's states for next computation) and activate blank status + nlhd = new LabelHistoryDescriptor(*lhd); + if (isLoop && !loopFeedbackAsBlank_) + nlhd->isBlank = false; + else + nlhd->isBlank = true; + nlhd->scores.clear(); + nlhd->position = position; + + result = labelHistoryManager_->updateCache(nlhd, position); + if (result.second) { + // caching newly extended label history for batch scoring + cacheHashQueue_.push_back(nlhd->cacheHash); + } + else { // this should not happen ! + verify_(labelHistoryManager_->isEqualSequence(nlhd, result.first->second)); + delete nlhd; + nlhd = static_cast(result.first->second); + } + } + h = labelHistoryManager_->history(nlhd); +} + +// always one time-step (sparse) +void TFRnnTransducer::feedDecodeInput(MappedTensorList& inputs) { + for (u32 vIdx = 0, vSize = decoding_input_tensor_names_.size(); vIdx < vSize; ++vIdx) { + verify(decoding_input_ndims_[vIdx] == 1); + std::vector vec(batch_.size()); + for (u32 bIdx = 0, bSize = batch_.size(); bIdx < bSize; ++bIdx) { + if (batch_[bIdx]->isBlank) { + // feed in blank to skip certain computation (graph must be aware), loop for posterior HMM + if (blankLabelIndex_ == Core::Type::max) + vec[bIdx] = batch_[bIdx]->labelSeq.back() + numClasses_; + else + vec[bIdx] = blankLabelIndex_; + } + else + vec[bIdx] = batch_[bIdx]->labelSeq.back(); + } + inputs.emplace_back(std::make_pair(var_feed_names_[vIdx], Tensorflow::Tensor::create(vec))); + } +} + +// --- FFNN Transducer --- +const Core::ParameterInt TFFfnnTransducer::paramContextSize( + "context-size", + "label context size (min 1: otherwise use precomputed label scorer)", + 1, 1); + +const Core::ParameterBool TFFfnnTransducer::paramCacheHistory( + "cache-history", + "cache appeared ngram history to avoid redundant computation (memory for high order !)", + true); + +// 
HMM-topology: implicit transition +const Core::ParameterBool TFFfnnTransducer::paramImplicitTransition( + "implicit-transition", + "derived implicit transition from label posterior: p(forward) = 1 - p(loop)", + false); + +// HMM-topology: explicit transition +const Core::ParameterBool TFFfnnTransducer::paramExplicitTransition( + "explicit-transition", + "explicit transition modeling: p(loop) appended as the last score element (|V|+1)", + false); + +const Core::ParameterBool TFFfnnTransducer::paramRenormTransition( + "renorm-transition", + "renormalize model over forward+loop (only for explicit-transition)", + true); + +const Core::ParameterBool TFFfnnTransducer::paramUseRelativePosition( + "use-relative-position", + "use (1st order) relative-position dependency", + false); + +TFFfnnTransducer::TFFfnnTransducer(Core::Configuration const& config) + : Core::Component(config), + Precursor(config), + contextSize_(paramContextSize(config)), + cacheHistory_(paramCacheHistory(config)), + implicitTransition_(paramImplicitTransition(config)), + explicitTransition_(paramExplicitTransition(config)), + renormTransition_(paramRenormTransition(config)), + useRelativePosition_(paramUseRelativePosition(config)) { + log() << "feedforward neural transducer with label context size " << contextSize_; + log() << "Note: decoder_input_vars order must be oldest first"; // add code to verify ? 
+ if (cacheHistory_) + log() << "apply history caching (memory for high order !)"; + verify(startPosition_ == 0); + + blankLabelIndex_ = getBlankLabelIndex(); + hmmTopology_ = blankLabelIndex_ == Core::Type::max; + if (!hmmTopology_) { + log() << "RNA topology with blank label index " << blankLabelIndex_; + if (blankUpdateHistory_) + log() << "blank label updates history"; + else + log() << "blank label does not updates history"; + } + else { // loop and blank is mutual exclusive so far + log() << "HMM topology: label loop without blank"; + verify(!useRelativePosition_); + if (isPositionDependent_) + criticalError() << "segmental scoring for HMM topology not supported yet !"; + if (loopUpdateHistory_) { + verify(!isPositionDependent_); // can't be segmental + log() << "label loop updates history"; + } + else { + log() << "label loop does not update history"; + } + } + + if (implicitTransition_ || explicitTransition_) { + verify(hmmTopology_ && !loopUpdateHistory_); + verify(!(implicitTransition_ && explicitTransition_)); + if (usePrior_) // TODO need to separate + criticalError() << "implicit/explicit transition + prior not supported yet"; + if (implicitTransition_) { + log() << "apply implicit transition derived from label posterior"; + } + else if (explicitTransition_) { + log() << "apply explicit transition from the model (last score element for loop)"; + if (renormTransition_) + log() << "renormalize model over forward+loop"; + } + } + + // size check + u32 nInput = decoding_input_tensor_names_.size(); + if (useRelativePosition_) { + verify(nInput == contextSize_ + 1); // also relative position + verify(!blankUpdateHistory_); + verify(!isPositionDependent_); // not explicit segmental + log() << "use first order relative position"; + } + else { + verify(nInput == contextSize_); + } + + for (u32 vIdx = 0; vIdx < nInput; ++vIdx) + verify(decoding_input_ndims_[vIdx] == 1); // all scalars + verify(var_feed_ops_.size() == nInput); // there should be no hidden states 
+ verify(decoding_ops_.size() == 1); + verify(decoding_output_tensor_names_.size() == 1); + verify(decoding_output_ndims_[0] == 2); +} + +TFFfnnTransducer::~TFFfnnTransducer() { + if (cacheHistory_) { + // free cache expicitly + const HistoryCache cache = labelHistoryManager_->historyCache(); + for (HistoryCache::const_iterator iter = cache.begin(); iter != cache.end(); ++iter) + delete iter->second; + labelHistoryManager_->reset(); + } +} + +void TFFfnnTransducer::reset() { + inputBuffer_.clear(); + nInput_ = 0; + eos_ = false; + decodeStep_ = 0; + + scoreCache_.clear(); + batchHashQueue_.clear(); + batchHash_.clear(); + scoreTransitionCache_.clear(); + positionScoreCache_.clear(); + + if (!cacheHistory_) { + labelSeqCache_.clear(); + labelHistoryManager_->reset(); + } +} + +void TFFfnnTransducer::cleanUpBeforeExtension(u32 minPos) { + scoreCache_.clear(); + batchHashQueue_.clear(); + scoreTransitionCache_.clear(); + + if (isPositionDependent_) { + // cache clean up w.r.t min position among all hypotheses (otherwise memory expensive ?) 
+ for (std::pair& kv : positionScoreCache_) + if (kv.first < minPos) + kv.second.clear(); + } +} + +LabelHistory TFFfnnTransducer::startHistory() { + LabelHistoryDescriptor* lhd = new LabelHistoryDescriptor(); + if (hmmTopology_ & !loopUpdateHistory_) // keep previous segment label for loop history + lhd->labelSeq.resize(contextSize_ + 1, startLabelIndex_); + else + lhd->labelSeq.resize(contextSize_, startLabelIndex_); + + CacheUpdateResult result = labelHistoryManager_->updateCache(lhd, startPosition_); + if (!result.second) { + delete lhd; + lhd = static_cast(result.first->second); + } + else { + if (cacheHistory_) + lhd->ref_count += 1; // always kept in cache + if (hmmTopology_ & !loopUpdateHistory_) { + LabelSequence labelSeq(contextSize_, startLabelIndex_); + lhd->forwardHash = label_sequence_hash(labelSeq); + lhd->loopHash = lhd->forwardHash; + labelSeqCache_.insert(std::make_pair(lhd->forwardHash, labelSeq)); + } + } + if (decodeStep_ == 0) { + if (hmmTopology_ & !loopUpdateHistory_) + batchHashQueue_.insert(lhd->forwardHash); + else + batchHashQueue_.insert(lhd->cacheHash); + } + return labelHistoryManager_->history(lhd); +} + +// need further speed up ? 
+void TFFfnnTransducer::extendLabelHistory(LabelHistory& h, LabelIndex idx, u32 position, bool isLoop) { + LabelHistoryDescriptor* lhd = static_cast(h.handle()); + LabelHistoryDescriptor* nlhd; + + if (!useRelativePosition_) { + if (idx == blankLabelIndex_ && !blankUpdateHistory_) { + // RNA topology: blank does not update history and no loop + batchHashQueue_.insert(lhd->cacheHash); + return; + } + else if (hmmTopology_ && !loopUpdateHistory_ && isLoop) { + // HMM topology: loop does not update history and no blank + batchHashQueue_.insert(lhd->forwardHash); + batchHashQueue_.insert(lhd->loopHash); + return; + } + // unless relative position: history cache is only label-seq dependent + position = 0; + nlhd = new LabelHistoryDescriptor(lhd->labelSeq, idx); + } + else { + // position-aware ffnn-transducer: only for RNA topology + // cache hash: both label-seq and rel-position dependent + if (idx == blankLabelIndex_) + nlhd = new LabelHistoryDescriptor(*lhd); + else + nlhd = new LabelHistoryDescriptor(lhd->labelSeq, idx); + nlhd->position = position; + } + + CacheUpdateResult result = labelHistoryManager_->updateCache(nlhd, position); + if (!result.second) { + delete nlhd; + nlhd = static_cast(result.first->second); + } + else { // new one: compute hash and cache label sequence + if (cacheHistory_) + nlhd->ref_count += 1; // always kept in cache + if (hmmTopology_ & !loopUpdateHistory_) { + LabelSequence fSeq(nlhd->labelSeq.begin() + 1, nlhd->labelSeq.end()); + LabelSequence lSeq(nlhd->labelSeq.begin(), nlhd->labelSeq.end() - 1); + nlhd->forwardHash = label_sequence_hash(fSeq); + nlhd->loopHash = label_sequence_hash(lSeq); + labelSeqCache_.insert(std::make_pair(nlhd->forwardHash, fSeq)); + labelSeqCache_.insert(std::make_pair(nlhd->loopHash, lSeq)); + } + } + + if (hmmTopology_ & !loopUpdateHistory_) { + batchHashQueue_.insert(nlhd->forwardHash); + if (!isPositionDependent_) + batchHashQueue_.insert(nlhd->loopHash); + } + else { + 
batchHashQueue_.insert(nlhd->cacheHash); + } + h = labelHistoryManager_->history(nlhd); +} + +// set global position of encodings to the next step (time synchronous) +// called after each decoding step (position 0 is initialized via encoding_ops_) +void TFFfnnTransducer::increaseDecodeStep() { + Precursor::increaseDecodeStep(); + verify(global_var_feed_names_.size() == 1); + if (!isPositionDependent_) + setDecodePosition(decodeStep_); +} + +void TFFfnnTransducer::setDecodePosition(u32 pos) { + MappedTensorList inputs; + inputs.emplace_back(std::make_pair(global_var_feed_names_[0], Tensorflow::Tensor::create(s32(pos)))); + session_.run(inputs, global_var_feed_ops_); +} + +const std::vector& TFFfnnTransducer::getScores(const LabelHistory& h, bool isLoop) { + // hmmTopology_ && !loopUpdateHistory_: special handling to include transition scores + // p(forward) = 1 at the first frame (decodeStep_ = 0) + if (explicitTransition_ || (implicitTransition_ && !isLoop && decodeStep_ > 0)) + return getScoresWithTransition(h, isLoop); + + LabelHistoryDescriptor* lhd = static_cast(h.handle()); + size_t hash; + if (hmmTopology_ && !loopUpdateHistory_) { + // segment label dependent scoring: differs for loop and forward + hash = isLoop ? 
lhd->loopHash : lhd->forwardHash; + } + else { + hash = lhd->cacheHash; + } + const std::vector& scores = scoreCache_[hash]; + if (!scores.empty()) + return scores; + + // batch computation + makeBatch(lhd); + verify(batchHash_.size() > 0); + decodeBatch(scoreCache_); + + // results + verify(!scores.empty()); + return scores; +} + +void TFFfnnTransducer::makeBatch(LabelHistoryDescriptor* targetLhd) { + if (hmmTopology_ && !loopUpdateHistory_) { + if (batchHashQueue_.erase(targetLhd->forwardHash) > 0) + batchHash_.push_back(targetLhd->forwardHash); + if (batchHashQueue_.erase(targetLhd->loopHash) > 0) + batchHash_.push_back(targetLhd->loopHash); + } + else if (batchHashQueue_.erase(targetLhd->cacheHash) > 0) + batchHash_.push_back(targetLhd->cacheHash); + + std::unordered_set::const_iterator iter = batchHashQueue_.begin(); + while (batchHash_.size() < maxBatchSize_ && iter != batchHashQueue_.end()) + batchHash_.push_back(*(iter++)); + batchHashQueue_.erase(batchHashQueue_.begin(), iter); +} + +void TFFfnnTransducer::decodeBatch(ScoreCache& scoreCache) { + // feed in label context: left to right (right-most latest) + MappedTensorList inputs; + std::vector> vecs(contextSize_, std::vector(batchHash_.size())); + u32 offset = 0; + if (hmmTopology_ && !loopUpdateHistory_) { + for (u32 bIdx = 0, bSize = batchHash_.size(); bIdx < bSize; ++bIdx) { + const LabelSequence& seq = labelSeqCache_[batchHash_[bIdx]]; + for (u32 vIdx = 0; vIdx < contextSize_; ++vIdx) + vecs[vIdx][bIdx] = seq[vIdx]; + } + } + else { + const HistoryCache& cache = labelHistoryManager_->historyCache(); + std::vector pos(batchHash_.size()); // optional first-order relative position + for (u32 bIdx = 0, bSize = batchHash_.size(); bIdx < bSize; ++bIdx) { + LabelHistoryDescriptor* lhd = static_cast(cache.at(batchHash_[bIdx])); + for (u32 vIdx = 0; vIdx < contextSize_; ++vIdx) + vecs[vIdx][bIdx] = lhd->labelSeq[vIdx]; + pos[bIdx] = lhd->position; + } + if (useRelativePosition_) { + 
inputs.emplace_back(std::make_pair(var_feed_names_[0], Tensorflow::Tensor::create(pos))); + offset = 1; // first input is always relative position + } + pos.clear(); + } + for (u32 vIdx = 0; vIdx < contextSize_; ++vIdx) + inputs.emplace_back(std::make_pair(var_feed_names_[vIdx + offset], Tensorflow::Tensor::create(vecs[vIdx]))); + vecs.clear(); + + session_.run(inputs, var_feed_ops_); + updateBatchVariables(); + computeBatchScores(scoreCache); + batchHash_.clear(); +} + +void TFFfnnTransducer::computeBatchScores(ScoreCache& scoreCache) { + // compute batch scores (optional prior) + session_.run({}, decoding_ops_); + TensorList outputs; + session_.run({}, decoding_output_tensor_names_, {}, outputs); + verify(outputs.size() == 1); + + for (u32 bIdx = 0, bSize = batchHash_.size(); bIdx < bSize; ++bIdx) { + // cache score to reuse + std::vector& score = scoreCache[batchHash_[bIdx]]; + verify(score.empty()); + outputs[0].get(bIdx, score); + + // -scale * log(posterior) + if (decoding_output_transform_function_) + std::transform(score.begin(), score.end(), score.begin(), + std::bind(decoding_output_transform_function_, std::placeholders::_1, scale_)); + + // optional adding static log priors + if (usePrior_) { + if (priorContextSize_ == 0) { // context-independent prior + std::transform(logPriors_.begin(), logPriors_.end(), score.begin(), score.begin(), + std::plus()); + } + else { // (truncated) context-dependent prior + size_t hash; + if (hmmTopology_ && !loopUpdateHistory_) { + const LabelSequence& seq = labelSeqCache_[batchHash_[bIdx]]; + hash = labelHistoryManager_->reducedHashKey(seq, priorContextSize_); + } + else { + const LabelSequence& seq = labelHistoryManager_->historyCache().at(batchHash_[bIdx])->labelSeq; + hash = labelHistoryManager_->reducedHashKey(seq, priorContextSize_); + } + ScoreCache::iterator iter = contextLogPriors_.find(hash); + verify(iter != contextLogPriors_.end()); + std::transform(iter->second.begin(), iter->second.end(), score.begin(), 
score.begin(), + std::plus()); + } + } + } +} + +// Transducer w/o blank - HMM topology: p(label|...) p(transition|...) +const std::vector& TFFfnnTransducer::getScoresWithTransition(const LabelHistory& h, bool isLoop) { + // need both forward and loop scores + // cacheHash defines the label sequence, thus everything + LabelHistoryDescriptor* lhd = static_cast(h.handle()); + std::vector& scores = scoreTransitionCache_[lhd->cacheHash]; + if (!scores.empty()) + return scores; + + const std::vector& forwardScores = scoreCache_[lhd->forwardHash]; + const std::vector& loopScores = scoreCache_[lhd->loopHash]; + if (forwardScores.empty() || loopScores.empty()) { + // batch computation + makeBatch(lhd); + verify(batchHash_.size() > 0); + decodeBatch(scoreCache_); + } + + if (implicitTransition_) { + // e.g. p(y_t | a_{s_t - 1}, h_1^T) only + // - forward transition scores at segment begin + // - derived from label posterior p(forward) = 1 - p(loop_label) + verify(forwardScores.size() == numClasses_ && loopScores.size() == numClasses_); + scores.resize(numClasses_, 0); + Score forward = getExclusiveScore(loopScores.at(lhd->labelSeq.back())); + std::transform(forwardScores.begin(), forwardScores.end(), scores.begin(), + std::bind(std::plus(), std::placeholders::_1, forward)); + } + else { // explicitTransition_ + // e.g. p(y_t | a_{s_t - 1}, h_1^T) * p(delta_t | y_{t-1}, h_1^T) + // - transition score at each frame: |V|+1 -th output for p(loop | y_{t-1}, h_1^T) + // - forward: y_{t-1} = a_{s_t - 1} only feed forwardHash needed + // => p(y_t | a_{s_t - 1}, h_1^T) * p(forward) = 1 - p(loop) + // - loop: feed loopHash for p(y_t=y_{t-1}| ...) + // => p(y_t=y_{t-1} | a_{s_t - 1}, h_1^T) * p(loop) + // put all to model graph ? 
then a lot of redundant computation + + // appended ILM for forward labels only + bool forwardILM = forwardScores.size() == 2 * numClasses_ + 1; + if (forwardILM) + verify(loopScores.size() == 2 * numClasses_ + 1); + else + verify(forwardScores.size() == numClasses_ + 1 && loopScores.size() == numClasses_ + 1); + + scores.resize(numClasses_ + 1, 0); + Score loop = forwardScores.at(numClasses_); + Score forward = getExclusiveScore(loop); + std::transform(forwardScores.begin(), forwardScores.begin() + numClasses_, scores.begin(), + std::bind(std::plus(), std::placeholders::_1, forward)); + + if (decodeStep_ > 0) + scores.back() = loopScores.at(lhd->labelSeq.back()) + loop; + else + scores.back() = Core::Type::max; // no loop for the 1st frame + + // optional renormalization over forward + loop + if (renormTransition_) { + Score sum = computeScoreSum(scores); + std::transform(scores.begin(), scores.end(), scores.begin(), + std::bind(std::plus(), std::placeholders::_1, -sum)); + } + // ILM on output sequence level: all forward positions + if (forwardILM) + std::transform(scores.begin(), scores.end() - 1, forwardScores.begin() + numClasses_ + 1, + scores.begin(), std::minus()); + } + return scores; +} + +// -scale * log(p) => -scale * log(1 - p) +Score TFFfnnTransducer::getExclusiveScore(Score score) { + // note: possible nan or inf when use prior + return -scale_ * std::log1p(-std::exp(score / (-scale_))); +} + +// label-sync segmental decoding (expensive) +// RNA topology only: equivalence of segmental and transducer modeling +const SegmentScore& TFFfnnTransducer::getSegmentScores(const LabelHistory& h, LabelIndex segIdx, u32 startPos) { + verify(isPositionDependent_); + segmentScore_.clear(); + + u32 totalLen = getEncoderLength() - 1; + verify(totalLen >= startPos); + u32 remainLen = totalLen - startPos; + if (remainLen < minSegLen_) + return segmentScore_; // empty + + LabelHistoryDescriptor* lhd = static_cast(h.handle()); + size_t hash = lhd->cacheHash; + u32 
maxLen = std::min(remainLen, maxSegLen_); + u32 minLen = std::min(u32(1), minSegLen_); // 0-frame segment also possible + + Score score = 0; + for (u32 len = minLen; len <= maxLen; ++len) { + u32 pos = startPos + len - 1; + const std::vector& scores = getPositionScores(hash, pos); + // regard label peak as segment end for scoring (simplicity: same history) + if (len >= minSegLen_) + segmentScore_.push_back(std::make_pair(len, score + scores[segIdx])); + score += scores[blankLabelIndex_]; + } + + return segmentScore_; +} + +const std::vector& TFFfnnTransducer::getPositionScores(size_t hash, u32 pos) { + ScoreCache& scoreCache = positionScoreCache_[pos]; + const std::vector& scores = scoreCache[hash]; + if (scores.empty()) { + makePositionBatch(hash, scoreCache); + setDecodePosition(pos); + decodeBatch(scoreCache); + } + verify(!scores.empty()); + return scores; +} + +// input scoreCache is position dependent +void TFFfnnTransducer::makePositionBatch(size_t hash, const ScoreCache& scoreCache) { + verify(batchHashQueue_.count(hash) > 0); + batchHash_.push_back(hash); + + std::unordered_set::const_iterator iter = batchHashQueue_.begin(); + while (batchHash_.size() < maxBatchSize_ && iter != batchHashQueue_.end()) { + // target hash is already in scoreCache with empty scores + if (scoreCache.count(*iter) == 0) + batchHash_.push_back(*iter); + ++iter; + } + // Note: there might be a little waste of batch computation if at this step for this position, + // only a few context is remained for scoring, but a few more new context appear at the next step + // to be scored for this position (maybe only for low order context and only at beginning ?) + // For higher order context, leave it as on demand + if (decodeStep_ > 0 && contextSize_ == 1 && batchHash_.size() < maxBatchSize_ / 2) { + // also cacheHash ? 
anyway not major use case + LabelSeqCache::const_iterator iter = labelSeqCache_.begin(); + while (batchHash_.size() < maxBatchSize_ && iter != labelSeqCache_.end()) { + // fill other possible context + if (batchHashQueue_.count(iter->first) == 0 && scoreCache.count(iter->first) == 0) + batchHash_.push_back(iter->first); + ++iter; + } + } +} + +// --- Segmental Model --- +/* +TFSegmentalModel::TFSegmentalModel(Core::Configuration const& config): + Core::Component(config), + Precursor(config) +{ + needEndProcessing_ = true; +} +*/ diff --git a/src/Nn/TFLabelScorer.hh b/src/Nn/TFLabelScorer.hh index 76d20ffb..e6d17755 100644 --- a/src/Nn/TFLabelScorer.hh +++ b/src/Nn/TFLabelScorer.hh @@ -16,12 +16,300 @@ #ifndef TF_LABEL_SCORER_HH #define TF_LABEL_SCORER_HH +#include +#include +#include +#include +#include #include "LabelScorer.hh" namespace Nn { - -} // namesapce +typedef std::vector TensorList; +typedef std::vector> MappedTensorList; -#endif +struct TFLabelHistory : public LabelHistoryBase { + std::vector scores; + TensorList variables; + u32 position; + bool isBlank; // for next feedback + + typedef LabelHistoryBase Precursor; + + TFLabelHistory() + : Precursor(), position(0), isBlank(false) {} + TFLabelHistory(const TFLabelHistory& ref) + : Precursor(ref), scores(ref.scores), variables(ref.variables), position(ref.position), isBlank(ref.isBlank) {} +}; + +// Encoder-Decoder Label Scorer based on Tensorflow back-end +// computation logics based on a predefined order of I/O and op collections in graph +// prerequisite: model graph compilation that parse the model into these collections +class TFModelBase : public LabelScorer { + typedef LabelScorer Precursor; + +public: + // config params for graph computation + static const Core::ParameterBool paramTransformOuputLog; + static const Core::ParameterBool paramTransformOuputNegate; + static const Core::ParameterInt paramMaxBatchSize; + + // overwrite descriptor in derived class for specific history + typedef 
TFLabelHistory LabelHistoryDescriptor; + +public: + TFModelBase(const Core::Configuration& config); + virtual ~TFModelBase(); + + virtual void reset(); + virtual void cleanUpBeforeExtension(u32 minPos) { + cacheHashQueue_.clear(); + } + + // history handling + virtual LabelHistory startHistory(); + virtual void extendLabelHistory(LabelHistory& h, LabelIndex idx, u32 position, bool isLoop); + + // encoding + virtual void encode(); + + // get scores for the next output position + virtual const std::vector& getScores(const LabelHistory& h, bool isLoop); + +protected: + void init(); + void initDecoder(); + void initStartHistory(); + void loadPrior(); + + // ---- batch-wise graph computation ---- + virtual void initComputation(); + + virtual void makeBatch(LabelHistoryDescriptor* targetLhd); + virtual void decodeBatch(); + + virtual void feedBatchVariables(); + virtual void feedDecodeInput(MappedTensorList& inputs); + virtual void updateBatchVariables(bool post = false); + virtual void fetchBatchVariables(); + + virtual void addPriorToBatch(); + virtual void computeBatchScores(); + virtual void processBatchOutput(const TensorList& outputs); + // -------------------------------------- + + bool debug_; + void debugFetch(const std::vector& fetchNames, std::string msg = ""); + +protected: + // Note: graph related params follow snake_case naming style + mutable Tensorflow::Session session_; + std::unique_ptr loader_; + std::unique_ptr graph_; + + // --- encoder --- + std::string encoding_input_tensor_name_; + std::string encoding_input_seq_length_tensor_name_; + + // --- decoder --- + std::vector decoding_input_tensor_names_; + std::vector decoding_output_tensor_names_; + std::vector decoding_input_ndims_; + std::vector decoding_output_ndims_; + // binary function including scaling + std::function decoding_output_transform_function_; + + std::vector var_feed_names_; + std::vector var_feed_ops_; + std::vector var_fetch_names_; + + // --- step ops --- + std::vector 
encoding_ops_; + std::vector decoding_ops_; + std::vector var_update_ops_; + std::vector var_post_update_ops_; + + // --- global --- + std::vector global_var_feed_names_; + std::vector global_var_feed_ops_; + +protected: + LabelHistoryDescriptor* startHistoryDescriptor_; // only common stuff, no states or scores + + typedef std::vector Batch; + Batch batch_; + std::deque cacheHashQueue_; + u32 maxBatchSize_; + + typedef std::unordered_map> ScoreCache; + ScoreCache contextLogPriors_; +}; + +// Attention-based Encoder-Decoder Model +// attention mechanism only in model graph (soft/hard): no additional latent variable here +class TFAttentionModel : public TFModelBase { + typedef TFModelBase Precursor; + +public: + TFAttentionModel(const Core::Configuration& config) + : Core::Component(config), + Precursor(config) { + needEndProcessing_ = true; + } +}; + +// RNN-Transducer|Aligner +// - blank-based topology +// - strictly monotonic (time|alignment-sync search w.r.t. decodeStep_) +// - either global_var simplification for enc_position +// or empty global_var: each hyp has its own position state_var (always +1) +// - optional label loop: different score and history handling +// - vertical transition (alignment-sync search) +// - empty global_var: each hyp has its own position state_var (+1 for blank) +// - additional ending detection/processing based on position +// - non-blank based topology: HMM-like with label loop +// - feedback: always the last alignment label (masking done in the graph) +// - dependency (recombination) +// - default: output label sequence +// - optional include blanks (e.g. 
towards full alignment sequence) +// - optional include loops +class TFRnnTransducer : public TFModelBase { + typedef TFModelBase Precursor; +public: + static const Core::ParameterBool paramLoopFeedbackAsBlank; + static const Core::ParameterBool paramVerticalTransition; + +public: + TFRnnTransducer(const Core::Configuration& config); + + bool useVerticalTransition() const { + return verticalTransition_; + } + + void increaseDecodeStep(); + void extendLabelHistory(LabelHistory& h, LabelIndex idx, u32 position, bool isLoop); + +protected: + void feedDecodeInput(MappedTensorList& inputs); + void setDecodePosition(u32 pos); + +private: + LabelIndex blankLabelIndex_; + bool loopFeedbackAsBlank_; + bool verticalTransition_; +}; + +// no state vars or scores: just label sequence and context hash +struct NgramLabelHistory : public LabelHistoryBase { + size_t forwardHash, loopHash; + u32 position; // only for position-aware ffnn-transducer + + typedef LabelHistoryBase Precursor; + + NgramLabelHistory() + : Precursor(), forwardHash(0), loopHash(0), position(0) {} + NgramLabelHistory(const NgramLabelHistory& ref) + : Precursor(ref), forwardHash(ref.forwardHash), loopHash(ref.loopHash), position(ref.position) {} + NgramLabelHistory(const LabelSequence& labSeq, LabelIndex nextIdx) + : Precursor(), forwardHash(0), loopHash(0), position(0) { + // always fixed context size (+1) and right-most latest + LabelSequence newSeq(labSeq.begin() + 1, labSeq.end()); + newSeq.push_back(nextIdx); + labelSeq.swap(newSeq); + } +}; + +// FFNN transducer with ngram context (no recurrency in decoder) +// - strictly monotonic topology only + global_var simplification for enc_position +// - both time-synchronous and label-synchronous search possible +// - latter: re-interpreted segmental decoding based on frame-wise output +// - label topology +// - either HMM-topology: loop without blank +// - or RNA-topology: blank without loop +// - dependency +// - output/segment label sequence or alignment 
sequence +// - additional first-order relative-position (so far only for RNA topology) +// Note: speed-up with context embedding lookup should be configured in the model graph +class TFFfnnTransducer : public TFModelBase { + typedef TFModelBase Precursor; + typedef NgramLabelHistory LabelHistoryDescriptor; + +public: + static const Core::ParameterInt paramContextSize; + static const Core::ParameterBool paramCacheHistory; + static const Core::ParameterBool paramImplicitTransition; + static const Core::ParameterBool paramExplicitTransition; + static const Core::ParameterBool paramUseRelativePosition; + static const Core::ParameterBool paramRenormTransition; + +public: + TFFfnnTransducer(Core::Configuration const& config); + ~TFFfnnTransducer(); + + void reset(); + void cleanUpBeforeExtension(u32 minPos); + + bool useRelativePosition() const { + return useRelativePosition_; + } + + // history handling + LabelHistory startHistory(); + void extendLabelHistory(LabelHistory& h, LabelIndex idx, u32 position, bool isLoop); + + // global position of encodings + void increaseDecodeStep(); + + // get label scores for the next output position + const std::vector& getScores(const LabelHistory& h, bool isLoop); + + // get segment scores for the next label segment given start position + const SegmentScore& getSegmentScores(const LabelHistory& h, LabelIndex segIdx, u32 startPos); + +protected: + void initComputation() {} + void makeBatch(LabelHistoryDescriptor* targetLhd); + void decodeBatch(ScoreCache& scoreCache); + void computeBatchScores(ScoreCache& scoreCache); + void setDecodePosition(u32 pos); + + const std::vector& getScoresWithTransition(const LabelHistory& h, bool isLoop); + Score getExclusiveScore(Score score); + + // for segmental decoding + const std::vector& getPositionScores(size_t hash, u32 pos); + void makePositionBatch(size_t hash, const ScoreCache& scoreCache); + +private: + u32 contextSize_; + bool cacheHistory_; + + // context (and position) dependent cache: 
central handling of scores instead of each history + ScoreCache scoreCache_; + std::unordered_set batchHashQueue_; + std::vector batchHash_; + + // HMM topology differs w.r.t. loopUpdateHistory_, if true then + // - alignment sequence dependency (otherwise output/segment label sequence) + // - loop scoring based on previous frame labels (otherwise segment labels) + bool hmmTopology_; + typedef std::unordered_map LabelSeqCache; + LabelSeqCache labelSeqCache_; // only for HMM topology: need clean up if not cacheHistory_ ? + ScoreCache scoreTransitionCache_; + bool implicitTransition_; + bool explicitTransition_; + bool renormTransition_; + + LabelIndex blankLabelIndex_; + bool useRelativePosition_; + + // for segmental decoding {position: {context: scores}} + std::unordered_map positionScoreCache_; +}; + +// TODO segmental model with explicit duration model ? + +} // namespace Nn + +#endif From 861c951190282b7dcb303c8a64ae8dc95e3cf829 Mon Sep 17 00:00:00 2001 From: SimBe195 <37951951+SimBe195@users.noreply.github.com> Date: Mon, 17 Apr 2023 15:43:01 +0200 Subject: [PATCH 2/6] Update LabelHistoryManager.hh Remove author line. --- src/Nn/LabelHistoryManager.hh | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/Nn/LabelHistoryManager.hh b/src/Nn/LabelHistoryManager.hh index 20167d77..9d4850ff 100644 --- a/src/Nn/LabelHistoryManager.hh +++ b/src/Nn/LabelHistoryManager.hh @@ -11,8 +11,6 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
- * - * author: Wei Zhou */ #ifndef LABEL_HISTORY_MANAGER_HH From a049f21d0627eda9fb52a4ec2d30b3c6eb0538aa Mon Sep 17 00:00:00 2001 From: SimBe195 <37951951+SimBe195@users.noreply.github.com> Date: Fri, 5 May 2023 12:06:11 +0200 Subject: [PATCH 3/6] Update src/Nn/LabelHistoryManager.hh Co-authored-by: Eugen Beck --- src/Nn/LabelHistoryManager.hh | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Nn/LabelHistoryManager.hh b/src/Nn/LabelHistoryManager.hh index 9d4850ff..432781fb 100644 --- a/src/Nn/LabelHistoryManager.hh +++ b/src/Nn/LabelHistoryManager.hh @@ -24,10 +24,12 @@ // Note: not 100% collision-free, better with additional safety where it's applied static size_t updateHashKey(size_t hash, size_t update) { // nothing to update - if (update == 0) + if (update == 0) { return hash; - if (hash == 0) + } + if (hash == 0) { return update; + } return hash ^ (update + 0x9e3779b9 + (hash << 6) + (hash >> 2)); } From 4bbcf10bf14dd49a4e7d72b2ebaa5abc01b5210c Mon Sep 17 00:00:00 2001 From: SimBe195 <37951951+SimBe195@users.noreply.github.com> Date: Fri, 5 May 2023 12:07:35 +0200 Subject: [PATCH 4/6] Update src/Nn/LabelHistoryManager.hh Co-authored-by: Eugen Beck --- src/Nn/LabelHistoryManager.hh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Nn/LabelHistoryManager.hh b/src/Nn/LabelHistoryManager.hh index 432781fb..15b15138 100644 --- a/src/Nn/LabelHistoryManager.hh +++ b/src/Nn/LabelHistoryManager.hh @@ -118,8 +118,9 @@ public: : mang_(0), desc_(0) {} LabelHistory(const LabelHistory& ref) : mang_(ref.mang_), desc_(ref.desc_) { - if (desc_) + if (desc_) { mang_->acquire(desc_); + } } ~LabelHistory() { From 6361f53b372b78d6672a9ec024c6ae45a2d4460a Mon Sep 17 00:00:00 2001 From: SimBe195 <37951951+SimBe195@users.noreply.github.com> Date: Fri, 5 May 2023 12:08:04 +0200 Subject: [PATCH 5/6] Update src/Nn/LabelHistoryManager.hh Co-authored-by: Eugen Beck --- src/Nn/LabelHistoryManager.hh | 3 ++- 1 file changed, 2 insertions(+), 
1 deletion(-) diff --git a/src/Nn/LabelHistoryManager.hh b/src/Nn/LabelHistoryManager.hh index 15b15138..e5d18602 100644 --- a/src/Nn/LabelHistoryManager.hh +++ b/src/Nn/LabelHistoryManager.hh @@ -167,8 +167,9 @@ private: }; inline LabelHistoryHandle LabelHistoryManager::acquire(LabelHistoryHandle lhd) const { - if (lhd) + if (lhd) { ++(lhd->ref_count); + } return lhd; } From b8565c401515feee900d658dbd78dd59aa5fb323 Mon Sep 17 00:00:00 2001 From: SimBe195 <37951951+SimBe195@users.noreply.github.com> Date: Fri, 5 May 2023 12:09:17 +0200 Subject: [PATCH 6/6] Apply suggestions from code review Co-authored-by: Eugen Beck --- src/Nn/LabelHistoryManager.hh | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/src/Nn/LabelHistoryManager.hh b/src/Nn/LabelHistoryManager.hh index e5d18602..9e08625b 100644 --- a/src/Nn/LabelHistoryManager.hh +++ b/src/Nn/LabelHistoryManager.hh @@ -186,8 +186,9 @@ inline void LabelHistoryManager::release(LabelHistoryHandle lhd) const { } inline size_t LabelHistoryManager::reducedHashKey(const LabelSequence& labelSeq, s32 limit) const { - if (limit < 0 || (u32)limit >= labelSeq.size()) + if (limit < 0 || (u32)limit >= labelSeq.size()) { return label_sequence_hash(labelSeq); + } LabelSequence reducedLabelSeq(labelSeq.end() - limit, labelSeq.end()); return label_sequence_hash(reducedLabelSeq); } @@ -199,8 +200,9 @@ inline size_t LabelHistoryManager::extendedHashKey(const LabelHistoryHandle lhd, } inline size_t LabelHistoryManager::reducedExtendedHashKey(const LabelHistoryHandle lhd, s32 limit, LabelIndex lIdx) const { - if (limit < 0 || (u32)limit > lhd->labelSeq.size()) + if (limit < 0 || (u32)limit > lhd->labelSeq.size()) { return extendedHashKey(lhd, lIdx); + } LabelSequence reducedLabelSeq(lhd->labelSeq.end() - (limit - 1), lhd->labelSeq.end()); reducedLabelSeq.push_back(lIdx); return label_sequence_hash(reducedLabelSeq); @@ -231,45 +233,53 @@ inline CacheUpdateResult 
LabelHistoryManager::updateCache(LabelHistoryHandle lhd } inline const LabelHistory& LabelHistory::operator=(const LabelHistory& rhs) { - if (rhs.desc_) + if (rhs.desc_) { rhs.mang_->acquire(rhs.desc_); - if (desc_) + } + if (desc_) { mang_->release(desc_); + } mang_ = rhs.mang_; desc_ = rhs.desc_; return *this; } inline size_t LabelHistory::hashKey() const { - if (desc_) + if (desc_) { return mang_->hashKey(desc_); + } return 0; } inline size_t LabelHistory::reducedHashKey(s32 limit) const { - if (desc_ && limit != 0) + if (desc_ && limit != 0) { return mang_->reducedHashKey(desc_, limit); + } return 0; } inline size_t LabelHistory::reducedExtendedHashKey(s32 limit, LabelIndex lIdx) const { - if (desc_ && limit != 0) + if (desc_ && limit != 0) { return mang_->reducedExtendedHashKey(desc_, limit, lIdx); + } return 0; } inline LabelIndex LabelHistory::getLastLabel() const { - if (desc_ && !desc_->labelSeq.empty()) + if (desc_ && !desc_->labelSeq.empty()) { return desc_->labelSeq.back(); + } return Core::Type::max; } // debug inline void LabelHistory::format() const { std::cout << " LabelHistory: "; - if (desc_) - for (LabelIndex label : desc_->labelSeq) + if (desc_) { + for (LabelIndex label : desc_->labelSeq) { std::cout << label << " "; + } + } std::cout << std::endl; }