Support SegmentID when doing data parallel SPMD #8425

Merged: 3 commits, Nov 28, 2024
130 changes: 129 additions & 1 deletion test/test_pallas_spmd.py
@@ -3,6 +3,7 @@
import unittest

import torch
import numpy as np
from torch import nn as nn

import torch_xla
@@ -22,8 +23,24 @@

class PallasTest(unittest.TestCase):

def _attention(self, q, k, v):
# This creates a block-diagonal mask in which only elements within the same
# segment can attend to each other. Since the mask marks the irrelevant
# positions to be masked out, we use != instead of ==.
def _make_attention_mask_from_segment_ids(self, q_segment_ids,
kv_segment_ids):
return q_segment_ids.view(q_segment_ids.shape[0], 1,
q_segment_ids.shape[1], 1) != kv_segment_ids.view(
kv_segment_ids.shape[0], 1, 1,
kv_segment_ids.shape[1])

def _attention(self, q, k, v, *, attn_mask=None, ab=None):
attn_weight = q @ k.transpose(-2, -1)
if attn_mask is not None:
# Mask out the irrelevant parts.
attn_weight = attn_weight.masked_fill(attn_mask,
torch.finfo(attn_weight.dtype).min)
if ab is not None:
attn_weight = attn_weight + ab
attn_weight = nn.functional.softmax(attn_weight, dim=-1)
attn_output = attn_weight @ v
return attn_output
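
For reference, a minimal self-contained sketch (a toy 1x4 input, not part of the PR) of the mask that _make_attention_mask_from_segment_ids produces and that _attention then applies via masked_fill:

import torch

# Toy example: batch=1, seq_len=4, tokens split into two segments [0, 0, 1, 1].
seg = torch.tensor([[0, 0, 1, 1]])
# Same broadcasting pattern as _make_attention_mask_from_segment_ids above.
mask = seg.view(1, 1, 4, 1) != seg.view(1, 1, 1, 4)
# mask[0, 0] is True wherever the query and key tokens belong to different
# segments, i.e. exactly the positions _attention fills with the dtype minimum:
# [[False, False, True,  True ],
#  [False, False, True,  True ],
#  [True,  True,  False, False],
#  [True,  True,  False, False]]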
@@ -98,6 +115,117 @@ def test_flash_attention_backward_spmd_data_parallel(self):
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
jax.config.update('jax_default_matmul_precision', "default")

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
def test_flash_attention_wrapper_segment_ids_spmd(self):
from torch_xla.experimental.custom_kernel import flash_attention
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention as jax_flash_attention, SegmentIds
xs.set_global_mesh(xs.get_1d_mesh("data"))

q = torch.randn(3, 2, 128, 4)
k = torch.randn(3, 2, 128, 4)
v = torch.randn(3, 2, 128, 4)
zeros = torch.zeros(3, 32)
segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
segment_ids_xla = segment_ids.to("xla")
# Only shard the data (batch) dimension.
o = flash_attention(
q.to("xla"),
k.to("xla"),
v.to("xla"),
False,
segment_ids_xla,
segment_ids.to("xla"),
partition_spec=("data", None, None, None))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(o),
f"{{devices=[{xr.global_runtime_device_count()},1,1,1]0,1,2,3}}")

jax_q = jnp.array(q.numpy(), dtype=jnp.float32)
jax_k = jnp.array(k.numpy(), dtype=jnp.float32)
jax_v = jnp.array(v.numpy(), dtype=jnp.float32)
jax_segment_ids = jnp.array(segment_ids.numpy(), dtype=jnp.float32)
expected_o = torch.from_numpy(
np.array(
jax_flash_attention(
jax_q,
jax_k,
jax_v,
segment_ids=SegmentIds(jax_segment_ids, jax_segment_ids),
)))

self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
jax.config.update('jax_default_matmul_precision', "default")

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
def test_flash_attention_backward_segment_ids_spmd(self):
jax.config.update("jax_default_matmul_precision", "highest")
from torch_xla.experimental.custom_kernel import flash_attention
n_devices = xr.global_runtime_device_count()
xs.set_global_mesh(xs.get_1d_mesh("data"))

torch.manual_seed(42)
q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
zeros = torch.zeros(4, 32).to("xla")
segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
q.retain_grad()
k.retain_grad()
v.retain_grad()

o = flash_attention(
q,
k,
v,
False,
segment_ids,
segment_ids,
partition_spec=("data", None, None, None))
loss = o.sum()
loss.backward()
q_grad = q.grad
k_grad = k.grad
v_grad = v.grad
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(o),
f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(q_grad),
f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(k_grad),
f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(v_grad),
f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
torch_xla.sync()

torch.manual_seed(42)
q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
zeros = torch.zeros(4, 32).to("xla")
segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
q.retain_grad()
k.retain_grad()
v.retain_grad()

o = self._attention(
q,
k,
v,
attn_mask=self._make_attention_mask_from_segment_ids(
segment_ids, segment_ids))
loss = o.sum()
loss.backward()
xm.mark_step()

for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
jax.config.update("jax_default_matmul_precision", "default")


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
23 changes: 17 additions & 6 deletions torch_xla/experimental/custom_kernel.py
@@ -266,7 +266,15 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
dtypes.append(torch.float32)

with torch.no_grad():
segment_ids, q_segment_ids, kv_segment_ids = FlashAttention.prepare_segment_ids(
if partition_spec is not None and q_segment_ids is not None and kv_segment_ids is not None:
# partition_spec covers q, k, v with shape [batch, num_head, seq_len, head_dim],
# while segment ids have shape [batch, seq_len], so derive a matching 2-D spec.
segment_id_partition_spec = (partition_spec[0], partition_spec[2])
q_segment_ids = xs.enable_manual_sharding(
q_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
kv_segment_ids = xs.enable_manual_sharding(
kv_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
q_segment_ids, kv_segment_ids)
ctx.segment_ids = segment_ids
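
To illustrate the spec derivation above, a small sketch with hypothetical values (only the indexing mirrors the code; the actual spec comes from the caller):

# q/k/v are sharded with a 4-D spec over [batch, num_head, seq_len, head_dim].
partition_spec = ("data", None, None, None)
# Segment ids have shape [batch, seq_len], so reuse the batch and seq_len entries.
segment_id_partition_spec = (partition_spec[0], partition_spec[2])  # ("data", None)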

@@ -297,7 +305,7 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
if ab is not None:
args += [ab]
if segment_ids is not None:
args += [q_segment_ids, kv_segment_ids]
args += [q_segment_ids_fa, kv_segment_ids_fa]
o = torch_xla._XLAC._xla_tpu_custom_call(args, payload, shapes, dtypes)

if not save_residuals:
@@ -319,20 +327,23 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
m = xs.disable_manual_sharding(
m, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor

ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids,
kv_segment_ids, full_ab)
# q_segment_ids_fa and kv_segment_ids_fa are saved in their sharded form when
# partition_spec is provided, which is fine because the backward pass uses the
# same partition_spec.
ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids_fa,
kv_segment_ids_fa, full_ab)
return o

@staticmethod
def backward(ctx, grad_output):
from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv

q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab = ctx.saved_tensors
q, k, v, o, l, m, q_segment_ids_fa, kv_segment_ids_fa, ab = ctx.saved_tensors
causal = ctx.causal
sm_scale = ctx.sm_scale
partition_spec = ctx.partition_spec
mesh = ctx.mesh
full_shape = ctx.full_shape
# segment_ids here reflects only the local (per-shard) shape of the segment ids.
segment_ids = ctx.segment_ids
grad_q = grad_k = grad_v = grad_ab = None

@@ -398,7 +409,7 @@ def backward(ctx, grad_output):
if ab is not None:
args += [ab]
if segment_ids is not None:
args += [q_segment_ids, kv_segment_ids]
args += [q_segment_ids_fa, kv_segment_ids_fa]
args += [expanded_l, expanded_m, grad_output, expanded_grad_i]

outputs = [q]
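
Taken together, a hedged end-to-end sketch of how the segment-id path is exercised under data-parallel SPMD; the shapes, mesh setup, and device assumptions are illustrative and mirror the tests above rather than prescribing anything beyond what they show:

import torch
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()
xs.set_global_mesh(xs.get_1d_mesh("data"))

q = torch.randn(4, 2, 128, 8).to("xla")
k = torch.randn(4, 2, 128, 8).to("xla")
v = torch.randn(4, 2, 128, 8).to("xla")
# One segment id per token: shape [batch, seq_len].
segment_ids = torch.zeros(4, 128).to("xla")

o = flash_attention(
    q, k, v,
    False,           # causal
    segment_ids,     # q_segment_ids
    segment_ids,     # kv_segment_ids
    partition_spec=("data", None, None, None))  # shard only the batch dimension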