diff --git a/pypesto/objective/aesara/base.py b/pypesto/objective/aesara/base.py index dc2528f10..ce7ad8c22 100644 --- a/pypesto/objective/aesara/base.py +++ b/pypesto/objective/aesara/base.py @@ -110,6 +110,7 @@ def call_unprocessed( x: np.ndarray, sensi_orders: tuple[int, ...], mode: ModeType, + return_dict: bool, **kwargs, ) -> ResultDict: """ diff --git a/pypesto/objective/aggregated.py b/pypesto/objective/aggregated.py index 21e1844b0..13b236816 100644 --- a/pypesto/objective/aggregated.py +++ b/pypesto/objective/aggregated.py @@ -1,3 +1,5 @@ +import inspect +import warnings from collections.abc import Sequence from copy import deepcopy from typing import Any @@ -92,6 +94,7 @@ def call_unprocessed( sensi_orders: tuple[int, ...], mode: ModeType, kwargs_list: Sequence[dict[str, Any]] = None, + return_dict: bool = False, **kwargs, ) -> ResultDict: """ @@ -113,6 +116,20 @@ def call_unprocessed( "The length of `kwargs_list` must match the number of " "objectives you are aggregating." ) + for objective_, objective_kwargs in zip(self._objectives, kwargs_list): + if ( + "return_dict" + in inspect.signature(objective_.call_unprocessed).parameters + ): + objective_kwargs["return_dict"] = return_dict + else: + warnings.warn( + "Please add `return_dict` to the argument list of your " + "objective's `call_unprocessed` method. " + f"Current objective: `{type(objective_)}`.", + DeprecationWarning, + stacklevel=1, + ) return aggregate_results( [ objective.call_unprocessed( diff --git a/pypesto/objective/amici/amici.py b/pypesto/objective/amici/amici.py index 9c7035329..b69fcfa6e 100644 --- a/pypesto/objective/amici/amici.py +++ b/pypesto/objective/amici/amici.py @@ -26,7 +26,7 @@ HistoryTypeError, MemoryHistory, ) -from ..base import ObjectiveBase, ResultDict +from ..base import ObjectiveBase from .amici_calculator import AmiciCalculator from .amici_util import ( create_identity_parameter_mapping, @@ -68,6 +68,8 @@ def create_edatas(self, model: AmiciModel) -> Sequence["amici.ExpData"]: class AmiciObjective(ObjectiveBase): """Allows to create an objective directly from an amici model.""" + share_return_dict = True + def __init__( self, amici_model: AmiciModel, @@ -415,33 +417,12 @@ def check_mode(self, mode: ModeType) -> bool: """See `ObjectiveBase` documentation.""" return mode in [MODE_FUN, MODE_RES] - def __call__( - self, - x: np.ndarray, - sensi_orders: tuple[int, ...] = (0,), - mode: ModeType = MODE_FUN, - return_dict: bool = False, - **kwargs, - ) -> Union[float, np.ndarray, tuple, ResultDict]: - """See `ObjectiveBase` documentation.""" - import amici - - # Use AMICI full reporting if amici.ReturnDatas are returned and no - # other reporting mode was set - if ( - return_dict - and self.amici_reporting is None - and "amici_reporting" not in kwargs - ): - kwargs["amici_reporting"] = amici.RDataReporting.full - - return super().__call__(x, sensi_orders, mode, return_dict, **kwargs) - def call_unprocessed( self, x: np.ndarray, sensi_orders: tuple[int, ...], mode: ModeType, + return_dict: bool = False, edatas: Sequence["amici.ExpData"] = None, parameter_mapping: "ParameterMapping" = None, amici_reporting: Optional["amici.RDataReporting"] = None, @@ -458,18 +439,23 @@ def call_unprocessed( x_dct = self.par_arr_to_dct(x) - # only ask amici to compute required quantities amici_reporting = ( self.amici_reporting if amici_reporting is None else amici_reporting ) if amici_reporting is None: - amici_reporting = ( - amici.RDataReporting.likelihood - if mode == MODE_FUN - else amici.RDataReporting.residuals - ) + if return_dict: + # Use AMICI full reporting if amici.ReturnDatas are returned + # and no other reporting mode was set + amici_reporting = amici.RDataReporting.full + else: + # Else, only ask amici to compute required quantities + amici_reporting = ( + amici.RDataReporting.likelihood + if mode == MODE_FUN + else amici.RDataReporting.residuals + ) self.amici_solver.setReturnDataReportingMode(amici_reporting) # update steady state diff --git a/pypesto/objective/base.py b/pypesto/objective/base.py index e5212b31c..438833aac 100644 --- a/pypesto/objective/base.py +++ b/pypesto/objective/base.py @@ -1,5 +1,7 @@ import copy +import inspect import logging +import warnings from abc import ABC, abstractmethod from collections.abc import Iterable, Sequence from typing import Optional, Union @@ -178,6 +180,19 @@ def __call__( x_full = self.pre_post_processor.preprocess(x=x) # compute result + if ( + "return_dict" + in inspect.signature(self.call_unprocessed).parameters + ): + kwargs["return_dict"] = return_dict + else: + warnings.warn( + "Please add `return_dict` to the argument list of your " + "objective's `call_unprocessed` method.", + DeprecationWarning, + stacklevel=1, + ) + result = self.call_unprocessed( x=x_full, sensi_orders=sensi_orders, mode=mode, **kwargs ) @@ -204,6 +219,7 @@ def call_unprocessed( x: np.ndarray, sensi_orders: tuple[int, ...], mode: ModeType, + return_dict: bool, **kwargs, ) -> ResultDict: """ @@ -217,6 +233,10 @@ def call_unprocessed( Specifies which sensitivities to compute, e.g. (0,1) -> fval, grad. mode: Whether to compute function values or residuals. + return_dict: + Whether the user requested additional information. Objectives can + use this to determine whether to e.g. return "full" or "minimal" + information. Returns ------- diff --git a/pypesto/objective/finite_difference.py b/pypesto/objective/finite_difference.py index 78d997a57..4bf7a505f 100644 --- a/pypesto/objective/finite_difference.py +++ b/pypesto/objective/finite_difference.py @@ -373,6 +373,7 @@ def call_unprocessed( x: np.ndarray, sensi_orders: tuple[int, ...], mode: ModeType, + return_dict: bool, **kwargs, ) -> ResultDict: """ @@ -383,11 +384,17 @@ def call_unprocessed( """ if mode == MODE_FUN: result = self._call_mode_fun( - x=x, sensi_orders=sensi_orders, **kwargs + x=x, + sensi_orders=sensi_orders, + return_dict=return_dict, + **kwargs, ) elif mode == MODE_RES: result = self._call_mode_res( - x=x, sensi_orders=sensi_orders, **kwargs + x=x, + sensi_orders=sensi_orders, + return_dict=return_dict, + **kwargs, ) else: raise ValueError("This mode is not supported.") @@ -398,6 +405,7 @@ def _call_mode_fun( self, x: np.ndarray, sensi_orders: tuple[int, ...], + return_dict: bool, **kwargs, ) -> ResultDict: """Handle calls in function value mode. @@ -408,6 +416,7 @@ def _call_mode_fun( sensi_orders_obj, result = self._call_from_obj_fun( x=x, sensi_orders=sensi_orders, + return_dict=return_dict, **kwargs, ) @@ -429,13 +438,21 @@ def _call_mode_fun( def f_fval(x): """Short-hand to get a function value.""" return self.obj.call_unprocessed( - x=x, sensi_orders=(0,), mode=MODE_FUN, **kwargs + x=x, + sensi_orders=(0,), + mode=MODE_FUN, + return_dict=return_dict, + **kwargs, )[FVAL] def f_grad(x): """Short-hand to get a gradient value.""" return self.obj.call_unprocessed( - x=x, sensi_orders=(1,), mode=MODE_FUN, **kwargs + x=x, + sensi_orders=(1,), + mode=MODE_FUN, + return_dict=return_dict, + **kwargs, )[GRAD] # update delta vectors @@ -487,6 +504,7 @@ def _call_mode_res( self, x: np.ndarray, sensi_orders: tuple[int, ...], + return_dict: bool, **kwargs, ) -> ResultDict: """Handle calls in residual mode. @@ -497,6 +515,7 @@ def _call_mode_res( sensi_orders_obj, result = self._call_from_obj_res( x=x, sensi_orders=sensi_orders, + return_dict=return_dict, **kwargs, ) @@ -507,7 +526,11 @@ def _call_mode_res( def f_res(x): """Short-hand to get a function value.""" return self.obj.call_unprocessed( - x=x, sensi_orders=(0,), mode=MODE_RES, **kwargs + x=x, + sensi_orders=(0,), + mode=MODE_RES, + return_dict=return_dict, + **kwargs, )[RES] # update delta vector @@ -532,6 +555,7 @@ def _call_from_obj_fun( self, x: np.ndarray, sensi_orders: tuple[int, ...], + return_dict: bool, **kwargs, ) -> tuple[tuple[int, ...], ResultDict]: """ @@ -553,7 +577,11 @@ def _call_from_obj_fun( result = {} if sensi_orders_obj: result = self.obj.call_unprocessed( - x=x, sensi_orders=sensi_orders_obj, mode=MODE_FUN, **kwargs + x=x, + sensi_orders=sensi_orders_obj, + mode=MODE_FUN, + return_dict=return_dict, + **kwargs, ) return sensi_orders_obj, result @@ -561,6 +589,7 @@ def _call_from_obj_res( self, x: np.ndarray, sensi_orders: tuple[int, ...], + return_dict: bool, **kwargs, ) -> tuple[tuple[int, ...], ResultDict]: """ @@ -580,7 +609,11 @@ def _call_from_obj_res( result = {} if sensi_orders_obj: result = self.obj.call_unprocessed( - x=x, sensi_orders=sensi_orders_obj, mode=MODE_RES, **kwargs + x=x, + sensi_orders=sensi_orders_obj, + mode=MODE_RES, + return_dict=return_dict, + **kwargs, ) return sensi_orders_obj, result diff --git a/pypesto/objective/function.py b/pypesto/objective/function.py index 696da36a8..1fce9549e 100644 --- a/pypesto/objective/function.py +++ b/pypesto/objective/function.py @@ -134,6 +134,7 @@ def call_unprocessed( x: np.ndarray, sensi_orders: tuple[int, ...], mode: ModeType, + return_dict: bool, **kwargs, ) -> ResultDict: """ diff --git a/pypesto/objective/jax/base.py b/pypesto/objective/jax/base.py index ea86b07a0..14204f32d 100644 --- a/pypesto/objective/jax/base.py +++ b/pypesto/objective/jax/base.py @@ -215,6 +215,7 @@ def call_unprocessed( x: np.ndarray, sensi_orders: tuple[int, ...], mode: ModeType, + return_dict: bool, **kwargs, ) -> ResultDict: """ diff --git a/pypesto/objective/priors.py b/pypesto/objective/priors.py index 3b8f177fe..e460c1572 100644 --- a/pypesto/objective/priors.py +++ b/pypesto/objective/priors.py @@ -72,6 +72,7 @@ def call_unprocessed( x: np.ndarray, sensi_orders: tuple[int, ...], mode: C.ModeType, + return_dict: bool, **kwargs, ) -> ResultDict: """ diff --git a/pypesto/objective/roadrunner/road_runner.py b/pypesto/objective/roadrunner/road_runner.py index 635d2b6e5..c2f88de80 100644 --- a/pypesto/objective/roadrunner/road_runner.py +++ b/pypesto/objective/roadrunner/road_runner.py @@ -96,6 +96,7 @@ def call_unprocessed( x: np.ndarray, sensi_orders: tuple[int, ...], mode: ModeType, + return_dict: bool, edatas: Optional[Sequence[ExpData]] = None, parameter_mapping: Optional[list[ParMappingDictQuadruple]] = None, ) -> dict: diff --git a/test/base/test_objective.py b/test/base/test_objective.py index a9d2e3151..b5d3ee930 100644 --- a/test/base/test_objective.py +++ b/test/base/test_objective.py @@ -447,3 +447,29 @@ def test_fds(fd_method, fd_delta): assert obj_fd.delta_fun.updates == 0 else: assert obj_fd.delta_fun.updates > 1 + + +def test_call_unprocessed_return_dict(): + class Objective0(pypesto.objective.ObjectiveBase): + def call_unprocessed(self, *args, **kwargs): + return {"fval": 0, "return_dict": "return_dict" in kwargs} + + def check_sensi_orders(self, *args, **kwargs): + return True + + class Objective1(Objective0): + def call_unprocessed(self, *args, return_dict: bool, **kwargs): + return {"fval": 0, "return_dict": return_dict} + + objective0 = Objective0() + objective1 = Objective1() + + with pytest.warns(DeprecationWarning, match="Please add `return_dict`"): + result0 = objective0([0], return_dict=True) + + result1 = objective1([0], return_dict=True) + + # `return_dict` is not shared with `call_unprocessed` by default, + assert not result0["return_dict"] + # but is shared if the `call_unprocessed` signature supports it. + assert result1["return_dict"]