From c00bf53118812b3e0fe4272da831aaaec8f4b138 Mon Sep 17 00:00:00 2001
From: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Date: Thu, 28 Nov 2024 01:06:16 -0800
Subject: [PATCH] Support SegmentID when doing data parallel SPMD (#8425)

---
 test/test_pallas_spmd.py                | 130 +++++++++++++++++++++++-
 torch_xla/experimental/custom_kernel.py |  23 +++--
 2 files changed, 146 insertions(+), 7 deletions(-)

diff --git a/test/test_pallas_spmd.py b/test/test_pallas_spmd.py
index e88b8b2caff1..713def2b8b1a 100644
--- a/test/test_pallas_spmd.py
+++ b/test/test_pallas_spmd.py
@@ -3,6 +3,7 @@
 import unittest
 
 import torch
+import numpy as np
 from torch import nn as nn
 
 import torch_xla
@@ -22,8 +23,24 @@
 
 class PallasTest(unittest.TestCase):
 
-  def _attention(self, q, k, v):
+  # This creates a block-diagonal mask in which only elements within the same
+  # segment can attend to each other. Since the mask marks the positions to
+  # mask *out*, we use != instead of ==.
+  def _make_attention_mask_from_segment_ids(self, q_segment_ids,
+                                            kv_segment_ids):
+    return q_segment_ids.view(q_segment_ids.shape[0], 1,
+                              q_segment_ids.shape[1], 1) != kv_segment_ids.view(
+                                  kv_segment_ids.shape[0], 1, 1,
+                                  kv_segment_ids.shape[1])
+
+  def _attention(self, q, k, v, *, attn_mask=None, ab=None):
     attn_weight = q @ k.transpose(-2, -1)
+    if attn_mask is not None:
+      # Mask out the irrelevant parts.
+      attn_weight = attn_weight.masked_fill(attn_mask,
+                                            torch.finfo(attn_weight.dtype).min)
+    if ab is not None:
+      attn_weight = attn_weight + ab
     attn_weight = nn.functional.softmax(attn_weight, dim=-1)
     attn_output = attn_weight @ v
     return attn_output
@@ -98,6 +115,117 @@ def test_flash_attention_backward_spmd_data_parallel(self):
       self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
     jax.config.update('jax_default_matmul_precision', "default")
 
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
+                   "This test only works on TPUv3+.")
+  def test_flash_attention_wrapper_segment_ids_spmd(self):
+    from torch_xla.experimental.custom_kernel import flash_attention
+    from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention as jax_flash_attention, SegmentIds
+    xs.set_global_mesh(xs.get_1d_mesh("data"))
+
+    q = torch.randn(3, 2, 128, 4)
+    k = torch.randn(3, 2, 128, 4)
+    v = torch.randn(3, 2, 128, 4)
+    zeros = torch.zeros(3, 32)
+    segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
+    segment_ids_xla = segment_ids.to("xla")
+    # Only shard the data dimension.
+    o = flash_attention(
+        q.to("xla"),
+        k.to("xla"),
+        v.to("xla"),
+        False,
+        segment_ids_xla,
+        segment_ids.to("xla"),
+        partition_spec=("data", None, None, None))
+    self.assertEqual(
+        torch_xla._XLAC._get_xla_sharding_spec(o),
+        f"{{devices=[{xr.global_runtime_device_count()},1,1,1]0,1,2,3}}")
+
+    jax_q = jnp.array(q.numpy(), dtype=jnp.float32)
+    jax_k = jnp.array(k.numpy(), dtype=jnp.float32)
+    jax_v = jnp.array(v.numpy(), dtype=jnp.float32)
+    jax_segment_ids = jnp.array(segment_ids.numpy(), dtype=jnp.float32)
+    expected_o = torch.from_numpy(
+        np.array(
+            jax_flash_attention(
+                jax_q,
+                jax_k,
+                jax_v,
+                segment_ids=SegmentIds(jax_segment_ids, jax_segment_ids),
+            )))
+
+    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
+    jax.config.update('jax_default_matmul_precision', "default")
+
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
+                   "This test only works on TPUv3+.")
+  def test_flash_attention_backward_segment_ids_spmd(self):
+    jax.config.update("jax_default_matmul_precision", "highest")
+    from torch_xla.experimental.custom_kernel import flash_attention
+    n_devices = xr.global_runtime_device_count()
+    xs.set_global_mesh(xs.get_1d_mesh("data"))
+
+    torch.manual_seed(42)
+    q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    zeros = torch.zeros(4, 32).to("xla")
+    segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
+    q.retain_grad()
+    k.retain_grad()
+    v.retain_grad()
+
+    o = flash_attention(
+        q,
+        k,
+        v,
+        False,
+        segment_ids,
+        segment_ids,
+        partition_spec=("data", None, None, None))
+    loss = o.sum()
+    loss.backward()
+    q_grad = q.grad
+    k_grad = k.grad
+    v_grad = v.grad
+    self.assertEqual(
+        torch_xla._XLAC._get_xla_sharding_spec(o),
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+    self.assertEqual(
+        torch_xla._XLAC._get_xla_sharding_spec(q_grad),
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+    self.assertEqual(
+        torch_xla._XLAC._get_xla_sharding_spec(k_grad),
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+    self.assertEqual(
+        torch_xla._XLAC._get_xla_sharding_spec(v_grad),
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+    torch_xla.sync()
+
+    torch.manual_seed(42)
+    q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    zeros = torch.zeros(4, 32).to("xla")
+    segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
+    q.retain_grad()
+    k.retain_grad()
+    v.retain_grad()
+
+    o = self._attention(
+        q,
+        k,
+        v,
+        attn_mask=self._make_attention_mask_from_segment_ids(
+            segment_ids, segment_ids))
+    loss = o.sum()
+    loss.backward()
+    xm.mark_step()
+
+    for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
+      self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
+    jax.config.update("jax_default_matmul_precision", "default")
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index fdc5992c3b05..5e30ffba26a6 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -266,7 +266,15 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
         dtypes.append(torch.float32)
 
     with torch.no_grad():
-      segment_ids, q_segment_ids, kv_segment_ids = FlashAttention.prepare_segment_ids(
+      if partition_spec is not None and q_segment_ids is not None and kv_segment_ids is not None:
+        # partition_spec applies to q, k, v with shape [batch, num_head, seq_len, head_dim];
+        # segment ids have shape [batch, seq_len], so their spec keeps only the batch and seq_len dims.
+        segment_id_partition_spec = (partition_spec[0], partition_spec[2])
+        q_segment_ids = xs.enable_manual_sharding(
+            q_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
+        kv_segment_ids = xs.enable_manual_sharding(
+            kv_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
+      segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
           q_segment_ids, kv_segment_ids)
       ctx.segment_ids = segment_ids
 
@@ -297,7 +305,7 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
       if ab is not None:
         args += [ab]
       if segment_ids is not None:
-        args += [q_segment_ids, kv_segment_ids]
+        args += [q_segment_ids_fa, kv_segment_ids_fa]
       o = torch_xla._XLAC._xla_tpu_custom_call(args, payload, shapes, dtypes)
 
     if not save_residuals:
@@ -319,20 +327,23 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
         m = xs.disable_manual_sharding(
             m, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor
 
-    ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids,
-                          kv_segment_ids, full_ab)
+    # q_segment_ids and kv_segment_ids are sharded here when partition_spec is
+    # provided, but that is fine because backward uses the same partition_spec.
+    ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids_fa,
+                          kv_segment_ids_fa, full_ab)
     return o
 
   @staticmethod
   def backward(ctx, grad_output):
     from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv
 
-    q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab = ctx.saved_tensors
+    q, k, v, o, l, m, q_segment_ids_fa, kv_segment_ids_fa, ab = ctx.saved_tensors
     causal = ctx.causal
     sm_scale = ctx.sm_scale
     partition_spec = ctx.partition_spec
     mesh = ctx.mesh
     full_shape = ctx.full_shape
+    # ctx.segment_ids was built from the sharded segment ids, so it reflects the local shape.
     segment_ids = ctx.segment_ids
     grad_q = grad_k = grad_v = grad_ab = None
 
@@ -398,7 +409,7 @@ def backward(ctx, grad_output):
      if ab is not None:
        args += [ab]
      if segment_ids is not None:
-      args += [q_segment_ids, kv_segment_ids]
+      args += [q_segment_ids_fa, kv_segment_ids_fa]
      args += [expanded_l, expanded_m, grad_output, expanded_grad_i]

      outputs = [q]
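
Usage note: the sketch below, distilled from test_flash_attention_wrapper_segment_ids_spmd and
test_flash_attention_backward_segment_ids_spmd above, shows how the new segment-id arguments are
passed to flash_attention under data-parallel SPMD. It is not part of the patch; the explicit
xr.use_spmd() call and the shapes are assumptions for illustration, and it requires a TPU host
whose devices back the 1D "data" mesh.

    import torch
    import torch_xla
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs
    from torch_xla.experimental.custom_kernel import flash_attention

    xr.use_spmd()                               # assumption: SPMD mode enabled up front
    xs.set_global_mesh(xs.get_1d_mesh("data"))  # 1D mesh over all devices

    # q, k, v have shape [batch, num_head, seq_len, head_dim].
    q = torch.randn(4, 2, 128, 8).to("xla")
    k = torch.randn(4, 2, 128, 8).to("xla")
    v = torch.randn(4, 2, 128, 8).to("xla")

    # Four contiguous segments of length 32 per batch row; a token may only
    # attend to tokens that share its segment id.
    zeros = torch.zeros(4, 32)
    segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1).to("xla")

    # Shard only the batch ("data") dimension of q, k, v; the kernel derives the
    # segment-id sharding from dims 0 and 2 of this partition_spec.
    o = flash_attention(
        q, k, v, False, segment_ids, segment_ids,
        partition_spec=("data", None, None, None))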
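
For reference, the "adjustment" in the forward pass simply projects the 4D q/k/v partition spec
onto the 2D [batch, seq_len] layout of the segment ids. A hypothetical helper, not part of the
patch, that makes the mapping explicit:

    from typing import Optional, Tuple

    def segment_id_partition_spec(
        qkv_spec: Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]
    ) -> Tuple[Optional[str], Optional[str]]:
      # Keep the batch (dim 0) and seq_len (dim 2) axes; drop num_head and head_dim.
      return (qkv_spec[0], qkv_spec[2])

    assert segment_id_partition_spec(("data", None, None, None)) == ("data", None)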