diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py
index 13fe70c54129..88874f2b949a 100644
--- a/jax/_src/debugging.py
+++ b/jax/_src/debugging.py
@@ -16,7 +16,7 @@
 from __future__ import annotations
 
 from collections.abc import Callable, Sequence
-import functools
+from functools import partial
 import importlib.util
 import logging
 import string
@@ -45,7 +45,8 @@
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.sharding import Sharding
-from jax._src.sharding_impls import NamedSharding, parse_flatten_op_sharding
+from jax._src.sharding_impls import (
+    NamedSharding, PartitionSpec as P, parse_flatten_op_sharding)
 from jax._src.state import discharge as state_discharge
 
 logger = logging.getLogger(__name__)
@@ -114,7 +115,7 @@ def get_arg_at_dim(i, dim, arg):
     return lax.index_in_dim(arg, i, axis=dim, keepdims=False)
   outs = []
   for i in range(axis_size):
-    args_idx = map(functools.partial(get_arg_at_dim, i), dims, args)
+    args_idx = map(partial(get_arg_at_dim, i), dims, args)
     outs.append(debug_callback_p.bind(*args_idx, **params))
   outs = [jnp.stack(xs) for xs in zip(*outs)]
   return outs, (0,) * len(outs)
@@ -130,38 +131,55 @@ def debug_callback_transpose_rule(*flat_args, callback: Callable[..., Any],
   raise ValueError("Transpose doesn't support debugging callbacks.")
 ad.primitive_transposes[debug_callback_p] = debug_callback_transpose_rule
 
-def debug_callback_lowering(ctx, *args, effect, callback, **params):
+def _debug_callback_partial_auto(axis_context, *args, **params):
+  from jax.experimental.shard_map import shard_map
+  partial_auto = list(set(axis_context.mesh.axis_names) - axis_context.manual_axes)
+  def f():
+    idx = jax.lax.with_sharding_constraint(
+        jax.lax.axis_index(*partial_auto),
+        NamedSharding(axis_context.mesh, P()))
+    return jax.lax.cond(idx == 0,
+                        lambda: debug_callback_p.bind(*args, **params),
+                        lambda: [])
+  return shard_map(f, axis_context.mesh, in_specs=(), out_specs=[])()
+
+def debug_callback_lowering(ctx, *args, effect, callback, **params):
   axis_context = ctx.module_context.axis_context
-  if (isinstance(axis_context, sharding_impls.SPMDAxisContext) and
-      set(axis_context.manual_axes) == set(axis_context.mesh.axis_names)):
-    if config.use_shardy_partitioner.value:
-      assert len(ctx.avals_out) == 1
-      sharding = sharding_impls.SdyArrayShardingList([
-          sharding_impls.SdyArraySharding(
-              mesh_shape=(),
-              dimension_shardings=[
-                  sharding_impls.SdyDimSharding(axes=[], is_closed=True)
-              ] * ctx.avals_out[0].ndim,
-              logical_device_ids=())])
-    else:
+  if isinstance(axis_context, sharding_impls.SPMDAxisContext):
+    # We're a shard_map, which might be partial-manual or full-manual.
+    partial_auto = set(axis_context.mesh.axis_names) - axis_context.manual_axes
+    if partial_auto:
+      # If we have partial manual / partial auto sharding, we gather and
+      # conditionally run the callback.
+      lower = partial(_debug_callback_partial_auto, axis_context,
+                      effect=effect, callback=callback, **params)
+      return mlir.lower_fun(lower)(ctx, *args)
+    elif set(axis_context.manual_axes) == set(axis_context.mesh.axis_names):
       # If we have fully manual sharding during lowering, that means the JAX
       # program has per-device semantics, so we run the callback on each device.
-      sharding = xc.OpSharding()
-      sharding.type = xc.OpSharding.Type.MANUAL
-  elif isinstance(
-      axis_context,
-      (sharding_impls.ShardingContext, sharding_impls.SPMDAxisContext),
-  ):
+      if config.use_shardy_partitioner.value:
+        assert len(ctx.avals_out) == 1
+        sharding = sharding_impls.SdyArrayShardingList([
+            sharding_impls.SdyArraySharding(
+                mesh_shape=(),
+                dimension_shardings=[
+                    sharding_impls.SdyDimSharding(axes=[], is_closed=True)
+                ] * ctx.avals_out[0].ndim,
+                logical_device_ids=())])
+      else:
+        sharding = xc.OpSharding()
+        sharding.type = xc.OpSharding.Type.MANUAL
+    else:
+      assert False  # Unreachable
+  elif isinstance(axis_context, sharding_impls.ShardingContext):
+    # If we have fully automatic sharding during lowering, that means the JAX
+    # program has bulk array semantics, so we run the callback with a MAXIMAL
+    # sharding and hence execute it only once on the full logical value).
     if config.use_shardy_partitioner.value:
       sharding = sharding_impls.SdyArrayShardingList([
           sharding_impls.SdyArraySharding(
               mesh_shape=(), dimension_shardings=[], logical_device_ids=(0,))])
     else:
-      # If we have fully automatic sharding during lowering, that means the JAX
-      # program has bulk array semantics, so we run the callback with a MAXIMAL
-      # sharding and hence execute it only once on the full logical value).
-      # If we have partially automatic sharding, we do this too... not sure why!
       sharding = xc.OpSharding()
       sharding.type = xc.OpSharding.Type.MAXIMAL
       sharding.tile_assignment_dimensions = [1]
@@ -354,7 +372,7 @@ def debug_print(fmt: str, *args, **kwargs):
   # Check that we provide the correct arguments to be formatted.
   formatter.format(fmt, *args, **kwargs)
 
-  debug_callback(functools.partial(_format_print_callback, fmt, np.get_printoptions()),
+  debug_callback(partial(_format_print_callback, fmt, np.get_printoptions()),
                  *args, **kwargs, ordered=ordered)
 
 
diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py
index a2b3b9de5a9d..c4923e5c5a01 100644
--- a/tests/shard_map_test.py
+++ b/tests/shard_map_test.py
@@ -2204,6 +2204,24 @@ def f(x):
     #
     # f(x)  # don't crash
 
+  def test_partial_auto_debug_print(self):
+    if config.use_shardy_partitioner.value:
+      raise unittest.SkipTest("shardy error")
+
+    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
+    x = jnp.arange(8.)
+
+    def g(x):
+      jax.debug.print('{}', x)
+
+    @jax.jit
+    def f(x):
+      return shard_map(g,
+                       mesh, in_specs=P('i'), out_specs=None,
+                       check_rep=False, auto=frozenset({'j'}))(x)
+
+    y = f(x)  # don't crash
+
   def test_partial_auto_of_random_keys(self):
     mesh = jtu.create_mesh((4, 2), ('i', 'j'))
     keys = jax.random.split(jax.random.key(0), 8)
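
Reviewer note, not part of the patch: below is a minimal standalone sketch of the situation this change targets, i.e. jax.debug.print under a shard_map that is manual over 'i' but auto over 'j', the same path exercised by the new test_partial_auto_debug_print. The 8-device setup, the direct use of jax.sharding.Mesh, and the expected print cadence (once per 'i' block, on the shards where the auto axis index is 0) are my own assumptions for illustration, not something the patch or its test asserts.

    # Run on CPU with XLA_FLAGS=--xla_force_host_platform_device_count=8, or on 8 real devices.
    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, PartitionSpec as P
    from jax.experimental.shard_map import shard_map

    # 4x2 mesh: 'i' stays manual inside the shard_map, 'j' is left auto.
    mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ('i', 'j'))

    def g(x):
      # x is this shard's 'i' block; it is replicated over the auto axis 'j'.
      jax.debug.print('block = {}', x)
      return x

    @jax.jit
    def f(x):
      return shard_map(g, mesh, in_specs=P('i'), out_specs=P('i'),
                       check_rep=False, auto=frozenset({'j'}))(x)

    f(jnp.arange(8.))  # with this patch the callback lowers under partial auto

With the new lowering path, the callback is wrapped in a shard_map over the mesh and guarded by a cond on the auto-axis index, so it should fire once per manual 'i' block rather than failing to lower.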