Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix return_dict when an AmiciObjective is inside AggregatedObjective #1424

Merged
merged 7 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pypesto/objective/aesara/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def call_unprocessed(
x: np.ndarray,
sensi_orders: tuple[int, ...],
mode: ModeType,
return_dict: bool,
**kwargs,
) -> ResultDict:
"""
Expand Down
17 changes: 17 additions & 0 deletions pypesto/objective/aggregated.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import inspect
import warnings
from collections.abc import Sequence
from copy import deepcopy
from typing import Any
Expand Down Expand Up @@ -92,6 +94,7 @@ def call_unprocessed(
sensi_orders: tuple[int, ...],
mode: ModeType,
kwargs_list: Sequence[dict[str, Any]] = None,
return_dict: bool = False,
**kwargs,
) -> ResultDict:
"""
Expand All @@ -113,6 +116,20 @@ def call_unprocessed(
"The length of `kwargs_list` must match the number of "
"objectives you are aggregating."
)
for objective_, objective_kwargs in zip(self._objectives, kwargs_list):
if (
"return_dict"
in inspect.signature(objective_.call_unprocessed).parameters
):
objective_kwargs["return_dict"] = return_dict
else:
warnings.warn(
"Please add `return_dict` to the argument list of your "
"objective's `call_unprocessed` method. "
f"Current objective: `{type(objective_)}`.",
DeprecationWarning,
stacklevel=1,
)
return aggregate_results(
[
objective.call_unprocessed(
Expand Down
44 changes: 15 additions & 29 deletions pypesto/objective/amici/amici.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
HistoryTypeError,
MemoryHistory,
)
from ..base import ObjectiveBase, ResultDict
from ..base import ObjectiveBase
from .amici_calculator import AmiciCalculator
from .amici_util import (
create_identity_parameter_mapping,
Expand Down Expand Up @@ -68,6 +68,8 @@ def create_edatas(self, model: AmiciModel) -> Sequence["amici.ExpData"]:
class AmiciObjective(ObjectiveBase):
"""Allows to create an objective directly from an amici model."""

share_return_dict = True

def __init__(
self,
amici_model: AmiciModel,
Expand Down Expand Up @@ -415,33 +417,12 @@ def check_mode(self, mode: ModeType) -> bool:
"""See `ObjectiveBase` documentation."""
return mode in [MODE_FUN, MODE_RES]

def __call__(
self,
x: np.ndarray,
sensi_orders: tuple[int, ...] = (0,),
mode: ModeType = MODE_FUN,
return_dict: bool = False,
**kwargs,
) -> Union[float, np.ndarray, tuple, ResultDict]:
"""See `ObjectiveBase` documentation."""
import amici

# Use AMICI full reporting if amici.ReturnDatas are returned and no
# other reporting mode was set
if (
return_dict
and self.amici_reporting is None
and "amici_reporting" not in kwargs
):
kwargs["amici_reporting"] = amici.RDataReporting.full

return super().__call__(x, sensi_orders, mode, return_dict, **kwargs)

def call_unprocessed(
self,
x: np.ndarray,
sensi_orders: tuple[int, ...],
mode: ModeType,
return_dict: bool = False,
edatas: Sequence["amici.ExpData"] = None,
parameter_mapping: "ParameterMapping" = None,
amici_reporting: Optional["amici.RDataReporting"] = None,
Expand All @@ -458,18 +439,23 @@ def call_unprocessed(

x_dct = self.par_arr_to_dct(x)

# only ask amici to compute required quantities
amici_reporting = (
self.amici_reporting
if amici_reporting is None
else amici_reporting
)
if amici_reporting is None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be made clear in the documentation that amici_reporting takes precedence over return_dict imo.

amici_reporting = (
amici.RDataReporting.likelihood
if mode == MODE_FUN
else amici.RDataReporting.residuals
)
if return_dict:
# Use AMICI full reporting if amici.ReturnDatas are returned
# and no other reporting mode was set
amici_reporting = amici.RDataReporting.full
else:
# Else, only ask amici to compute required quantities
amici_reporting = (
amici.RDataReporting.likelihood
if mode == MODE_FUN
else amici.RDataReporting.residuals
)
self.amici_solver.setReturnDataReportingMode(amici_reporting)

# update steady state
Expand Down
20 changes: 20 additions & 0 deletions pypesto/objective/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import copy
import inspect
import logging
import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterable, Sequence
from typing import Optional, Union
Expand Down Expand Up @@ -178,6 +180,19 @@ def __call__(
x_full = self.pre_post_processor.preprocess(x=x)

# compute result
if (
"return_dict"
in inspect.signature(self.call_unprocessed).parameters
):
kwargs["return_dict"] = return_dict
else:
warnings.warn(
"Please add `return_dict` to the argument list of your "
"objective's `call_unprocessed` method.",
DeprecationWarning,
stacklevel=1,
)

result = self.call_unprocessed(
x=x_full, sensi_orders=sensi_orders, mode=mode, **kwargs
)
Expand All @@ -204,6 +219,7 @@ def call_unprocessed(
x: np.ndarray,
sensi_orders: tuple[int, ...],
mode: ModeType,
return_dict: bool,
**kwargs,
) -> ResultDict:
"""
Expand All @@ -217,6 +233,10 @@ def call_unprocessed(
Specifies which sensitivities to compute, e.g. (0,1) -> fval, grad.
mode:
Whether to compute function values or residuals.
return_dict:
Whether the user requested additional information. Objectives can
use this to determine whether to e.g. return "full" or "minimal"
information.

Returns
-------
Expand Down
47 changes: 40 additions & 7 deletions pypesto/objective/finite_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ def call_unprocessed(
x: np.ndarray,
sensi_orders: tuple[int, ...],
mode: ModeType,
return_dict: bool,
**kwargs,
) -> ResultDict:
"""
Expand All @@ -383,11 +384,17 @@ def call_unprocessed(
"""
if mode == MODE_FUN:
result = self._call_mode_fun(
x=x, sensi_orders=sensi_orders, **kwargs
x=x,
sensi_orders=sensi_orders,
return_dict=return_dict,
**kwargs,
)
elif mode == MODE_RES:
result = self._call_mode_res(
x=x, sensi_orders=sensi_orders, **kwargs
x=x,
sensi_orders=sensi_orders,
return_dict=return_dict,
**kwargs,
)
else:
raise ValueError("This mode is not supported.")
Expand All @@ -398,6 +405,7 @@ def _call_mode_fun(
self,
x: np.ndarray,
sensi_orders: tuple[int, ...],
return_dict: bool,
**kwargs,
) -> ResultDict:
"""Handle calls in function value mode.
Expand All @@ -408,6 +416,7 @@ def _call_mode_fun(
sensi_orders_obj, result = self._call_from_obj_fun(
x=x,
sensi_orders=sensi_orders,
return_dict=return_dict,
**kwargs,
)

Expand All @@ -429,13 +438,21 @@ def _call_mode_fun(
def f_fval(x):
"""Short-hand to get a function value."""
return self.obj.call_unprocessed(
x=x, sensi_orders=(0,), mode=MODE_FUN, **kwargs
x=x,
sensi_orders=(0,),
mode=MODE_FUN,
return_dict=return_dict,
**kwargs,
)[FVAL]

def f_grad(x):
"""Short-hand to get a gradient value."""
return self.obj.call_unprocessed(
x=x, sensi_orders=(1,), mode=MODE_FUN, **kwargs
x=x,
sensi_orders=(1,),
mode=MODE_FUN,
return_dict=return_dict,
**kwargs,
)[GRAD]

# update delta vectors
Expand Down Expand Up @@ -487,6 +504,7 @@ def _call_mode_res(
self,
x: np.ndarray,
sensi_orders: tuple[int, ...],
return_dict: bool,
**kwargs,
) -> ResultDict:
"""Handle calls in residual mode.
Expand All @@ -497,6 +515,7 @@ def _call_mode_res(
sensi_orders_obj, result = self._call_from_obj_res(
x=x,
sensi_orders=sensi_orders,
return_dict=return_dict,
**kwargs,
)

Expand All @@ -507,7 +526,11 @@ def _call_mode_res(
def f_res(x):
"""Short-hand to get a function value."""
return self.obj.call_unprocessed(
x=x, sensi_orders=(0,), mode=MODE_RES, **kwargs
x=x,
sensi_orders=(0,),
mode=MODE_RES,
return_dict=return_dict,
**kwargs,
)[RES]

# update delta vector
Expand All @@ -532,6 +555,7 @@ def _call_from_obj_fun(
self,
x: np.ndarray,
sensi_orders: tuple[int, ...],
return_dict: bool,
**kwargs,
) -> tuple[tuple[int, ...], ResultDict]:
"""
Expand All @@ -553,14 +577,19 @@ def _call_from_obj_fun(
result = {}
if sensi_orders_obj:
result = self.obj.call_unprocessed(
x=x, sensi_orders=sensi_orders_obj, mode=MODE_FUN, **kwargs
x=x,
sensi_orders=sensi_orders_obj,
mode=MODE_FUN,
return_dict=return_dict,
**kwargs,
)
return sensi_orders_obj, result

def _call_from_obj_res(
self,
x: np.ndarray,
sensi_orders: tuple[int, ...],
return_dict: bool,
**kwargs,
) -> tuple[tuple[int, ...], ResultDict]:
"""
Expand All @@ -580,7 +609,11 @@ def _call_from_obj_res(
result = {}
if sensi_orders_obj:
result = self.obj.call_unprocessed(
x=x, sensi_orders=sensi_orders_obj, mode=MODE_RES, **kwargs
x=x,
sensi_orders=sensi_orders_obj,
mode=MODE_RES,
return_dict=return_dict,
**kwargs,
)
return sensi_orders_obj, result

Expand Down
1 change: 1 addition & 0 deletions pypesto/objective/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def call_unprocessed(
x: np.ndarray,
sensi_orders: tuple[int, ...],
mode: ModeType,
return_dict: bool,
**kwargs,
) -> ResultDict:
"""
Expand Down
1 change: 1 addition & 0 deletions pypesto/objective/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def call_unprocessed(
x: np.ndarray,
sensi_orders: tuple[int, ...],
mode: ModeType,
return_dict: bool,
**kwargs,
) -> ResultDict:
"""
Expand Down
1 change: 1 addition & 0 deletions pypesto/objective/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def call_unprocessed(
x: np.ndarray,
sensi_orders: tuple[int, ...],
mode: C.ModeType,
return_dict: bool,
**kwargs,
) -> ResultDict:
"""
Expand Down
1 change: 1 addition & 0 deletions pypesto/objective/roadrunner/road_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def call_unprocessed(
x: np.ndarray,
sensi_orders: tuple[int, ...],
mode: ModeType,
return_dict: bool,
edatas: Optional[Sequence[ExpData]] = None,
parameter_mapping: Optional[list[ParMappingDictQuadruple]] = None,
) -> dict:
Expand Down
26 changes: 26 additions & 0 deletions test/base/test_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,3 +447,29 @@ def test_fds(fd_method, fd_delta):
assert obj_fd.delta_fun.updates == 0
else:
assert obj_fd.delta_fun.updates > 1


def test_call_unprocessed_return_dict():
class Objective0(pypesto.objective.ObjectiveBase):
def call_unprocessed(self, *args, **kwargs):
return {"fval": 0, "return_dict": "return_dict" in kwargs}

def check_sensi_orders(self, *args, **kwargs):
return True

class Objective1(Objective0):
def call_unprocessed(self, *args, return_dict: bool, **kwargs):
return {"fval": 0, "return_dict": return_dict}

objective0 = Objective0()
objective1 = Objective1()

with pytest.warns(DeprecationWarning, match="Please add `return_dict`"):
result0 = objective0([0], return_dict=True)

result1 = objective1([0], return_dict=True)

# `return_dict` is not shared with `call_unprocessed` by default,
assert not result0["return_dict"]
# but is shared if the `call_unprocessed` signature supports it.
assert result1["return_dict"]
Loading