diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 6fd171ea6dab..5e30ffba26a6 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -274,7 +274,7 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
           q_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
       kv_segment_ids = xs.enable_manual_sharding(
           kv_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
-    segment_ids, q_segment_ids, kv_segment_ids = FlashAttention.prepare_segment_ids(
+    segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
         q_segment_ids, kv_segment_ids)
     ctx.segment_ids = segment_ids
 
@@ -305,7 +305,7 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
     if ab is not None:
       args += [ab]
     if segment_ids is not None:
-      args += [q_segment_ids, kv_segment_ids]
+      args += [q_segment_ids_fa, kv_segment_ids_fa]
     o = torch_xla._XLAC._xla_tpu_custom_call(args, payload, shapes, dtypes)
 
     if not save_residuals:
@@ -329,15 +329,15 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
 
     # q_segment_ids and kv_segment_ids are sharded here if partition_spec is provided
     # but it should be OK as the backward will use the same partition_spec
-    ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids,
-                          kv_segment_ids, full_ab)
+    ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids_fa,
+                          kv_segment_ids_fa, full_ab)
     return o
 
   @staticmethod
   def backward(ctx, grad_output):
     from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv
 
-    q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab = ctx.saved_tensors
+    q, k, v, o, l, m, q_segment_ids_fa, kv_segment_ids_fa, ab = ctx.saved_tensors
     causal = ctx.causal
     sm_scale = ctx.sm_scale
     partition_spec = ctx.partition_spec
@@ -409,7 +409,7 @@ def backward(ctx, grad_output):
     if ab is not None:
       args += [ab]
     if segment_ids is not None:
-      args += [q_segment_ids, kv_segment_ids]
+      args += [q_segment_ids_fa, kv_segment_ids_fa]
     args += [expanded_l, expanded_m, grad_output, expanded_grad_i]
 
     outputs = [q]