Hierarchical: avoid recomputing inner parameters if simulation failed (
dilpath authored Jul 4, 2024
1 parent 10d6092 commit ef34eeb
Showing 2 changed files with 26 additions and 22 deletions.
46 changes: 25 additions & 21 deletions pypesto/objective/amici/amici.py
@@ -1,11 +1,13 @@
from __future__ import annotations

import abc
import copy
import os
import tempfile
from collections import OrderedDict
from collections.abc import Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Union

import numpy as np

@@ -34,6 +36,8 @@
)

if TYPE_CHECKING:
from ...hierarchical import InnerCalculatorCollector

try:
import amici
from amici.petab.parameter_mapping import ParameterMapping
@@ -61,7 +65,7 @@ def create_solver(self, model: AmiciModel) -> AmiciSolver:
"""Create an AMICI solver."""

@abc.abstractmethod
def create_edatas(self, model: AmiciModel) -> Sequence["amici.ExpData"]:
def create_edatas(self, model: AmiciModel) -> Sequence[amici.ExpData]:
"""Create AMICI experimental data."""


@@ -74,17 +78,17 @@ def __init__(
self,
amici_model: AmiciModel,
amici_solver: AmiciSolver,
edatas: Union[Sequence["amici.ExpData"], "amici.ExpData"],
max_sensi_order: Optional[int] = None,
x_ids: Optional[Sequence[str]] = None,
x_names: Optional[Sequence[str]] = None,
parameter_mapping: Optional["ParameterMapping"] = None,
guess_steadystate: Optional[Optional[bool]] = None,
n_threads: Optional[int] = 1,
fim_for_hess: Optional[bool] = True,
amici_object_builder: Optional[AmiciObjectBuilder] = None,
calculator: Optional[AmiciCalculator] = None,
amici_reporting: Optional["amici.RDataReporting"] = None,
edatas: Sequence[amici.ExpData] | amici.ExpData,
max_sensi_order: int | None = None,
x_ids: Sequence[str] | None = None,
x_names: Sequence[str] | None = None,
parameter_mapping: ParameterMapping | None = None,
guess_steadystate: bool | None = None,
n_threads: int | None = 1,
fim_for_hess: bool | None = True,
amici_object_builder: AmiciObjectBuilder | None = None,
calculator: AmiciCalculator | InnerCalculatorCollector | None = None,
amici_reporting: amici.RDataReporting | None = None,
):
"""
Initialize objective.
@@ -278,7 +282,7 @@ def initialize(self):
self.reset_steadystate_guesses()
self.calculator.initialize()

def __deepcopy__(self, memodict: dict = None) -> "AmiciObjective":
def __deepcopy__(self, memodict: dict = None) -> AmiciObjective:
import amici

other = self.__class__.__new__(self.__class__)
@@ -423,9 +427,9 @@ def call_unprocessed(
sensi_orders: tuple[int, ...],
mode: ModeType,
return_dict: bool = False,
edatas: Sequence["amici.ExpData"] = None,
parameter_mapping: "ParameterMapping" = None,
amici_reporting: Optional["amici.RDataReporting"] = None,
edatas: Sequence[amici.ExpData] = None,
parameter_mapping: ParameterMapping = None,
amici_reporting: amici.RDataReporting | None = None,
):
"""
Call objective function without pre- or post-processing and formatting.
@@ -539,7 +543,7 @@ def store_steadystate_guess(
self,
condition_ix: int,
x_dct: dict,
rdata: "amici.ReturnData",
rdata: amici.ReturnData,
) -> None:
"""
Store condition parameter, steadystate and steadystate sensitivity.
@@ -584,9 +588,9 @@ def apply_custom_timepoints(self) -> None:

def set_custom_timepoints(
self,
timepoints: Sequence[Sequence[Union[float, int]]] = None,
timepoints_global: Sequence[Union[float, int]] = None,
) -> "AmiciObjective":
timepoints: Sequence[Sequence[float | int]] = None,
timepoints_global: Sequence[float | int] = None,
) -> AmiciObjective:
"""
Create a copy of this objective that is evaluated at custom timepoints.
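
Note on the annotation changes above: they lean on the `from __future__ import annotations` import already at the top of the module, under which annotations are evaluated lazily (PEP 563), so names imported only under TYPE_CHECKING can be written unquoted and `X | None` (PEP 604) can replace `Optional[X]`. A minimal, self-contained sketch of that pattern follows; the names are hypothetical stand-ins, not pyPESTO or AMICI code.

from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to the type checker only; never imported at runtime.
    from heavy_simulator import ExpData  # hypothetical stand-in for amici.ExpData


def normalize_edatas(edatas: Sequence[ExpData] | ExpData | None = None) -> list[ExpData]:
    """Normalize a single data set, a sequence of them, or None into a list."""
    if edatas is None:
        return []
    if isinstance(edatas, Sequence):
        return list(edatas)
    return [edatas]
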
2 changes: 1 addition & 1 deletion pypesto/optimize/optimizer.py
@@ -61,7 +65,7 @@ def wrapped_minimize(
optimize_options=optimize_options,
)

if isinstance(problem, HierarchicalProblem):
if isinstance(problem, HierarchicalProblem) and result.x is not None:
# Call the objective to obtain inner parameters of
# the optimal outer optimization parameters
return_dict = problem.objective(
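
For context, the guard added in optimizer.py only triggers the extra objective evaluation, which retrieves the inner parameters at the optimal outer parameters, when the optimizer actually returned a point; a failed run leaves result.x as None and the recomputation is skipped. A minimal, self-contained sketch of that pattern (hypothetical names and return-dict key, not the actual pyPESTO implementation):

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class OptimizerResult:
    x: list | None = None              # optimal outer parameters; None if the run failed
    inner_parameters: dict | None = None


def postprocess_hierarchical(is_hierarchical: bool, result: OptimizerResult, objective):
    """Recompute inner parameters only for successful hierarchical runs."""
    if is_hierarchical and result.x is not None:
        # One extra objective call at the optimum returns the inner
        # (analytically optimized) parameters alongside the function value.
        return_dict = objective(result.x, return_dict=True)
        result.inner_parameters = return_dict.get("inner_parameters")
    return result
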
