-
Notifications
You must be signed in to change notification settings - Fork 47
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix return_dict
when an AmiciObjective
is inside AggregatedObjective
#1424
Changes from 1 commit
5860c0f
27779ab
1ff18a9
5a592e1
0a4cc11
8b21c42
b9ec235
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,7 +26,7 @@ | |
HistoryTypeError, | ||
MemoryHistory, | ||
) | ||
from ..base import ObjectiveBase, ResultDict | ||
from ..base import ObjectiveBase | ||
from .amici_calculator import AmiciCalculator | ||
from .amici_util import ( | ||
create_identity_parameter_mapping, | ||
|
@@ -68,6 +68,8 @@ def create_edatas(self, model: AmiciModel) -> Sequence["amici.ExpData"]: | |
class AmiciObjective(ObjectiveBase): | ||
"""Allows to create an objective directly from an amici model.""" | ||
|
||
share_return_dict = True | ||
|
||
def __init__( | ||
self, | ||
amici_model: AmiciModel, | ||
|
@@ -415,28 +417,6 @@ def check_mode(self, mode: ModeType) -> bool: | |
"""See `ObjectiveBase` documentation.""" | ||
return mode in [MODE_FUN, MODE_RES] | ||
|
||
def __call__( | ||
self, | ||
x: np.ndarray, | ||
sensi_orders: tuple[int, ...] = (0,), | ||
mode: ModeType = MODE_FUN, | ||
return_dict: bool = False, | ||
**kwargs, | ||
) -> Union[float, np.ndarray, tuple, ResultDict]: | ||
"""See `ObjectiveBase` documentation.""" | ||
import amici | ||
|
||
# Use AMICI full reporting if amici.ReturnDatas are returned and no | ||
# other reporting mode was set | ||
if ( | ||
return_dict | ||
and self.amici_reporting is None | ||
and "amici_reporting" not in kwargs | ||
): | ||
kwargs["amici_reporting"] = amici.RDataReporting.full | ||
|
||
return super().__call__(x, sensi_orders, mode, return_dict, **kwargs) | ||
|
||
def call_unprocessed( | ||
self, | ||
x: np.ndarray, | ||
|
@@ -445,6 +425,7 @@ def call_unprocessed( | |
edatas: Sequence["amici.ExpData"] = None, | ||
parameter_mapping: "ParameterMapping" = None, | ||
amici_reporting: Optional["amici.RDataReporting"] = None, | ||
return_dict: bool = False, | ||
): | ||
""" | ||
Call objective function without pre- or post-processing and formatting. | ||
|
@@ -458,18 +439,23 @@ def call_unprocessed( | |
|
||
x_dct = self.par_arr_to_dct(x) | ||
|
||
# only ask amici to compute required quantities | ||
amici_reporting = ( | ||
self.amici_reporting | ||
if amici_reporting is None | ||
else amici_reporting | ||
) | ||
if amici_reporting is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It should be made clear in the documentation that amici_reporting takes precedence over return_dict imo. |
||
amici_reporting = ( | ||
amici.RDataReporting.likelihood | ||
if mode == MODE_FUN | ||
else amici.RDataReporting.residuals | ||
) | ||
if return_dict: | ||
# Use AMICI full reporting if amici.ReturnDatas are returned | ||
# and no other reporting mode was set | ||
amici_reporting = amici.RDataReporting.full | ||
else: | ||
# Else, only ask amici to compute required quantities | ||
amici_reporting = ( | ||
amici.RDataReporting.likelihood | ||
if mode == MODE_FUN | ||
else amici.RDataReporting.residuals | ||
) | ||
self.amici_solver.setReturnDataReportingMode(amici_reporting) | ||
|
||
# update steady state | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,8 +45,13 @@ class ObjectiveBase(ABC): | |
pre_post_processor: | ||
Preprocess input values to and postprocess output values from | ||
__call__. Configured in `update_from_problem()`. | ||
share_return_dict: | ||
Whether the objective uses `return_dict` in its `call_unprocessed` | ||
method. | ||
""" | ||
|
||
share_return_dict: bool = False | ||
|
||
def __init__( | ||
self, | ||
x_names: Optional[Sequence[str]] = None, | ||
|
@@ -178,6 +183,8 @@ def __call__( | |
x_full = self.pre_post_processor.preprocess(x=x) | ||
|
||
# compute result | ||
if self.share_return_dict: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why is this necessary? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Otherwise there might be some argument error, because |
||
kwargs["return_dict"] = return_dict | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It cannot be, since |
||
result = self.call_unprocessed( | ||
x=x_full, sensi_orders=sensi_orders, mode=mode, **kwargs | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as in the base objective (see above or below comment)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would actually argue that the outer
return_dict
should take precedence over innerreturn_dict
and that we don't need theif objective_.share_return_dict:
check at all.