From 15289fe5984133b1144f366c34afd474358512b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Thu, 28 Mar 2024 09:03:06 +0000 Subject: [PATCH 01/11] fix #1337 --- pypesto/objective/jax/base.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pypesto/objective/jax/base.py b/pypesto/objective/jax/base.py index c29876fc8..6d05e9080 100644 --- a/pypesto/objective/jax/base.py +++ b/pypesto/objective/jax/base.py @@ -17,7 +17,6 @@ try: import jax - import jax.experimental.host_callback as hcb import jax.numpy as jnp from jax import custom_jvp, grad except ImportError: @@ -52,7 +51,7 @@ def _device_fun(obj: "JaxObjective", x: jnp.array): but this is not possible at the time of writing as this is not supported by signature inspection in the underlying bind call. """ - return hcb.call( + return jax.pure_callback( obj.cached_fval, x, result_shape=jax.ShapeDtypeStruct((), np.float64), @@ -80,7 +79,7 @@ def _device_fun_grad(obj: "JaxObjective", x: jnp.array): but this is not possible at the time of writing as this is not supported by signature inspection in the underlying bind call. """ - return hcb.call( + return jax.pure_callback( obj.cached_grad, x, result_shape=jax.ShapeDtypeStruct( @@ -110,7 +109,7 @@ def _device_fun_hess(obj: "JaxObjective", x: jnp.array): but this is not possible at the time of writing as this is not supported by signature inspection in the underlying bind call. """ - return hcb.call( + return jax.pure_callback( obj.cached_hess, x, result_shape=jax.ShapeDtypeStruct( From 2747c00688e308966d698e58ee25d5d2d478256a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Thu, 28 Mar 2024 10:37:01 +0000 Subject: [PATCH 02/11] fix pure callback calls --- pypesto/objective/jax/base.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/pypesto/objective/jax/base.py b/pypesto/objective/jax/base.py index 6d05e9080..edd255902 100644 --- a/pypesto/objective/jax/base.py +++ b/pypesto/objective/jax/base.py @@ -26,8 +26,9 @@ "`pip install jax jaxlib`." ) from None -# jax compatible (jittable) objective function using host callback, see -# https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html +# jax compatible (jit-able) objective function using external callback, see +# https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html +# note that these functions are impure since they rely on cached values @partial(custom_jvp, nondiff_argnums=(0,)) @@ -35,7 +36,7 @@ def _device_fun(obj: "JaxObjective", x: jnp.array): """Jax compatible objective function execution using host callback. This function does not actually call the underlying objective function, - but instead extracts cached return values. Thus it must only be called + but instead extracts cached return values. Thus, it must only be called from within obj.call_unprocessed, and obj.cached_base_ret must be populated. Parameters @@ -53,8 +54,8 @@ def _device_fun(obj: "JaxObjective", x: jnp.array): """ return jax.pure_callback( obj.cached_fval, - x, - result_shape=jax.ShapeDtypeStruct((), np.float64), + jax.ShapeDtypeStruct((), x.dtype), + (x,), ) @@ -63,7 +64,7 @@ def _device_fun_grad(obj: "JaxObjective", x: jnp.array): """Jax compatible objective gradient execution using host callback. This function does not actually call the underlying objective function, - but instead extracts cached return values. Thus it must only be called + but instead extracts cached return values. Thus, it must only be called from within obj.call_unprocessed and obj.cached_base_ret must be populated. Parameters @@ -81,11 +82,11 @@ def _device_fun_grad(obj: "JaxObjective", x: jnp.array): """ return jax.pure_callback( obj.cached_grad, - x, - result_shape=jax.ShapeDtypeStruct( + jax.ShapeDtypeStruct( obj.cached_base_ret[GRAD].shape, # bootstrap from cached value - np.float64, + x.dtype, ), + (x,), ) @@ -93,7 +94,7 @@ def _device_fun_hess(obj: "JaxObjective", x: jnp.array): """Jax compatible objective Hessian execution using host callback. This function does not actually call the underlying objective function, - but instead extracts cached return values. Thus it must only be called + but instead extracts cached return values. Thus, it must only be called from within obj.call_unprocessed and obj.cached_base_ret must be populated. Parameters @@ -111,11 +112,11 @@ def _device_fun_hess(obj: "JaxObjective", x: jnp.array): """ return jax.pure_callback( obj.cached_hess, - x, - result_shape=jax.ShapeDtypeStruct( + jax.ShapeDtypeStruct( obj.cached_base_ret[HESS].shape, # bootstrap from cached value - np.float64, + x.dtype, ), + (x,), ) From 9811a713e295c799fbba1393c022b5b51bd4679f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Thu, 28 Mar 2024 10:39:12 +0000 Subject: [PATCH 03/11] run test in 32bit and 64bit mode --- test/base/test_objective.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/base/test_objective.py b/test/base/test_objective.py index 73615be32..b4fb6ada5 100644 --- a/test/base/test_objective.py +++ b/test/base/test_objective.py @@ -214,11 +214,14 @@ def test_aesara(max_sensi_order, integrated): ) -def test_jax(max_sensi_order, integrated): +@pytest.mark.parametrize("enable_x64", [True, False]) +def test_jax(max_sensi_order, integrated, enable_x64): """Test function composition and gradient computation via jax""" import jax import jax.numpy as jnp + jax.config.update("jax_enable_x64", enable_x64) + from pypesto.objective.jax import JaxObjective prob = rosen_for_sensi(max_sensi_order, integrated, [0, 1]) From a6537b1b4d755b6f6c8bd8e22886841072c8eeaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Thu, 28 Mar 2024 14:50:30 +0000 Subject: [PATCH 04/11] add vmap compatibility, remove second order support --- pypesto/objective/jax/base.py | 189 ++++++++++++++-------------------- test/base/test_objective.py | 52 ++++++---- 2 files changed, 110 insertions(+), 131 deletions(-) diff --git a/pypesto/objective/jax/base.py b/pypesto/objective/jax/base.py index edd255902..f88989fa2 100644 --- a/pypesto/objective/jax/base.py +++ b/pypesto/objective/jax/base.py @@ -8,17 +8,17 @@ import copy from functools import partial -from typing import Callable, Sequence, Tuple +from typing import Callable, Sequence, Tuple, Union import numpy as np -from ...C import FVAL, GRAD, HESS, MODE_FUN, RDATAS, ModeType +from ...C import MODE_FUN, ModeType from ..base import ObjectiveBase, ResultDict try: import jax import jax.numpy as jnp - from jax import custom_jvp, grad + from jax import custom_jvp except ImportError: raise ImportError( "Using a jax objective requires an installation of " @@ -53,14 +53,13 @@ def _device_fun(obj: "JaxObjective", x: jnp.array): by signature inspection in the underlying bind call. """ return jax.pure_callback( - obj.cached_fval, + partial(obj.base_objective, sensi_orders=(0,)), jax.ShapeDtypeStruct((), x.dtype), - (x,), + x, ) -@partial(custom_jvp, nondiff_argnums=(0,)) -def _device_fun_grad(obj: "JaxObjective", x: jnp.array): +def _device_fun_value_and_grad(obj: "JaxObjective", x: jnp.array): """Jax compatible objective gradient execution using host callback. This function does not actually call the underlying objective function, @@ -81,42 +80,21 @@ def _device_fun_grad(obj: "JaxObjective", x: jnp.array): by signature inspection in the underlying bind call. """ return jax.pure_callback( - obj.cached_grad, - jax.ShapeDtypeStruct( - obj.cached_base_ret[GRAD].shape, # bootstrap from cached value - x.dtype, + partial( + obj.base_objective, + sensi_orders=( + 0, + 1, + ), ), - (x,), - ) - - -def _device_fun_hess(obj: "JaxObjective", x: jnp.array): - """Jax compatible objective Hessian execution using host callback. - - This function does not actually call the underlying objective function, - but instead extracts cached return values. Thus, it must only be called - from within obj.call_unprocessed and obj.cached_base_ret must be populated. - - Parameters - ---------- - obj: - The wrapped jax objective. - x: - jax computed input array. - - Note - ---- - This function should rather be implemented as class method of JaxObjective, - but this is not possible at the time of writing as this is not supported - by signature inspection in the underlying bind call. - """ - return jax.pure_callback( - obj.cached_hess, - jax.ShapeDtypeStruct( - obj.cached_base_ret[HESS].shape, # bootstrap from cached value - x.dtype, + ( + jax.ShapeDtypeStruct((), x.dtype), + jax.ShapeDtypeStruct( + x.shape, # bootstrap from cached value + x.dtype, + ), ), - (x,), + x, ) @@ -131,17 +109,8 @@ def _device_fun_jvp( """JVP implementation for device_fun.""" (x,) = primals (x_dot,) = tangents - return _device_fun(obj, x), _device_fun_grad(obj, x).dot(x_dot) - - -@_device_fun_grad.defjvp -def _device_fun_grad_jvp( - obj: "JaxObjective", primals: jnp.array, tangents: jnp.array -): - """JVP implementation for device_fun_grad.""" - (x,) = primals - (x_dot,) = tangents - return _device_fun_grad(obj, x), _device_fun_hess(obj, x).dot(x_dot) + value, grad = _device_fun_value_and_grad(obj, x) + return value, grad @ x_dot class JaxObjective(ObjectiveBase): @@ -169,7 +138,11 @@ def __init__( raise NotImplementedError( f"objective must support mode={MODE_FUN}" ) - super().__init__(x_names) + # store names directly rather than calling __init__ of super class + # as we can't initialize history as we are exposing the history of the + # inner objective + self._x_names = x_names + self.base_objective = objective self.jax_fun = jax_fun @@ -182,28 +155,7 @@ def jax_objective(x): y = jax_fun(x) return _device_fun(self, y) - # jit objective & derivatives (not integrated) - self.jax_objective = jax.jit(jax_objective) - self.jax_objective_grad = jax.jit(grad(jax_objective)) - self.jax_objective_hess = jax.jit(jax.hessian(jax_objective)) - - # jit input function - self.infun = jax.jit(self.jax_fun) - - # temporary storage for evaluation results of objective - self.cached_base_ret: ResultDict = {} - - def cached_fval(self, _): - """Return cached function value.""" - return self.cached_base_ret[FVAL] - - def cached_grad(self, _): - """Return cached gradient.""" - return self.cached_base_ret[GRAD] - - def cached_hess(self, _): - """Return cached Hessian.""" - return self.cached_base_ret[HESS] + self.jax_objective = jax_objective def check_mode(self, mode: ModeType) -> bool: """See `ObjectiveBase` documentation.""" @@ -214,7 +166,44 @@ def check_sensi_orders(self, sensi_orders, mode: ModeType) -> bool: if not self.check_mode(mode): return False else: - return self.base_objective.check_sensi_orders(sensi_orders, mode) + return ( + self.base_objective.check_sensi_orders(sensi_orders, mode) + and max(sensi_orders) == 0 + ) + + def __call__( + self, + x: jnp.ndarray, + sensi_orders: Tuple[int, ...] = (0,), + mode: ModeType = MODE_FUN, + return_dict: bool = False, + **kwargs, + ) -> Union[jnp.ndarray, Tuple, ResultDict]: + """ + See `ObjectiveBase` for more documentation. + + Note that this function delegates pre- and post-processing as well as + history handling to the inner objective. + """ + + if not self.check_mode(mode): + raise ValueError( + f"This Objective cannot be called with mode" f"={mode}." + ) + if not self.check_sensi_orders(sensi_orders, mode): + raise ValueError( + f"This Objective cannot be called with " + f"sensi_orders= {sensi_orders} and mode={mode}." + ) + + # this computes all the results from the inner objective, rendering + # them accessible as cached values for device_fun, etc. + if kwargs.pop("return_dict", False): + raise ValueError( + "return_dict=True is not available for JaxObjective evaluation" + ) + + return self.jax_objective(x) def call_unprocessed( self, @@ -226,40 +215,11 @@ def call_unprocessed( """ See `ObjectiveBase` for more documentation. - Main method to overwrite from the base class. It handles and - delegates the actual objective evaluation. + This function is not implemented for JaxObjective as it is not called + in the override for __call__. However, it's marked as abstract so we + need to implement it. """ - # derivative computation in jax always requires lower order - # derivatives, see jvp rules for device_fun and device_fun_grad. - if 2 in sensi_orders: - sensi_orders = (0, 1, 2) - elif 1 in sensi_orders: - sensi_orders = (0, 1) - else: - sensi_orders = (0,) - - # this computes all the results from the inner objective, rendering - # them accessible as cached values for device_fun, etc. - set_return_dict, return_dict = ( - "return_dict" in kwargs, - kwargs.pop("return_dict", False), - ) - self.cached_base_ret = self.base_objective( - self.infun(x), sensi_orders, mode, return_dict=True, **kwargs - ) - if set_return_dict: - kwargs["return_dict"] = return_dict - ret = {} - if RDATAS in self.cached_base_ret: - ret[RDATAS] = self.cached_base_ret[RDATAS] - if 0 in sensi_orders: - ret[FVAL] = float(self.jax_objective(x)) - if 1 in sensi_orders: - ret[GRAD] = self.jax_objective_grad(x) - if 2 in sensi_orders: - ret[HESS] = self.jax_objective_hess(x) - - return ret + pass def __deepcopy__(self, memodict=None): other = JaxObjective( @@ -267,5 +227,14 @@ def __deepcopy__(self, memodict=None): copy.deepcopy(self.jax_fun), copy.deepcopy(self.x_names), ) - return other + + @property + def history(self): + """Exposes the history of the inner objective.""" + return self.base_objective.history + + @property + def pre_post_processor(self): + """Exposes the pre_post_processor of inner objective.""" + return self.base_objective.pre_post_processor diff --git a/test/base/test_objective.py b/test/base/test_objective.py index b4fb6ada5..74843f8e2 100644 --- a/test/base/test_objective.py +++ b/test/base/test_objective.py @@ -220,6 +220,9 @@ def test_jax(max_sensi_order, integrated, enable_x64): import jax import jax.numpy as jnp + if max_sensi_order == 2: + pytest.skip("Not Implemented") + jax.config.update("jax_enable_x64", enable_x64) from pypesto.objective.jax import JaxObjective @@ -227,32 +230,39 @@ def test_jax(max_sensi_order, integrated, enable_x64): prob = rosen_for_sensi(max_sensi_order, integrated, [0, 1]) # apply inverse transform such that we evaluate at prob['x'] - x_ref = np.arcsinh(prob["x"]) + x_ref = np.asarray(prob["x"]) / 2 - def jac_op(x: jnp.array) -> jnp.array: - return jax.lax.sinh(x) + def jax_op(x: jnp.array) -> jnp.array: + # pick a simple function here to avoid numerical issues + return 2.0 * x # compose rosenbrock function with sinh transformation - obj = JaxObjective(prob["obj"], jac_op) + obj = JaxObjective(prob["obj"], jax_op) - # check function values and derivatives, also after copy + # evaluate for a couple of random points such that we can assess + # compatibility with vmap + xx = x_ref + np.random.randn(10, x_ref.shape[0]) + rvals_ref = [ + prob["obj"](jax_op(xxi), sensi_orders=(max_sensi_order,)) for xxi in xx + ] for _obj in (obj, copy.deepcopy(obj)): - # function value - assert _obj(x_ref) == prob["fval"] - - # gradient - if max_sensi_order > 0: - assert np.allclose( - _obj(x_ref, sensi_orders=(1,)), prob["grad"] * np.cosh(x_ref) - ) - - # hessian - if max_sensi_order > 1: - assert np.allclose( - prob["hess"] * (np.diag(np.power(np.cosh(x_ref), 2))) - + np.diag(prob["grad"] * np.sinh(x_ref)), - _obj(x_ref, sensi_orders=(2,)), - ) + jaxfun = _obj + if max_sensi_order == 1: + jaxfun = jax.grad(jaxfun) + # check compatibility with vmap and jit + vmapped = jax.vmap(jaxfun) + rvals_jax = vmapped(xx) + atol = 0 + # also need to account for roundoff errors in input, so we are not + # we can't use rtol = 1e-8 for 32bit + rtol = 1e-16 if enable_x64 else 1e-6 + for x, rref, rj in zip(xx, rvals_ref, rvals_jax): + if max_sensi_order == 0: + np.testing.assert_allclose(rref, rj, atol=atol, rtol=rtol) + if max_sensi_order == 1: + np.testing.assert_allclose( + rref @ jax.jacfwd(jax_op)(x), rj, atol=atol, rtol=rtol + ) @pytest.fixture( From cd0b7271576cf51f836ae0dab014438152ca9935 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Thu, 28 Mar 2024 15:34:47 +0000 Subject: [PATCH 05/11] Update test_objective.py --- test/base/test_objective.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/base/test_objective.py b/test/base/test_objective.py index 74843f8e2..de2879b56 100644 --- a/test/base/test_objective.py +++ b/test/base/test_objective.py @@ -255,7 +255,7 @@ def jax_op(x: jnp.array) -> jnp.array: atol = 0 # also need to account for roundoff errors in input, so we are not # we can't use rtol = 1e-8 for 32bit - rtol = 1e-16 if enable_x64 else 1e-6 + rtol = 1e-16 if enable_x64 else 1e-5 for x, rref, rj in zip(xx, rvals_ref, rvals_jax): if max_sensi_order == 0: np.testing.assert_allclose(rref, rj, atol=atol, rtol=rtol) From 36fc8990fe5f53a6ee2a05877b09a2ea057459c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Fri, 29 Mar 2024 11:43:22 +0000 Subject: [PATCH 06/11] Apply suggestions from code review Co-authored-by: Paul Jonas Jost <70631928+PaulJonasJost@users.noreply.github.com> --- pypesto/objective/jax/base.py | 4 ++-- test/base/test_objective.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pypesto/objective/jax/base.py b/pypesto/objective/jax/base.py index f88989fa2..e527ec57f 100644 --- a/pypesto/objective/jax/base.py +++ b/pypesto/objective/jax/base.py @@ -180,7 +180,7 @@ def __call__( **kwargs, ) -> Union[jnp.ndarray, Tuple, ResultDict]: """ - See `ObjectiveBase` for more documentation. + See :class:`ObjectiveBase` for more documentation. Note that this function delegates pre- and post-processing as well as history handling to the inner objective. @@ -213,7 +213,7 @@ def call_unprocessed( **kwargs, ) -> ResultDict: """ - See `ObjectiveBase` for more documentation. + See :class:`ObjectiveBase` for more documentation. This function is not implemented for JaxObjective as it is not called in the override for __call__. However, it's marked as abstract so we diff --git a/test/base/test_objective.py b/test/base/test_objective.py index de2879b56..56bd2c22e 100644 --- a/test/base/test_objective.py +++ b/test/base/test_objective.py @@ -253,8 +253,8 @@ def jax_op(x: jnp.array) -> jnp.array: vmapped = jax.vmap(jaxfun) rvals_jax = vmapped(xx) atol = 0 - # also need to account for roundoff errors in input, so we are not - # we can't use rtol = 1e-8 for 32bit + # also need to account for roundoff errors in input, so we + # can't use rtol = 1e-8 for 32bit rtol = 1e-16 if enable_x64 else 1e-5 for x, rref, rj in zip(xx, rvals_ref, rvals_jax): if max_sensi_order == 0: From f1183de18531ce0c389a4ec8e3342f33b46d1493 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Fri, 29 Mar 2024 12:30:28 +0000 Subject: [PATCH 07/11] alter implementation, fix doc and extend tests --- pypesto/objective/jax/base.py | 60 ++++++++++++----------------------- test/base/test_objective.py | 49 +++++++++++++++++++++------- 2 files changed, 58 insertions(+), 51 deletions(-) diff --git a/pypesto/objective/jax/base.py b/pypesto/objective/jax/base.py index f88989fa2..fdaec90bf 100644 --- a/pypesto/objective/jax/base.py +++ b/pypesto/objective/jax/base.py @@ -8,7 +8,7 @@ import copy from functools import partial -from typing import Callable, Sequence, Tuple, Union +from typing import Sequence, Tuple, Union import numpy as np @@ -28,16 +28,11 @@ # jax compatible (jit-able) objective function using external callback, see # https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html -# note that these functions are impure since they rely on cached values @partial(custom_jvp, nondiff_argnums=(0,)) def _device_fun(obj: "JaxObjective", x: jnp.array): - """Jax compatible objective function execution using host callback. - - This function does not actually call the underlying objective function, - but instead extracts cached return values. Thus, it must only be called - from within obj.call_unprocessed, and obj.cached_base_ret must be populated. + """Jax compatible objective function execution using external callback. Parameters ---------- @@ -45,12 +40,6 @@ def _device_fun(obj: "JaxObjective", x: jnp.array): The wrapped jax objective. x: jax computed input array. - - Note - ---- - This function should rather be implemented as class method of JaxObjective, - but this is not possible at the time of writing as this is not supported - by signature inspection in the underlying bind call. """ return jax.pure_callback( partial(obj.base_objective, sensi_orders=(0,)), @@ -60,11 +49,13 @@ def _device_fun(obj: "JaxObjective", x: jnp.array): def _device_fun_value_and_grad(obj: "JaxObjective", x: jnp.array): - """Jax compatible objective gradient execution using host callback. + """Jax compatible objective gradient execution using external callback. - This function does not actually call the underlying objective function, - but instead extracts cached return values. Thus, it must only be called - from within obj.call_unprocessed and obj.cached_base_ret must be populated. + This function will be called when computing the gradient of the + `JaxObjective` using `jax.grad` or `jax.value_and_grad`. In the latter + case, the function will return both the function value and the gradient, + so no caching is necessary. For higher order derivatives, caching would + be advantageous, but unclear how to implement this. Parameters ---------- @@ -72,12 +63,6 @@ def _device_fun_value_and_grad(obj: "JaxObjective", x: jnp.array): The wrapped jax objective. x: jax computed input array. - - Note - ---- - This function should rather be implemented as class method of JaxObjective, - but this is not possible at the time of writing as this is not supported - by signature inspection in the underlying bind call. """ return jax.pure_callback( partial( @@ -98,7 +83,7 @@ def _device_fun_value_and_grad(obj: "JaxObjective", x: jnp.array): ) -# define custom jvp for device_fun & device_fun_grad to enable autodiff, see +# define custom jvp for device_fun to enable autodiff, see # https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html @@ -114,22 +99,26 @@ def _device_fun_jvp( class JaxObjective(ObjectiveBase): - """Objective function that combines pypesto objectives with jax functions. + """Objective function that enables use of pypesto objectives in jax models. - The generated objective function will evaluate objective(jax_fun(x)). + The generated function should generally be compatible with jax, but cannot + compute higher order derivatives and is not vectorized (but still + compatible with jax.vmap) Parameters ---------- objective: - pyPESTO objective - jax_fun: - jax function (not jitted) that computes input to the pyPESTO objective + pyPESTO objective to be wrapped. + + Note + ---- + Currently only implements MODE_FUN and sensi_orders=(0,). Support for + MODE_RES should be straightforward to add. """ def __init__( self, objective: ObjectiveBase, - jax_fun: Callable, x_names: Sequence[str] = None, ): if not isinstance(objective, ObjectiveBase): @@ -145,17 +134,9 @@ def __init__( self.base_objective = objective - self.jax_fun = jax_fun - # would be cleaner to also have this as class method, but not supported # by signature inspection in bind call. - def jax_objective(x): - # device fun doesn't actually need the value of y, but we need to - # compute this here for autodiff to work - y = jax_fun(x) - return _device_fun(self, y) - - self.jax_objective = jax_objective + self.jax_objective = partial(_device_fun, self) def check_mode(self, mode: ModeType) -> bool: """See `ObjectiveBase` documentation.""" @@ -224,7 +205,6 @@ def call_unprocessed( def __deepcopy__(self, memodict=None): other = JaxObjective( copy.deepcopy(self.base_objective), - copy.deepcopy(self.jax_fun), copy.deepcopy(self.x_names), ) return other diff --git a/test/base/test_objective.py b/test/base/test_objective.py index de2879b56..66ed1b711 100644 --- a/test/base/test_objective.py +++ b/test/base/test_objective.py @@ -2,6 +2,7 @@ import copy import numbers +from functools import partial import numpy as np import pytest @@ -232,26 +233,43 @@ def test_jax(max_sensi_order, integrated, enable_x64): # apply inverse transform such that we evaluate at prob['x'] x_ref = np.asarray(prob["x"]) / 2 - def jax_op(x: jnp.array) -> jnp.array: + def jax_op_in(x: jnp.array) -> jnp.array: # pick a simple function here to avoid numerical issues - return 2.0 * x + return 3.0 * x + + def jax_op_out(x: jnp.array) -> jnp.array: + # pick a simple function here to avoid numerical issues + return 0.5 * x # compose rosenbrock function with sinh transformation - obj = JaxObjective(prob["obj"], jax_op) + obj = JaxObjective(prob["obj"]) # evaluate for a couple of random points such that we can assess # compatibility with vmap xx = x_ref + np.random.randn(10, x_ref.shape[0]) rvals_ref = [ - prob["obj"](jax_op(xxi), sensi_orders=(max_sensi_order,)) for xxi in xx + jax_op_out( + prob["obj"](jax_op_in(xxi), sensi_orders=(max_sensi_order,)) + ) + for xxi in xx ] + + def _fun(y, pypesto_fun, jax_fun_in, jax_fun_out): + return jax_fun_out(pypesto_fun(jax_fun_in(y))) + for _obj in (obj, copy.deepcopy(obj)): - jaxfun = _obj + fun = partial( + _fun, + pypesto_fun=_obj, + jax_fun_in=jax_op_in, + jax_fun_out=jax_op_out, + ) + if max_sensi_order == 1: - jaxfun = jax.grad(jaxfun) + fun = jax.grad(fun) # check compatibility with vmap and jit - vmapped = jax.vmap(jaxfun) - rvals_jax = vmapped(xx) + vmapped_fun = jax.vmap(fun) + rvals_jax = vmapped_fun(xx) atol = 0 # also need to account for roundoff errors in input, so we are not # we can't use rtol = 1e-8 for 32bit @@ -260,9 +278,18 @@ def jax_op(x: jnp.array) -> jnp.array: if max_sensi_order == 0: np.testing.assert_allclose(rref, rj, atol=atol, rtol=rtol) if max_sensi_order == 1: - np.testing.assert_allclose( - rref @ jax.jacfwd(jax_op)(x), rj, atol=atol, rtol=rtol - ) + # g(x) = b(c(x)) => g'(x) = b'(c(x))) * c'(x) + # f(x) = a(g(x)) => f'(x) = a'(g(x)) * g'(x) + # c: jax_op_in, b: prob["obj"], a: jax_op_out + # g(x) = b(c(x)) + g = prob["obj"](jax_op_in(x)) + # g'(x) = b'(c(x))) * c'(x) + g_prime = prob["obj"]( + jax_op_in(x), sensi_orders=(1,) + ) @ jax.jacfwd(jax_op_in)(x) + # f'(x) = a'(g(x)) * g'(x) + f_prime = jax.jacfwd(jax_op_out)(g) * g_prime + np.testing.assert_allclose(f_prime, rj, atol=atol, rtol=rtol) @pytest.fixture( From 78e9eb1d39bc10c596ee8f2b4fd0d2f9bff04850 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Fri, 29 Mar 2024 12:36:18 +0000 Subject: [PATCH 08/11] Update test_objective.py --- test/base/test_objective.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/base/test_objective.py b/test/base/test_objective.py index dde27ecb3..29994f0c8 100644 --- a/test/base/test_objective.py +++ b/test/base/test_objective.py @@ -230,8 +230,7 @@ def test_jax(max_sensi_order, integrated, enable_x64): prob = rosen_for_sensi(max_sensi_order, integrated, [0, 1]) - # apply inverse transform such that we evaluate at prob['x'] - x_ref = np.asarray(prob["x"]) / 2 + x_ref = np.asarray(prob["x"]) def jax_op_in(x: jnp.array) -> jnp.array: # pick a simple function here to avoid numerical issues From 34ff6c47266ec3d7e86f52da691c0b2afc087b3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Wed, 3 Apr 2024 21:21:26 +0100 Subject: [PATCH 09/11] expose inner x_names rather than allowing them to be set --- pypesto/objective/jax/base.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/pypesto/objective/jax/base.py b/pypesto/objective/jax/base.py index 8022c4c30..e0d741c1c 100644 --- a/pypesto/objective/jax/base.py +++ b/pypesto/objective/jax/base.py @@ -8,7 +8,7 @@ import copy from functools import partial -from typing import Sequence, Tuple, Union +from typing import Tuple, Union import numpy as np @@ -119,7 +119,6 @@ class JaxObjective(ObjectiveBase): def __init__( self, objective: ObjectiveBase, - x_names: Sequence[str] = None, ): if not isinstance(objective, ObjectiveBase): raise TypeError("objective must be an ObjectiveBase instance") @@ -127,11 +126,6 @@ def __init__( raise NotImplementedError( f"objective must support mode={MODE_FUN}" ) - # store names directly rather than calling __init__ of super class - # as we can't initialize history as we are exposing the history of the - # inner objective - self._x_names = x_names - self.base_objective = objective # would be cleaner to also have this as class method, but not supported @@ -205,7 +199,6 @@ def call_unprocessed( def __deepcopy__(self, memodict=None): other = JaxObjective( copy.deepcopy(self.base_objective), - copy.deepcopy(self.x_names), ) return other @@ -218,3 +211,8 @@ def history(self): def pre_post_processor(self): """Exposes the pre_post_processor of inner objective.""" return self.base_objective.pre_post_processor + + @property + def x_names(self): + """Exposes the x_names of inner objective.""" + return self.base_objective.x_names From 4e420b4e22a3a26e56f6e72c44fd865d760d2cba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Fri, 5 Apr 2024 11:09:06 +0100 Subject: [PATCH 10/11] Update base.py --- pypesto/objective/jax/base.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pypesto/objective/jax/base.py b/pypesto/objective/jax/base.py index aa79eef78..49327bb37 100644 --- a/pypesto/objective/jax/base.py +++ b/pypesto/objective/jax/base.py @@ -8,7 +8,7 @@ import copy from functools import partial -from typing import Tuple, Union +from typing import Union import numpy as np @@ -31,7 +31,7 @@ @partial(custom_jvp, nondiff_argnums=(0,)) -def _device_fun(obj: "JaxObjective", x: jnp.array): +def _device_fun(base_objective: ObjectiveBase, x: jnp.array): """Jax compatible objective function execution using external callback. Parameters @@ -42,13 +42,13 @@ def _device_fun(obj: "JaxObjective", x: jnp.array): jax computed input array. """ return jax.pure_callback( - partial(obj.base_objective, sensi_orders=(0,)), + partial(base_objective, sensi_orders=(0,)), jax.ShapeDtypeStruct((), x.dtype), x, ) -def _device_fun_value_and_grad(obj: "JaxObjective", x: jnp.array): +def _device_fun_value_and_grad(base_objective: ObjectiveBase, x: jnp.array): """Jax compatible objective gradient execution using external callback. This function will be called when computing the gradient of the @@ -66,7 +66,7 @@ def _device_fun_value_and_grad(obj: "JaxObjective", x: jnp.array): """ return jax.pure_callback( partial( - obj.base_objective, + base_objective, sensi_orders=( 0, 1, @@ -130,7 +130,7 @@ def __init__( # would be cleaner to also have this as class method, but not supported # by signature inspection in bind call. - self.jax_objective = partial(_device_fun, self) + self.jax_objective = partial(_device_fun, self.base_objective) def check_mode(self, mode: ModeType) -> bool: """See `ObjectiveBase` documentation.""" @@ -149,11 +149,11 @@ def check_sensi_orders(self, sensi_orders, mode: ModeType) -> bool: def __call__( self, x: jnp.ndarray, - sensi_orders: Tuple[int, ...] = (0,), + sensi_orders: tuple[int, ...] = (0,), mode: ModeType = MODE_FUN, return_dict: bool = False, **kwargs, - ) -> Union[jnp.ndarray, Tuple, ResultDict]: + ) -> Union[jnp.ndarray, tuple, ResultDict]: """ See :class:`ObjectiveBase` for more documentation. From 180545d6d79e273b44988d6972ebd241fb50acd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Fri, 5 Apr 2024 16:07:32 +0100 Subject: [PATCH 11/11] Update test_objective.py --- test/base/test_objective.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/base/test_objective.py b/test/base/test_objective.py index 29994f0c8..d2c556927 100644 --- a/test/base/test_objective.py +++ b/test/base/test_objective.py @@ -272,7 +272,7 @@ def _fun(y, pypesto_fun, jax_fun_in, jax_fun_out): atol = 0 # also need to account for roundoff errors in input, so we # can't use rtol = 1e-8 for 32bit - rtol = 1e-16 if enable_x64 else 1e-5 + rtol = 1e-16 if enable_x64 else 1e-4 for x, rref, rj in zip(xx, rvals_ref, rvals_jax): if max_sensi_order == 0: np.testing.assert_allclose(rref, rj, atol=atol, rtol=rtol)