Skip to content

Commit

Permalink
Fix startpoint sampling for PEtab-derived problems with fixed paramet…
Browse files Browse the repository at this point in the history
…ers (#1169)

Startpoint sampling for `PetabImporter`-derived problems didn't work correctly in case any parameters were fixed in addition to those marked `estimate=0` in the underlying PEtab problem.

Fixing any parameters after the construction of the `pypesto.Problem` and the corresponding startpoint method would lead to errors during startpoint sampling because the list of fixed parameters was never updated.

In order to fix that, we need to have the current `pypesto.Problem` available for startpoint sampling to get access to the currently fixed parameters. Accessing `pypesto.Problem` is not compatible with the current `FunctionStartpoints`. Therefore, we derive a new `PetabStartpoints` class from `CheckedStartpoints.sample` that will allow forwarding/accessing the `Problem`.

Co-authored-by: Fabian Fröhlich <fabian@schaluck.com>
  • Loading branch information
dweindl and FFroehlich authored Nov 3, 2023
1 parent ddf7798 commit 8c30dc0
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 29 deletions.
114 changes: 86 additions & 28 deletions pypesto/petab/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import tempfile
import warnings
from dataclasses import dataclass
from functools import partial
from typing import (
Any,
Callable,
Expand All @@ -17,6 +18,7 @@
List,
Optional,
Sequence,
Tuple,
Union,
)

Expand Down Expand Up @@ -44,15 +46,19 @@
from ..predict import AmiciPredictor
from ..problem import Problem
from ..result import PredictionResult
from ..startpoint import FunctionStartpoints, StartpointMethod
from ..startpoint import CheckedStartpoints, StartpointMethod

try:
import amici
import amici.parameter_mapping
import amici.petab_import
import amici.petab_objective
import petab
from petab.C import PREEQUILIBRATION_CONDITION_ID, SIMULATION_CONDITION_ID
from petab.C import (
ESTIMATE,
PREEQUILIBRATION_CONDITION_ID,
SIMULATION_CONDITION_ID,
)
from petab.models import MODEL_TYPE_SBML
except ImportError:
pass
Expand Down Expand Up @@ -643,39 +649,16 @@ def create_prior(self) -> Union[NegLogParameterPriors, None]:
else:
return None

def create_startpoint_method(
self, x_ids: Sequence[str] = None, **kwargs
) -> StartpointMethod:
def create_startpoint_method(self, **kwargs) -> StartpointMethod:
"""Create a startpoint method.
Parameters
----------
x_ids:
If provided, create a startpoint method that only samples the
parameters with the given IDs.
**kwargs:
Additional keyword arguments passed on to
:meth:`pypesto.startpoint.FunctionStartpoints.__init__`.
"""

def startpoint_method(n_starts: int, **kwargs):
startpoints = petab.sample_parameter_startpoints(
self.petab_problem.parameter_df, n_starts=n_starts
)
if x_ids is None:
return startpoints

# subset parameters according to the provided parameter IDs
from petab.C import ESTIMATE

parameter_df = self.petab_problem.parameter_df
pars_to_estimate = list(
parameter_df.index[parameter_df[ESTIMATE] == 1]
)
x_idxs = [pars_to_estimate.index(x_id) for x_id in x_ids]
return startpoints[:, x_idxs]

return FunctionStartpoints(function=startpoint_method, **kwargs)
return PetabStartpoints(petab_problem=self.petab_problem, **kwargs)

def create_problem(
self,
Expand Down Expand Up @@ -777,7 +760,7 @@ def create_problem(
x_scales=x_scales,
x_priors_defs=prior,
startpoint_method=self.create_startpoint_method(
x_ids=np.delete(x_ids, x_fixed_indices), **startpoint_kwargs
**startpoint_kwargs
),
**problem_kwargs,
)
Expand Down Expand Up @@ -971,3 +954,78 @@ def get_petab_non_quantitative_data_types(
if len(non_quantitative_data_types) == 0:
return None
return non_quantitative_data_types


class PetabStartpoints(CheckedStartpoints):
"""Startpoint method for PEtab problems.
Samples optimization startpoints from the distributions defined in the
provided PEtab problem. The PEtab-problem is copied.
"""

def __init__(self, petab_problem: petab.Problem, **kwargs):
super().__init__(**kwargs)
self._parameter_df = petab_problem.parameter_df.copy()
self._priors: Optional[List[Tuple]] = None
self._free_ids: Optional[List[str]] = None

def _setup(
self,
pypesto_problem: Problem,
):
"""Update priors if necessary.
Check if ``problem.x_free_indices`` changed since last call, and if so,
get the corresponding priors from PEtab.
"""
current_free_ids = np.asarray(pypesto_problem.x_names)[
pypesto_problem.x_free_indices
]

if (
self._priors is not None
and len(current_free_ids) == len(self._free_ids)
and np.all(current_free_ids == self._free_ids)
):
# no need to update
return

# update priors
self._free_ids = current_free_ids
id_to_prior = dict(
zip(
self._parameter_df.index[self._parameter_df[ESTIMATE] == 1],
petab.parameters.get_priors_from_df(
self._parameter_df, mode=petab.INITIALIZATION
),
)
)

self._priors = list(map(id_to_prior.__getitem__, current_free_ids))

def __call__(
self,
n_starts: int,
problem: Problem,
) -> np.ndarray:
"""Call the startpoint method."""
# Update the list of priors if needed
self._setup(pypesto_problem=problem)

return super().__call__(n_starts, problem)

def sample(
self,
n_starts: int,
lb: np.ndarray,
ub: np.ndarray,
) -> np.ndarray:
"""Actual startpoint sampling.
Must only be called through `self.__call__` to ensure that the list of priors
matches the currently free parameters in the :class:`pypesto.Problem`.
"""
sampler = partial(petab.sample_from_prior, n_starts=n_starts)
startpoints = list(map(sampler, self._priors))

return np.array(startpoints).T
23 changes: 22 additions & 1 deletion test/petab/test_petab_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,28 @@ def test_2_simulate(self):

self.assertTrue(np.isfinite(ret))

def test_3_optimize(self):
def test_3_startpoints(self):
# test startpoint sampling
for obj_edatas, importer in zip(self.obj_edatas, self.petab_importers):
obj = obj_edatas[0]
problem = importer.create_problem(obj)

# test for original problem
original_dim = problem.dim
startpoints = problem.startpoint_method(
n_starts=2, problem=problem
)
self.assertEqual(startpoints.shape, (2, problem.dim))

# test with fixed parameters
problem.fix_parameters(0, 1)
self.assertEqual(problem.dim, original_dim - 1)
startpoints = problem.startpoint_method(
n_starts=2, problem=problem
)
self.assertEqual(startpoints.shape, (2, problem.dim))

def test_4_optimize(self):
# run optimization
for obj_edatas, importer in zip(self.obj_edatas, self.petab_importers):
obj = obj_edatas[0]
Expand Down
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ envlist =
# Base-environment

[testenv]
passenv = AMICI_PARALLEL_COMPILE

# Sub-environments
# inherit settings defined in the base
Expand Down

0 comments on commit 8c30dc0

Please sign in to comment.