Skip to content

Commit

Permalink
fix loss performance (#1622)
Browse files: browse the repository at this point in the history.
Co-authored-by: thucpham <minhthuc.pham@systrangroup.com>
  • Loading branch information
minhthuc2502 and thucpham authored Feb 15, 2024
1 parent 8e82733 commit a4d7820
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/layers/transformer.cc
Original file line number Diff line number Diff line change
@@ -558,7 +558,7 @@ namespace ctranslate2 {
 }

 for (size_t i = 0; i < layer_ins.size(); ++i) {
-auto layer_in_chunk = layer_ins[i];
+StorageView* layer_in_chunk = &layer_ins[i];
 for (size_t l = 0; l < _layers.size(); ++l) {
 StorageView* cached_self_attn_keys = nullptr;
 StorageView* cached_self_attn_values = nullptr;
@@ -583,22 +583,22 @@ namespace ctranslate2 {
 dim_t offset = _sliding_window * i + step;
 offset = offset < 0 ? 0 : offset;
 if (i > 0) {
-auto max_tokens = _sliding_window + layer_in_chunk.dim(1);
-StorageView tmp_lengths = StorageView(Shape{layer_in_chunk.dim(0)}, int32_t(max_tokens), device);
+auto max_tokens = _sliding_window + layer_in_chunk->dim(1);
+StorageView tmp_lengths = StorageView(Shape{layer_in_chunk->dim(0)}, int32_t(max_tokens), device);
 StorageView lengths_mask = layers::MultiHeadAttention::prepare_length_mask(
 tmp_lengths,
 _num_heads,
 max_tokens,
 /*mask_future=*/true,
 multi_query);

-const ops::Slide slide_lengths_op(2, _sliding_window, layer_in_chunk.dim(1));
+const ops::Slide slide_lengths_op(2, _sliding_window, layer_in_chunk->dim(1));
 // reuse tmp_lengths
 slide_lengths_op(lengths_mask, tmp_lengths);
 input_lengths_mask = std::make_unique<StorageView>(std::move(tmp_lengths));
 }

-(*_layers[l])(layer_in_chunk,
+(*_layers[l])(*layer_in_chunk,
 input_lengths_mask.get(),
 memory,
 memory_lengths_mask.get(),
@@ -613,14 +613,14 @@ namespace ctranslate2 {
 return_normalized_attention(),
 &position_bias,
 offset);
-layer_in_chunk = std::move(layer_out);
+*layer_in_chunk = std::move(layer_out);

 if (layer_attention) {
 alignment_heads.emplace_back(dtype, device);
 ops::Gather(1, 1)(*layer_attention, *heads_to_select, alignment_heads.back());
 }
 }
-layer_in = std::move(layer_in_chunk);
+layer_in = std::move(*layer_in_chunk);
 }

 if (step == 0) {
Expand Down

0 comments on commit a4d7820

Please sign in to comment.