Skip to content

Commit

Permalink
sort hypotheses
Browse files Browse the repository at this point in the history
  • Loading branch information
minhthuc2502 committed Sep 25, 2024
1 parent b9842b1 commit 1fca6c9
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions src/decoding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ namespace ctranslate2 {
static inline void sort_hypotheses(DecodingResult& result,
size_t max_hypotheses,
bool keep_scores,
bool keep_attention) {
bool keep_attention,
bool keep_logits_vocab) {
std::vector<size_t> idx(result.hypotheses.size());
std::iota(idx.begin(), idx.end(), 0);
std::sort(idx.begin(), idx.end(),
Expand All @@ -226,14 +227,20 @@ namespace ctranslate2 {
result.attention = index_vector(result.attention, idx);
else
result.attention.clear();

if (keep_logits_vocab)
result.logits_vocab = index_vector(result.logits_vocab, idx);
else
result.logits_vocab.clear();
}

static inline void finalize_result(DecodingResult& result,
const size_t max_hypotheses,
const float length_penalty,
const float coverage_penalty,
const bool keep_scores,
const bool keep_attention) {
const bool keep_attention,
const bool keep_logits_vocab) {
for (size_t i = 0; i < result.scores.size(); ++i) {
const auto* attention = result.attention.empty() ? nullptr : &result.attention[i];
result.scores[i] = finalize_hypothesis_score(result.scores[i],
Expand All @@ -243,7 +250,7 @@ namespace ctranslate2 {
attention);
}

sort_hypotheses(result, max_hypotheses, keep_scores, keep_attention);
sort_hypotheses(result, max_hypotheses, keep_scores, keep_attention, keep_logits_vocab);
}

BiasedDecoder::BiasedDecoder(const float prefix_bias_beta,
Expand Down Expand Up @@ -651,7 +658,8 @@ namespace ctranslate2 {
_length_penalty,
_coverage_penalty,
return_scores,
return_attention);
return_attention,
return_logits_vocab);
} else {
non_finished_index.emplace_back(i);
}
Expand Down Expand Up @@ -796,7 +804,7 @@ namespace ctranslate2 {
}

for (auto& result : final_results)
sort_hypotheses(result, num_hypotheses, return_scores, return_attention);
sort_hypotheses(result, num_hypotheses, return_scores, return_attention, return_logits_vocab);

return final_results;
}
Expand Down Expand Up @@ -932,7 +940,8 @@ namespace ctranslate2 {
_length_penalty,
_coverage_penalty,
return_scores,
return_attention);
return_attention,
return_logits_vocab);
} else {
non_finished_index.emplace_back(i);
sample_from.at<int32_t>(i) = word_id;
Expand Down

0 comments on commit 1fca6c9

Please sign in to comment.