Skip to content

Commit

Permalink
Add BufferedLabelScorer (#88)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimBe195 authored Jan 24, 2025
1 parent e08c2f9 commit bc4426b
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 0 deletions.
49 changes: 49 additions & 0 deletions src/Nn/LabelScorer/BufferedLabelScorer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/** Copyright 2025 RWTH Aachen University. All rights reserved.
*
* Licensed under the RWTH ASR License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "BufferedLabelScorer.hh"

namespace Nn {

BufferedLabelScorer::BufferedLabelScorer(Core::Configuration const& config)
: Core::Component(config),
Precursor(config),
inputBuffer_(),
featureSize_(Core::Type<size_t>::max),
expectMoreFeatures_(true) {
}

void BufferedLabelScorer::reset() {
inputBuffer_.clear();
featureSize_ = Core::Type<size_t>::max;
expectMoreFeatures_ = true;
}

void BufferedLabelScorer::signalNoMoreFeatures() {
expectMoreFeatures_ = false;
}

void BufferedLabelScorer::addInput(std::shared_ptr<const f32[]> const& input, size_t featureSize) {
if (featureSize_ == Core::Type<size_t>::max) {
featureSize_ = featureSize;
}
else if (featureSize_ != featureSize) {
error() << "Label scorer received incompatible feature size " << featureSize << "; was set to " << featureSize_ << " before.";
}

inputBuffer_.push_back(input);
}

} // namespace Nn
52 changes: 52 additions & 0 deletions src/Nn/LabelScorer/BufferedLabelScorer.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/** Copyright 2025 RWTH Aachen University. All rights reserved.
*
* Licensed under the RWTH ASR License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef BUFFERED_LABEL_SCORER_HH
#define BUFFERED_LABEL_SCORER_HH

#include "LabelScorer.hh"

namespace Nn {

/*
* Extension of `LabelScorer` that implements some commonly used buffering logic for input features
* and timeframes as well as a flag that indicates whether more features are expected to be added to the buffer.
* This serves as a base class for other LabelScorers.
*/
class BufferedLabelScorer : public LabelScorer {
public:
using Precursor = LabelScorer;

BufferedLabelScorer(Core::Configuration const& config);

// Prepares the LabelScorer to receive new inputs by resetting input buffer, timeframe buffer
// and segment end flag
virtual void reset() override;

// Tells the LabelScorer that there will be no more input features coming in the current segment
virtual void signalNoMoreFeatures() override;

// Add a single input feature to the buffer
virtual void addInput(std::shared_ptr<const f32[]> const& input, size_t featureSize) override;

protected:
std::vector<std::shared_ptr<const f32[]>> inputBuffer_; // Buffer that contains all the feature data for the current segment
size_t featureSize_; // Feature dimension size of features in the buffer (same for all features)
bool expectMoreFeatures_; // Flag to record segment end signal
};

} // namespace Nn

#endif // BUFFERED_LABEL_SCORER_HH
1 change: 1 addition & 0 deletions src/Nn/LabelScorer/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ SUBDIRS =
TARGETS = libSprintLabelScorer.$(a)

LIBSPRINTLABELSCORER_O = \
$(OBJDIR)/BufferedLabelScorer.o \
$(OBJDIR)/LabelScorer.o \
$(OBJDIR)/ScoringContext.o

Expand Down

0 comments on commit bc4426b

Please sign in to comment.