Skip to content

Commit

Permalink
Removed use of deprecated jax.core.pp_*
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Jun 21, 2024
1 parent 50c7668 commit 7d07305
Showing 1 changed file with 0 additions and 40 deletions.
40 changes: 0 additions & 40 deletions equinox/internal/_noinline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from typing import Any, Optional, Union

import jax
import jax._src.pretty_printer as pp
import jax._src.source_info_util as source_info_util
import jax.core
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
Expand Down Expand Up @@ -284,43 +282,6 @@ def _noinline_batch(inputs, batch_axes):
return out, jtu.tree_map(lambda _: 0, out)


def _pp_transform(x):
if x is _jvp_transform:
return "jvp"
elif type(x) is _MetaTransposeTransform:
return "transpose"
elif type(x) is _MetaBatchTransform:
return "vmap"
else:
assert False


def _noinline_pretty_print(eqn, context, settings):
_, abstract_fn, transforms, _ = jtu.tree_unflatten(
eqn.params["treedef"], eqn.params["static"]
)
pretty_params = dict(abstract_fn=abstract_fn)
if type(eqn.invars[0]) is jax.core.Literal:
static_fn_leaves, static_fn_treedef = _index_to_fn[eqn.invars[0].val]
pretty_params["static_fn"] = jtu.tree_unflatten(
static_fn_treedef, static_fn_leaves
)
if len(transforms) > 1:
# skip impl
transforms = [_pp_transform(x) for x in transforms[1:]]
pretty_params["transforms"] = transforms
lhs = jax.core.pp_vars(eqn.outvars, context, print_shapes=settings.print_shapes)
rhs = [
pp.text(eqn.primitive.name),
jax.core.pp_kv_pairs(sorted(pretty_params.items()), context, settings),
pp.text(" ") + jax.core.pp_vars(eqn.invars, context),
]
annotation = (
source_info_util.summarize(eqn.source_info) if settings.source_info else None
)
return pp.concat([lhs, pp.text(" = ", annotation=annotation), *rhs])


# Not a PyTree
class _MlirWrapper:
def __init__(self, val):
Expand Down Expand Up @@ -376,7 +337,6 @@ def _noinline_mlir(ctx, *dynamic, treedef, static, flatten, **kwargs):
ad.primitive_jvps[noinline_p] = _noinline_jvp
ad.primitive_transposes[noinline_p] = _noinline_transpose
batching.primitive_batchers[noinline_p] = _noinline_batch
jax.core.pp_eqn_rules[noinline_p] = _noinline_pretty_print
mlir.register_lowering(noinline_p, _noinline_mlir)


Expand Down

0 comments on commit 7d07305

Please sign in to comment.