From bbd09e81e6f528fe6091e035cf0f8cbe11b4890d Mon Sep 17 00:00:00 2001 From: Yunlong Liu Date: Tue, 18 Feb 2025 23:03:24 -0800 Subject: [PATCH] Adds a grad partial auto test. --- tests/shard_map_test.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 0b6a3c1d0e7a..fbda84deaf84 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -2063,6 +2063,24 @@ def f(x): v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) self.assertAllClose(v*v, f(v), check_dtypes=False) + def test_grad_partial_auto(self): + mesh = jtu.create_mesh((2, 2), ('i', 'j')) + + def h(x): + return x ** 2 + + @jax.jit + def f(x): + return shard_map(h, mesh, + in_specs=P('i', None), + out_specs=P('i', None), + check_rep=False, + auto=frozenset({'j'}))(x).sum() + + v = jnp.arange(32.).reshape(4, 8) + v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) + self.assertAllClose(2*v, jax.grad(f)(v), check_dtypes=False) + def test_grad_nested_partial_auto(self): mesh = jtu.create_mesh((2, 2), ('i', 'j'))