Skip to content

Commit

Permalink
[sharding_in_types] Make slice and ellipsis work with `.at[...].get(out_sharding=P(...))`
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 729723470
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Feb 22, 2025
1 parent 629426f commit 80f18de
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 17 deletions.
15 changes: 6 additions & 9 deletions jax/_src/numpy/array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@
from jax._src.numpy import indexing
from jax._src.numpy import lax_numpy
from jax._src.numpy import tensor_contractions
from jax._src import mesh as mesh_lib
from jax._src.pjit import auto_axes, PartitionSpec
from jax._src.pjit import PartitionSpec
from jax._src.sharding_impls import canonicalize_sharding, NamedSharding
from jax._src.numpy import reductions
from jax._src.numpy import ufuncs
Expand Down Expand Up @@ -778,16 +777,14 @@ def get(self, *, indices_are_sorted=False, unique_indices=False,
See :mod:`jax.ops` for details.
"""
take = partial(indexing.rewriting_take,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode,
fill_value=fill_value)
if out_sharding is not None:
assert isinstance(out_sharding, (NamedSharding, PartitionSpec))
out_sharding = canonicalize_sharding(out_sharding, '.get')
take = auto_axes(take, axes=mesh_lib.get_abstract_mesh().axis_names,
out_shardings=out_sharding.spec)
return take(self.array, self.index)
return indexing.rewriting_take(self.array, self.index,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode,
fill_value=fill_value,
out_sharding=out_sharding)

def set(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None):
Expand Down
25 changes: 17 additions & 8 deletions jax/_src/numpy/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from jax._src.api import jit
from jax._src.lax import lax as lax_internal
from jax._src.numpy import einsum
from jax._src import mesh as mesh_lib
from jax._src.pjit import auto_axes
from jax._src.numpy import lax_numpy
from jax._src.numpy import ufuncs
from jax._src.numpy import util
Expand Down Expand Up @@ -600,7 +602,7 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) ->


def rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
mode=None, fill_value=None):
mode=None, fill_value=None, out_sharding=None):
# Computes arr[idx].
# All supported cases of indexing can be implemented as an XLA gather,
# followed by an optional reverse and broadcast_in_dim.
Expand All @@ -624,13 +626,13 @@ def rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,

treedef, static_idx, dynamic_idx = split_index_for_jit(idx, arr.shape)
return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
unique_indices, mode, fill_value)
unique_indices, mode, fill_value, out_sharding)

# TODO(phawkins): re-enable jit after fixing excessive recompilation for
# slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.).
# @partial(jit, static_argnums=(1, 2))
def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
unique_indices, mode, fill_value):
unique_indices, mode, fill_value, out_sharding):
idx = merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
indexer = index_to_gather(np.shape(arr), idx) # shared with _scatter_update
y = arr
Expand All @@ -653,11 +655,18 @@ def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,

# We avoid generating a gather when indexer.gather_indices.size is empty.
if not core.is_empty_shape(indexer.gather_indices.shape):
y = lax.gather(
y, indexer.gather_indices, indexer.dnums, indexer.gather_slice_shape,
unique_indices=unique_indices or indexer.unique_indices,
indices_are_sorted=indices_are_sorted or indexer.indices_are_sorted,
mode=mode, fill_value=fill_value)
internal_gather = partial(
lax.gather,
dimension_numbers=indexer.dnums,
slice_sizes=indexer.gather_slice_shape,
unique_indices=unique_indices or indexer.unique_indices,
indices_are_sorted=indices_are_sorted or indexer.indices_are_sorted,
mode=mode, fill_value=fill_value)
if out_sharding is not None:
internal_gather = auto_axes(
internal_gather, axes=mesh_lib.get_abstract_mesh().axis_names,
out_shardings=out_sharding)
y = internal_gather(y, indexer.gather_indices)

# Reverses axes with negative strides.
if indexer.reversed_y_dims:
Expand Down
8 changes: 8 additions & 0 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6234,6 +6234,14 @@ def f(embed_vd, token_bt):
out = embed_vd.at[token_bt].get(out_sharding=P('x', None, None))
self.assertEqual(out.shape, (8, 4, 16))
self.assertEqual(out.aval.sharding.spec, P('x', None, None))

out2 = embed_vd.at[token_bt, :].get(out_sharding=P('x', None, None))
self.assertEqual(out2.shape, (8, 4, 16))
self.assertEqual(out2.aval.sharding.spec, P('x', None, None))

out3 = embed_vd.at[token_bt, ...].get(out_sharding=P('x', None, None))
self.assertEqual(out3.shape, (8, 4, 16))
self.assertEqual(out3.aval.sharding.spec, P('x', None, None))
return out

out = f(embed, tok)
Expand Down

0 comments on commit 80f18de

Please sign in to comment.