Merge pull request #25887 from mattjj:partial-auto-something
PiperOrigin-RevId: 727110625
Google-ML-Automation committed Feb 15, 2025
2 parents d3850e7 + 3681960 commit df135d2
Showing 2 changed files with 63 additions and 27 deletions.
72 changes: 45 additions & 27 deletions jax/_src/debugging.py
@@ -16,7 +16,7 @@
from __future__ import annotations

from collections.abc import Callable, Sequence
import functools
from functools import partial
import importlib.util
import logging
import string
@@ -45,7 +45,8 @@
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding import Sharding
from jax._src.sharding_impls import NamedSharding, parse_flatten_op_sharding
from jax._src.sharding_impls import (
NamedSharding, PartitionSpec as P, parse_flatten_op_sharding)
from jax._src.state import discharge as state_discharge

logger = logging.getLogger(__name__)
@@ -114,7 +115,7 @@ def get_arg_at_dim(i, dim, arg):
return lax.index_in_dim(arg, i, axis=dim, keepdims=False)
outs = []
for i in range(axis_size):
args_idx = map(functools.partial(get_arg_at_dim, i), dims, args)
args_idx = map(partial(get_arg_at_dim, i), dims, args)
outs.append(debug_callback_p.bind(*args_idx, **params))
outs = [jnp.stack(xs) for xs in zip(*outs)]
return outs, (0,) * len(outs)
@@ -130,38 +131,55 @@ def debug_callback_transpose_rule(*flat_args, callback: Callable[..., Any],
raise ValueError("Transpose doesn't support debugging callbacks.")
ad.primitive_transposes[debug_callback_p] = debug_callback_transpose_rule

def debug_callback_lowering(ctx, *args, effect, callback, **params):
def _debug_callback_partial_auto(axis_context, *args, **params):
from jax.experimental.shard_map import shard_map
partial_auto = list(set(axis_context.mesh.axis_names) - axis_context.manual_axes)
def f():
idx = jax.lax.with_sharding_constraint(
jax.lax.axis_index(*partial_auto),
NamedSharding(axis_context.mesh, P()))
return jax.lax.cond(idx == 0,
lambda: debug_callback_p.bind(*args, **params),
lambda: [])
return shard_map(f, axis_context.mesh, in_specs=(), out_specs=[])()

def debug_callback_lowering(ctx, *args, effect, callback, **params):
axis_context = ctx.module_context.axis_context
if (isinstance(axis_context, sharding_impls.SPMDAxisContext) and
set(axis_context.manual_axes) == set(axis_context.mesh.axis_names)):
if config.use_shardy_partitioner.value:
assert len(ctx.avals_out) == 1
sharding = sharding_impls.SdyArrayShardingList([
sharding_impls.SdyArraySharding(
mesh_shape=(),
dimension_shardings=[
sharding_impls.SdyDimSharding(axes=[], is_closed=True)
] * ctx.avals_out[0].ndim,
logical_device_ids=())])
else:
if isinstance(axis_context, sharding_impls.SPMDAxisContext):
# We're a shard_map, which might be partial-manual or full-manual.
partial_auto = set(axis_context.mesh.axis_names) - axis_context.manual_axes
if partial_auto:
# If we have partial manual / partial auto sharding, we gather and
# conditionally run the callback.
lower = partial(_debug_callback_partial_auto, axis_context,
effect=effect, callback=callback, **params)
return mlir.lower_fun(lower)(ctx, *args)
elif set(axis_context.manual_axes) == set(axis_context.mesh.axis_names):
# If we have fully manual sharding during lowering, that means the JAX
# program has per-device semantics, so we run the callback on each device.
sharding = xc.OpSharding()
sharding.type = xc.OpSharding.Type.MANUAL
elif isinstance(
axis_context,
(sharding_impls.ShardingContext, sharding_impls.SPMDAxisContext),
):
if config.use_shardy_partitioner.value:
assert len(ctx.avals_out) == 1
sharding = sharding_impls.SdyArrayShardingList([
sharding_impls.SdyArraySharding(
mesh_shape=(),
dimension_shardings=[
sharding_impls.SdyDimSharding(axes=[], is_closed=True)
] * ctx.avals_out[0].ndim,
logical_device_ids=())])
else:
sharding = xc.OpSharding()
sharding.type = xc.OpSharding.Type.MANUAL
else:
assert False # Unreachable
elif isinstance(axis_context, sharding_impls.ShardingContext):
# If we have fully automatic sharding during lowering, that means the JAX
# program has bulk array semantics, so we run the callback with a MAXIMAL
# sharding and hence execute it only once on the full logical value).
if config.use_shardy_partitioner.value:
sharding = sharding_impls.SdyArrayShardingList([
sharding_impls.SdyArraySharding(
mesh_shape=(), dimension_shardings=[], logical_device_ids=(0,))])
else:
# If we have fully automatic sharding during lowering, that means the JAX
# program has bulk array semantics, so we run the callback with a MAXIMAL
# sharding and hence execute it only once on the full logical value).
# If we have partially automatic sharding, we do this too... not sure why!
sharding = xc.OpSharding()
sharding.type = xc.OpSharding.Type.MAXIMAL
sharding.tile_assignment_dimensions = [1]
@@ -354,7 +372,7 @@ def debug_print(fmt: str, *args, **kwargs):
# Check that we provide the correct arguments to be formatted.
formatter.format(fmt, *args, **kwargs)

debug_callback(functools.partial(_format_print_callback, fmt, np.get_printoptions()),
debug_callback(partial(_format_print_callback, fmt, np.get_printoptions()),
*args, **kwargs, ordered=ordered)


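The new _debug_callback_partial_auto helper above handles a shard_map that is only partially manual: it wraps the callback in a nested shard_map over the mesh, takes lax.axis_index over the still-automatic axes, and binds the callback under lax.cond only where that index is zero. Below is a minimal user-level sketch of the same gating pattern using only public APIs; the 4x2 mesh, the axis names, and the assumption of at least 8 devices (for example XLA_FLAGS=--xla_force_host_platform_device_count=8 on CPU) are illustrative and not taken from this change, and check_rep=False simply mirrors the new test.

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ('i', 'j'))

def print_once_along_j(x):
  # Gate the side effect on the position along 'j', so each 'i' shard is
  # printed once rather than once per 'j' program instance.
  jax.lax.cond(jax.lax.axis_index('j') == 0,
               lambda: jax.debug.print('shard: {}', x),
               lambda: None)
  return x

f = jax.jit(shard_map(print_once_along_j, mesh,
                      in_specs=P('i'), out_specs=P('i'), check_rep=False))
f(jnp.arange(8.))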
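The two branches that remain in debug_callback_lowering match the comments in the diff: fully manual sharding keeps per-device semantics (MANUAL sharding), while fully automatic sharding runs the callback once on the full logical value (MAXIMAL sharding). The snippet below is a hedged illustration of that user-visible difference, again assuming at least 8 devices and a one-axis mesh chosen for illustration rather than anything taken from this change.

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:8]), ('i',))
x = jnp.arange(8.)

# Fully manual: each program instance runs the callback on its own shard,
# so this is expected to print one line per device.
jax.jit(shard_map(lambda v: jax.debug.print('per-device shard: {}', v),
                  mesh, in_specs=P('i'), out_specs=None))(x)

# Fully automatic: no shard_map, so the callback is expected to run once
# on the whole array.
jax.jit(lambda v: jax.debug.print('full value: {}', v))(x)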
18 changes: 18 additions & 0 deletions tests/shard_map_test.py
@@ -2204,6 +2204,24 @@ def f(x):
#
# f(x) # don't crash

def test_partial_auto_debug_print(self):
if config.use_shardy_partitioner.value:
raise unittest.SkipTest("shardy error")

mesh = jtu.create_mesh((4, 2), ('i', 'j'))
x = jnp.arange(8.)

def g(x):
jax.debug.print('{}', x)

@jax.jit
def f(x):
return shard_map(g,
mesh, in_specs=P('i'), out_specs=None,
check_rep=False, auto=frozenset({'j'}))(x)

y = f(x) # don't crash

def test_partial_auto_of_random_keys(self):
mesh = jtu.create_mesh((4, 2), ('i', 'j'))
keys = jax.random.split(jax.random.key(0), 8)
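For reference, the scenario the new test exercises can be reproduced outside the test harness with a script along the following lines. The XLA_FLAGS override for faking 8 CPU devices and the explicit Mesh construction are assumptions for running on a CPU-only machine, and, as in the test, the Shardy partitioner is assumed to be off.

import os
os.environ.setdefault('XLA_FLAGS', '--xla_force_host_platform_device_count=8')  # set before importing jax

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ('i', 'j'))

def g(x):
  jax.debug.print('{}', x)

@jax.jit
def f(x):
  return shard_map(g, mesh, in_specs=P('i'), out_specs=None,
                   check_rep=False, auto=frozenset({'j'}))(x)

f(jnp.arange(8.))  # per the new test, this should run without crashing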
