From ef34eeb14052d67a77ea6c65a47846d5dce21f04 Mon Sep 17 00:00:00 2001 From: Dilan Pathirana <59329744+dilpath@users.noreply.github.com> Date: Thu, 4 Jul 2024 12:24:47 +0200 Subject: [PATCH] Hierarchical: avoid recomputing inner parameters if simulation failed (#1426) --- pypesto/objective/amici/amici.py | 46 +++++++++++++++++--------------- pypesto/optimize/optimizer.py | 2 +- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/pypesto/objective/amici/amici.py b/pypesto/objective/amici/amici.py index b69fcfa6e..3af3f7023 100644 --- a/pypesto/objective/amici/amici.py +++ b/pypesto/objective/amici/amici.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import abc import copy import os @@ -5,7 +7,7 @@ from collections import OrderedDict from collections.abc import Sequence from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Union import numpy as np @@ -34,6 +36,8 @@ ) if TYPE_CHECKING: + from ...hierarchical import InnerCalculatorCollector + try: import amici from amici.petab.parameter_mapping import ParameterMapping @@ -61,7 +65,7 @@ def create_solver(self, model: AmiciModel) -> AmiciSolver: """Create an AMICI solver.""" @abc.abstractmethod - def create_edatas(self, model: AmiciModel) -> Sequence["amici.ExpData"]: + def create_edatas(self, model: AmiciModel) -> Sequence[amici.ExpData]: """Create AMICI experimental data.""" @@ -74,17 +78,17 @@ def __init__( self, amici_model: AmiciModel, amici_solver: AmiciSolver, - edatas: Union[Sequence["amici.ExpData"], "amici.ExpData"], - max_sensi_order: Optional[int] = None, - x_ids: Optional[Sequence[str]] = None, - x_names: Optional[Sequence[str]] = None, - parameter_mapping: Optional["ParameterMapping"] = None, - guess_steadystate: Optional[Optional[bool]] = None, - n_threads: Optional[int] = 1, - fim_for_hess: Optional[bool] = True, - amici_object_builder: Optional[AmiciObjectBuilder] = None, - calculator: Optional[AmiciCalculator] = None, - amici_reporting: Optional["amici.RDataReporting"] = None, + edatas: Sequence[amici.ExpData] | amici.ExpData, + max_sensi_order: int | None = None, + x_ids: Sequence[str] | None = None, + x_names: Sequence[str] | None = None, + parameter_mapping: ParameterMapping | None = None, + guess_steadystate: bool | None = None, + n_threads: int | None = 1, + fim_for_hess: bool | None = True, + amici_object_builder: AmiciObjectBuilder | None = None, + calculator: AmiciCalculator | InnerCalculatorCollector | None = None, + amici_reporting: amici.RDataReporting | None = None, ): """ Initialize objective. @@ -278,7 +282,7 @@ def initialize(self): self.reset_steadystate_guesses() self.calculator.initialize() - def __deepcopy__(self, memodict: dict = None) -> "AmiciObjective": + def __deepcopy__(self, memodict: dict = None) -> AmiciObjective: import amici other = self.__class__.__new__(self.__class__) @@ -423,9 +427,9 @@ def call_unprocessed( sensi_orders: tuple[int, ...], mode: ModeType, return_dict: bool = False, - edatas: Sequence["amici.ExpData"] = None, - parameter_mapping: "ParameterMapping" = None, - amici_reporting: Optional["amici.RDataReporting"] = None, + edatas: Sequence[amici.ExpData] = None, + parameter_mapping: ParameterMapping = None, + amici_reporting: amici.RDataReporting | None = None, ): """ Call objective function without pre- or post-processing and formatting. @@ -539,7 +543,7 @@ def store_steadystate_guess( self, condition_ix: int, x_dct: dict, - rdata: "amici.ReturnData", + rdata: amici.ReturnData, ) -> None: """ Store condition parameter, steadystate and steadystate sensitivity. @@ -584,9 +588,9 @@ def apply_custom_timepoints(self) -> None: def set_custom_timepoints( self, - timepoints: Sequence[Sequence[Union[float, int]]] = None, - timepoints_global: Sequence[Union[float, int]] = None, - ) -> "AmiciObjective": + timepoints: Sequence[Sequence[float | int]] = None, + timepoints_global: Sequence[float | int] = None, + ) -> AmiciObjective: """ Create a copy of this objective that is evaluated at custom timepoints. diff --git a/pypesto/optimize/optimizer.py b/pypesto/optimize/optimizer.py index b7e7bc4f4..b05570e73 100644 --- a/pypesto/optimize/optimizer.py +++ b/pypesto/optimize/optimizer.py @@ -61,7 +61,7 @@ def wrapped_minimize( optimize_options=optimize_options, ) - if isinstance(problem, HierarchicalProblem): + if isinstance(problem, HierarchicalProblem) and result.x is not None: # Call the objective to obtain inner parameters of # the optimal outer optimization parameters return_dict = problem.objective(