diff --git a/jax/_src/cudnn/scaled_matmul_stablehlo.py b/jax/_src/cudnn/scaled_matmul_stablehlo.py new file mode 100644 index 000000000000..392c38387946 --- /dev/null +++ b/jax/_src/cudnn/scaled_matmul_stablehlo.py @@ -0,0 +1,699 @@ +from dataclasses import dataclass +import json +import operator +from functools import partial, reduce +from typing import List + +# Third-party imports +import jax +import jax.numpy as jnp +import numpy as np +from jax import custom_vjp, lax +from jax._src import core, dispatch, dtypes +from jax._src.custom_partitioning import custom_partitioning +from jax._src.interpreters import batching +from jax._src.lax.lax import ranges_like, remaining +from jax._src.typing import DTypeLike +from jax.interpreters import mlir, xla +from jax.interpreters.mlir import ir +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P + + +Array = jnp.ndarray +block_scaled_dot_name = "__op$block_scaled_dot" + +@dataclass +class BlockScaleConfig: + mode: str + block_size: int + data_type: DTypeLike + scale_type: DTypeLike + global_scale: Array | None + infer_only: bool + +def default_layouts(*shapes): + return [range(len(shape) - 1, -1, -1) for shape in shapes] + +def element_type_to_backend_config_type(dtype): + _element_type_to_backend_config_type_mapping = { + ir.BF16Type.get(): "BF16", + ir.F16Type.get(): "F16", + ir.F32Type.get(): "F32", + } + return _element_type_to_backend_config_type_mapping[dtype] + + +def _scaled_matmul_impl(a, b, a_scale, b_scale, preferred_element_type): + return _scaled_matmul_p.bind( + a, b, a_scale, b_scale, preferred_element_type=preferred_element_type + ) + + +def _scaled_matmul_cuda_lowering( + ctx, a, b, a_scales, b_scales, preferred_element_type + ): + lhs_type = ir.RankedTensorType(a.type) + lhs_shape = lhs_type.shape + rhs_type = ir.RankedTensorType(b.type) + rhs_shape = rhs_type.shape + + batch, non_contracting_lhs, contracting = lhs_shape + _, non_contracting_rhs, _ = rhs_shape + result_shape = (batch, non_contracting_lhs, non_contracting_rhs) + + out_type = mlir.dtype_to_ir_type(preferred_element_type) + result_types = [ir.RankedTensorType.get(result_shape, out_type)] + + operands = [a, b, a_scales, b_scales] + backend_config = { + "scaled_dot_backend_config": { + "lhs_batch_dimensions": [0], + "rhs_batch_dimensions": [0], + "dequantize_type": element_type_to_backend_config_type(out_type), + } + } + + backend_config = json.dumps(backend_config) + out = mlir.custom_call( + block_scaled_dot_name, + result_types=result_types, + operands=operands, + backend_config=backend_config, + operand_layouts=default_layouts( + *[ir.RankedTensorType(operand.type).shape for operand in operands] + ), + result_layouts=default_layouts(result_shape), + ) + return [out.result] + + +def _scaled_matmul_abstract(a, b, a_scale, b_scale, *, preferred_element_type): + a_dtype = dtypes.canonicalize_dtype(a.dtype) + batch, non_contracting_lhs, contracting_lhs = a.shape + _, non_contracting_rhs, _ = b.shape + output_shape = (batch, non_contracting_lhs, non_contracting_rhs) + return (core.ShapedArray(output_shape, preferred_element_type),) + + +_scaled_matmul_p = core.Primitive("scaled_matmul") +_scaled_matmul_p.multiple_results = True +_scaled_matmul_p.def_impl(partial(xla.apply_primitive, _scaled_matmul_p)) +_scaled_matmul_p.def_abstract_eval(_scaled_matmul_abstract) + + +mlir.register_lowering( + _scaled_matmul_p, + _scaled_matmul_cuda_lowering, + platform="cuda", +) + +_scaled_matmul_p_wrapper = core.Primitive("scaled_matmul_wrapper") 
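+# `_scaled_matmul_p` lowers directly to the cuDNN custom call registered above,
+# while the wrapper primitive below is what user code binds: its lowering goes
+# through `custom_partitioning` (see `_scaled_matmul_lower`), so sharding and
+# batching rules can be layered on top of the raw custom call.
+#
+# Ignoring layout details, both compute a batched matmul of the block-wise
+# dequantized operands; roughly (a sketch of the semantics, not the lowering):
+#   lhs_deq = lhs.astype(jnp.float32) * jnp.repeat(lhs_scales.astype(jnp.float32), block_size, axis=-1)
+#   rhs_deq = rhs.astype(jnp.float32) * jnp.repeat(rhs_scales.astype(jnp.float32), block_size, axis=-1)
+#   out     = jnp.einsum("bmk,bnk->bmn", lhs_deq, rhs_deq)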
+_scaled_matmul_p_wrapper.multiple_results = True +_scaled_matmul_p_wrapper.def_impl(_scaled_matmul_impl) +_scaled_matmul_p_wrapper.def_abstract_eval(_scaled_matmul_abstract) + +# Given the inputs already sharded as +# ([B], M, K1), ([B], N, K2) +# We define the following rules to apply necessary AllGather based on +# "Input specs", and to define the "Output spec". +# 1. If K1 == K2 != None and N == None: +# - Input spec : ([B], M, K1), ([B], None, K2) +# - Output spec: ([B], M, None) -> AllReduce -> ([B], M, None) +# 2. If K1 == K2 != None and M == N != None: +# - Input spec : ([B], M, K1), ([B], None, K2) +# - Output spec: ([B], M, None) -> ReduceScatter -> ([B], M, N) +# 3. If N == M: +# - Input specs : ([B], M, None), ([B], None, None) +# - Output specs: ([B], M, None) +# 4. If N != M: +# - Input spec : ([B], M, None), ([B], N, None) +# - Output spec: ([B], M, N) +def _check_shardings(shardings): + if len(shardings) != 4: + msg = f"shardings should contain 4 inputs, but got {len(shardings)}" + raise TypeError(msg) + lhs, rhs, _, _ = shardings + if len(lhs.spec) != 3 or len(rhs.spec) != 3: + msg = (f'shardings specs rank should be 3, but got lhs: {len(lhs.spec)} ' + f'and rhs: {len(rhs.spec)}') + raise TypeError(msg) + if lhs.spec[0] != rhs.spec[0]: + msg = ('shardings spec for batch dim should be the same, but got lhs: ' + f'{lhs.spec[0]} and rhs: {rhs.spec[0]}') + raise TypeError(msg) + + +def _enable_reduce_scatter(lhs, rhs): + batch_spec, m_spec, lhs_k_spec = lhs.spec + _, n_spec, rhs_k_spec = rhs.spec + return ( + lhs_k_spec != None + and lhs_k_spec == rhs_k_spec + and m_spec != None + and m_spec == n_spec + ) + + +def _enable_all_reduce(lhs, rhs): + batch_spec, m_spec, lhs_k_spec = lhs.spec + _, n_spec, rhs_k_spec = rhs.spec + return lhs_k_spec != None and lhs_k_spec == rhs_k_spec and n_spec == None + + +def _get_output_sharding(mesh, shardings): + lhs, rhs = shardings[0], shardings[1] + batch_spec, m_spec, _ = lhs.spec + _, n_spec, _ = rhs.spec + + if _enable_reduce_scatter(lhs, rhs): + return [NamedSharding(lhs.mesh, P(*lhs.spec))] + + output_specs = (batch_spec, m_spec) + output_specs += (n_spec,) if m_spec != n_spec else (None,) + return [NamedSharding(lhs.mesh, P(*output_specs))] + + +def _scaled_matmul_infer_sharding_from_operands( + preferred_element_type, mesh, shapes, output_shape + ): + shardings = jax.tree.map(lambda x: x.sharding, shapes) + _check_shardings(shardings) + + return _get_output_sharding(mesh, shardings) + + +def supported_in_sharding(mesh, shardings): + lhs_sharding, rhs_sharding = shardings[0], shardings[1] + use_reduce_scatter = _enable_reduce_scatter(lhs_sharding, rhs_sharding) + use_all_reduce = _enable_all_reduce(lhs_sharding, rhs_sharding) + assert not (use_all_reduce and use_reduce_scatter) + + lhs_specs, rhs_specs = list(lhs_sharding.spec), list(rhs_sharding.spec) + + def named_sharding(lhs, rhs, lhs_specs, rhs_specs): + lhs_sharding = NamedSharding(lhs.mesh, P(*lhs_specs)) + rhs_sharding = NamedSharding(rhs.mesh, P(*rhs_specs)) + return (lhs_sharding, rhs_sharding, lhs_sharding, rhs_sharding) + + if use_all_reduce: + return named_sharding(lhs_sharding, rhs_sharding, lhs_specs, rhs_specs) + + if use_reduce_scatter: + rhs_specs[1] = None + return named_sharding(lhs_sharding, rhs_sharding, lhs_specs, rhs_specs) + + lhs_specs[2] = None + rhs_specs[2] = None + m_spec, n_spec = lhs_specs[1], rhs_specs[1] + if m_spec == n_spec: + rhs_specs[1] = None + + return named_sharding(lhs_sharding, rhs_sharding, lhs_specs, rhs_specs) + + +def _scaled_matmul_partition(
+ preferred_element_type, mesh, shapes, output_shape + ): + shardings = jax.tree.map(lambda x: x.sharding, shapes) + _check_shardings(shardings) + + lhs, rhs = shardings[0], shardings[1] + use_all_reduce = _enable_all_reduce(lhs, rhs) + use_reduce_scatter = _enable_reduce_scatter(lhs, rhs) + lhs_k_spec = lhs.spec[2] + + def _scaled_matmul_impl_partition(a, b, a_scale, b_scale): + z = _scaled_matmul_impl(a, b, a_scale, b_scale, preferred_element_type) + if use_reduce_scatter: + z = jax.lax.psum_scatter( + z, lhs_k_spec, scatter_dimension=2, tiled=True + ) + if use_all_reduce: + z = jax.lax.psum(z, lhs_k_spec) + return z + + out_shardings = _get_output_sharding(mesh, shardings) + arg_shardings = supported_in_sharding(mesh, shardings) + return mesh, _scaled_matmul_impl_partition, out_shardings, arg_shardings + + +_scaled_matmul_lower = custom_partitioning( + _scaled_matmul_impl, static_argnums=(4,) +) + +_scaled_matmul_lower.def_partition( + infer_sharding_from_operands=_scaled_matmul_infer_sharding_from_operands, + partition=_scaled_matmul_partition, +) + + +def _scaled_matmul_batcher(batched_args, batch_dims, *, preferred_element_type): + assert len(batch_dims) == 4 + assert ( + batch_dims[0] == batch_dims[1] + and batch_dims[0] == batch_dims[2] + and batch_dims[0] == batch_dims[3] + ) + lhs_bdims = batch_dims[0] + out_bdims = (batch_dims[0],) + lhs, rhs, lhs_scales, rhs_scales = batched_args + *batch, lhs_non_contracting, contracting = lhs.shape + *_, _, scales_contracting = lhs_scales.shape + *_, rhs_non_contracting, _ = rhs.shape + + new_batch = reduce(operator.mul, batch) + # reshape to 3D shape + lhs = jnp.reshape(lhs, (new_batch, lhs_non_contracting, contracting)) + lhs_scales = jnp.reshape( + lhs_scales, (new_batch, lhs_non_contracting, scales_contracting) + ) + rhs = jnp.reshape(rhs, (new_batch, rhs_non_contracting, contracting)) + rhs_scales = jnp.reshape( + rhs_scales, (new_batch, rhs_non_contracting, scales_contracting) + ) + output = jnp.reshape( + _scaled_matmul_p_wrapper.bind( + lhs, + rhs, + lhs_scales, + rhs_scales, + preferred_element_type=preferred_element_type, + )[0], + (*batch, lhs_non_contracting, rhs_non_contracting), + ) + return (output,), out_bdims + + +mlir.register_lowering( + _scaled_matmul_p_wrapper, + mlir.lower_fun(_scaled_matmul_lower, multiple_results=True), +) + +dispatch.prim_requires_devices_during_lowering.add(_scaled_matmul_p) +dispatch.prim_requires_devices_during_lowering.add(_scaled_matmul_p_wrapper) + +batching.primitive_batchers[_scaled_matmul_p_wrapper] = _scaled_matmul_batcher +batching.primitive_batchers[_scaled_matmul_p] = _scaled_matmul_batcher + + +@partial(jax.jit, static_argnames=("preferred_element_type",)) +def _scaled_matmul( + lhs: Array, + rhs: Array, + lhs_scales: Array, + rhs_scales: Array, + preferred_element_type: DTypeLike = jnp.float32, + ) -> Array: + output = _scaled_matmul_p_wrapper.bind( + lhs, rhs, lhs_scales, rhs_scales, + preferred_element_type=preferred_element_type + ) + return output[0] + +def scaled_matmul_wrapper( + lhs: Array, + rhs: Array, + lhs_scales: Array, + rhs_scales: Array, + preferred_element_type: DTypeLike = jnp.float32, +) -> Array: + """ + Performs scaled matrix multiplication between two 3D arrays, with scaling + factors applied to the matrices. + + Args: + lhs (Array): A 3D array of shape (B, M, K). + rhs (Array): A 3D array of shape (B, N, K). + lhs_scales (Array): A 3D array of shape (B, M, K_block). + rhs_scales (Array): A 3D array of shape (B, N, K_block). 
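+            Both scale operands use K_block = K // block_size (K // 32 for the
+            default mxfp8 configuration), i.e. one scale per block along the
+            contracting dimension, matching the layout produced by quantize().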
+ preferred_element_type (DTypeLike, optional): The preferred data type + for the computation. Defaults to `jnp.float32`. + + Returns: + Array: A 3D array of shape (B, M, N) representing the scaled matrix + multiplication result. + + Raises: + AssertionError: If the number of columns in `lhs` (`lhs_K`) does not + match the number of columns in `rhs` (`rhs_K`). + + Notes: + - The function ensures that the `preferred_element_type` is + canonicalized before passing it to the underlying computation. + - Scaling is applied to the matrices based on the `lhs_scales` and + `rhs_scales` arrays, enabling efficient computations in blocks. + + """ + B, M, lhs_K = lhs.shape + _, N, rhs_K = rhs.shape + assert lhs_K == rhs_K + _, _, K_block = lhs_scales.shape + + preferred_element_type = dtypes.canonicalize_dtype( + np.dtype(preferred_element_type) + ) + out = _scaled_matmul( + lhs, + rhs, + lhs_scales, + rhs_scales, + preferred_element_type=preferred_element_type, + ) + return out + +def shape_normalization(x, dimension_numbers): + """ + Normalizes the shape of the input tensor `x` to `(B, M, K)`. + + This function rearranges and reshapes the input tensor `x` such that: + - `B` represents the batch dimensions. + - `M` represents the non-contracting dimensions. + - `K` represents the contracting dimensions. + + The dimensions are reordered and reshaped based on the provided + `dimension_numbers`. + + Parameters: + x: The input tensor to normalize. + dimension_numbers: A tuple containing two elements: + - `contracting_dims` (tuple): The dimensions of `x` to be treated as + contracting dimensions. + - `batch_dims` (tuple): The dimensions of `x` to be treated as batch + dimensions. + + Returns: + jax.numpy.ndarray: The reshaped tensor with shape `(B, M, K)` + """ + + orig_order = list(range(x.ndim)) + contracting_dims, batch_dims = dimension_numbers + contracting_order = [d for d in orig_order if d in contracting_dims] + batch_order = [d for d in orig_order if d in batch_dims] + non_contracting_order = [ + d + for d in orig_order + if d not in contracting_dims and d not in batch_dims + ] + batch_shape = [x.shape[d] for d in batch_order] + rows_shape = [x.shape[d] for d in non_contracting_order] + cols_shape = [x.shape[d] for d in contracting_order] + new_order = batch_order + non_contracting_order + contracting_order + rows, cols, batches = ( + np.prod(rows_shape), + np.prod(cols_shape), + np.prod(batch_shape, dtype=int), + ) + t = jnp.transpose(x, new_order) + return jnp.reshape(t, (batches, rows, cols)) + + +def compute_dot_output_shape( + lhs_shape, rhs_shape, lhs_dimension_numbers, rhs_dimension_numbers + ): + """ + Computes the output shape for a `lax.dot_general`-like operation.
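+
+  For example, lhs_shape (2, 128, 64) and rhs_shape (2, 256, 64) with
+  dimension numbers ((2,), (0,)) for both operands (contract the last dim,
+  batch over dim 0) give (2, 128, 256).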
+ """ + lhs_contract, lhs_batch = lhs_dimension_numbers[0], lhs_dimension_numbers[1] + rhs_contract, rhs_batch = rhs_dimension_numbers[0], rhs_dimension_numbers[1] + + output_shape = [] + # Add dimensions for batch (assuming the batch dims of LHS and RHS + # should be same) + for i, dim in enumerate(lhs_shape): + if i in lhs_batch: + output_shape.append(dim) + # Add dimensions from the LHS that are non contracting + for i, dim in enumerate(lhs_shape): + if i not in lhs_contract and i not in lhs_batch: + output_shape.append(dim) + # Add dimensions from the RHS that are non contracting + for i, dim in enumerate(rhs_shape): + if i not in rhs_contract and i not in rhs_batch: + output_shape.append(dim) + return tuple(output_shape) + + +def cast_to_e8m0_with_rounding_up(x): + temp = x.astype(jnp.float32).view(jnp.uint32) + exp = temp >> 23 + mant = temp & 0x7FFFFF + is_ru = jnp.logical_and( + jnp.logical_and((mant > 0), (exp != 0xFE)), + ~jnp.logical_and((exp == 0), (mant <= 0x400000)) + ) + exp = jnp.where(is_ru, exp + 1, exp) + new_x = exp.astype(jnp.uint8) + return new_x + + +def e8m0_to_dtype(x, dtype): + temp = x.astype(jnp.uint32) + exp = temp << 23 + new_x = exp.view(jnp.float32) + near_zero_value = 2**-15 if dtype == jnp.float16 else 2**-127 + new_x = jnp.where( + new_x == 0, jnp.array(near_zero_value, jnp.float32), new_x + ) + return new_x.astype(dtype) + +def quantize(x, config): + x_shape = x.shape + contract_dim = x_shape[-1] + block_size = config.block_size + assert contract_dim >= block_size and contract_dim % block_size == 0 + x_new_shape = x_shape[:-1] + (x_shape[-1] // block_size, block_size) + x = x.reshape(x_new_shape) # shape = (B, M, K / block_size, block_size) + + amax = jnp.max(jnp.abs(x), axis=-1, keepdims=True) + MAX = jnp.finfo(config.data_type).max.astype(x.dtype) + scales = amax / MAX # shape = (B, M, K / block_size, 1) + + if config.mode == "mxfp8": + assert config.scale_type == jnp.float8_e8m0fnu + scales_q = cast_to_e8m0_with_rounding_up(scales) + scaled_x = x / e8m0_to_dtype(scales_q, scales.dtype) + elif config.mode == "nvfp4": + assert config.scale_type == jnp.float8_e4m3fn + # shuw(TODO): Add when XLA is ready and e2m1fn is available. + scales_q = scales + scales_x = x + else: + raise ValueError(f"Unrecognized mode: {config.mode}.") + + clipped_x = jnp.clip(scaled_x, -MAX, MAX) + x_q = clipped_x.astype(config.data_type) + + x_q = x_q.reshape(x_shape) # shape = (B, M, K) + scales_q = jnp.reshape(scales_q, scales_q.shape[:-1]).view( + config.scale_type + ) + return x_q, scales_q + + +def quantize_to_qtype(x, q_dtype, compute_dtype, scale): + # Explicitly cast the max values to the compute dtype to avoid unnecessary + # casting to FP32 during the subsequent math operations." 
+ assert q_dtype in (jnp.float8_e4m3fn, ) + dtype_max = jnp.finfo(q_dtype).max.astype(compute_dtype) + scaled_x = x / jnp.broadcast_to( + jnp.asarray(scale, dtype=compute_dtype), x.shape + ) + clipped_x = jnp.clip(scaled_x, -dtype_max, dtype_max) + return clipped_x.astype(q_dtype) + +def quantize_dequantize(x, q_dtype, scale, compute_dtype): + qx = quantize_to_qtype(x, q_dtype, compute_dtype, scale) + out = qx.astype(x.dtype) * jnp.broadcast_to( + jnp.asarray(scale, dtype=x.dtype), qx.shape + ) + return out + + + +def scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type, + configs): + if preferred_element_type is None: + preferred_element_type = dtypes.result_type( + lhs, rhs, return_weak_type_flag=False + ) + else: + preferred_element_type = dtypes.canonicalize_dtype( + np.dtype(preferred_element_type) + ) + + (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers + lhs_dn = (lhs_contract, lhs_batch) + rhs_dn = (rhs_contract, rhs_batch) + + lhs_3d = shape_normalization(lhs, lhs_dn) + rhs_3d = shape_normalization(rhs, rhs_dn) + lhs_config, rhs_config = configs[0], configs[1] + lhs_q, lhs_scales = quantize(lhs_3d, lhs_config) + rhs_q, rhs_scales = quantize(rhs_3d, rhs_config) + + out = scaled_matmul_wrapper( + lhs_q, rhs_q, lhs_scales, rhs_scales, preferred_element_type + ) + + expanded_out_shape = compute_dot_output_shape( + lhs.shape, rhs.shape, lhs_dn, rhs_dn + ) + expanded_out = jnp.reshape(out, expanded_out_shape) + return expanded_out + + +def scaled_dot_general_transpose_lhs( + g, x, y, *, dimension_numbers, preferred_element_type, configs, + swap_ans=False + ): + (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers + x_ndim = x.aval.ndim + x_kept = remaining(range(x_ndim), x_contract, x_batch) + y_kept = remaining(range(np.ndim(y)), y_contract, y_batch) + if swap_ans: + ans_batch, ans_y, _ = ranges_like(x_batch, y_kept, x_kept) + else: + ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept) + + dims = ((ans_y, y_kept), (ans_batch, y_batch)) + x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract))) + out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y) + + y_dn = (y_kept, y_batch) + g_dn = (ans_y, ans_batch) + + y_3d = shape_normalization(y, y_dn) + g_3d = shape_normalization(g, g_dn) + + g_config, y_config = configs[0], configs[1] + + g_q, g_scales = quantize(g_3d, g_config) + y_q, y_scales = quantize(y_3d, y_config) + + out = scaled_matmul_wrapper( + g_q, y_q, g_scales, y_scales, preferred_element_type + ) + + expanded_out_shape = compute_dot_output_shape(g.shape, y.shape, g_dn, y_dn) + expanded_out = jnp.reshape(out, expanded_out_shape) + x_bar = lax.transpose(expanded_out, tuple(out_axes)) + return x_bar + + +def scaled_dot_general_transpose_rhs( + g, x, y, *, dimension_numbers, preferred_element_type: DTypeLike, + configs: List[BlockScaleConfig] + ): + (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers + swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch)) + y_bar = scaled_dot_general_transpose_lhs( + g, + y, + x, + dimension_numbers=swapped_dimension_numbers, + preferred_element_type=preferred_element_type, + configs=configs, + swap_ans=True, + ) + return y_bar + + +@partial(custom_vjp, nondiff_argnums=(2, 3, 4)) +def scaled_dot_general_fn(lhs, rhs, dimension_numbers, preferred_element_type, + configs): + return scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type, + configs) + + +def scaled_dot_fwd(lhs, rhs, dimension_numbers, 
preferred_element_type, + configs): + out = scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type, + configs) + res = (lhs, rhs) + return out, res + + +def scaled_dot_bwd(dimension_numbers, preferred_element_type, configs, res, g): + (lhs, rhs) = res + + args = [g, lhs, rhs] + kw_args = { + "dimension_numbers": dimension_numbers, + "preferred_element_type": preferred_element_type, + } + lhs_kw_args = { + **kw_args, + "configs": [configs[2], configs[1]] + } + rhs_kw_args = { + **kw_args, + "configs": [configs[2], configs[0]] + } + grad_lhs = scaled_dot_general_transpose_lhs(*args, **lhs_kw_args) + grad_rhs = scaled_dot_general_transpose_rhs(*args, **rhs_kw_args) + return (grad_lhs, grad_rhs) + + +scaled_dot_general_fn.defvjp(scaled_dot_fwd, scaled_dot_bwd) + + +def ensure_tuple(dimension_numbers): + _to_tuple = lambda x: x if isinstance(x, tuple) else tuple(x) + + (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers + lhs_contract = _to_tuple(lhs_contract) + rhs_contract = _to_tuple(rhs_contract) + lhs_batch = _to_tuple(lhs_batch) + rhs_batch = _to_tuple(rhs_batch) + return (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) + + +def _ensure_batch_dim(lhs, rhs, dimension_numbers): + contracting_dims, (lhs_batch, rhs_batch) = dimension_numbers + lhs_batched = lhs + rhs_batched = rhs + + if lhs_batch == (): # expand the last dim + lhs_batched = jnp.expand_dims(lhs, axis=lhs.aval.ndim) + lhs_batch = (lhs.aval.ndim,) + if rhs_batch == (): + rhs_batched = jnp.expand_dims(rhs, axis=rhs.aval.ndim) + rhs_batch = (rhs.aval.ndim,) + dn_batched = contracting_dims, (lhs_batch, rhs_batch) + return lhs_batched, rhs_batched, dn_batched + + +def scaled_dot_general_wrapper( + lhs, rhs, dimension_numbers, + preferred_element_type=jnp.float32, + configs: List[BlockScaleConfig] | None = None, + ): + if preferred_element_type not in (jnp.float32, jnp.bfloat16, jnp.float16): + msg = ('Only supports preferred_element_type in (f32, bf16, f16), but got ' + f'{preferred_element_type}') + raise TypeError(msg) + if configs is None: + mxfp8_config = BlockScaleConfig( + mode='mxfp8', + block_size=32, + data_type=jnp.float8_e4m3fn, + scale_type=jnp.float8_e8m0fnu, + global_scale=None, + infer_only=False + ) + configs = [mxfp8_config, mxfp8_config, mxfp8_config] + + dimension_numbers = ensure_tuple(dimension_numbers) + lhs_batched, rhs_batched, dn_batched = _ensure_batch_dim( + lhs, rhs, dimension_numbers + ) + out = scaled_dot_general_fn( + lhs_batched, rhs_batched, dn_batched, preferred_element_type, configs, + ) + + # Expanding batch dims for operands adds a singleton batch dim at axis 0 in + # the output, which we need to squeeze.
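+  # For example, a 2D (M, K) x (K, N) dot with empty batch dims is run with a
+  # trailing singleton batch dim appended to each operand; the batched result
+  # comes back as (1, M, N), and squeezing axis 0 restores the expected (M, N).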
+ if dn_batched != dimension_numbers: + return jnp.squeeze(out, axis=0) + return out diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 72ac74c38c5d..fb00169b1505 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -21,13 +21,17 @@ import operator import math import numpy as np -from typing import Any, Literal +from typing import Any, List, Literal import jax import jax.numpy as jnp from jax import custom_jvp from jax import lax from jax._src import config +from jax._src.cudnn.scaled_matmul_stablehlo import ( + scaled_matmul_wrapper as cudnn_scaled_matmul, + scaled_dot_general_wrapper as cudnn_scaled_dot_general, + BlockScaleConfig) from jax._src import core from jax._src import deprecations from jax._src import dtypes @@ -39,7 +43,7 @@ from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.numpy import util as numpy_util -from jax._src.typing import Array, ArrayLike, DType +from jax._src.typing import Array, ArrayLike, DType, DTypeLike from jax._src.ops.special import logsumexp as _logsumexp @@ -1159,3 +1163,107 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int], raise ValueError(f"Unsupported implementation option: {implementation}") return jnp.reshape(out, output_shape) + +def scaled_matmul( + lhs: Array, + rhs: Array, + lhs_scales: Array, + rhs_scales: Array, + preferred_element_type: DTypeLike = jnp.float32, +) -> Array: + r""" + Performs scaled matrix multiplication between two 3D arrays, with scaling + factors applied to the matrices. + .. math:: + \mathrm{ScaledMatmul}(lhs, rhs, lhs\_scales, rhs\_scales)=lhs\_scales \cdot rhs\_scales \cdot \mathrm{dot}(lhs, rhs) + Args: + lhs (Array): A 3D array of shape (B, M, K). + rhs (Array): A 3D array of shape (B, N, K). + lhs_scales (Array): A 3D array of shape (B, M, K_block). + rhs_scales (Array): A 3D array of shape (B, N, K_block). + preferred_element_type (DTypeLike, optional): The preferred data type + for the computation. Defaults to `jnp.float32`. + Returns: + Array: A 3D array of shape (B, M, N) representing the scaled matrix + multiplication result. + Raises: + AssertionError: If the number of columns in `lhs` (`lhs_K`) does not + match the number of columns in `rhs` (`rhs_K`). + Notes: + - The function ensures that the `preferred_element_type` is + canonicalized before passing it to the underlying computation. + - Scaling is applied to the matrices based on the `lhs_scales` and + `rhs_scales` arrays, enabling efficient computations in blocks. + """ + B, M, lhs_K = lhs.shape + _, N, rhs_K = rhs.shape + assert lhs_K == rhs_K + _, _, K_block = lhs_scales.shape + + preferred_element_type = dtypes.canonicalize_dtype( + np.dtype(preferred_element_type) + ) + out = cudnn_scaled_matmul( + lhs, + rhs, + lhs_scales, + rhs_scales, + preferred_element_type=preferred_element_type, + ) + return out + +def scaled_dot_general( + lhs, rhs, + dimension_numbers, + preferred_element_type=jnp.float32, + configs: List[BlockScaleConfig] | None = None, + implementation: Literal['cudnn'] | None = None, + ): + r"""Scaled dot general operation. + Computes the scaled dot general on lhs, rhs with quantization specified by configs: + .. math:: + \widehat{lhs}, s_a=\mathrm{quantize}(lhs) \\ + \widehat{rhs}, s_b=\mathrm{quantize}(rhs) \\ + \mathrm{ScaledDot}(lhs, rhs)=s_a \cdot s_b \cdot \mathrm{dot}(\widehat{lhs}, \widehat{rhs}) + Args: + lhs: Left-hand side input tensor. + rhs: Right-hand side input tensor.
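+      Both operands are quantized according to `configs` before the dot is
+      computed; they do not need to be pre-quantized by the caller.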
+ dimension_numbers: A tuple specifying the contraction and batch dimensions + for the dot general operation. Must follow the format: + `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`. + preferred_element_type: The preferred output data type. Supported types are + `jnp.float32`, `jnp.bfloat16`, and `jnp.float16`. Defaults to `jnp.float32`. + configs: A list of `BlockScaleConfig` specifying the scaling + configurations for the operation. Defaults to `mxfp8`. + implementation: A string to control which implementation backend to use. + Supported strings are `cudnn` (cuDNN block scaled dot). It defaults + to `None`, which will automatically select the best available backend. + Returns: + The result of the scaled dot general operation. + """ + # Create configs if not provided + if configs is None: + mxfp8_config = BlockScaleConfig( + mode='mxfp8', + block_size=32, + data_type=jnp.float8_e4m3fn, + scale_type=jnp.float8_e8m0fnu, + global_scale=None, + infer_only=False + ) + configs = [mxfp8_config for _ in range(3)] + + if implementation is None: + implementation = 'cudnn' + + match implementation: + case 'cudnn': + out = cudnn_scaled_dot_general( + lhs, rhs, dimension_numbers, + preferred_element_type=preferred_element_type, + configs=configs + ) + case _: + raise ValueError(f"Unsupported implementation option: {implementation}") + + return out diff --git a/jax/nn/__init__.py b/jax/nn/__init__.py index ebe725c448ee..3f08e1c0fd12 100644 --- a/jax/nn/__init__.py +++ b/jax/nn/__init__.py @@ -37,6 +37,8 @@ relu as relu, relu6 as relu6, dot_product_attention as dot_product_attention, + scaled_dot_general as scaled_dot_general, + scaled_matmul as scaled_matmul, selu as selu, sigmoid as sigmoid, soft_sign as soft_sign, diff --git a/tests/nn_test.py b/tests/nn_test.py index 1f032b3f03cd..f322e3f9979e 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -27,8 +27,15 @@ from jax._src import ad_checkpoint from jax._src import config from jax._src import core +from jax._src import dtypes as _dtypes from jax._src import test_util as jtu from jax._src.lib import cuda_versions +from jax._src.cudnn.scaled_matmul_stablehlo import ( + quantize, + quantize_dequantize, + shape_normalization, + BlockScaleConfig, +) from jax.test_util import check_grads from jax import nn from jax import random @@ -37,9 +44,9 @@ config.parse_flags_with_absl() -def _is_required_cudnn_version_satisfied(min_cudnn_version): +def _is_required_cudnn_version_satisfied(min_cc, min_cudnn_version): return ( - jtu.is_cuda_compute_capability_at_least("8.0") and + jtu.is_cuda_compute_capability_at_least(min_cc) and cuda_versions is not None and cuda_versions.cudnn_get_version() >= min_cudnn_version ) @@ -51,9 +58,140 @@ def _check_cudnn_backend(fn, *args, **kwargs): _cudnn_dbias_error = 'cuDNN only supports bias gradient' +def _generate_quantized_tensors( + batch, lhs_non_contract, contract, rhs_non_contract, + configs, dtype=jnp.float32, + ): + cast_to_representable = partial( + quantize_dequantize, + scale=jnp.ones((1,)), + compute_dtype=dtype, + ) + + k1, k2 = jax.random.split(jax.random.key(123), 2) + + a = cast_to_representable( + jax.random.uniform( + k1, (batch, lhs_non_contract, contract), minval=-1.0, dtype=dtype + ), + configs[0].data_type, + ) + b = cast_to_representable( + jax.random.uniform( + k2, (batch, rhs_non_contract, contract), minval=-1.0, dtype=dtype + ), + configs[1].data_type, + ) + + dn = ((2,), (0,)) + a_3d = shape_normalization(a, dn) + b_3d = shape_normalization(b, dn) + a_q, 
a_scales = quantize(a, configs[0]) + b_q, b_scales = quantize(b, configs[1]) + + return a, b, a_q, b_q, a_scales, b_scales + +def create_mxfp8_configs_if_available(): + if _dtypes.float8_e8m0fnu is None: + raise unittest.SkipTest("float8_e8m0fnu is not available.") + + def _create_mxfp8_config(): + return BlockScaleConfig( + mode='mxfp8', + block_size=32, + data_type=jnp.float8_e4m3fn, + scale_type=jnp.float8_e8m0fnu, + global_scale=None, + infer_only=False + ) + + return [_create_mxfp8_config() for _ in range(3)] + + @jtu.with_config(jax_legacy_prng_key="allow", jax_numpy_dtype_promotion="standard") class NNFunctionsTest(jtu.JaxTestCase): + @parameterized.product( + contract=[160, 96], + lhs_non_contract=[240, 100], + dtype=[jnp.float16, jnp.bfloat16, jnp.float32], + impl=['cudnn',], + ) + def testScaledMatmul(self, contract, lhs_non_contract, dtype, impl): + if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700): + raise unittest.SkipTest("CUDA or cuDNN versions are not compatible") + # Check if float8_e8m0fnu is available + configs = create_mxfp8_configs_if_available() + batch, rhs_non_contract = 4, 256 + a, b, a_q, b_q, a_scales, b_scales = _generate_quantized_tensors( + batch, lhs_non_contract, contract, rhs_non_contract, + configs, dtype=dtype, + ) + out = nn.scaled_matmul(a_q, b_q, a_scales, b_scales, + preferred_element_type=dtype) + out_ref = jnp.matmul(a.astype(jnp.float32), + jnp.transpose(b, (0, 2, 1)).astype(jnp.float32)) + self.assertArraysAllClose( + out, out_ref.astype(dtype), rtol=1e-3, atol=1e-3 + ) + + @parameterized.product( + is_training=[True, False], + output_type=[jnp.float16, jnp.bfloat16, jnp.float32], + impl=['cudnn',], + ) + def testScaledDotGeneral( + self, is_training, output_type, impl): + if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700): + raise unittest.SkipTest("CUDA or cuDNN versions are not compatible") + + configs = create_mxfp8_configs_if_available() + cast_to_representable = partial( + quantize_dequantize, + scale=jnp.ones((1,)), + compute_dtype=jnp.float32, + ) + k1, k2 = jax.random.split(jax.random.key(0), 2) + a_shape = [2, 256, 96] + b_shape = [2, 96, 160] + dimension_numbers = (([2], [1]), ([0], [0])) + a = cast_to_representable( + jax.random.uniform(k1, a_shape, minval=-1.0, dtype=output_type), + configs[0].data_type, + ) + b = cast_to_representable( + jax.random.uniform(k2, b_shape, minval=-1.0, dtype=output_type), + configs[1].data_type, + ) + + scaled_dot_general_fn = partial( + nn.scaled_dot_general, configs=configs + ) + def fwd(a, b, is_ref=False): + fn = jax.lax.dot_general if is_ref else scaled_dot_general_fn + y = fn(a, b, dimension_numbers, + preferred_element_type=output_type) + return jnp.sum(y) + + if is_training: + j_train = jax.jit(jax.value_and_grad(fwd, argnums=[0, 1])) + + j_train_ref = jax.jit( + jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1]) + ) + out, (x_grad, w_grad) = j_train(a, b) + out_ref, (x_grad_ref, w_grad_ref) = j_train_ref(a, b) + + self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2) + self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1) + self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1) + else: + j_inference = jax.jit(fwd) + j_inference_ref = jax.jit(partial(fwd, is_ref=True)) + out = j_inference(a, b) + out_ref = j_inference_ref(a, b) + self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2) + @parameterized.product( dtype=[jnp.bfloat16, jnp.float16], group_num=[1, 2, 4], @@ -61,7 +199,7 @@ class 
NNFunctionsTest(jtu.JaxTestCase): impl=['cudnn', 'xla'], ) def testDotProductAttention(self, dtype, group_num, use_vmap, impl): - if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(8904): + if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("8.0", 8904): raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") if impl == 'cudnn' and dtype == jnp.float32: raise unittest.SkipTest("cuDNN only supports fp16 or bf16.") @@ -110,7 +248,7 @@ def testDotProductAttentionMask(self, mask_mode): if isinstance(mask_mode, str): mask_mode = (mask_mode,) min_cudnn_version = 90200 if 'sliding_window' in mask_mode else 8904 - if not _is_required_cudnn_version_satisfied(min_cudnn_version): + if not _is_required_cudnn_version_satisfied("8.0", min_cudnn_version): raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") dtype = jnp.bfloat16 @@ -173,7 +311,7 @@ def testDotProductAttentionMask(self, mask_mode): use_vmap=[False, True], ) def testDotProductAttentionBiasGradient(self, batch_size, use_vmap): - if not _is_required_cudnn_version_satisfied(8904): + if not _is_required_cudnn_version_satisfied("8.0", 8904): raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") dtype = jnp.bfloat16 diff --git a/tests/scaled_matmul_stablehlo_test.py b/tests/scaled_matmul_stablehlo_test.py new file mode 100644 index 000000000000..056bbdc44c7f --- /dev/null +++ b/tests/scaled_matmul_stablehlo_test.py @@ -0,0 +1,503 @@ +from functools import partial +from absl.testing import absltest + +import re +import unittest +import numpy as np +import jax +import jax.numpy as jnp +from jax.sharding import Mesh +from jax.sharding import PartitionSpec, NamedSharding +from jax._src import config +from jax._src import dtypes as _dtypes +from jax._src import test_util as jtu +from jax._src.cudnn.fused_attention_stablehlo import check_cudnn_version +from jax._src.cudnn.scaled_matmul_stablehlo import ( + scaled_matmul_wrapper, + scaled_dot_general_wrapper, + shape_normalization, + quantize, + quantize_dequantize, + BlockScaleConfig, +) + + +config.parse_flags_with_absl() +input_shardings = [ + (("dp", None, "tp"), ("dp", None, "tp")), + (("dp", None, "tp"), ("dp", None, None)), + (("dp", None, "tp"), ("dp", "tp", None)), + (("dp", None, None), ("dp", "tp", None)), + (("dp", "tp", None), ("dp", "tp", None)), + ((None, "dp", "tp"), (None, "dp", "tp")), + ((None, "tp", None), (None, "tp", None)), + ((None, None, "tp"), (None, "tp", None)), +] +c_name = "__cudnn$blockScaledDot" +expected_hlos = [ + (c_name, "all-reduce", "f32[1,512,512]", "replica_groups={{0,1},{2,3}}"), + ("all-gather", "f8e4m3fn[1,512,512]", "replica_groups=[2,2]<=[4]", c_name), + ("all-gather", "f8e4m3fn[1,512,512]", "replica_groups=[2,2]<=[4]", c_name), + (c_name,), + ("all-gather", "f8e4m3fn[1,256,1024]", "replica_groups=[2,2]<=[4]", c_name), + (c_name, "reduce-scatter", "f32[2,256,512]", "replica_groups={{0,1},{2,3}}"), + ("all-gather", "f8e4m3fn[2,512,1024]", "replica_groups=[2,2]<=[4]", c_name), + ("all-gather", "f8e4m3fn[2,512,512]", "replica_groups=[2,2]<=[4]", c_name), +] +expected_output_spec = [ + PartitionSpec('dp',), + PartitionSpec('dp',), + PartitionSpec('dp', None, 'tp'), + PartitionSpec('dp', None, 'tp'), + PartitionSpec('dp', 'tp', None), + PartitionSpec(None, 'dp', 'tp'), + PartitionSpec(None, 'tp', None), + PartitionSpec(None, None, 'tp'), +] +sharding_configs = { + input_sharding: (hlo, output_spec) + for input_sharding, hlo, output_spec in zip(input_shardings, expected_hlos, expected_output_spec) +} + +def create_mxfp8_configs_if_available(): + if
_dtypes.float8_e8m0fnu is None: + raise unittest.SkipTest("float8_e8m0fnu is not available.") + + def _create_mxfp8_config(): + return BlockScaleConfig( + mode='mxfp8', + block_size=32, + data_type=jnp.float8_e4m3fn, + scale_type=jnp.float8_e8m0fnu, + global_scale=None, + infer_only=False + ) + + return [_create_mxfp8_config() for _ in range(3)] + +def generate_quantized_tensors( + batch, lhs_non_contract, contract, rhs_non_contract, + configs, dtype=jnp.float32, + ): + cast_to_representable = partial( + quantize_dequantize, + scale=jnp.ones((1,)), + compute_dtype=dtype, + ) + + k1, k2 = jax.random.split(jax.random.key(123), 2) + + a = cast_to_representable( + jax.random.uniform( + k1, (batch, lhs_non_contract, contract), minval=-1.0, dtype=dtype + ), + configs[0].data_type, + ) + b = cast_to_representable( + jax.random.uniform( + k2, (batch, rhs_non_contract, contract), minval=-1.0, dtype=dtype + ), + configs[1].data_type, + ) + + dn = ((2,), (0,)) + a_3d = shape_normalization(a, dn) + b_3d = shape_normalization(b, dn) + a_q, a_scales = quantize(a, configs[0]) + b_q, b_scales = quantize(b, configs[1]) + + return a, b, a_q, b_q, a_scales, b_scales + + +def shard_and_device_put( + mesh, a_sharding, b_sharding, a, b, a_scales=None, b_scales=None + ): + a_spec = PartitionSpec(*a_sharding) + b_spec = PartitionSpec(*b_sharding) + + a_named_sharding = NamedSharding(mesh, a_spec) + b_named_sharding = NamedSharding(mesh, b_spec) + + a = jax.device_put(a, a_named_sharding) + b = jax.device_put(b, b_named_sharding) + if a_scales is not None: + a_scales = jax.device_put(a_scales, a_named_sharding) + if b_scales is not None: + b_scales = jax.device_put(b_scales, b_named_sharding) + + in_shardings = ( + a_named_sharding, + b_named_sharding, + ) + if a_scales is not None and b_scales is not None: + in_shardings = ( + a_named_sharding, + b_named_sharding, + a_named_sharding, + b_named_sharding, + ) + return a, b, a_scales, b_scales, in_shardings + + return a, b, in_shardings + + +def get_hlo_text(in_shardings, block_scale_configs=None): + if block_scale_configs is None: + block_scale_configs = create_mxfp8_configs_if_available() + + mesh_names = ("dp", "tp") + devices = np.array(jax.local_devices()[:4]).reshape((2, 2)) + mesh = Mesh(devices, mesh_names) + _, _, a_q, b_q, a_scales, b_scales = generate_quantized_tensors( + 2, 512, 1024, 512, block_scale_configs, + ) + + with mesh: + a_q, b_q, a_scales, b_scales, in_shardings = shard_and_device_put( + mesh, in_shardings[0], in_shardings[1], a_q, b_q, a_scales, b_scales + ) + pjit_fn = jax.jit(scaled_matmul_wrapper, in_shardings=in_shardings) + hlo = pjit_fn.lower(a_q, b_q, a_scales, b_scales).compile() + return hlo.as_text() + + +@jtu.with_config(jax_numpy_dtype_promotion="standard") +class ScaledMatmulTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + try: + cudnn_version = check_cudnn_version() + except RuntimeError as e: + self.skipTest(str(e)) + return + if cudnn_version < 90700: + self.skipTest("Requires >= cuDNN 9.7.0") + if not jtu.is_cuda_compute_capability_at_least("10.0"): + self.skipTest("Requires at least Blackwell arch") + + @jtu.sample_product( + in_shardings=sharding_configs, + ) + @jtu.run_on_devices("cuda") + def test_collectives(self, in_shardings): + if jtu.device_under_test() != "gpu" or len(jax.local_devices()) < 4: + self.skipTest("Partition Test enabled for at least 4 GPUs") + + expected_hlo = sharding_configs[in_shardings][0] + hlo_text = get_hlo_text(in_shardings) + + hlo_pattern = re.compile( + r".*".join([re.escape(x) 
for x in expected_hlo]), flags=re.DOTALL + ) + self.assertRegex( + hlo_text, hlo_pattern, msg=f"Failed to find pattern: {expected_hlo}" + ) + + @jtu.sample_product( + contract=[160, 96], + lhs_non_contract=[240, 100], + dtype=[jnp.float16, jnp.bfloat16, jnp.float32], + block_scale_configs=[create_mxfp8_configs_if_available(),], + ) + @jtu.run_on_devices("cuda") + def test_scaled_matmul( + self, contract, lhs_non_contract, dtype, block_scale_configs, + ): + batch, rhs_non_contract = 2, 128 + a, b, a_q, b_q, a_scales, b_scales = generate_quantized_tensors( + batch, lhs_non_contract, contract, rhs_non_contract, + block_scale_configs, dtype=dtype, + ) + + def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type): + return scaled_matmul_wrapper( + lhs, + rhs, + lhs_scales, + rhs_scales, + preferred_element_type=out_type, + ) + + j_scaled_matmul = jax.jit(partial(wrapper, out_type=dtype)) + hlo_text = ( + j_scaled_matmul.lower(a_q, b_q, a_scales, b_scales) + .compile() + .as_text() + ) + hlo_pattern = re.compile( + r".*".join([re.escape(x) for x in ("custom-call", c_name)]) + ) + self.assertRegex(hlo_text, hlo_pattern) + + out = j_scaled_matmul(a_q, b_q, a_scales, b_scales) + out_ref = np.einsum( + "BMK,BNK->BMN", a.astype(jnp.float32), b.astype(jnp.float32) + ) + self.assertArraysAllClose( + out, out_ref.astype(dtype), rtol=1e-3, atol=1e-3 + ) + + @jtu.sample_product( + in_shardings=sharding_configs, + block_scale_configs=[create_mxfp8_configs_if_available(),], + ) + @jtu.run_on_devices("cuda") + def test_scaled_matmul_sharded(self, in_shardings, block_scale_configs): + if len(jax.local_devices()) < 4: + self.skipTest("Require at least 4 devices to run sharding tests.") + batch, contract, non_contract = 2, 1024, 256 + a, b, a_q, b_q, a_scales, b_scales = generate_quantized_tensors( + batch, non_contract, contract, non_contract, block_scale_configs, + ) + + devices = np.array(jax.local_devices()[:4]) + devices = devices.reshape((2, 2)) + expected_output_spec = sharding_configs[in_shardings][1] + + with Mesh(devices, ("dp", "tp")) as mesh: + a_q, b_q, a_scales, b_scales, input_shardings = ( + shard_and_device_put( + mesh, + in_shardings[0], + in_shardings[1], + a_q, + b_q, + a_scales, + b_scales, + ) + ) + + args = [a_q, b_q, a_scales, b_scales] + j_scaled_matmul = jax.jit( + scaled_matmul_wrapper, in_shardings=input_shardings + ) + hlo_compiled = j_scaled_matmul.lower(*args).compile() + hlo_pattern = re.compile( + r".*".join([re.escape(x) for x in ("custom-call", c_name)]) + ) + self.assertRegex(hlo_compiled.as_text(), hlo_pattern) + + j_ref = jax.jit( + partial( + jax.lax.dot_general, + dimension_numbers=(([2], [2]), ([0], [0])), + ), + in_shardings=input_shardings[:2], + ) + + out = j_scaled_matmul(*args) + out_ref = j_ref(a, b) + expected_output_sharding = NamedSharding( + mesh=mesh, spec=expected_output_spec + ) + self.assertArraysAllClose(out, out_ref, rtol=1e-3, atol=1e-3) + self.assertTrue( + out.sharding.is_equivalent_to(expected_output_sharding, out.ndim) + ) + +@jtu.with_config(jax_numpy_dtype_promotion="standard") +class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase): + + block_scale_configs = create_mxfp8_configs_if_available() + + def setUp(self): + super().setUp() + try: + cudnn_version = check_cudnn_version() + except RuntimeError as e: + self.skipTest(str(e)) + return + if cudnn_version < 90700: + self.skipTest("Requires >= cuDNN 9.7.0") + if not jtu.is_cuda_compute_capability_at_least("10.0"): + self.skipTest("Requires at least Blackwell arch") + + @jtu.sample_product( + configs=[ + # 
a_shape, b_shape, dimension_numbers, is_training + ((1, 32), (2, 32), (([1], [1]), ([], [])), False), + ((30, 64), (100, 64), (([1], [1]), ([], [])), False), + ((192, 96), (160, 96), (([1], [1]), ([], [])), True), + ((64, 128, 4), (128, 128), (([1], [0]), ([], [])), True), + ((1, 128, 1024), (1, 1024, 128), (([2], [1]), ([0], [0])), True), + ( + (1, 128, 128, 2), + (128, 1, 2, 128), + (([2], [0]), ([0, 3], [1, 2])), + True, + ), + ], + output_type=[jnp.float16, jnp.bfloat16, jnp.float32], + ) + @jtu.run_on_devices("cuda") + def test_dot_general(self, configs, output_type): + cast_to_representable = partial( + quantize_dequantize, + scale=jnp.ones((1,)), + compute_dtype=jnp.float32, + ) + k1, k2 = jax.random.split(jax.random.key(0), 2) + + a_shape, b_shape, dimension_numbers, is_training = configs + a = cast_to_representable( + jax.random.uniform(k1, a_shape, minval=-1.0, dtype=output_type), + self.block_scale_configs[0].data_type, + ) + b = cast_to_representable( + jax.random.uniform(k2, b_shape, minval=-1.0, dtype=output_type), + self.block_scale_configs[1].data_type, + ) + + scaled_dot_general = partial( + scaled_dot_general_wrapper, + configs=self.block_scale_configs + ) + def fwd(a, b, is_ref=False): + fn = jax.lax.dot_general if is_ref else scaled_dot_general + y = fn(a, b, dimension_numbers, + preferred_element_type=output_type) + return jnp.sum(y) + + if is_training: + j_train = jax.jit(jax.value_and_grad(fwd, argnums=[0, 1])) + + j_train_ref = jax.jit( + jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1]) + ) + out, (x_grad, w_grad) = j_train(a, b) + out_ref, (x_grad_ref, w_grad_ref) = j_train_ref(a, b) + + self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2) + self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1) + self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1) + else: + j_inference = jax.jit(fwd) + j_inference_ref = jax.jit(partial(fwd, is_ref=True)) + out = j_inference(a, b) + out_ref = j_inference_ref(a, b) + self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2) + + @jtu.sample_product(in_shardings=sharding_configs) + @jtu.run_on_devices("cuda") + def test_dot_general_sharded(self, in_shardings): + if len(jax.local_devices()) < 4: + self.skipTest("Require at least 4 devices to run sharding tests.") + + cast_to_representable = partial( + quantize_dequantize, + scale=jnp.ones((1,)), + compute_dtype=jnp.float32, + ) + + dimension_numbers = (([2], [2]), ([0], [0])) + a_shape = (2, 128, 512) + b_shape = (2, 256, 512) + + k1, k2 = jax.random.split(jax.random.key(0), 2) + a = cast_to_representable( + jax.random.uniform(k1, a_shape, minval=-1.0), + self.block_scale_configs[0].data_type, + ) + b = cast_to_representable( + jax.random.uniform(k2, b_shape, minval=-1.0), + self.block_scale_configs[1].data_type, + ) + + scaled_dot_general = partial( + scaled_dot_general_wrapper, + configs=self.block_scale_configs + ) + def fwd(a, b, is_ref=False): + fn = jax.lax.dot_general if is_ref else scaled_dot_general + y = fn(a, b, dimension_numbers) + # Use a little complex loss function to avoid constant grads, whose + # sharding info might be optimized off and then cause issue with the + # custom scaled_matmul op. 
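+      # With a plain jnp.sum(y), the cotangent of y is a constant array of
+      # ones; jnp.tanh makes it depend on y, so the backward scaled_matmul
+      # receives non-trivially sharded inputs.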
+ return jnp.sum(jnp.tanh(y)) + + devices = np.array(jax.local_devices()[:4]) + devices = devices.reshape((2, 2)) + with Mesh(devices, ("dp", "tp")) as mesh: + a, b, input_shardings = ( + shard_and_device_put( + mesh, + in_shardings[0], + in_shardings[1], + a, + b, + ) + ) + + j_train = jax.jit(jax.value_and_grad(partial(fwd), argnums=[0, 1]), + in_shardings=input_shardings) + hlo_text = j_train.lower(a, b).compile().as_text() + hlo_pattern = re.compile( + r".*".join([re.escape(x) for x in ("custom-call", c_name)]) + ) + + j_train_ref = jax.jit( + jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1]), + in_shardings=input_shardings + ) + out, (x_grad, w_grad) = j_train(a, b) + out_ref, (x_grad_ref, w_grad_ref) = j_train_ref(a, b) + self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2) + self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1) + self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1) + + + @jtu.sample_product( + configs=[ + ((1, 128, 256), (1, 128, 256), (0, 0, 0)), + ((2, 128, 128), (2, 128, 128), (0, 0, 0)), + ((2, 128, 128), (128, 2, 128), (0, 1, 2)), + ] + ) + @jtu.run_on_devices("cuda") + def test_dot_general_vmap(self, configs): + cast_to_representable = partial( + quantize_dequantize, + scale=jnp.ones((1,)), + compute_dtype=jnp.float32, + ) + k1, k2 = jax.random.split(jax.random.key(0), 2) + + a_shape, b_shape, vmap_axes = configs + a_axis, b_axis, o_axis = vmap_axes + dimension_numbers = (([1], [1]), ([], [])) + + a = cast_to_representable( + jax.random.uniform(k1, a_shape, minval=-1.0), + self.block_scale_configs[0].data_type, + ) + b = cast_to_representable( + jax.random.uniform(k2, b_shape, minval=-1.0), + self.block_scale_configs[1].data_type, + ) + + scaled_dot_general = partial( + scaled_dot_general_wrapper, + configs=self.block_scale_configs + ) + def fwd(a, b, is_ref=False): + fn = jax.vmap( + jax.lax.dot_general if is_ref else scaled_dot_general, + in_axes=(a_axis, b_axis, None), + out_axes=o_axis, + ) + y = fn(a, b, dimension_numbers) + return jnp.sum(y) + + j_train = jax.jit(jax.value_and_grad(fwd, argnums=[0, 1])) + j_train_ref = jax.jit( + jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1]) + ) + out, (x_grad, w_grad) = j_train(a, b) + out_ref, (x_grad_ref, w_grad_ref) = j_train_ref(a, b) + + self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e2) + self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1) + self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader())
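As a usage illustration (not part of the patch), the following is a minimal sketch of calling the new jax.nn.scaled_dot_general API added above. It assumes a Blackwell-class GPU with cuDNN >= 9.7.0, as required by the tests, and relies only on the defaults defined in this change (mxfp8 block scaling, cuDNN backend):

import jax
import jax.numpy as jnp

k1, k2 = jax.random.split(jax.random.key(0))
a = jax.random.uniform(k1, (2, 256, 96), dtype=jnp.bfloat16, minval=-1.0)
b = jax.random.uniform(k2, (2, 96, 160), dtype=jnp.bfloat16, minval=-1.0)

# Contract dim 2 of `a` with dim 1 of `b`; batch over dim 0 of both.
dimension_numbers = (((2,), (1,)), ((0,), (0,)))

# Forward pass with the default mxfp8 configs and the cuDNN implementation.
out = jax.nn.scaled_dot_general(
    a, b, dimension_numbers, preferred_element_type=jnp.bfloat16)

# Gradients flow through the custom VJP defined in scaled_matmul_stablehlo.py.
loss = lambda a, b: jnp.sum(
    jax.nn.scaled_dot_general(a, b, dimension_numbers))
grads = jax.grad(loss, argnums=(0, 1))(a, b)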