diff --git a/src/layers/transformer.cc b/src/layers/transformer.cc index d5c7fcc16..056a01f99 100644 --- a/src/layers/transformer.cc +++ b/src/layers/transformer.cc @@ -558,7 +558,7 @@ namespace ctranslate2 { } for (size_t i = 0; i < layer_ins.size(); ++i) { - auto layer_in_chunk = layer_ins[i]; + StorageView* layer_in_chunk = &layer_ins[i]; for (size_t l = 0; l < _layers.size(); ++l) { StorageView* cached_self_attn_keys = nullptr; StorageView* cached_self_attn_values = nullptr; @@ -583,8 +583,8 @@ namespace ctranslate2 { dim_t offset = _sliding_window * i + step; offset = offset < 0 ? 0 : offset; if (i > 0) { - auto max_tokens = _sliding_window + layer_in_chunk.dim(1); - StorageView tmp_lengths = StorageView(Shape{layer_in_chunk.dim(0)}, int32_t(max_tokens), device); + auto max_tokens = _sliding_window + layer_in_chunk->dim(1); + StorageView tmp_lengths = StorageView(Shape{layer_in_chunk->dim(0)}, int32_t(max_tokens), device); StorageView lengths_mask = layers::MultiHeadAttention::prepare_length_mask( tmp_lengths, _num_heads, @@ -592,13 +592,13 @@ namespace ctranslate2 { /*mask_future=*/true, multi_query); - const ops::Slide slide_lengths_op(2, _sliding_window, layer_in_chunk.dim(1)); + const ops::Slide slide_lengths_op(2, _sliding_window, layer_in_chunk->dim(1)); // reuse tmp_lengths slide_lengths_op(lengths_mask, tmp_lengths); input_lengths_mask = std::make_unique<StorageView>(std::move(tmp_lengths)); } - (*_layers[l])(layer_in_chunk, + (*_layers[l])(*layer_in_chunk, input_lengths_mask.get(), memory, memory_lengths_mask.get(), @@ -613,14 +613,14 @@ namespace ctranslate2 { return_normalized_attention(), &position_bias, offset); - layer_in_chunk = std::move(layer_out); + *layer_in_chunk = std::move(layer_out); if (layer_attention) { alignment_heads.emplace_back(dtype, device); ops::Gather(1, 1)(*layer_attention, *heads_to_select, alignment_heads.back()); } } - layer_in = std::move(layer_in_chunk); + layer_in = std::move(*layer_in_chunk); } if (step == 0) {