diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py
index c9efa6d54681..036820d6cc6e 100644
--- a/jax/_src/lax/linalg.py
+++ b/jax/_src/lax/linalg.py
@@ -1324,452 +1324,115 @@ def _householder_product_cpu_gpu_lowering(ctx, a, taus, *,
     householder_product_p, _householder_product_cpu_gpu_lowering)
 
 
-# Symmetric product
-
-def _symmetric_product_shape_rule(a_shape, c_shape, **_):
-  if a_shape[0] != c_shape[1] or c_shape[0] != c_shape[1]:
-    raise ValueError(
-        "symmetric_update expects a rectangular matrix of shape (m, n) and a "
-        f"square matrix of shape (n, n). Got shapes {a_shape} and {c_shape}.")
-  return c_shape
-
-def _symmetric_product_jax_fn(a, c, *, alpha, beta):
-  a_T = lax.transpose(a, (*range(a.ndim - 2), a.ndim - 1, a.ndim - 2))
-  return alpha * lax.batch_matmul(
-      a, a_T, precision=lax.Precision.HIGHEST) + beta * c
+# LU decomposition
 
-def _symmetric_product_gpu_lowering(
-    platform, ctx, a_tensor, c_tensor, alpha, beta):
-  a_aval, c_aval = ctx.avals_in[:2]
-  dtype = a_aval.dtype
-  alpha_aval = beta_aval = ShapedArray((), dtype)
+# Computes a pivoted LU decomposition such that
+# PA = LU
+# In the style of LAPACK, LU are stored in the same matrix.
 
-  alpha_array = mlir.full_like_aval(ctx, alpha, alpha_aval)
-  beta_array = mlir.full_like_aval(ctx, beta, beta_aval)
+def _lu_unblocked(a):
+  """Unblocked LU decomposition, as a rolled loop."""
+  m, n = a.shape
+  def body(k, state):
+    pivot, perm, a = state
+    m_idx = lax.iota('int32', m)
+    n_idx = lax.iota('int32', n)
 
-  rule = ffi.ffi_lowering(f"{platform}solver_syrk_ffi",
-                          operand_output_aliases={1: 0})
-  ctx = ctx.replace(avals_in=[a_aval, c_aval, alpha_aval, beta_aval])
-  return rule(ctx, a_tensor, c_tensor, alpha_array, beta_array, transpose=False)
+    if dtypes.issubdtype(a.dtype, np.complexfloating):
+      t = a[:, k]
+      magnitude = abs(t.real) + abs(t.imag)
+    else:
+      magnitude = abs(a[:, k])
+    i = lax.argmax(lax.select(m_idx >= k, magnitude, lax.full_like(magnitude, -np.inf)),
+                   axis=0, index_dtype=pivot.dtype)
+    pivot = pivot.at[k].set(i)
+    a = a.at[[k, i],].set(a[[i, k],])
+    perm = perm.at[[i, k],].set(perm[[k, i],])
 
-symmetric_product_p = standard_linalg_primitive(
-    (_float, _float), (2, 2), _symmetric_product_shape_rule,
-    "symmetric_product")
-mlir.register_lowering(
-    symmetric_product_p,
-    partial(_symmetric_product_gpu_lowering, "cu"), platform="cuda")
-mlir.register_lowering(
-    symmetric_product_p,
-    mlir.lower_fun(_symmetric_product_jax_fn, multiple_results=False))
+    # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
+    x = a[k, k]
+    a = a.at[:, k].set(lax.select((m_idx > k) & (x != 0), a[:, k] / x, a[:, k]))
 
+    # a[k+1:, k+1:] -= jnp.outer(a[k+1:, k], a[k, k+1:])
+    a_outer = a[:, k, None] * a[k, None]
+    a = a - lax.select((m_idx[:, None] > k) & (n_idx[None, :] > k),
+                       a_outer, lax_internal._zeros(a_outer))
+    return pivot, perm, a
 
-# Triangular solve
+  pivot = lax.full((min(m, n),), 0, dtype=np.int32)
+  perm = lax.iota('int32', m)
+  if m == 0 and n == 0:
+    # If the array is empty, the loop body never executes but tracing it to a
+    # jaxpr fails because the indexing cannot succeed.
+    return (pivot, perm, a)
+  return lax.fori_loop(0, min(m, n), body, (pivot, perm, a))
 
-_triangular_solve_dtype_rule = partial(
-    naryop_dtype_rule, _input_dtype, (_float | _complex, _float | _complex),
-    'triangular_solve')
 
-def _triangular_solve_shape_rule(a, b, *, left_side=False, **unused_kwargs):
-  if a.ndim < 2:
-    msg = "triangular_solve requires a.ndim to be at least 2, got {}."
-    raise TypeError(msg.format(a.ndim))
-  if b.ndim < 2:
-    msg = "triangular_solve requires b.ndim to be at least 2, got {}."
-    raise TypeError(msg.format(b.ndim))
-  if a.shape[-1] != a.shape[-2]:
-    msg = ("triangular_solve requires the last two dimensions of a to be equal "
-           "in size, got a.shape of {}.")
-    raise TypeError(msg.format(a.shape))
-  if a.shape[:-2] != b.shape[:-2]:
-    msg = ("triangular_solve requires both arguments to have the same number "
-           "of dimensions and equal batch dimensions, got {} and {}.")
-    raise TypeError(msg.format(a.shape, b.shape))
-  common_dim = -2 if left_side else -1
-  if a.shape[-1] != b.shape[common_dim]:
-    msg = "Incompatible shapes for arguments to triangular_solve: {} and {}."
-    raise TypeError(msg.format(a.shape, b.shape))
-  return b.shape
+def _lu_blocked(a, block_size=128):
+  """Blocked LU decomposition, as an unrolled loop."""
+  m, n = a.shape
+  r = min(m, n)
+  pivot = lax.full((r,), 0, dtype=np.int32)
+  perm = lax.iota('int32', m)
+  for k in range(0, r, block_size):
+    b = min(r - k, block_size)
+    block_pivot, block_perm, lu_block = _lu_unblocked(a[k:, k:k+b])
 
-def _triangular_solve_sharding_rule(a, b, *, left_side=False, **unused_kwargs):
-  a_spec, b_spec = a.sharding.spec, b.sharding.spec
-  if a_spec[-1] != a_spec[-2]:
-    raise TypeError(
-        "triangular_solve requires the last two dimensions of a to be equal "
-        f"in sharding, got a_spec of {a_spec}.")
-  if a_spec[:-2] != b_spec[:-2]:
-    raise TypeError(
-        "triangular_solve requires both arguments to have the same number "
-        f"of dimensions and equal batch shardings, got {a_spec} and {b_spec}.")
-  common_dim = -2 if left_side else -1
-  if a_spec[-1] != b_spec[common_dim]:
-    raise TypeError(
-        "Incompatible shardings for arguments to triangular_solve:"
-        f" {a_spec} and {b_spec}.")
-  return b.sharding
+    pivot = pivot.at[k:k+b].set(block_pivot + k)
+    perm = perm.at[k:].set(perm[block_perm + k])
+    a = a.at[k:, :].set(a[block_perm + k, :])
+    a = a.at[k:, k:k+b].set(lu_block)
 
+    if k + b < n:
+      a = a.at[k:k+b, k+b:].set(
+          triangular_solve(a[k:k+b, k:k+b], a[k:k+b, k+b:], left_side=True,
+                           lower=True, unit_diagonal=True))
+      a = a.at[k+b:, k+b:].add(-lax.dot(a[k+b:, k:k+b], a[k:k+b, k+b:],
+                                        precision=lax.Precision.HIGHEST))
+  return a, pivot, perm
 
-def _triangular_solve_jvp_rule_a(
-    g_a, ans, a, b, *, left_side, lower, transpose_a, conjugate_a,
-    unit_diagonal):
-  m, n = b.shape[-2:]
-  k = 1 if unit_diagonal else 0
-  g_a = _tril(g_a, k=-k) if lower else _triu(g_a, k=k)
-  g_a = lax.neg(g_a)
-  g_a = _T(g_a) if transpose_a else g_a
-  g_a = g_a.conj() if conjugate_a else g_a
-  dot = partial(lax.dot if g_a.ndim == 2 else lax.batch_matmul,
-                precision=lax.Precision.HIGHEST)
+def _lu_python(x):
+  """Default LU decomposition in Python, where no better version exists."""
+  batch_dims = x.shape[:-2]
+  fn = _lu_blocked
+  for _ in range(len(batch_dims)):
+    fn = api.vmap(fn)
 
-  def a_inverse(rhs):
-    return triangular_solve(a, rhs, left_side=left_side, lower=lower,
-                            transpose_a=transpose_a, conjugate_a=conjugate_a,
-                            unit_diagonal=unit_diagonal)
+  return fn(x)
 
-  # triangular_solve is about the same cost as matrix multplication (~n^2 FLOPs
-  # for matrix/vector inputs). Order these operations in whichever order is
-  # cheaper.
-  if left_side:
-    assert g_a.shape[-2:] == a.shape[-2:] == (m, m) and ans.shape[-2:] == (m, n)
-    if m > n:
-      return a_inverse(dot(g_a, ans))  # A^{-1} (∂A X)
-    else:
-      return dot(a_inverse(g_a), ans)  # (A^{-1} ∂A) X
-  else:
-    assert g_a.shape[-2:] == a.shape[-2:] == (n, n) and ans.shape[-2:] == (m, n)
-    if m < n:
-      return a_inverse(dot(ans, g_a))  # (X ∂A) A^{-1}
-    else:
-      return dot(ans, a_inverse(g_a))  # X (∂A A^{-1})
 
-def _triangular_solve_transpose_rule(
-    cotangent, a, b, *, left_side, lower, transpose_a, conjugate_a,
-    unit_diagonal):
-  # Triangular solve is nonlinear in its first argument and linear in its second
-  # argument, analogous to `div` but swapped.
-  assert not ad.is_undefined_primal(a) and ad.is_undefined_primal(b)
-  if type(cotangent) is ad_util.Zero:
-    cotangent_b = ad_util.Zero(b.aval)
-  else:
-    cotangent_b = triangular_solve(a, cotangent, left_side=left_side,
-                                   lower=lower, transpose_a=not transpose_a,
-                                   conjugate_a=conjugate_a,
-                                   unit_diagonal=unit_diagonal)
-  return [None, cotangent_b]
+def _lu_shape_rule(shape):
+  m, n = shape
+  return shape, (core.min_dim(m, n),), (m,)
 
-def _triangular_solve_batching_rule(batched_args, batch_dims, *, left_side,
-                                    lower, transpose_a, conjugate_a,
-                                    unit_diagonal):
-  x, y = batched_args
-  bx, by = batch_dims
-  if bx is batching.not_mapped:
-    if left_side:
-      y = batching.moveaxis(y, by, -1)
-      y_flat = y.reshape(y.shape[:-2] + (y.shape[-2] * y.shape[-1],))
-      bdim_out = y.ndim - 1
-    else:
-      y = batching.moveaxis(y, by, -2)
-      y_flat = y.reshape(y.shape[:-3] + (y.shape[-3] * y.shape[-2], y.shape[-1]))
-      bdim_out = y.ndim - 2
-    out_flat = triangular_solve(
-        x, y_flat, left_side=left_side, lower=lower,
-        transpose_a=transpose_a, conjugate_a=conjugate_a,
-        unit_diagonal=unit_diagonal)
-    return out_flat.reshape(y.shape), bdim_out
-  else:
-    size = next(t.shape[i] for t, i in zip(batched_args, batch_dims)
-                if i is not None)
-    x = batching.bdim_at_front(x, bx, size)
-    y = batching.bdim_at_front(y, by, size)
-    return triangular_solve(x, y, left_side=left_side, lower=lower,
-                            transpose_a=transpose_a, conjugate_a=conjugate_a,
-                            unit_diagonal=unit_diagonal), 0
+def _lu_dtype_rule(dtype, **_):
+  dtype = dtypes.canonicalize_dtype(dtype)
+  return dtype, dtypes.dtype(np.int32), dtypes.dtype(np.int32)
 
-triangular_solve_p = standard_primitive(
-    _triangular_solve_shape_rule, _triangular_solve_dtype_rule,
-    'triangular_solve', sharding_rule=_triangular_solve_sharding_rule)
-ad.defjvp2(triangular_solve_p,
-           _triangular_solve_jvp_rule_a,
-           lambda g_b, _, a, b, **kws: triangular_solve(a, g_b, **kws))
-ad.primitive_transposes[triangular_solve_p] = _triangular_solve_transpose_rule
-batching.primitive_batchers[triangular_solve_p] = _triangular_solve_batching_rule
+def _lu_jvp_rule(primals, tangents):
+  a, = primals
+  a_dot, = tangents
+  lu, pivots, permutation = lu_p.bind(a)
 
-def _triangular_solve_lowering(
-    ctx, a, b, *, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
-  out_aval, = ctx.avals_out
-  if conjugate_a and not transpose_a:
-    a = chlo.ConjOp(a)
-    conjugate_a = False
-  if not transpose_a:
-    transpose = "NO_TRANSPOSE"
-  else:
-    transpose = "ADJOINT" if conjugate_a else "TRANSPOSE"
-  out = hlo.triangular_solve(a, b, ir.BoolAttr.get(left_side),
-                             ir.BoolAttr.get(lower),
-                             ir.BoolAttr.get(unit_diagonal),
-                             hlo.TransposeAttr.get(transpose))
-  return [mlir.lower_sharding_under_shit(ctx, out, out_aval)]
+  a_shape = np.shape(a)
+  m, n = a_shape[-2:]
+  dtype = lax.dtype(a)
+  k = min(m, n)
 
-_cpu_lapack_types = {np.dtype(np.float32), np.dtype(np.float64),
-                     np.dtype(np.complex64), np.dtype(np.complex128)}
+  batch_dims = a_shape[:-2]
+  iotas = _broadcasted_iotas(*batch_dims, 1)
+  x = a_dot[(*iotas[:-1], permutation, slice(None))]
 
-_cpu_lapack_types = {np.dtype(np.float32), np.dtype(np.float64),
-                     np.dtype(np.complex64), np.dtype(np.complex128)}
-
-def _triangular_solve_cpu_lower(
-    ctx, a, b, *, left_side, lower, transpose_a,
-    conjugate_a, unit_diagonal):
-  a_aval, b_aval = ctx.avals_in
-
-  if conjugate_a and not transpose_a:
-    a = chlo.conj(a)
-    conjugate_a = False
-  if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types:
-    target_name = lapack.prepare_lapack_call("trsm_ffi", a_aval.dtype)
-    alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype))
-    alpha_aval = ShapedArray((), a_aval.dtype)
-    rule = _linalg_ffi_lowering(target_name,
-                                [a_aval, b_aval, alpha_aval],
-                                operand_output_aliases={1: 0})
-    return rule(ctx, a, b, alpha,
-                side=_matrix_side_attr(left_side),
-                uplo=_matrix_uplo_attr(lower),
-                trans_x=_matrix_transpose_attr(transpose_a, conjugate_a),
-                diag=_matrix_diagonal_attr(unit_diagonal))
-  else:
-    # Fall back to the HLO implementation for unsupported types or batching.
-    # TODO: Consider swapping XLA for LAPACK in batched case
-    if transpose_a:
-      transpose = "ADJOINT" if conjugate_a else "TRANSPOSE"
-    else:
-      transpose = "NO_TRANSPOSE"
-    return [hlo.triangular_solve(a, b, ir.BoolAttr.get(left_side),
-                                 ir.BoolAttr.get(lower),
-                                 ir.BoolAttr.get(unit_diagonal),
-                                 hlo.TransposeAttr.get(transpose))]
-
-
-mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering)
-mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower,
-                       platform='cpu')
-
-
-# Support operation for LU decomposition: Transformation of the pivots returned
-# by LU decomposition into permutations.
-
-# Define this outside lu_pivots_to_permutation to ensure fori_loop cache hits
-def _lu_pivots_body_fn_inner(i, permutation, swaps):
-  j = swaps[i]
-  x = permutation[i]
-  y = permutation[j]
-  permutation = permutation.at[i].set(y)
-  return permutation.at[j].set(x)
-
-def _lu_pivots_body_fn(i, permutation_and_swaps):
-  permutation, swaps = permutation_and_swaps
-  batch_dims = swaps.shape[:-1]
-  fn = _lu_pivots_body_fn_inner
-  for _ in range(len(batch_dims)):
-    fn = api.vmap(fn, in_axes=(None, 0, 0), out_axes=0)
-  return fn(i, permutation, swaps), swaps
-
-def _generic_lu_pivots_to_permutation(swaps, permutation_size):
-  """Converts the pivots (row swaps) returned by LU to a permutation.
-
-  We build a permutation rather than applying `swaps` directly to the rows
-  of a matrix because lax loops aren't differentiable.
-
-  Args:
-    swaps: an array of shape (..., k) of row swaps to perform
-    permutation_size: the size of the output permutation. Should be >= k.
-  Returns:
-    An int32 array of shape (..., m).
-  """
-  assert len(swaps.shape) >= 1
-  batch_dims = swaps.shape[:-1]
-  k = swaps.shape[-1]
-  m = permutation_size
-
-  permutation = lax.broadcasted_iota(np.int32, batch_dims + (m,),
-                                     len(batch_dims))
-  if m == 0 or k == 0:
-    return permutation
-  upper = np.array(k, np.int32) if is_constant_dim(k) else k
-  result, _ = lax.fori_loop(np.array(0, np.int32), upper, _lu_pivots_body_fn,
-                            (permutation, swaps))
-  return result
-
-
-def _lu_pivots_to_permutation_abstract_eval(pivots, *, permutation_size):
-  if isinstance(pivots, ShapedArray):
-    if pivots.ndim < 1 or pivots.dtype != np.dtype(np.int32):
-      raise ValueError(
-          'Argument to lu_pivots_to_permutation must have rank >= 1 and dtype '
-          'int32. Got shape={} and dtype={}'.format(pivots.shape, pivots.dtype))
-    pivots_size = pivots.shape[-1]
-    if not permutation_size >= pivots_size:
-      raise ValueError(
-          'Output permutation size {} has to exceed the trailing dimension of '
-          'the pivots. Got pivots size {}'.format(permutation_size, pivots_size))
-    return pivots.update(shape=(*pivots.shape[:-1], permutation_size))
-  else:
-    return pivots
-
-
-def _lu_pivots_to_permutation_batching_rule(batched_args, batch_dims, *,
-                                            permutation_size):
-  x, = batched_args
-  bd, = batch_dims
-  x = batching.moveaxis(x, bd, 0)
-  return lu_pivots_to_permutation_p.bind(
-      x, permutation_size=permutation_size), 0
-
-def _lu_pivots_to_permutation_gpu_lowering(platform, ctx, pivots, *,
-                                           permutation_size):
-  del permutation_size  # unused
-  rule = ffi.ffi_lowering(f"{platform}_lu_pivots_to_permutation")
-  return rule(ctx, pivots)
-
-
-lu_pivots_to_permutation_p = Primitive('lu_pivots_to_permutation')
-lu_pivots_to_permutation_p.multiple_results = False
-lu_pivots_to_permutation_p.def_impl(
-    partial(dispatch.apply_primitive, lu_pivots_to_permutation_p))
-lu_pivots_to_permutation_p.def_abstract_eval(
-    _lu_pivots_to_permutation_abstract_eval)
-batching.primitive_batchers[lu_pivots_to_permutation_p] = (
-    _lu_pivots_to_permutation_batching_rule)
-mlir.register_lowering(
-    lu_pivots_to_permutation_p,
-    mlir.lower_fun(_generic_lu_pivots_to_permutation, multiple_results=False))
-mlir.register_lowering(
-    lu_pivots_to_permutation_p,
-    partial(_lu_pivots_to_permutation_gpu_lowering, "cu"),
-    platform='cuda')
-mlir.register_lowering(
-    lu_pivots_to_permutation_p,
-    partial(_lu_pivots_to_permutation_gpu_lowering, "hip"),
-    platform='rocm')
-
-# LU decomposition
-
-# Computes a pivoted LU decomposition such that
-# PA = LU
-# In the style of LAPACK, LU are stored in the same matrix.
-
-def _lu_unblocked(a):
-  """Unblocked LU decomposition, as a rolled loop."""
-  m, n = a.shape
-  def body(k, state):
-    pivot, perm, a = state
-    m_idx = lax.iota('int32', m)
-    n_idx = lax.iota('int32', n)
-
-    if dtypes.issubdtype(a.dtype, np.complexfloating):
-      t = a[:, k]
-      magnitude = abs(t.real) + abs(t.imag)
-    else:
-      magnitude = abs(a[:, k])
-    i = lax.argmax(lax.select(m_idx >= k, magnitude, lax.full_like(magnitude, -np.inf)),
-                   axis=0, index_dtype=pivot.dtype)
-    pivot = pivot.at[k].set(i)
-    a = a.at[[k, i],].set(a[[i, k],])
-    perm = perm.at[[i, k],].set(perm[[k, i],])
-
-    # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
-    x = a[k, k]
-    a = a.at[:, k].set(lax.select((m_idx > k) & (x != 0), a[:, k] / x, a[:, k]))
-
-    # a[k+1:, k+1:] -= jnp.outer(a[k+1:, k], a[k, k+1:])
-    a_outer = a[:, k, None] * a[k, None]
-    a = a - lax.select((m_idx[:, None] > k) & (n_idx[None, :] > k),
-                       a_outer, lax_internal._zeros(a_outer))
-    return pivot, perm, a
-
-  pivot = lax.full((min(m, n),), 0, dtype=np.int32)
-  perm = lax.iota('int32', m)
-  if m == 0 and n == 0:
-    # If the array is empty, the loop body never executes but tracing it to a
-    # jaxpr fails because the indexing cannot succeed.
-    return (pivot, perm, a)
-  return lax.fori_loop(0, min(m, n), body, (pivot, perm, a))
-
-
-def _lu_blocked(a, block_size=128):
-  """Blocked LU decomposition, as an unrolled loop."""
-  m, n = a.shape
-  r = min(m, n)
-  pivot = lax.full((r,), 0, dtype=np.int32)
-  perm = lax.iota('int32', m)
-  for k in range(0, r, block_size):
-    b = min(r - k, block_size)
-    block_pivot, block_perm, lu_block = _lu_unblocked(a[k:, k:k+b])
-
-    pivot = pivot.at[k:k+b].set(block_pivot + k)
-    perm = perm.at[k:].set(perm[block_perm + k])
-    a = a.at[k:, :].set(a[block_perm + k, :])
-    a = a.at[k:, k:k+b].set(lu_block)
-
-    if k + b < n:
-      a = a.at[k:k+b, k+b:].set(
-          triangular_solve(a[k:k+b, k:k+b], a[k:k+b, k+b:], left_side=True,
-                           lower=True, unit_diagonal=True))
-      a = a.at[k+b:, k+b:].add(-lax.dot(a[k+b:, k:k+b], a[k:k+b, k+b:],
-                                        precision=lax.Precision.HIGHEST))
-  return a, pivot, perm
-
-def _lu_python(x):
-  """Default LU decomposition in Python, where no better version exists."""
-  batch_dims = x.shape[:-2]
-  fn = _lu_blocked
-  for _ in range(len(batch_dims)):
-    fn = api.vmap(fn)
-
-  return fn(x)
-
-def _lu_impl(operand):
-  lu, pivot, perm = dispatch.apply_primitive(lu_p, operand)
-  return lu, pivot, perm
-
-def _lu_abstract_eval(operand):
-  if isinstance(operand, ShapedArray):
-    if operand.ndim < 2:
-      raise ValueError("Argument to LU decomposition must have ndims >= 2")
-
-    batch_dims = operand.shape[:-2]
-    m = operand.shape[-2]
-    n = operand.shape[-1]
-    pivot = operand.update(shape=batch_dims + (core.min_dim(m, n),),
-                           dtype=np.int32)
-    perm = operand.update(shape=batch_dims + (m,), dtype=np.int32)
-  else:
-    pivot = operand
-    perm = operand
-  return operand, pivot, perm
-
-def _lu_jvp_rule(primals, tangents):
-  a, = primals
-  a_dot, = tangents
-  lu, pivots, permutation = lu_p.bind(a)
-
-  a_shape = np.shape(a)
-  m, n = a_shape[-2:]
-  dtype = lax.dtype(a)
-  k = min(m, n)
-
-  batch_dims = a_shape[:-2]
-  iotas = _broadcasted_iotas(*batch_dims, 1)
-  x = a_dot[(*iotas[:-1], permutation, slice(None))]
-
-  # Differentiation of Matrix Functionals Using Triangular Factorization
-  # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
-  #
-  #     LU = A
-  # ==> L'U + LU' = A'
-  # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
-  # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
-  #     U' = triu(inv(L) . A' . inv(U)) . U
+  # Differentiation of Matrix Functionals Using Triangular Factorization
+  # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
+  #
+  #     LU = A
+  # ==> L'U + LU' = A'
+  # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
+  # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
+  #     U' = triu(inv(L) . A' . inv(U)) . U
 
   ndims = len(a_shape)
   l_padding = [(0, 0, 0)] * ndims
@@ -1797,12 +1460,6 @@ def _lu_jvp_rule(primals, tangents):
                  ad_util.Zero.from_primal_value(permutation))
 
 
-def _lu_batching_rule(batched_args, batch_dims):
-  x, = batched_args
-  bd, = batch_dims
-  x = batching.moveaxis(x, bd, 0)
-  return lu_p.bind(x), (0, 0, 0)
-
 def _lu_cpu_gpu_lowering(ctx, operand, *, target_name_prefix: str):
   operand_aval, = ctx.avals_in
   out_aval, pivot_aval, perm_aval = ctx.avals_out
@@ -1851,26 +1508,19 @@ def _lu_tpu_lowering_rule(ctx, operand):
     return op.results
 
 
-lu_p = Primitive('lu')
-lu_p.multiple_results = True
-lu_p.def_impl(_lu_impl)
-lu_p.def_abstract_eval(_lu_abstract_eval)
-mlir.register_lowering(lu_p, mlir.lower_fun(_lu_python, multiple_results=True))
+lu_p = linalg_primitive(
+    _lu_dtype_rule, (_float | _complex,), (2,), _lu_shape_rule, "lu",
+    multiple_results=True)
 ad.primitive_jvps[lu_p] = _lu_jvp_rule
-batching.primitive_batchers[lu_p] = _lu_batching_rule
-
-mlir.register_lowering(
-    lu_p, partial(_lu_cpu_gpu_lowering, target_name_prefix="cpu"),
-    platform="cpu")
+mlir.register_lowering(lu_p, mlir.lower_fun(_lu_python, multiple_results=True))
+mlir.register_lowering(lu_p, _lu_tpu_lowering_rule, platform='tpu')
+register_cpu_gpu_lowering(lu_p, _lu_cpu_gpu_lowering)
 
-mlir.register_lowering(
-    lu_p, partial(_lu_cpu_gpu_lowering, target_name_prefix="cu"),
-    platform="cuda")
-mlir.register_lowering(
-    lu_p, partial(_lu_cpu_gpu_lowering, target_name_prefix="hip"),
-    platform="rocm")
-mlir.register_lowering(lu_p, _lu_tpu_lowering_rule, platform='tpu')
 
+def lu_solve(lu: ArrayLike, permutation: ArrayLike, b: ArrayLike,
+             trans: int = 0) -> Array:
+  """LU solve with broadcasting."""
+  return _lu_solve(lu, permutation, b, trans)
 
 def _lu_solve_core(lu: Array, permutation: Array, b: Array, trans: int) -> Array:
@@ -1930,11 +1580,311 @@ def _lu_solve(lu: Array, permutation: Array, b: Array, trans: int) -> Array:
     x = fn(lu, permutation, b, trans)
     return x[..., 0] if rhs_vector else x
 
+# Support operation for LU decomposition: Transformation of the pivots returned
+# by LU decomposition into permutations.
 
-def lu_solve(lu: ArrayLike, permutation: ArrayLike, b: ArrayLike,
-             trans: int = 0) -> Array:
-  """LU solve with broadcasting."""
-  return _lu_solve(lu, permutation, b, trans)
+# Define this outside lu_pivots_to_permutation to ensure fori_loop cache hits
+def _lu_pivots_body_fn_inner(i, permutation, swaps):
+  j = swaps[i]
+  x = permutation[i]
+  y = permutation[j]
+  permutation = permutation.at[i].set(y)
+  return permutation.at[j].set(x)
+
+
+def _lu_pivots_body_fn(i, permutation_and_swaps):
+  permutation, swaps = permutation_and_swaps
+  batch_dims = swaps.shape[:-1]
+  fn = _lu_pivots_body_fn_inner
+  for _ in range(len(batch_dims)):
+    fn = api.vmap(fn, in_axes=(None, 0, 0), out_axes=0)
+  return fn(i, permutation, swaps), swaps
+
+
+def _generic_lu_pivots_to_permutation(swaps, permutation_size):
+  """Converts the pivots (row swaps) returned by LU to a permutation.
+
+  We build a permutation rather than applying `swaps` directly to the rows
+  of a matrix because lax loops aren't differentiable.
+
+  Args:
+    swaps: an array of shape (..., k) of row swaps to perform
+    permutation_size: the size of the output permutation. Should be >= k.
+  Returns:
+    An int32 array of shape (..., m).
+ """ + assert len(swaps.shape) >= 1 + batch_dims = swaps.shape[:-1] + k = swaps.shape[-1] + m = permutation_size + + permutation = lax.broadcasted_iota(np.int32, batch_dims + (m,), + len(batch_dims)) + if m == 0 or k == 0: + return permutation + upper = np.array(k, np.int32) if is_constant_dim(k) else k + result, _ = lax.fori_loop(np.array(0, np.int32), upper, _lu_pivots_body_fn, + (permutation, swaps)) + return result + + +def _lu_pivots_to_permutation_shape_rule(shape, *, permutation_size): + pivots_size, = shape + if not permutation_size >= pivots_size: + raise ValueError( + f"Output permutation size {permutation_size} has to exceed the " + f"trailing dimension of the pivots. Got pivots size {pivots_size}") + return (permutation_size,) + + +def _lu_pivots_to_permutation_gpu_lowering(ctx, pivots, *, + permutation_size, + target_name_prefix): + del permutation_size # unused + rule = ffi.ffi_lowering(f"{target_name_prefix}_lu_pivots_to_permutation") + return rule(ctx, pivots) + + +lu_pivots_to_permutation_p = standard_linalg_primitive( + ({np.int32},), (1,), _lu_pivots_to_permutation_shape_rule, + "lu_pivots_to_permutation") +mlir.register_lowering( + lu_pivots_to_permutation_p, + mlir.lower_fun(_generic_lu_pivots_to_permutation, multiple_results=False)) +register_cpu_gpu_lowering( + lu_pivots_to_permutation_p, _lu_pivots_to_permutation_gpu_lowering, + ("cuda", "rocm")) + + +# Symmetric product + +def _symmetric_product_shape_rule(a_shape, c_shape, **_): + if a_shape[0] != c_shape[1] or c_shape[0] != c_shape[1]: + raise ValueError( + "symmetric_update expects a rectangular matrix of shape (m, n) and a " + f"square matrix of shape (n, n). Got shapes {a_shape} and {c_shape}.") + return c_shape + +def _symmetric_product_jax_fn(a, c, *, alpha, beta): + a_T = lax.transpose(a, (*range(a.ndim - 2), a.ndim - 1, a.ndim - 2)) + return alpha * lax.batch_matmul( + a, a_T, precision=lax.Precision.HIGHEST) + beta * c + +def _symmetric_product_gpu_lowering( + platform, ctx, a_tensor, c_tensor, alpha, beta): + a_aval, c_aval = ctx.avals_in[:2] + dtype = a_aval.dtype + alpha_aval = beta_aval = ShapedArray((), dtype) + + alpha_array = mlir.full_like_aval(ctx, alpha, alpha_aval) + beta_array = mlir.full_like_aval(ctx, beta, beta_aval) + + rule = ffi.ffi_lowering(f"{platform}solver_syrk_ffi", + operand_output_aliases={1: 0}) + ctx = ctx.replace(avals_in=[a_aval, c_aval, alpha_aval, beta_aval]) + return rule(ctx, a_tensor, c_tensor, alpha_array, beta_array, transpose=False) + +symmetric_product_p = standard_linalg_primitive( + (_float, _float), (2, 2), _symmetric_product_shape_rule, + "symmetric_product") +mlir.register_lowering( + symmetric_product_p, + partial(_symmetric_product_gpu_lowering, "cu"), platform="cuda") +mlir.register_lowering( + symmetric_product_p, + mlir.lower_fun(_symmetric_product_jax_fn, multiple_results=False)) + + +# Triangular solve + +_triangular_solve_dtype_rule = partial( + naryop_dtype_rule, _input_dtype, (_float | _complex, _float | _complex), + 'triangular_solve') + +def _triangular_solve_shape_rule(a, b, *, left_side=False, **unused_kwargs): + if a.ndim < 2: + msg = "triangular_solve requires a.ndim to be at least 2, got {}." + raise TypeError(msg.format(a.ndim)) + if b.ndim < 2: + msg = "triangular_solve requires b.ndim to be at least 2, got {}." 
+    raise TypeError(msg.format(b.ndim))
+  if a.shape[-1] != a.shape[-2]:
+    msg = ("triangular_solve requires the last two dimensions of a to be equal "
+           "in size, got a.shape of {}.")
+    raise TypeError(msg.format(a.shape))
+  if a.shape[:-2] != b.shape[:-2]:
+    msg = ("triangular_solve requires both arguments to have the same number "
+           "of dimensions and equal batch dimensions, got {} and {}.")
+    raise TypeError(msg.format(a.shape, b.shape))
+  common_dim = -2 if left_side else -1
+  if a.shape[-1] != b.shape[common_dim]:
+    msg = "Incompatible shapes for arguments to triangular_solve: {} and {}."
+    raise TypeError(msg.format(a.shape, b.shape))
+  return b.shape
+
+def _triangular_solve_sharding_rule(a, b, *, left_side=False, **unused_kwargs):
+  a_spec, b_spec = a.sharding.spec, b.sharding.spec
+  if a_spec[-1] != a_spec[-2]:
+    raise TypeError(
+        "triangular_solve requires the last two dimensions of a to be equal "
+        f"in sharding, got a_spec of {a_spec}.")
+  if a_spec[:-2] != b_spec[:-2]:
+    raise TypeError(
+        "triangular_solve requires both arguments to have the same number "
+        f"of dimensions and equal batch shardings, got {a_spec} and {b_spec}.")
+  common_dim = -2 if left_side else -1
+  if a_spec[-1] != b_spec[common_dim]:
+    raise TypeError(
+        "Incompatible shardings for arguments to triangular_solve:"
+        f" {a_spec} and {b_spec}.")
+  return b.sharding
+
+
+def _triangular_solve_jvp_rule_a(
+    g_a, ans, a, b, *, left_side, lower, transpose_a, conjugate_a,
+    unit_diagonal):
+  m, n = b.shape[-2:]
+  k = 1 if unit_diagonal else 0
+  g_a = _tril(g_a, k=-k) if lower else _triu(g_a, k=k)
+  g_a = lax.neg(g_a)
+  g_a = _T(g_a) if transpose_a else g_a
+  g_a = g_a.conj() if conjugate_a else g_a
+  dot = partial(lax.dot if g_a.ndim == 2 else lax.batch_matmul,
+                precision=lax.Precision.HIGHEST)
+
+  def a_inverse(rhs):
+    return triangular_solve(a, rhs, left_side=left_side, lower=lower,
+                            transpose_a=transpose_a, conjugate_a=conjugate_a,
+                            unit_diagonal=unit_diagonal)
+
+  # triangular_solve is about the same cost as matrix multiplication (~n^2 FLOPs
+  # for matrix/vector inputs). Order these operations in whichever order is
+  # cheaper.
+  if left_side:
+    assert g_a.shape[-2:] == a.shape[-2:] == (m, m) and ans.shape[-2:] == (m, n)
+    if m > n:
+      return a_inverse(dot(g_a, ans))  # A^{-1} (∂A X)
+    else:
+      return dot(a_inverse(g_a), ans)  # (A^{-1} ∂A) X
+  else:
+    assert g_a.shape[-2:] == a.shape[-2:] == (n, n) and ans.shape[-2:] == (m, n)
+    if m < n:
+      return a_inverse(dot(ans, g_a))  # (X ∂A) A^{-1}
+    else:
+      return dot(ans, a_inverse(g_a))  # X (∂A A^{-1})
+
+def _triangular_solve_transpose_rule(
+    cotangent, a, b, *, left_side, lower, transpose_a, conjugate_a,
+    unit_diagonal):
+  # Triangular solve is nonlinear in its first argument and linear in its second
+  # argument, analogous to `div` but swapped.
+  assert not ad.is_undefined_primal(a) and ad.is_undefined_primal(b)
+  if type(cotangent) is ad_util.Zero:
+    cotangent_b = ad_util.Zero(b.aval)
+  else:
+    cotangent_b = triangular_solve(a, cotangent, left_side=left_side,
+                                   lower=lower, transpose_a=not transpose_a,
+                                   conjugate_a=conjugate_a,
+                                   unit_diagonal=unit_diagonal)
+  return [None, cotangent_b]
+
+
+def _triangular_solve_batching_rule(batched_args, batch_dims, *, left_side,
+                                    lower, transpose_a, conjugate_a,
+                                    unit_diagonal):
+  x, y = batched_args
+  bx, by = batch_dims
+  if bx is batching.not_mapped:
+    if left_side:
+      y = batching.moveaxis(y, by, -1)
+      y_flat = y.reshape(y.shape[:-2] + (y.shape[-2] * y.shape[-1],))
+      bdim_out = y.ndim - 1
+    else:
+      y = batching.moveaxis(y, by, -2)
+      y_flat = y.reshape(y.shape[:-3] + (y.shape[-3] * y.shape[-2], y.shape[-1]))
+      bdim_out = y.ndim - 2
+    out_flat = triangular_solve(
+        x, y_flat, left_side=left_side, lower=lower,
+        transpose_a=transpose_a, conjugate_a=conjugate_a,
+        unit_diagonal=unit_diagonal)
+    return out_flat.reshape(y.shape), bdim_out
+  else:
+    size = next(t.shape[i] for t, i in zip(batched_args, batch_dims)
+                if i is not None)
+    x = batching.bdim_at_front(x, bx, size)
+    y = batching.bdim_at_front(y, by, size)
+    return triangular_solve(x, y, left_side=left_side, lower=lower,
+                            transpose_a=transpose_a, conjugate_a=conjugate_a,
+                            unit_diagonal=unit_diagonal), 0
+
+triangular_solve_p = standard_primitive(
+    _triangular_solve_shape_rule, _triangular_solve_dtype_rule,
+    'triangular_solve', sharding_rule=_triangular_solve_sharding_rule)
+ad.defjvp2(triangular_solve_p,
+           _triangular_solve_jvp_rule_a,
+           lambda g_b, _, a, b, **kws: triangular_solve(a, g_b, **kws))
+ad.primitive_transposes[triangular_solve_p] = _triangular_solve_transpose_rule
+batching.primitive_batchers[triangular_solve_p] = _triangular_solve_batching_rule
+
+
+def _triangular_solve_lowering(
+    ctx, a, b, *, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
+  out_aval, = ctx.avals_out
+  if conjugate_a and not transpose_a:
+    a = chlo.ConjOp(a)
+    conjugate_a = False
+  if not transpose_a:
+    transpose = "NO_TRANSPOSE"
+  else:
+    transpose = "ADJOINT" if conjugate_a else "TRANSPOSE"
+  out = hlo.triangular_solve(a, b, ir.BoolAttr.get(left_side),
+                             ir.BoolAttr.get(lower),
+                             ir.BoolAttr.get(unit_diagonal),
+                             hlo.TransposeAttr.get(transpose))
+  return [mlir.lower_sharding_under_shit(ctx, out, out_aval)]
+
+_cpu_lapack_types = {np.dtype(np.float32), np.dtype(np.float64),
+                     np.dtype(np.complex64), np.dtype(np.complex128)}
+
+def _triangular_solve_cpu_lower(
+    ctx, a, b, *, left_side, lower, transpose_a,
+    conjugate_a, unit_diagonal):
+  a_aval, b_aval = ctx.avals_in
+
+  if conjugate_a and not transpose_a:
+    a = chlo.conj(a)
+    conjugate_a = False
+  if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types:
+    target_name = lapack.prepare_lapack_call("trsm_ffi", a_aval.dtype)
+    alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype))
+    alpha_aval = ShapedArray((), a_aval.dtype)
+    rule = _linalg_ffi_lowering(target_name,
+                                [a_aval, b_aval, alpha_aval],
+                                operand_output_aliases={1: 0})
+    return rule(ctx, a, b, alpha,
+                side=_matrix_side_attr(left_side),
+                uplo=_matrix_uplo_attr(lower),
+                trans_x=_matrix_transpose_attr(transpose_a, conjugate_a),
+                diag=_matrix_diagonal_attr(unit_diagonal))
+  else:
+    # Fall back to the HLO implementation for unsupported types or batching.
+    # TODO: Consider swapping XLA for LAPACK in batched case
+    if transpose_a:
+      transpose = "ADJOINT" if conjugate_a else "TRANSPOSE"
+    else:
+      transpose = "NO_TRANSPOSE"
+    return [hlo.triangular_solve(a, b, ir.BoolAttr.get(left_side),
+                                 ir.BoolAttr.get(lower),
+                                 ir.BoolAttr.get(unit_diagonal),
+                                 hlo.TransposeAttr.get(transpose))]
+
+
+mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering)
+mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower,
+                       platform='cpu')
 
 
 # QR decomposition
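For readers following the moved LU code, here is a small end-to-end sketch of how the pieces above compose. It is an illustration, not part of the patch, and relies only on the public re-exports in jax.lax.linalg (lu, lu_pivots_to_permutation, lu_solve):

import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

a = jnp.array([[4., 3., 2.],
               [6., 3., 1.],
               [0., 5., 8.]])

# lu packs L and U into a single matrix, LAPACK-style, and returns the raw
# row swaps plus the equivalent permutation (internally, the pivots are run
# through lu_pivots_to_permutation).
lu, pivots, perm = lax_linalg.lu(a)
assert (perm == lax_linalg.lu_pivots_to_permutation(pivots, 3)).all()

# Reconstruction check: permuting the rows of a gives L @ U.
l = jnp.tril(lu, -1) + jnp.eye(3)  # the unit diagonal is implicit in the packing
u = jnp.triu(lu)
assert jnp.allclose(a[perm], l @ u, atol=1e-5)

# lu_solve solves a @ x = b via two triangular_solve passes over the factors.
b = jnp.array([1., 2., 3.])
x = lax_linalg.lu_solve(lu, perm, b)
assert jnp.allclose(a @ x, b, atol=1e-5)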