Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix jax deprecations #1346

Merged
merged 18 commits into from
Apr 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 103 additions & 157 deletions pypesto/objective/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,121 +7,83 @@
"""

import copy
from collections.abc import Sequence
from functools import partial
from typing import Callable
from typing import Union

import numpy as np

from ...C import FVAL, GRAD, HESS, MODE_FUN, RDATAS, ModeType
from ...C import MODE_FUN, ModeType
from ..base import ObjectiveBase, ResultDict

try:
import jax
import jax.experimental.host_callback as hcb
import jax.numpy as jnp
from jax import custom_jvp, grad
from jax import custom_jvp
except ImportError:
raise ImportError(
"Using a jax objective requires an installation of "
"the python package jax. Please install jax via "
"`pip install jax jaxlib`."
) from None

# jax compatible (jittable) objective function using host callback, see
# https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html
# jax compatible (jit-able) objective function using external callback, see
# https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html


@partial(custom_jvp, nondiff_argnums=(0,))
def _device_fun(obj: "JaxObjective", x: jnp.array):
"""Jax compatible objective function execution using host callback.

This function does not actually call the underlying objective function,
but instead extracts cached return values. Thus it must only be called
from within obj.call_unprocessed, and obj.cached_base_ret must be populated.
def _device_fun(base_objective: ObjectiveBase, x: jnp.array):
"""Jax compatible objective function execution using external callback.

Parameters
----------
obj:
The wrapped jax objective.
x:
jax computed input array.

Note
----
This function should rather be implemented as class method of JaxObjective,
but this is not possible at the time of writing as this is not supported
by signature inspection in the underlying bind call.
"""
return hcb.call(
obj.cached_fval,
return jax.pure_callback(
partial(base_objective, sensi_orders=(0,)),
jax.ShapeDtypeStruct((), x.dtype),
x,
result_shape=jax.ShapeDtypeStruct((), np.float64),
)


@partial(custom_jvp, nondiff_argnums=(0,))
def _device_fun_grad(obj: "JaxObjective", x: jnp.array):
"""Jax compatible objective gradient execution using host callback.
def _device_fun_value_and_grad(base_objective: ObjectiveBase, x: jnp.array):
"""Jax compatible objective gradient execution using external callback.

This function does not actually call the underlying objective function,
but instead extracts cached return values. Thus it must only be called
from within obj.call_unprocessed and obj.cached_base_ret must be populated.
This function will be called when computing the gradient of the
`JaxObjective` using `jax.grad` or `jax.value_and_grad`. In the latter
case, the function will return both the function value and the gradient,
so no caching is necessary. For higher order derivatives, caching would
be advantageous, but unclear how to implement this.

Parameters
----------
obj:
The wrapped jax objective.
x:
jax computed input array.

Note
----
This function should rather be implemented as class method of JaxObjective,
but this is not possible at the time of writing as this is not supported
by signature inspection in the underlying bind call.
"""
return hcb.call(
obj.cached_grad,
x,
result_shape=jax.ShapeDtypeStruct(
obj.cached_base_ret[GRAD].shape, # bootstrap from cached value
np.float64,
return jax.pure_callback(
partial(
base_objective,
sensi_orders=(
0,
1,
),
),
)


def _device_fun_hess(obj: "JaxObjective", x: jnp.array):
"""Jax compatible objective Hessian execution using host callback.

This function does not actually call the underlying objective function,
but instead extracts cached return values. Thus it must only be called
from within obj.call_unprocessed and obj.cached_base_ret must be populated.

Parameters
----------
obj:
The wrapped jax objective.
x:
jax computed input array.

Note
----
This function should rather be implemented as class method of JaxObjective,
but this is not possible at the time of writing as this is not supported
by signature inspection in the underlying bind call.
"""
return hcb.call(
obj.cached_hess,
x,
result_shape=jax.ShapeDtypeStruct(
obj.cached_base_ret[HESS].shape, # bootstrap from cached value
np.float64,
(
jax.ShapeDtypeStruct((), x.dtype),
jax.ShapeDtypeStruct(
x.shape, # bootstrap from cached value
x.dtype,
),
),
x,
)


# define custom jvp for device_fun & device_fun_grad to enable autodiff, see
# define custom jvp for device_fun to enable autodiff, see
# https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html


Expand All @@ -132,79 +94,43 @@ def _device_fun_jvp(
"""JVP implementation for device_fun."""
(x,) = primals
(x_dot,) = tangents
return _device_fun(obj, x), _device_fun_grad(obj, x).dot(x_dot)


@_device_fun_grad.defjvp
def _device_fun_grad_jvp(
obj: "JaxObjective", primals: jnp.array, tangents: jnp.array
):
"""JVP implementation for device_fun_grad."""
(x,) = primals
(x_dot,) = tangents
return _device_fun_grad(obj, x), _device_fun_hess(obj, x).dot(x_dot)
value, grad = _device_fun_value_and_grad(obj, x)
return value, grad @ x_dot


class JaxObjective(ObjectiveBase):
"""Objective function that combines pypesto objectives with jax functions.
"""Objective function that enables use of pypesto objectives in jax models.

The generated objective function will evaluate objective(jax_fun(x)).
The generated function should generally be compatible with jax, but cannot
compute higher order derivatives and is not vectorized (but still
compatible with jax.vmap)

Parameters
----------
objective:
pyPESTO objective
jax_fun:
jax function (not jitted) that computes input to the pyPESTO objective
pyPESTO objective to be wrapped.

Note
----
Currently only implements MODE_FUN and sensi_orders=(0,). Support for
MODE_RES should be straightforward to add.
"""

def __init__(
self,
objective: ObjectiveBase,
jax_fun: Callable,
x_names: Sequence[str] = None,
):
if not isinstance(objective, ObjectiveBase):
raise TypeError("objective must be an ObjectiveBase instance")
if not objective.check_mode(MODE_FUN):
raise NotImplementedError(
f"objective must support mode={MODE_FUN}"
)
super().__init__(x_names)
self.base_objective = objective

self.jax_fun = jax_fun

# would be cleaner to also have this as class method, but not supported
# by signature inspection in bind call.
def jax_objective(x):
# device fun doesn't actually need the value of y, but we need to
# compute this here for autodiff to work
y = jax_fun(x)
return _device_fun(self, y)

# jit objective & derivatives (not integrated)
self.jax_objective = jax.jit(jax_objective)
self.jax_objective_grad = jax.jit(grad(jax_objective))
self.jax_objective_hess = jax.jit(jax.hessian(jax_objective))

# jit input function
self.infun = jax.jit(self.jax_fun)

# temporary storage for evaluation results of objective
self.cached_base_ret: ResultDict = {}

def cached_fval(self, _):
"""Return cached function value."""
return self.cached_base_ret[FVAL]

def cached_grad(self, _):
"""Return cached gradient."""
return self.cached_base_ret[GRAD]

def cached_hess(self, _):
"""Return cached Hessian."""
return self.cached_base_ret[HESS]
self.jax_objective = partial(_device_fun, self.base_objective)

def check_mode(self, mode: ModeType) -> bool:
"""See `ObjectiveBase` documentation."""
Expand All @@ -215,7 +141,44 @@ def check_sensi_orders(self, sensi_orders, mode: ModeType) -> bool:
if not self.check_mode(mode):
return False
else:
return self.base_objective.check_sensi_orders(sensi_orders, mode)
return (
self.base_objective.check_sensi_orders(sensi_orders, mode)
and max(sensi_orders) == 0
)

def __call__(
self,
x: jnp.ndarray,
sensi_orders: tuple[int, ...] = (0,),
mode: ModeType = MODE_FUN,
return_dict: bool = False,
**kwargs,
) -> Union[jnp.ndarray, tuple, ResultDict]:
"""
See :class:`ObjectiveBase` for more documentation.

Note that this function delegates pre- and post-processing as well as
history handling to the inner objective.
"""

if not self.check_mode(mode):
raise ValueError(
f"This Objective cannot be called with mode" f"={mode}."
)
if not self.check_sensi_orders(sensi_orders, mode):
raise ValueError(
f"This Objective cannot be called with "
f"sensi_orders= {sensi_orders} and mode={mode}."
)

# this computes all the results from the inner objective, rendering
# them accessible as cached values for device_fun, etc.
if kwargs.pop("return_dict", False):
raise ValueError(
"return_dict=True is not available for JaxObjective evaluation"
)

return self.jax_objective(x)

def call_unprocessed(
self,
Expand All @@ -225,48 +188,31 @@ def call_unprocessed(
**kwargs,
) -> ResultDict:
"""
See `ObjectiveBase` for more documentation.
See :class:`ObjectiveBase` for more documentation.

Main method to overwrite from the base class. It handles and
delegates the actual objective evaluation.
This function is not implemented for JaxObjective as it is not called
in the override for __call__. However, it's marked as abstract so we
need to implement it.
"""
# derivative computation in jax always requires lower order
# derivatives, see jvp rules for device_fun and device_fun_grad.
if 2 in sensi_orders:
sensi_orders = (0, 1, 2)
elif 1 in sensi_orders:
sensi_orders = (0, 1)
else:
sensi_orders = (0,)

# this computes all the results from the inner objective, rendering
# them accessible as cached values for device_fun, etc.
set_return_dict, return_dict = (
"return_dict" in kwargs,
kwargs.pop("return_dict", False),
)
self.cached_base_ret = self.base_objective(
self.infun(x), sensi_orders, mode, return_dict=True, **kwargs
)
if set_return_dict:
kwargs["return_dict"] = return_dict
ret = {}
if RDATAS in self.cached_base_ret:
ret[RDATAS] = self.cached_base_ret[RDATAS]
if 0 in sensi_orders:
ret[FVAL] = float(self.jax_objective(x))
if 1 in sensi_orders:
ret[GRAD] = self.jax_objective_grad(x)
if 2 in sensi_orders:
ret[HESS] = self.jax_objective_hess(x)

return ret
pass

def __deepcopy__(self, memodict=None):
other = JaxObjective(
copy.deepcopy(self.base_objective),
copy.deepcopy(self.jax_fun),
copy.deepcopy(self.x_names),
)

return other

@property
def history(self):
"""Exposes the history of the inner objective."""
return self.base_objective.history

@property
def pre_post_processor(self):
"""Exposes the pre_post_processor of inner objective."""
return self.base_objective.pre_post_processor

@property
def x_names(self):
"""Exposes the x_names of inner objective."""
return self.base_objective.x_names
Loading
Loading