Skip to content

Commit

Permalink
Merge branch 'develop' into Codeonwers_proposal
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulJonasJost authored Nov 7, 2023
2 parents da6dd6e + 8c30dc0 commit dbbe958
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 29 deletions.
114 changes: 86 additions & 28 deletions pypesto/petab/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import tempfile
import warnings
from dataclasses import dataclass
from functools import partial
from typing import (
Any,
Callable,
Expand All @@ -17,6 +18,7 @@
List,
Optional,
Sequence,
Tuple,
Union,
)

Expand Down Expand Up @@ -44,15 +46,19 @@
from ..predict import AmiciPredictor
from ..problem import Problem
from ..result import PredictionResult
from ..startpoint import FunctionStartpoints, StartpointMethod
from ..startpoint import CheckedStartpoints, StartpointMethod

try:
import amici
import amici.parameter_mapping
import amici.petab_import
import amici.petab_objective
import petab
from petab.C import PREEQUILIBRATION_CONDITION_ID, SIMULATION_CONDITION_ID
from petab.C import (
ESTIMATE,
PREEQUILIBRATION_CONDITION_ID,
SIMULATION_CONDITION_ID,
)
from petab.models import MODEL_TYPE_SBML
except ImportError:
pass
Expand Down Expand Up @@ -643,39 +649,16 @@ def create_prior(self) -> Union[NegLogParameterPriors, None]:
else:
return None

def create_startpoint_method(
self, x_ids: Sequence[str] = None, **kwargs
) -> StartpointMethod:
def create_startpoint_method(self, **kwargs) -> StartpointMethod:
"""Create a startpoint method.
Parameters
----------
x_ids:
If provided, create a startpoint method that only samples the
parameters with the given IDs.
**kwargs:
Additional keyword arguments passed on to
:meth:`pypesto.startpoint.FunctionStartpoints.__init__`.
"""

def startpoint_method(n_starts: int, **kwargs):
startpoints = petab.sample_parameter_startpoints(
self.petab_problem.parameter_df, n_starts=n_starts
)
if x_ids is None:
return startpoints

# subset parameters according to the provided parameter IDs
from petab.C import ESTIMATE

parameter_df = self.petab_problem.parameter_df
pars_to_estimate = list(
parameter_df.index[parameter_df[ESTIMATE] == 1]
)
x_idxs = [pars_to_estimate.index(x_id) for x_id in x_ids]
return startpoints[:, x_idxs]

return FunctionStartpoints(function=startpoint_method, **kwargs)
return PetabStartpoints(petab_problem=self.petab_problem, **kwargs)

def create_problem(
self,
Expand Down Expand Up @@ -777,7 +760,7 @@ def create_problem(
x_scales=x_scales,
x_priors_defs=prior,
startpoint_method=self.create_startpoint_method(
x_ids=np.delete(x_ids, x_fixed_indices), **startpoint_kwargs
**startpoint_kwargs
),
**problem_kwargs,
)
Expand Down Expand Up @@ -971,3 +954,78 @@ def get_petab_non_quantitative_data_types(
if len(non_quantitative_data_types) == 0:
return None
return non_quantitative_data_types


class PetabStartpoints(CheckedStartpoints):
"""Startpoint method for PEtab problems.
Samples optimization startpoints from the distributions defined in the
provided PEtab problem. The PEtab-problem is copied.
"""

def __init__(self, petab_problem: petab.Problem, **kwargs):
super().__init__(**kwargs)
self._parameter_df = petab_problem.parameter_df.copy()
self._priors: Optional[List[Tuple]] = None
self._free_ids: Optional[List[str]] = None

def _setup(
self,
pypesto_problem: Problem,
):
"""Update priors if necessary.
Check if ``problem.x_free_indices`` changed since last call, and if so,
get the corresponding priors from PEtab.
"""
current_free_ids = np.asarray(pypesto_problem.x_names)[
pypesto_problem.x_free_indices
]

if (
self._priors is not None
and len(current_free_ids) == len(self._free_ids)
and np.all(current_free_ids == self._free_ids)
):
# no need to update
return

# update priors
self._free_ids = current_free_ids
id_to_prior = dict(
zip(
self._parameter_df.index[self._parameter_df[ESTIMATE] == 1],
petab.parameters.get_priors_from_df(
self._parameter_df, mode=petab.INITIALIZATION
),
)
)

self._priors = list(map(id_to_prior.__getitem__, current_free_ids))

def __call__(
self,
n_starts: int,
problem: Problem,
) -> np.ndarray:
"""Call the startpoint method."""
# Update the list of priors if needed
self._setup(pypesto_problem=problem)

return super().__call__(n_starts, problem)

def sample(
self,
n_starts: int,
lb: np.ndarray,
ub: np.ndarray,
) -> np.ndarray:
"""Actual startpoint sampling.
Must only be called through `self.__call__` to ensure that the list of priors
matches the currently free parameters in the :class:`pypesto.Problem`.
"""
sampler = partial(petab.sample_from_prior, n_starts=n_starts)
startpoints = list(map(sampler, self._priors))

return np.array(startpoints).T
23 changes: 22 additions & 1 deletion test/petab/test_petab_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,28 @@ def test_2_simulate(self):

self.assertTrue(np.isfinite(ret))

def test_3_optimize(self):
def test_3_startpoints(self):
# test startpoint sampling
for obj_edatas, importer in zip(self.obj_edatas, self.petab_importers):
obj = obj_edatas[0]
problem = importer.create_problem(obj)

# test for original problem
original_dim = problem.dim
startpoints = problem.startpoint_method(
n_starts=2, problem=problem
)
self.assertEqual(startpoints.shape, (2, problem.dim))

# test with fixed parameters
problem.fix_parameters(0, 1)
self.assertEqual(problem.dim, original_dim - 1)
startpoints = problem.startpoint_method(
n_starts=2, problem=problem
)
self.assertEqual(startpoints.shape, (2, problem.dim))

def test_4_optimize(self):
# run optimization
for obj_edatas, importer in zip(self.obj_edatas, self.petab_importers):
obj = obj_edatas[0]
Expand Down
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ envlist =
# Base-environment

[testenv]
passenv = AMICI_PARALLEL_COMPILE

# Sub-environments
# inherit settings defined in the base
Expand Down

0 comments on commit dbbe958

Please sign in to comment.