Skip to content

Commit

Permalink
pass batch_dim_idx to DeepSpeed sequence-parallel distributed attention to support batch sizes larger than 1
Browse files Browse the repository at this point in the history
  • Loading branch information
Jinghan Yao committed Aug 3, 2024
1 parent 1bfc35c commit 309d3f0
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion megatron/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,12 +817,14 @@ def forward(self, hidden_states, attention_mask,
# value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)

if self.enable_ds_sequence_parallel:
batch_dim_idx = 1
if self.use_flash_attn:
if not self.use_flash_attn_triton:
query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous()
for x in (query_layer, key_layer, value_layer)]
batch_dim_idx = 0

context_layer = self.dist_attn(query_layer, key_layer, value_layer)
context_layer = self.dist_attn(query_layer, key_layer, value_layer, batch_dim_idx)

if not self.use_flash_attn_triton:
context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
Expand Down

0 comments on commit 309d3f0

Please sign in to comment.