From 5860c0febfdf2547d4eae93649eadab0ca201518 Mon Sep 17 00:00:00 2001 From: dilpath Date: Wed, 3 Jul 2024 00:01:41 +0200 Subject: [PATCH 1/6] add `ObjectiveBase.share_return_dict`; small refactor to fix `return_dict` with amici posterior --- pypesto/objective/aggregated.py | 6 +++++ pypesto/objective/amici/amici.py | 44 +++++++++++--------------------- pypesto/objective/base.py | 7 +++++ 3 files changed, 28 insertions(+), 29 deletions(-) diff --git a/pypesto/objective/aggregated.py b/pypesto/objective/aggregated.py index 21e1844b0..dcfbce1f2 100644 --- a/pypesto/objective/aggregated.py +++ b/pypesto/objective/aggregated.py @@ -20,6 +20,8 @@ class AggregatedObjective(ObjectiveBase): """Aggregates multiple objectives into one objective.""" + share_return_dict = True + def __init__( self, objectives: Sequence[ObjectiveBase], @@ -92,6 +94,7 @@ def call_unprocessed( sensi_orders: tuple[int, ...], mode: ModeType, kwargs_list: Sequence[dict[str, Any]] = None, + return_dict: bool = False, **kwargs, ) -> ResultDict: """ @@ -113,6 +116,9 @@ def call_unprocessed( "The length of `kwargs_list` must match the number of " "objectives you are aggregating." ) + for objective_, objective_kwargs in zip(self._objectives, kwargs_list): + if objective_.share_return_dict: + objective_kwargs["return_dict"] = return_dict return aggregate_results( [ objective.call_unprocessed( diff --git a/pypesto/objective/amici/amici.py b/pypesto/objective/amici/amici.py index 9c7035329..00f975280 100644 --- a/pypesto/objective/amici/amici.py +++ b/pypesto/objective/amici/amici.py @@ -26,7 +26,7 @@ HistoryTypeError, MemoryHistory, ) -from ..base import ObjectiveBase, ResultDict +from ..base import ObjectiveBase from .amici_calculator import AmiciCalculator from .amici_util import ( create_identity_parameter_mapping, @@ -68,6 +68,8 @@ def create_edatas(self, model: AmiciModel) -> Sequence["amici.ExpData"]: class AmiciObjective(ObjectiveBase): """Allows to create an objective directly from an amici model.""" + share_return_dict = True + def __init__( self, amici_model: AmiciModel, @@ -415,28 +417,6 @@ def check_mode(self, mode: ModeType) -> bool: """See `ObjectiveBase` documentation.""" return mode in [MODE_FUN, MODE_RES] - def __call__( - self, - x: np.ndarray, - sensi_orders: tuple[int, ...] = (0,), - mode: ModeType = MODE_FUN, - return_dict: bool = False, - **kwargs, - ) -> Union[float, np.ndarray, tuple, ResultDict]: - """See `ObjectiveBase` documentation.""" - import amici - - # Use AMICI full reporting if amici.ReturnDatas are returned and no - # other reporting mode was set - if ( - return_dict - and self.amici_reporting is None - and "amici_reporting" not in kwargs - ): - kwargs["amici_reporting"] = amici.RDataReporting.full - - return super().__call__(x, sensi_orders, mode, return_dict, **kwargs) - def call_unprocessed( self, x: np.ndarray, @@ -445,6 +425,7 @@ def call_unprocessed( edatas: Sequence["amici.ExpData"] = None, parameter_mapping: "ParameterMapping" = None, amici_reporting: Optional["amici.RDataReporting"] = None, + return_dict: bool = False, ): """ Call objective function without pre- or post-processing and formatting. @@ -458,18 +439,23 @@ def call_unprocessed( x_dct = self.par_arr_to_dct(x) - # only ask amici to compute required quantities amici_reporting = ( self.amici_reporting if amici_reporting is None else amici_reporting ) if amici_reporting is None: - amici_reporting = ( - amici.RDataReporting.likelihood - if mode == MODE_FUN - else amici.RDataReporting.residuals - ) + if return_dict: + # Use AMICI full reporting if amici.ReturnDatas are returned + # and no other reporting mode was set + amici_reporting = amici.RDataReporting.full + else: + # Else, only ask amici to compute required quantities + amici_reporting = ( + amici.RDataReporting.likelihood + if mode == MODE_FUN + else amici.RDataReporting.residuals + ) self.amici_solver.setReturnDataReportingMode(amici_reporting) # update steady state diff --git a/pypesto/objective/base.py b/pypesto/objective/base.py index e5212b31c..b85ab7313 100644 --- a/pypesto/objective/base.py +++ b/pypesto/objective/base.py @@ -45,8 +45,13 @@ class ObjectiveBase(ABC): pre_post_processor: Preprocess input values to and postprocess output values from __call__. Configured in `update_from_problem()`. + share_return_dict: + Whether the objective uses `return_dict` in its `call_unprocessed` + method. """ + share_return_dict: bool = False + def __init__( self, x_names: Optional[Sequence[str]] = None, @@ -178,6 +183,8 @@ def __call__( x_full = self.pre_post_processor.preprocess(x=x) # compute result + if self.share_return_dict: + kwargs["return_dict"] = return_dict result = self.call_unprocessed( x=x_full, sensi_orders=sensi_orders, mode=mode, **kwargs ) From 27779ab7a5a3a4898efeb6e91d9c576658ee2820 Mon Sep 17 00:00:00 2001 From: dilpath Date: Wed, 3 Jul 2024 09:55:21 +0200 Subject: [PATCH 2/6] use user-supplied return_dict --- pypesto/objective/aggregated.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pypesto/objective/aggregated.py b/pypesto/objective/aggregated.py index dcfbce1f2..c39a211bd 100644 --- a/pypesto/objective/aggregated.py +++ b/pypesto/objective/aggregated.py @@ -118,7 +118,10 @@ def call_unprocessed( ) for objective_, objective_kwargs in zip(self._objectives, kwargs_list): if objective_.share_return_dict: - objective_kwargs["return_dict"] = return_dict + objective_kwargs["return_dict"] = objective_kwargs.get( + "return_dict", + return_dict, + ) return aggregate_results( [ objective.call_unprocessed( From 1ff18a974f528e0bc65cf42dce715dbedf403714 Mon Sep 17 00:00:00 2001 From: dilpath Date: Wed, 3 Jul 2024 10:13:55 +0200 Subject: [PATCH 3/6] test share_return_dict --- test/base/test_objective.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/test/base/test_objective.py b/test/base/test_objective.py index a9d2e3151..66b7928e6 100644 --- a/test/base/test_objective.py +++ b/test/base/test_objective.py @@ -447,3 +447,26 @@ def test_fds(fd_method, fd_delta): assert obj_fd.delta_fun.updates == 0 else: assert obj_fd.delta_fun.updates > 1 + + +def test_shared_return_dict(): + class Objective0(pypesto.objective.ObjectiveBase): + def call_unprocessed(self, *args, **kwargs): + return {"fval": 0, "return_dict": "return_dict" in kwargs} + + def check_sensi_orders(self, *args, **kwargs): + return True + + class Objective1(Objective0): + share_return_dict = True + + objective0 = Objective0() + objective1 = Objective1() + + result0 = objective0([0], return_dict=True) + result1 = objective1([0], return_dict=True) + + # `return_dict` is not shared with `call_unprocessed` by default, + assert not result0["return_dict"] + # but `ObjectiveBase.shared_return_dict = True` changes that. + assert result1["return_dict"] From 5a592e1b69de5b354cda6a13e57c18976752660a Mon Sep 17 00:00:00 2001 From: dilpath Date: Wed, 3 Jul 2024 10:47:24 +0200 Subject: [PATCH 4/6] make default with `DeprecationWarning` --- pypesto/objective/aggregated.py | 20 ++++++++++++++------ pypesto/objective/amici/amici.py | 2 +- pypesto/objective/base.py | 21 +++++++++++++++------ test/base/test_objective.py | 11 +++++++---- 4 files changed, 37 insertions(+), 17 deletions(-) diff --git a/pypesto/objective/aggregated.py b/pypesto/objective/aggregated.py index c39a211bd..13b236816 100644 --- a/pypesto/objective/aggregated.py +++ b/pypesto/objective/aggregated.py @@ -1,3 +1,5 @@ +import inspect +import warnings from collections.abc import Sequence from copy import deepcopy from typing import Any @@ -20,8 +22,6 @@ class AggregatedObjective(ObjectiveBase): """Aggregates multiple objectives into one objective.""" - share_return_dict = True - def __init__( self, objectives: Sequence[ObjectiveBase], @@ -117,10 +117,18 @@ def call_unprocessed( "objectives you are aggregating." ) for objective_, objective_kwargs in zip(self._objectives, kwargs_list): - if objective_.share_return_dict: - objective_kwargs["return_dict"] = objective_kwargs.get( - "return_dict", - return_dict, + if ( + "return_dict" + in inspect.signature(objective_.call_unprocessed).parameters + ): + objective_kwargs["return_dict"] = return_dict + else: + warnings.warn( + "Please add `return_dict` to the argument list of your " + "objective's `call_unprocessed` method. " + f"Current objective: `{type(objective_)}`.", + DeprecationWarning, + stacklevel=1, ) return aggregate_results( [ diff --git a/pypesto/objective/amici/amici.py b/pypesto/objective/amici/amici.py index 00f975280..b69fcfa6e 100644 --- a/pypesto/objective/amici/amici.py +++ b/pypesto/objective/amici/amici.py @@ -422,10 +422,10 @@ def call_unprocessed( x: np.ndarray, sensi_orders: tuple[int, ...], mode: ModeType, + return_dict: bool = False, edatas: Sequence["amici.ExpData"] = None, parameter_mapping: "ParameterMapping" = None, amici_reporting: Optional["amici.RDataReporting"] = None, - return_dict: bool = False, ): """ Call objective function without pre- or post-processing and formatting. diff --git a/pypesto/objective/base.py b/pypesto/objective/base.py index b85ab7313..a851411e1 100644 --- a/pypesto/objective/base.py +++ b/pypesto/objective/base.py @@ -1,5 +1,7 @@ import copy +import inspect import logging +import warnings from abc import ABC, abstractmethod from collections.abc import Iterable, Sequence from typing import Optional, Union @@ -45,13 +47,8 @@ class ObjectiveBase(ABC): pre_post_processor: Preprocess input values to and postprocess output values from __call__. Configured in `update_from_problem()`. - share_return_dict: - Whether the objective uses `return_dict` in its `call_unprocessed` - method. """ - share_return_dict: bool = False - def __init__( self, x_names: Optional[Sequence[str]] = None, @@ -183,8 +180,19 @@ def __call__( x_full = self.pre_post_processor.preprocess(x=x) # compute result - if self.share_return_dict: + if ( + "return_dict" + in inspect.signature(self.call_unprocessed).parameters + ): kwargs["return_dict"] = return_dict + else: + warnings.warn( + "Please add `return_dict` to the argument list of your " + "objective's `call_unprocessed` method.", + DeprecationWarning, + stacklevel=1, + ) + result = self.call_unprocessed( x=x_full, sensi_orders=sensi_orders, mode=mode, **kwargs ) @@ -211,6 +219,7 @@ def call_unprocessed( x: np.ndarray, sensi_orders: tuple[int, ...], mode: ModeType, + return_dict: bool, **kwargs, ) -> ResultDict: """ diff --git a/test/base/test_objective.py b/test/base/test_objective.py index 66b7928e6..b5d3ee930 100644 --- a/test/base/test_objective.py +++ b/test/base/test_objective.py @@ -449,7 +449,7 @@ def test_fds(fd_method, fd_delta): assert obj_fd.delta_fun.updates > 1 -def test_shared_return_dict(): +def test_call_unprocessed_return_dict(): class Objective0(pypesto.objective.ObjectiveBase): def call_unprocessed(self, *args, **kwargs): return {"fval": 0, "return_dict": "return_dict" in kwargs} @@ -458,15 +458,18 @@ def check_sensi_orders(self, *args, **kwargs): return True class Objective1(Objective0): - share_return_dict = True + def call_unprocessed(self, *args, return_dict: bool, **kwargs): + return {"fval": 0, "return_dict": return_dict} objective0 = Objective0() objective1 = Objective1() - result0 = objective0([0], return_dict=True) + with pytest.warns(DeprecationWarning, match="Please add `return_dict`"): + result0 = objective0([0], return_dict=True) + result1 = objective1([0], return_dict=True) # `return_dict` is not shared with `call_unprocessed` by default, assert not result0["return_dict"] - # but `ObjectiveBase.shared_return_dict = True` changes that. + # but is shared if the `call_unprocessed` signature supports it. assert result1["return_dict"] From 0a4cc1173186e732af72297e0dd1a4858294e3a5 Mon Sep 17 00:00:00 2001 From: dilpath Date: Wed, 3 Jul 2024 10:55:51 +0200 Subject: [PATCH 5/6] add `return_dict` to all `call_unprocessed` signatures --- pypesto/objective/aesara/base.py | 1 + pypesto/objective/base.py | 4 ++++ pypesto/objective/finite_difference.py | 1 + pypesto/objective/function.py | 1 + pypesto/objective/jax/base.py | 1 + pypesto/objective/priors.py | 1 + pypesto/objective/roadrunner/road_runner.py | 1 + 7 files changed, 10 insertions(+) diff --git a/pypesto/objective/aesara/base.py b/pypesto/objective/aesara/base.py index dc2528f10..ce7ad8c22 100644 --- a/pypesto/objective/aesara/base.py +++ b/pypesto/objective/aesara/base.py @@ -110,6 +110,7 @@ def call_unprocessed( x: np.ndarray, sensi_orders: tuple[int, ...], mode: ModeType, + return_dict: bool, **kwargs, ) -> ResultDict: """ diff --git a/pypesto/objective/base.py b/pypesto/objective/base.py index a851411e1..438833aac 100644 --- a/pypesto/objective/base.py +++ b/pypesto/objective/base.py @@ -233,6 +233,10 @@ def call_unprocessed( Specifies which sensitivities to compute, e.g. (0,1) -> fval, grad. mode: Whether to compute function values or residuals. + return_dict: + Whether the user requested additional information. Objectives can + use this to determine whether to e.g. return "full" or "minimal" + information. Returns ------- diff --git a/pypesto/objective/finite_difference.py b/pypesto/objective/finite_difference.py index 78d997a57..489bb1539 100644 --- a/pypesto/objective/finite_difference.py +++ b/pypesto/objective/finite_difference.py @@ -373,6 +373,7 @@ def call_unprocessed( x: np.ndarray, sensi_orders: tuple[int, ...], mode: ModeType, + return_dict: bool, **kwargs, ) -> ResultDict: """ diff --git a/pypesto/objective/function.py b/pypesto/objective/function.py index 696da36a8..1fce9549e 100644 --- a/pypesto/objective/function.py +++ b/pypesto/objective/function.py @@ -134,6 +134,7 @@ def call_unprocessed( x: np.ndarray, sensi_orders: tuple[int, ...], mode: ModeType, + return_dict: bool, **kwargs, ) -> ResultDict: """ diff --git a/pypesto/objective/jax/base.py b/pypesto/objective/jax/base.py index ea86b07a0..14204f32d 100644 --- a/pypesto/objective/jax/base.py +++ b/pypesto/objective/jax/base.py @@ -215,6 +215,7 @@ def call_unprocessed( x: np.ndarray, sensi_orders: tuple[int, ...], mode: ModeType, + return_dict: bool, **kwargs, ) -> ResultDict: """ diff --git a/pypesto/objective/priors.py b/pypesto/objective/priors.py index 3b8f177fe..e460c1572 100644 --- a/pypesto/objective/priors.py +++ b/pypesto/objective/priors.py @@ -72,6 +72,7 @@ def call_unprocessed( x: np.ndarray, sensi_orders: tuple[int, ...], mode: C.ModeType, + return_dict: bool, **kwargs, ) -> ResultDict: """ diff --git a/pypesto/objective/roadrunner/road_runner.py b/pypesto/objective/roadrunner/road_runner.py index 98a9a9e7e..f9f2f1d4b 100644 --- a/pypesto/objective/roadrunner/road_runner.py +++ b/pypesto/objective/roadrunner/road_runner.py @@ -96,6 +96,7 @@ def call_unprocessed( x: np.ndarray, sensi_orders: tuple[int, ...], mode: ModeType, + return_dict: bool, edatas: Optional[Sequence[ExpData]] = None, parameter_mapping: Optional[list[ParMappingDictQuadruple]] = None, ) -> dict: From b9ec235d58cf50419f5a3815f9beae7954edd20e Mon Sep 17 00:00:00 2001 From: dilpath Date: Wed, 3 Jul 2024 11:39:57 +0200 Subject: [PATCH 6/6] fix FDobjective --- pypesto/objective/finite_difference.py | 46 ++++++++++++++++++++++---- 1 file changed, 39 insertions(+), 7 deletions(-) diff --git a/pypesto/objective/finite_difference.py b/pypesto/objective/finite_difference.py index 489bb1539..4bf7a505f 100644 --- a/pypesto/objective/finite_difference.py +++ b/pypesto/objective/finite_difference.py @@ -384,11 +384,17 @@ def call_unprocessed( """ if mode == MODE_FUN: result = self._call_mode_fun( - x=x, sensi_orders=sensi_orders, **kwargs + x=x, + sensi_orders=sensi_orders, + return_dict=return_dict, + **kwargs, ) elif mode == MODE_RES: result = self._call_mode_res( - x=x, sensi_orders=sensi_orders, **kwargs + x=x, + sensi_orders=sensi_orders, + return_dict=return_dict, + **kwargs, ) else: raise ValueError("This mode is not supported.") @@ -399,6 +405,7 @@ def _call_mode_fun( self, x: np.ndarray, sensi_orders: tuple[int, ...], + return_dict: bool, **kwargs, ) -> ResultDict: """Handle calls in function value mode. @@ -409,6 +416,7 @@ def _call_mode_fun( sensi_orders_obj, result = self._call_from_obj_fun( x=x, sensi_orders=sensi_orders, + return_dict=return_dict, **kwargs, ) @@ -430,13 +438,21 @@ def _call_mode_fun( def f_fval(x): """Short-hand to get a function value.""" return self.obj.call_unprocessed( - x=x, sensi_orders=(0,), mode=MODE_FUN, **kwargs + x=x, + sensi_orders=(0,), + mode=MODE_FUN, + return_dict=return_dict, + **kwargs, )[FVAL] def f_grad(x): """Short-hand to get a gradient value.""" return self.obj.call_unprocessed( - x=x, sensi_orders=(1,), mode=MODE_FUN, **kwargs + x=x, + sensi_orders=(1,), + mode=MODE_FUN, + return_dict=return_dict, + **kwargs, )[GRAD] # update delta vectors @@ -488,6 +504,7 @@ def _call_mode_res( self, x: np.ndarray, sensi_orders: tuple[int, ...], + return_dict: bool, **kwargs, ) -> ResultDict: """Handle calls in residual mode. @@ -498,6 +515,7 @@ def _call_mode_res( sensi_orders_obj, result = self._call_from_obj_res( x=x, sensi_orders=sensi_orders, + return_dict=return_dict, **kwargs, ) @@ -508,7 +526,11 @@ def _call_mode_res( def f_res(x): """Short-hand to get a function value.""" return self.obj.call_unprocessed( - x=x, sensi_orders=(0,), mode=MODE_RES, **kwargs + x=x, + sensi_orders=(0,), + mode=MODE_RES, + return_dict=return_dict, + **kwargs, )[RES] # update delta vector @@ -533,6 +555,7 @@ def _call_from_obj_fun( self, x: np.ndarray, sensi_orders: tuple[int, ...], + return_dict: bool, **kwargs, ) -> tuple[tuple[int, ...], ResultDict]: """ @@ -554,7 +577,11 @@ def _call_from_obj_fun( result = {} if sensi_orders_obj: result = self.obj.call_unprocessed( - x=x, sensi_orders=sensi_orders_obj, mode=MODE_FUN, **kwargs + x=x, + sensi_orders=sensi_orders_obj, + mode=MODE_FUN, + return_dict=return_dict, + **kwargs, ) return sensi_orders_obj, result @@ -562,6 +589,7 @@ def _call_from_obj_res( self, x: np.ndarray, sensi_orders: tuple[int, ...], + return_dict: bool, **kwargs, ) -> tuple[tuple[int, ...], ResultDict]: """ @@ -581,7 +609,11 @@ def _call_from_obj_res( result = {} if sensi_orders_obj: result = self.obj.call_unprocessed( - x=x, sensi_orders=sensi_orders_obj, mode=MODE_RES, **kwargs + x=x, + sensi_orders=sensi_orders_obj, + mode=MODE_RES, + return_dict=return_dict, + **kwargs, ) return sensi_orders_obj, result