Commit

lint
JackCaoG committed Nov 27, 2024
1 parent 0453dd1 commit 5e2cb30
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions test/test_pallas_spmd.py
@@ -40,7 +40,7 @@ def _attention(self, q, k, v, *, attn_mask=None, ab=None):
       attn_weight = attn_weight.masked_fill(attn_mask,
                                             torch.finfo(attn_weight.dtype).min)
     if ab is not None:
-      attn_weight = attn_weight + ab
+      attn_weight = attn_weight + ab
     attn_weight = nn.functional.softmax(attn_weight, dim=-1)
     attn_output = attn_weight @ v
     return attn_output
@@ -139,7 +139,7 @@ def test_flash_attention_wrapper_segment_ids_spmd(self):
         partition_spec=("data", None, None, None))
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(o),
-        f"{{devices=[{xr.global_runtime_device_count()},1,1,1]0,1,2,3}}")
+        f"{{devices=[{xr.global_runtime_device_count()},1,1,1]0,1,2,3}}")
 
     jax_q = jnp.array(q.numpy(), dtype=jnp.float32)
     jax_k = jnp.array(k.numpy(), dtype=jnp.float32)
@@ -175,12 +175,19 @@ def test_flash_attention_backward_segment_ids_spmd(self):
     k.retain_grad()
     v.retain_grad()
 
-    o = flash_attention(q, k, v, False, segment_ids, segment_ids, partition_spec=("data", None, None, None))
+    o = flash_attention(
+        q,
+        k,
+        v,
+        False,
+        segment_ids,
+        segment_ids,
+        partition_spec=("data", None, None, None))
     loss = o.sum()
     loss.backward()
     q_grad = q.grad
     k_grad = k.grad
-    v_grad = v.grad
+    v_grad = v.grad
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(o),
         f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
@@ -192,10 +199,9 @@ def test_flash_attention_backward_segment_ids_spmd(self):
         f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(v_grad),
-        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
     torch_xla.sync()
-
 
     torch.manual_seed(42)
     q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
     k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
@@ -220,6 +226,7 @@ def test_flash_attention_backward_segment_ids_spmd(self):
       self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
     jax.config.update("jax_default_matmul_precision", "default")
 
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   torch.set_default_dtype(torch.float32)
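For context: the only line this lint pass substantively reflows is the flash_attention call in test_flash_attention_backward_segment_ids_spmd (one long line broken into the multi-line form above); the remaining hunks are whitespace and blank-line fixes. A minimal sketch of the SPMD usage pattern these tests exercise follows. It is illustrative only and not part of the commit: it assumes an SPMD-capable (TPU) environment, and the mesh and segment-id setup are placeholders chosen to match the shapes used in the test.

# Sketch only (not part of this commit): illustrates the SPMD flash attention
# call pattern exercised by test/test_pallas_spmd.py. Assumes a TPU/SPMD-capable
# environment; the mesh and segment-id construction below are illustrative.
import numpy as np
import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()
n_devices = xr.global_runtime_device_count()
# One-dimensional mesh whose single axis is named "data", matching the
# partition_spec=("data", None, None, None) used in the tests.
mesh = xs.Mesh(np.array(range(n_devices)), (n_devices,), ("data",))
xs.set_global_mesh(mesh)

q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
# Placeholder segment ids of shape (batch, seq_len); the real test builds its own.
zeros = torch.zeros(4, 32, dtype=torch.float32)
segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1).to("xla")

# Multi-line call formatting matches what this lint commit produces.
o = flash_attention(
    q,
    k,
    v,
    False,  # causal
    segment_ids,
    segment_ids,
    partition_spec=("data", None, None, None))
o.sum().backward()
torch_xla.sync()

# The wrapper shards the output (and gradients) along the batch axis, so the
# sharding spec should look like "{devices=[<n_devices>,1,1,1]0,1,...}".
print(torch_xla._XLAC._get_xla_sharding_spec(o))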
