Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix return_dict when an AmiciObjective is inside AggregatedObjective #1424

Merged
merged 7 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pypesto/objective/aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
class AggregatedObjective(ObjectiveBase):
"""Aggregates multiple objectives into one objective."""

share_return_dict = True

def __init__(
self,
objectives: Sequence[ObjectiveBase],
Expand Down Expand Up @@ -92,6 +94,7 @@ def call_unprocessed(
sensi_orders: tuple[int, ...],
mode: ModeType,
kwargs_list: Sequence[dict[str, Any]] = None,
return_dict: bool = False,
**kwargs,
) -> ResultDict:
"""
Expand All @@ -113,6 +116,12 @@ def call_unprocessed(
"The length of `kwargs_list` must match the number of "
"objectives you are aggregating."
)
for objective_, objective_kwargs in zip(self._objectives, kwargs_list):
if objective_.share_return_dict:
objective_kwargs["return_dict"] = objective_kwargs.get(
"return_dict",
return_dict,
)
return aggregate_results(
[
objective.call_unprocessed(
Expand Down
44 changes: 15 additions & 29 deletions pypesto/objective/amici/amici.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
HistoryTypeError,
MemoryHistory,
)
from ..base import ObjectiveBase, ResultDict
from ..base import ObjectiveBase
from .amici_calculator import AmiciCalculator
from .amici_util import (
create_identity_parameter_mapping,
Expand Down Expand Up @@ -68,6 +68,8 @@ def create_edatas(self, model: AmiciModel) -> Sequence["amici.ExpData"]:
class AmiciObjective(ObjectiveBase):
"""Allows to create an objective directly from an amici model."""

share_return_dict = True

def __init__(
self,
amici_model: AmiciModel,
Expand Down Expand Up @@ -415,28 +417,6 @@ def check_mode(self, mode: ModeType) -> bool:
"""See `ObjectiveBase` documentation."""
return mode in [MODE_FUN, MODE_RES]

def __call__(
self,
x: np.ndarray,
sensi_orders: tuple[int, ...] = (0,),
mode: ModeType = MODE_FUN,
return_dict: bool = False,
**kwargs,
) -> Union[float, np.ndarray, tuple, ResultDict]:
"""See `ObjectiveBase` documentation."""
import amici

# Use AMICI full reporting if amici.ReturnDatas are returned and no
# other reporting mode was set
if (
return_dict
and self.amici_reporting is None
and "amici_reporting" not in kwargs
):
kwargs["amici_reporting"] = amici.RDataReporting.full

return super().__call__(x, sensi_orders, mode, return_dict, **kwargs)

def call_unprocessed(
self,
x: np.ndarray,
Expand All @@ -445,6 +425,7 @@ def call_unprocessed(
edatas: Sequence["amici.ExpData"] = None,
parameter_mapping: "ParameterMapping" = None,
amici_reporting: Optional["amici.RDataReporting"] = None,
return_dict: bool = False,
):
"""
Call objective function without pre- or post-processing and formatting.
Expand All @@ -458,18 +439,23 @@ def call_unprocessed(

x_dct = self.par_arr_to_dct(x)

# only ask amici to compute required quantities
amici_reporting = (
self.amici_reporting
if amici_reporting is None
else amici_reporting
)
if amici_reporting is None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be made clear in the documentation that amici_reporting takes precedence over return_dict imo.

amici_reporting = (
amici.RDataReporting.likelihood
if mode == MODE_FUN
else amici.RDataReporting.residuals
)
if return_dict:
# Use AMICI full reporting if amici.ReturnDatas are returned
# and no other reporting mode was set
amici_reporting = amici.RDataReporting.full
else:
# Else, only ask amici to compute required quantities
amici_reporting = (
amici.RDataReporting.likelihood
if mode == MODE_FUN
else amici.RDataReporting.residuals
)
self.amici_solver.setReturnDataReportingMode(amici_reporting)

# update steady state
Expand Down
7 changes: 7 additions & 0 deletions pypesto/objective/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,13 @@ class ObjectiveBase(ABC):
pre_post_processor:
Preprocess input values to and postprocess output values from
__call__. Configured in `update_from_problem()`.
share_return_dict:
Whether the objective uses `return_dict` in its `call_unprocessed`
method.
"""

share_return_dict: bool = False

def __init__(
self,
x_names: Optional[Sequence[str]] = None,
Expand Down Expand Up @@ -178,6 +183,8 @@ def __call__(
x_full = self.pre_post_processor.preprocess(x=x)

# compute result
if self.share_return_dict:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this necessary?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise there might be some argument error, because return_dict is not in every custom user objective.call_unprocessed signature. But I'll switch to Daniel's suggestion now

kwargs["return_dict"] = return_dict
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if return_dict is in kwargs already?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It cannot be, since return_dict is it's own kwarg in this method's argument list.

result = self.call_unprocessed(
x=x_full, sensi_orders=sensi_orders, mode=mode, **kwargs
)
Expand Down
Loading