Skip to content

Commit

Permalink
Implement JAX CPU/GPU callbacks with XLA's FFI.
Browse files Browse the repository at this point in the history
- Change 4 of 4 addressing #3 in #25842.

PiperOrigin-RevId: 729641154
  • Loading branch information
danielsuo authored and Google-ML-Automation committed Feb 21, 2025
1 parent b3fcba7 commit 6e83de5
Showing 1 changed file with 96 additions and 48 deletions.
144 changes: 96 additions & 48 deletions jax/_src/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from jax._src.interpreters import xla
from jax._src.lax.control_flow.loops import map as lax_map
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding_impls import SdyArraySharding, SdyArrayShardingList, SingleDeviceSharding
Expand Down Expand Up @@ -200,7 +201,11 @@ def _callback_op_sharding(
# program has bulk array semantics, so we run the callback with a MAXIMAL
# sharding and hence execute it only once on the full logical value).
if config.use_shardy_partitioner.value:
op_sharding = sharding_impls.SdyArrayShardingList([
# For Shardy, the number of sharding annotations must equal the number of
# results of the op. If the op has no results, it still needs exactly one
# annotation.
num_sdy_shardings = max(1, len(avals_out))
op_sharding = sharding_impls.SdyArrayShardingList(num_sdy_shardings * [
sharding_impls.SdyArraySharding(
mesh_shape=(),
dimension_shardings=[],
Expand Down Expand Up @@ -822,55 +827,98 @@ def _wrapped_callback(*args):
for result_aval in result_avals]
return outputs, token, None

result_types = mlir.flatten_ir_types([mlir.aval_to_ir_type(aval) for aval in result_avals])
if token:
if xla_extension_version <= 316:
result_types = mlir.flatten_ir_types([mlir.aval_to_ir_type(aval) for aval in result_avals])
if token:

callback_without_token = _wrapped_callback
def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function-redefined
return (token, *callback_without_token(*args))

operand_shapes = [
xla.aval_to_xla_shapes(core.abstract_token)[0], *operand_shapes
]
result_shapes = [
xla.aval_to_xla_shapes(core.abstract_token)[0], *result_shapes
]
operands = [token, *operands]
result_types = [mlir.token_type(), *result_types]
operand_mlir_layouts = [_layout_to_mlir_layout(None), *operand_mlir_layouts]
result_mlir_layouts = [_layout_to_mlir_layout(None), *result_mlir_layouts]
callback_descriptor, ifrt_callback = (
backend.get_emit_python_callback_descriptor(_wrapped_callback,
operand_shapes,
result_shapes))
ctx.module_context.add_host_callback(ifrt_callback)
descriptor_operand = mlir.ir_constant(callback_descriptor)
callback_operands = [descriptor_operand, *operands]
if operand_mlir_layouts is not None:
operand_mlir_layouts = [_layout_to_mlir_layout([]), *operand_mlir_layouts]
result_type = ir.TupleType.get_tuple(result_types)
call_target_name = ("xla_python_gpu_callback"
if platform in {"cuda", "rocm"} else "xla_python_cpu_callback")
result = hlo.CustomCallOp(
[result_type],
callback_operands,
call_target_name=ir.StringAttr.get(call_target_name),
has_side_effect=ir.BoolAttr.get(has_side_effect),
api_version=mlir.i32_attr(2),
called_computations=ir.ArrayAttr.get([]),
backend_config=ir.StringAttr.get(str(callback_descriptor)),
operand_layouts=(
None if operand_mlir_layouts is None
else ir.ArrayAttr.get(operand_mlir_layouts)),
result_layouts=(
None if result_mlir_layouts is None
else ir.ArrayAttr.get(result_mlir_layouts)))
if sharding is not None:
mlir.set_sharding(result, sharding)
results = [
hlo.get_tuple_element(result, mlir.i32_attr(i))
for i in range(len(result_types))
]
else:
call_target_name = (
"xla_ffi_python_gpu_callback"
if platform in {"cuda", "rocm"}
else "xla_ffi_python_cpu_callback")
if token:
callback_without_token = _wrapped_callback
def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function-redefined
return (token, *callback_without_token(*args))
operands = [token, *operands]
if (
config.use_shardy_partitioner.value
and sharding is not None
and len(ctx.avals_out) > 0
and isinstance(sharding, sharding_impls.SdyArrayShardingList)
):
# Add a sharding annotation for the token if we have at least one
# output. Otherwise, the single shardy annotation required of all ops
# (even those without any results) can annotate the token.
sharding = sharding_impls.SdyArrayShardingList(
[*sharding.shardings, sharding.shardings[-1]]
)
ctx = dataclasses.replace(
ctx,
avals_in=[core.abstract_token, *ctx.avals_in],
avals_out=[core.abstract_token, *ctx.avals_out],
)

callback_without_token = _wrapped_callback
def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function-redefined
return (token, *callback_without_token(*args))
ifrt_callback = backend.get_emit_python_callback(
_wrapped_callback
)
ctx.module_context.add_host_callback(ifrt_callback)
index = np.uint64(len(ctx.module_context.host_callbacks) - 1)
result = ffi.build_ffi_lowering_function( # type: ignore
call_target_name,
has_side_effect=has_side_effect,
)(ctx, *operands, index=np.uint64(index))

operand_shapes = [
xla.aval_to_xla_shapes(core.abstract_token)[0], *operand_shapes
]
result_shapes = [
xla.aval_to_xla_shapes(core.abstract_token)[0], *result_shapes
]
operands = [token, *operands]
result_types = [mlir.token_type(), *result_types]
operand_mlir_layouts = [_layout_to_mlir_layout(None), *operand_mlir_layouts]
result_mlir_layouts = [_layout_to_mlir_layout(None), *result_mlir_layouts]
callback_descriptor, ifrt_callback = (
backend.get_emit_python_callback_descriptor(_wrapped_callback,
operand_shapes,
result_shapes))
ctx.module_context.add_host_callback(ifrt_callback)
descriptor_operand = mlir.ir_constant(callback_descriptor)
callback_operands = [descriptor_operand, *operands]
if operand_mlir_layouts is not None:
operand_mlir_layouts = [_layout_to_mlir_layout([]), *operand_mlir_layouts]
result_type = ir.TupleType.get_tuple(result_types)
call_target_name = ("xla_python_gpu_callback"
if platform in {"cuda", "rocm"} else "xla_python_cpu_callback")
result = hlo.CustomCallOp(
[result_type],
callback_operands,
call_target_name=ir.StringAttr.get(call_target_name),
has_side_effect=ir.BoolAttr.get(has_side_effect),
api_version=mlir.i32_attr(2),
called_computations=ir.ArrayAttr.get([]),
backend_config=ir.StringAttr.get(str(callback_descriptor)),
operand_layouts=(
None if operand_mlir_layouts is None
else ir.ArrayAttr.get(operand_mlir_layouts)),
result_layouts=(
None if result_mlir_layouts is None
else ir.ArrayAttr.get(result_mlir_layouts)))
if sharding is not None:
mlir.set_sharding(result, sharding)
results = [
hlo.get_tuple_element(result, mlir.i32_attr(i))
for i in range(len(result_types))
]
if sharding is not None:
mlir.set_sharding(result, sharding)

results = result.results # type: ignore
if token:
token, *results = results
return results, token, ifrt_callback

0 comments on commit 6e83de5

Please sign in to comment.