Skip to content

Commit

Permalink
Require return_dict in ObjectiveBase.call_unprocessed (fixes AMIC…
Browse files Browse the repository at this point in the history
…I posterior RData) (#1424)
  • Loading branch information
dilpath authored Jul 3, 2024
1 parent 5a8e014 commit 10d6092
Show file tree
Hide file tree
Showing 10 changed files with 123 additions and 36 deletions.
1 change: 1 addition & 0 deletions pypesto/objective/aesara/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def call_unprocessed(
x: np.ndarray,
sensi_orders: tuple[int, ...],
mode: ModeType,
return_dict: bool,
**kwargs,
) -> ResultDict:
"""
Expand Down
17 changes: 17 additions & 0 deletions pypesto/objective/aggregated.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import inspect
import warnings
from collections.abc import Sequence
from copy import deepcopy
from typing import Any
Expand Down Expand Up @@ -92,6 +94,7 @@ def call_unprocessed(
sensi_orders: tuple[int, ...],
mode: ModeType,
kwargs_list: Sequence[dict[str, Any]] = None,
return_dict: bool = False,
**kwargs,
) -> ResultDict:
"""
Expand All @@ -113,6 +116,20 @@ def call_unprocessed(
"The length of `kwargs_list` must match the number of "
"objectives you are aggregating."
)
for objective_, objective_kwargs in zip(self._objectives, kwargs_list):
if (
"return_dict"
in inspect.signature(objective_.call_unprocessed).parameters
):
objective_kwargs["return_dict"] = return_dict
else:
warnings.warn(
"Please add `return_dict` to the argument list of your "
"objective's `call_unprocessed` method. "
f"Current objective: `{type(objective_)}`.",
DeprecationWarning,
stacklevel=1,
)
return aggregate_results(
[
objective.call_unprocessed(
Expand Down
44 changes: 15 additions & 29 deletions pypesto/objective/amici/amici.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
HistoryTypeError,
MemoryHistory,
)
from ..base import ObjectiveBase, ResultDict
from ..base import ObjectiveBase
from .amici_calculator import AmiciCalculator
from .amici_util import (
create_identity_parameter_mapping,
Expand Down Expand Up @@ -68,6 +68,8 @@ def create_edatas(self, model: AmiciModel) -> Sequence["amici.ExpData"]:
class AmiciObjective(ObjectiveBase):
"""Allows to create an objective directly from an amici model."""

share_return_dict = True

def __init__(
self,
amici_model: AmiciModel,
Expand Down Expand Up @@ -415,33 +417,12 @@ def check_mode(self, mode: ModeType) -> bool:
"""See `ObjectiveBase` documentation."""
return mode in [MODE_FUN, MODE_RES]

def __call__(
self,
x: np.ndarray,
sensi_orders: tuple[int, ...] = (0,),
mode: ModeType = MODE_FUN,
return_dict: bool = False,
**kwargs,
) -> Union[float, np.ndarray, tuple, ResultDict]:
"""See `ObjectiveBase` documentation."""
import amici

# Use AMICI full reporting if amici.ReturnDatas are returned and no
# other reporting mode was set
if (
return_dict
and self.amici_reporting is None
and "amici_reporting" not in kwargs
):
kwargs["amici_reporting"] = amici.RDataReporting.full

return super().__call__(x, sensi_orders, mode, return_dict, **kwargs)

def call_unprocessed(
self,
x: np.ndarray,
sensi_orders: tuple[int, ...],
mode: ModeType,
return_dict: bool = False,
edatas: Sequence["amici.ExpData"] = None,
parameter_mapping: "ParameterMapping" = None,
amici_reporting: Optional["amici.RDataReporting"] = None,
Expand All @@ -458,18 +439,23 @@ def call_unprocessed(

x_dct = self.par_arr_to_dct(x)

# only ask amici to compute required quantities
amici_reporting = (
self.amici_reporting
if amici_reporting is None
else amici_reporting
)
if amici_reporting is None:
amici_reporting = (
amici.RDataReporting.likelihood
if mode == MODE_FUN
else amici.RDataReporting.residuals
)
if return_dict:
# Use AMICI full reporting if amici.ReturnDatas are returned
# and no other reporting mode was set
amici_reporting = amici.RDataReporting.full
else:
# Else, only ask amici to compute required quantities
amici_reporting = (
amici.RDataReporting.likelihood
if mode == MODE_FUN
else amici.RDataReporting.residuals
)
self.amici_solver.setReturnDataReportingMode(amici_reporting)

# update steady state
Expand Down
20 changes: 20 additions & 0 deletions pypesto/objective/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import copy
import inspect
import logging
import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterable, Sequence
from typing import Optional, Union
Expand Down Expand Up @@ -178,6 +180,19 @@ def __call__(
x_full = self.pre_post_processor.preprocess(x=x)

# compute result
if (
"return_dict"
in inspect.signature(self.call_unprocessed).parameters
):
kwargs["return_dict"] = return_dict
else:
warnings.warn(
"Please add `return_dict` to the argument list of your "
"objective's `call_unprocessed` method.",
DeprecationWarning,
stacklevel=1,
)

result = self.call_unprocessed(
x=x_full, sensi_orders=sensi_orders, mode=mode, **kwargs
)
Expand All @@ -204,6 +219,7 @@ def call_unprocessed(
x: np.ndarray,
sensi_orders: tuple[int, ...],
mode: ModeType,
return_dict: bool,
**kwargs,
) -> ResultDict:
"""
Expand All @@ -217,6 +233,10 @@ def call_unprocessed(
Specifies which sensitivities to compute, e.g. (0,1) -> fval, grad.
mode:
Whether to compute function values or residuals.
return_dict:
Whether the user requested additional information. Objectives can
use this to determine whether to e.g. return "full" or "minimal"
information.
Returns
-------
Expand Down
47 changes: 40 additions & 7 deletions pypesto/objective/finite_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ def call_unprocessed(
x: np.ndarray,
sensi_orders: tuple[int, ...],
mode: ModeType,
return_dict: bool,
**kwargs,
) -> ResultDict:
"""
Expand All @@ -383,11 +384,17 @@ def call_unprocessed(
"""
if mode == MODE_FUN:
result = self._call_mode_fun(
x=x, sensi_orders=sensi_orders, **kwargs
x=x,
sensi_orders=sensi_orders,
return_dict=return_dict,
**kwargs,
)
elif mode == MODE_RES:
result = self._call_mode_res(
x=x, sensi_orders=sensi_orders, **kwargs
x=x,
sensi_orders=sensi_orders,
return_dict=return_dict,
**kwargs,
)
else:
raise ValueError("This mode is not supported.")
Expand All @@ -398,6 +405,7 @@ def _call_mode_fun(
self,
x: np.ndarray,
sensi_orders: tuple[int, ...],
return_dict: bool,
**kwargs,
) -> ResultDict:
"""Handle calls in function value mode.
Expand All @@ -408,6 +416,7 @@ def _call_mode_fun(
sensi_orders_obj, result = self._call_from_obj_fun(
x=x,
sensi_orders=sensi_orders,
return_dict=return_dict,
**kwargs,
)

Expand All @@ -429,13 +438,21 @@ def _call_mode_fun(
def f_fval(x):
"""Short-hand to get a function value."""
return self.obj.call_unprocessed(
x=x, sensi_orders=(0,), mode=MODE_FUN, **kwargs
x=x,
sensi_orders=(0,),
mode=MODE_FUN,
return_dict=return_dict,
**kwargs,
)[FVAL]

def f_grad(x):
"""Short-hand to get a gradient value."""
return self.obj.call_unprocessed(
x=x, sensi_orders=(1,), mode=MODE_FUN, **kwargs
x=x,
sensi_orders=(1,),
mode=MODE_FUN,
return_dict=return_dict,
**kwargs,
)[GRAD]

# update delta vectors
Expand Down Expand Up @@ -487,6 +504,7 @@ def _call_mode_res(
self,
x: np.ndarray,
sensi_orders: tuple[int, ...],
return_dict: bool,
**kwargs,
) -> ResultDict:
"""Handle calls in residual mode.
Expand All @@ -497,6 +515,7 @@ def _call_mode_res(
sensi_orders_obj, result = self._call_from_obj_res(
x=x,
sensi_orders=sensi_orders,
return_dict=return_dict,
**kwargs,
)

Expand All @@ -507,7 +526,11 @@ def _call_mode_res(
def f_res(x):
"""Short-hand to get a function value."""
return self.obj.call_unprocessed(
x=x, sensi_orders=(0,), mode=MODE_RES, **kwargs
x=x,
sensi_orders=(0,),
mode=MODE_RES,
return_dict=return_dict,
**kwargs,
)[RES]

# update delta vector
Expand All @@ -532,6 +555,7 @@ def _call_from_obj_fun(
self,
x: np.ndarray,
sensi_orders: tuple[int, ...],
return_dict: bool,
**kwargs,
) -> tuple[tuple[int, ...], ResultDict]:
"""
Expand All @@ -553,14 +577,19 @@ def _call_from_obj_fun(
result = {}
if sensi_orders_obj:
result = self.obj.call_unprocessed(
x=x, sensi_orders=sensi_orders_obj, mode=MODE_FUN, **kwargs
x=x,
sensi_orders=sensi_orders_obj,
mode=MODE_FUN,
return_dict=return_dict,
**kwargs,
)
return sensi_orders_obj, result

def _call_from_obj_res(
self,
x: np.ndarray,
sensi_orders: tuple[int, ...],
return_dict: bool,
**kwargs,
) -> tuple[tuple[int, ...], ResultDict]:
"""
Expand All @@ -580,7 +609,11 @@ def _call_from_obj_res(
result = {}
if sensi_orders_obj:
result = self.obj.call_unprocessed(
x=x, sensi_orders=sensi_orders_obj, mode=MODE_RES, **kwargs
x=x,
sensi_orders=sensi_orders_obj,
mode=MODE_RES,
return_dict=return_dict,
**kwargs,
)
return sensi_orders_obj, result

Expand Down
1 change: 1 addition & 0 deletions pypesto/objective/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def call_unprocessed(
x: np.ndarray,
sensi_orders: tuple[int, ...],
mode: ModeType,
return_dict: bool,
**kwargs,
) -> ResultDict:
"""
Expand Down
1 change: 1 addition & 0 deletions pypesto/objective/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def call_unprocessed(
x: np.ndarray,
sensi_orders: tuple[int, ...],
mode: ModeType,
return_dict: bool,
**kwargs,
) -> ResultDict:
"""
Expand Down
1 change: 1 addition & 0 deletions pypesto/objective/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def call_unprocessed(
x: np.ndarray,
sensi_orders: tuple[int, ...],
mode: C.ModeType,
return_dict: bool,
**kwargs,
) -> ResultDict:
"""
Expand Down
1 change: 1 addition & 0 deletions pypesto/objective/roadrunner/road_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def call_unprocessed(
x: np.ndarray,
sensi_orders: tuple[int, ...],
mode: ModeType,
return_dict: bool,
edatas: Optional[Sequence[ExpData]] = None,
parameter_mapping: Optional[list[ParMappingDictQuadruple]] = None,
) -> dict:
Expand Down
26 changes: 26 additions & 0 deletions test/base/test_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,3 +447,29 @@ def test_fds(fd_method, fd_delta):
assert obj_fd.delta_fun.updates == 0
else:
assert obj_fd.delta_fun.updates > 1


def test_call_unprocessed_return_dict():
class Objective0(pypesto.objective.ObjectiveBase):
def call_unprocessed(self, *args, **kwargs):
return {"fval": 0, "return_dict": "return_dict" in kwargs}

def check_sensi_orders(self, *args, **kwargs):
return True

class Objective1(Objective0):
def call_unprocessed(self, *args, return_dict: bool, **kwargs):
return {"fval": 0, "return_dict": return_dict}

objective0 = Objective0()
objective1 = Objective1()

with pytest.warns(DeprecationWarning, match="Please add `return_dict`"):
result0 = objective0([0], return_dict=True)

result1 = objective1([0], return_dict=True)

# `return_dict` is not shared with `call_unprocessed` by default,
assert not result0["return_dict"]
# but is shared if the `call_unprocessed` signature supports it.
assert result1["return_dict"]

0 comments on commit 10d6092

Please sign in to comment.