Skip to content

Commit

Permalink
rename segment_id variables for the Pallas kernel to make the code less confusing
Browse files Browse the repository at this point in the history
  • Loading branch information
JackCaoG committed Nov 28, 2024
1 parent 5e2cb30 commit c4db598
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions torch_xla/experimental/custom_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
q_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
kv_segment_ids = xs.enable_manual_sharding(
kv_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
segment_ids, q_segment_ids, kv_segment_ids = FlashAttention.prepare_segment_ids(
segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
q_segment_ids, kv_segment_ids)
ctx.segment_ids = segment_ids

Expand Down Expand Up @@ -305,7 +305,7 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
if ab is not None:
args += [ab]
if segment_ids is not None:
args += [q_segment_ids, kv_segment_ids]
args += [q_segment_ids_fa, kv_segment_ids_fa]
o = torch_xla._XLAC._xla_tpu_custom_call(args, payload, shapes, dtypes)

if not save_residuals:
Expand All @@ -329,15 +329,15 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,

# q_segment_ids and kv_segment_ids are sharded here if partition_spec is provided
# but it should be OK as the backward will use the same partition_spec
ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids,
kv_segment_ids, full_ab)
ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids_fa,
kv_segment_ids_fa, full_ab)
return o

@staticmethod
def backward(ctx, grad_output):
from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv

q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab = ctx.saved_tensors
q, k, v, o, l, m, q_segment_ids_fa, kv_segment_ids_fa, ab = ctx.saved_tensors
causal = ctx.causal
sm_scale = ctx.sm_scale
partition_spec = ctx.partition_spec
Expand Down Expand Up @@ -409,7 +409,7 @@ def backward(ctx, grad_output):
if ab is not None:
args += [ab]
if segment_ids is not None:
args += [q_segment_ids, kv_segment_ids]
args += [q_segment_ids_fa, kv_segment_ids_fa]
args += [expanded_l, expanded_m, grad_output, expanded_grad_i]

outputs = [q]
Expand Down

0 comments on commit c4db598

Please sign in to comment.