diff --git a/test/test_pallas_spmd.py b/test/test_pallas_spmd.py
index 9a81ae18f39a..713def2b8b1a 100644
--- a/test/test_pallas_spmd.py
+++ b/test/test_pallas_spmd.py
@@ -40,7 +40,7 @@ def _attention(self, q, k, v, *, attn_mask=None, ab=None):
       attn_weight = attn_weight.masked_fill(attn_mask,
                                             torch.finfo(attn_weight.dtype).min)
     if ab is not None:
-      attn_weight = attn_weight + ab
+      attn_weight = attn_weight + ab
     attn_weight = nn.functional.softmax(attn_weight, dim=-1)
     attn_output = attn_weight @ v
     return attn_output
@@ -139,7 +139,7 @@ def test_flash_attention_wrapper_segment_ids_spmd(self):
         partition_spec=("data", None, None, None))
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(o),
-        f"{{devices=[{xr.global_runtime_device_count()},1,1,1]0,1,2,3}}")
+        f"{{devices=[{xr.global_runtime_device_count()},1,1,1]0,1,2,3}}")
 
     jax_q = jnp.array(q.numpy(), dtype=jnp.float32)
     jax_k = jnp.array(k.numpy(), dtype=jnp.float32)
@@ -175,12 +175,19 @@ def test_flash_attention_backward_segment_ids_spmd(self):
     k.retain_grad()
     v.retain_grad()
 
-    o = flash_attention(q, k, v, False, segment_ids, segment_ids, partition_spec=("data", None, None, None))
+    o = flash_attention(
+        q,
+        k,
+        v,
+        False,
+        segment_ids,
+        segment_ids,
+        partition_spec=("data", None, None, None))
     loss = o.sum()
     loss.backward()
     q_grad = q.grad
     k_grad = k.grad
-    v_grad = v.grad
+    v_grad = v.grad
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(o),
         f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
@@ -192,10 +199,9 @@ def test_flash_attention_backward_segment_ids_spmd(self):
         f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(v_grad),
-        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
     torch_xla.sync()
 
-
     torch.manual_seed(42)
     q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
     k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
@@ -220,6 +226,7 @@ def test_flash_attention_backward_segment_ids_spmd(self):
       self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
       jax.config.update("jax_default_matmul_precision", "default")
 
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   torch.set_default_dtype(torch.float32)