Skip to content

Commit

Permalink
Adds a grad partial auto test.
Browse files Browse the repository at this point in the history
  • Loading branch information
yliu120 authored Feb 19, 2025
1 parent e4b9fdb commit bbd09e8
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions tests/shard_map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2063,6 +2063,24 @@ def f(x):
v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
self.assertAllClose(v*v, f(v), check_dtypes=False)

def test_grad_partial_auto(self):
mesh = jtu.create_mesh((2, 2), ('i', 'j'))

def h(x):
return x ** 2

@jax.jit
def f(x):
return shard_map(h, mesh,
in_specs=P('i', None),
out_specs=P('i', None),
check_rep=False,
auto=frozenset({'j'}))(x).sum()

v = jnp.arange(32.).reshape(4, 8)
v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
self.assertAllClose(2*v, jax.grad(f)(v), check_dtypes=False)

def test_grad_nested_partial_auto(self):
mesh = jtu.create_mesh((2, 2), ('i', 'j'))

Expand Down

0 comments on commit bbd09e8

Please sign in to comment.