From bc1d35a7ddc8e9ab4ef5ac05ccfd0515802fc409 Mon Sep 17 00:00:00 2001 From: Doresic Date: Thu, 29 Feb 2024 17:06:03 +0100 Subject: [PATCH 1/6] Add spline knots to the optimization result Adds spline knots to the optimization result so they can be easily obtained after optimization. This then allows easy further use, if needed. --- pypesto/C.py | 2 +- .../inner_calculator_collector.py | 19 +++++-- pypesto/hierarchical/ordinal/calculator.py | 8 +-- pypesto/hierarchical/relative/calculator.py | 5 +- .../semiquantitative/calculator.py | 10 ++-- .../hierarchical/semiquantitative/problem.py | 42 ++++++++++++++++ .../hierarchical/semiquantitative/solver.py | 3 +- pypesto/history/csv.py | 5 -- pypesto/objective/amici/amici.py | 7 ++- pypesto/optimize/optimizer.py | 8 ++- pypesto/problem/hierarchical.py | 9 ++++ pypesto/result/optimize.py | 1 + pypesto/visualize/spline_approximation.py | 50 ++++++++++++++++--- 13 files changed, 130 insertions(+), 39 deletions(-) diff --git a/pypesto/C.py b/pypesto/C.py index 25f1468cb..f47455eb8 100644 --- a/pypesto/C.py +++ b/pypesto/C.py @@ -89,7 +89,6 @@ class EnsembleType(Enum): INNER_PARAMETERS = 'inner_parameters' INNER_RDATAS = 'inner_rdatas' PARAMETER_TYPE = 'parameterType' -X_INNER_OPT = 'x_inner_opt' RELATIVE = 'relative' @@ -207,6 +206,7 @@ class InnerParameterType(str, Enum): MIN_SIM_RANGE = 1e-16 SPLINE_PAR_TYPE = 'spline' +SPLINE_KNOTS = 'spline_knots' N_SPLINE_PARS = 'n_spline_pars' DATAPOINTS = 'datapoints' MIN_DATAPOINT = 'min_datapoint' diff --git a/pypesto/hierarchical/inner_calculator_collector.py b/pypesto/hierarchical/inner_calculator_collector.py index 72d21af2d..49e6c3e64 100644 --- a/pypesto/hierarchical/inner_calculator_collector.py +++ b/pypesto/hierarchical/inner_calculator_collector.py @@ -31,9 +31,10 @@ RES, SEMIQUANTITATIVE, SPLINE_APPROXIMATION_OPTIONS, + SPLINE_KNOTS, SPLINE_RATIO, SRES, - X_INNER_OPT, + InnerParameterType, ModeType, ) from ..objective.amici.amici_calculator import AmiciCalculator @@ -109,6 +110,7 @@ def __init__( self.quantitative_data_mask = self._get_quantitative_data_mask(edatas) self._known_least_squares_safe = False + self.semiquant_observable_ids = None def initialize(self): """Initialize.""" @@ -179,6 +181,12 @@ def construct_inner_calculators( semiquant_problem.get_noise_dummy_values(scaled=True) ) self.inner_calculators.append(semiquant_calculator) + self.semiquant_observable_ids = [ + model.getObservableIds()[group - 1] + for group in semiquant_problem.get_groups_for_xs( + InnerParameterType.SPLINE + ) + ] if self.data_types - { RELATIVE, @@ -384,7 +392,7 @@ def __call__( nllh, snllh, s2nllh, chi2, res, sres = init_return_values( sensi_orders, mode, dim ) - all_inner_pars = {} + spline_knots = None interpretable_inner_pars = [] # set order in solver @@ -423,7 +431,7 @@ def __call__( RES: res, SRES: sres, RDATAS: rdatas, - X_INNER_OPT: all_inner_pars, + SPLINE_KNOTS: None, INNER_PARAMETERS: None, } ret[FVAL] = np.inf @@ -475,9 +483,10 @@ def __call__( if 1 in sensi_orders: snllh += inner_result[GRAD] - all_inner_pars.update(inner_result[X_INNER_OPT]) if INNER_PARAMETERS in inner_result: interpretable_inner_pars.extend(inner_result[INNER_PARAMETERS]) + if SPLINE_KNOTS in inner_result: + spline_knots = inner_result[SPLINE_KNOTS] # add the quantitative data contribution if self.quantitative_data_mask is not None: @@ -508,7 +517,7 @@ def __call__( # Add inner parameters to return dict # only if the objective value improved. if ret[FVAL] < self.best_fval: - ret[X_INNER_OPT] = all_inner_pars + ret[SPLINE_KNOTS] = spline_knots ret[INNER_PARAMETERS] = ( interpretable_inner_pars if len(interpretable_inner_pars) > 0 diff --git a/pypesto/hierarchical/ordinal/calculator.py b/pypesto/hierarchical/ordinal/calculator.py index 757f04559..d2405964c 100644 --- a/pypesto/hierarchical/ordinal/calculator.py +++ b/pypesto/hierarchical/ordinal/calculator.py @@ -17,7 +17,6 @@ RDATAS, RES, SRES, - X_INNER_OPT, ) from ...objective.amici.amici_calculator import ( AmiciCalculator, @@ -126,7 +125,7 @@ def __call__( Returns ------- inner_result: - A dict containing the calculation results: FVAL, GRAD, RDATAS and X_INNER_OPT. + A dict containing the calculation results: FVAL, GRAD, RDATAS. """ if mode == MODE_RES: raise ValueError( @@ -178,7 +177,6 @@ def __call__( RES: res, SRES: sres, RDATAS: rdatas, - X_INNER_OPT: self.inner_problem.get_inner_parameter_dictionary(), } # if any amici simulation failed, it's unlikely we can compute @@ -201,13 +199,9 @@ def __call__( inner_result[FVAL] = self.inner_solver.calculate_obj_function( x_inner_opt ) - inner_result[ - X_INNER_OPT - ] = self.inner_problem.get_inner_parameter_dictionary() # calculate analytical gradients if requested if sensi_order > 0: - # print([opt['fun'] for opt in x_inner_opt]) sy = [rdata[AMICI_SY] for rdata in rdatas] ssigma = [rdata[AMICI_SSIGMAY] for rdata in rdatas] inner_result[GRAD] = self.inner_solver.calculate_gradients( diff --git a/pypesto/hierarchical/relative/calculator.py b/pypesto/hierarchical/relative/calculator.py index 6a2a41fe0..56b8819bb 100644 --- a/pypesto/hierarchical/relative/calculator.py +++ b/pypesto/hierarchical/relative/calculator.py @@ -25,7 +25,6 @@ RDATAS, RES, SRES, - X_INNER_OPT, ModeType, ) from ...objective.amici.amici_calculator import ( @@ -123,7 +122,7 @@ def __call__( Returns ------- inner_result: - A dict containing the calculation results: FVAL, GRAD, RDATAS and X_INNER_OPT. + A dict containing the calculation results: FVAL, GRAD, RDATAS and INNER_PARAMETERS. """ if not self.inner_problem.check_edatas(edatas=edatas): raise ValueError( @@ -164,11 +163,9 @@ def __call__( rdatas=rdatas, ) - inner_result[X_INNER_OPT] = {} inner_result[INNER_PARAMETERS] = np.array( [inner_parameters[x_id] for x_id in self.inner_problem.get_x_ids()] ) - # print("relative_inner_parameters: ", inner_parameters) return inner_result diff --git a/pypesto/hierarchical/semiquantitative/calculator.py b/pypesto/hierarchical/semiquantitative/calculator.py index f2ffcd039..c1daf1e24 100644 --- a/pypesto/hierarchical/semiquantitative/calculator.py +++ b/pypesto/hierarchical/semiquantitative/calculator.py @@ -15,8 +15,8 @@ MODE_RES, RDATAS, RES, + SPLINE_KNOTS, SRES, - X_INNER_OPT, ) from ...objective.amici.amici_calculator import ( AmiciCalculator, @@ -119,7 +119,8 @@ def __call__( Returns ------- inner_result: - A dict containing the calculation results: FVAL, GRAD, RDATAS and X_INNER_OPT. + A dict containing the calculation results: FVAL, GRAD, RDATAS, + INNER_PARAMETERS, and SPLINE_KNOTS. """ if mode == MODE_RES: raise ValueError( @@ -175,7 +176,6 @@ def __call__( RES: res, SRES: sres, RDATAS: rdatas, - X_INNER_OPT: self.inner_problem.get_inner_parameter_dictionary(), } # if any amici simulation failed, it's unlikely we can compute @@ -198,9 +198,7 @@ def __call__( inner_result[FVAL] = self.inner_solver.calculate_obj_function( x_inner_opt ) - inner_result[ - X_INNER_OPT - ] = self.inner_problem.get_inner_parameter_dictionary() + inner_result[SPLINE_KNOTS] = self.inner_problem.get_spline_knots() inner_result[ INNER_PARAMETERS diff --git a/pypesto/hierarchical/semiquantitative/problem.py b/pypesto/hierarchical/semiquantitative/problem.py index 65bd62bb7..1a9447553 100644 --- a/pypesto/hierarchical/semiquantitative/problem.py +++ b/pypesto/hierarchical/semiquantitative/problem.py @@ -205,6 +205,48 @@ def get_inner_parameter_dictionary(self) -> dict: inner_par_dict[x_id] = x.value return inner_par_dict + def get_spline_knots(self) -> np.ndarray: + """Get spline knots of all semiquantitative observables. + + Returns + ------- + list[list[list[float], list[float]]] + A list of lists of lists. Each list in the first list corresponds to a + semiquantitative observable. Each of these lists contains two lists: + the first list contains the spline bases, the second list contains the + spline knot values. The ordering of the observable lists is the same + as in `pypesto.problem.hierarchical.semiquant_observable_ids`. + """ + # We need the solver only for the rescaling function. + from .solver import SemiquantInnerSolver + + all_spline_knots = [] + + for group in self.get_groups_for_xs(InnerParameterType.SPLINE): + group_dict = self.groups[group] + n_spline_pars = group_dict[N_SPLINE_PARS] + n_data_points = group_dict[NUM_DATAPOINTS] + + inner_pars = np.array( + [x.value for x in self.get_xs_for_group(group)] + ) + + # Utility matrix for the spline knot calculation + lower_trian = np.tril(np.ones((n_spline_pars, n_spline_pars))) + knot_values = np.dot(lower_trian, inner_pars) + + _, knot_bases, _ = SemiquantInnerSolver._rescale_spline_bases( + sim_all=group_dict[CURRENT_SIMULATION], + N=n_spline_pars, + K=n_data_points, + ) + + spline_knots_for_observable = [knot_bases, knot_values] + + all_spline_knots.append(spline_knots_for_observable) + + return all_spline_knots + def get_measurements_for_group(self, gr) -> np.ndarray: """Get measurements for a group.""" # Taking the ixs of first inner parameter since diff --git a/pypesto/hierarchical/semiquantitative/solver.py b/pypesto/hierarchical/semiquantitative/solver.py index f3aa8b28b..1871b90d2 100644 --- a/pypesto/hierarchical/semiquantitative/solver.py +++ b/pypesto/hierarchical/semiquantitative/solver.py @@ -438,7 +438,8 @@ def inner_gradient_wrapper(x): return results - def _rescale_spline_bases(self, sim_all: np.ndarray, N: int, K: int): + @staticmethod + def _rescale_spline_bases(sim_all: np.ndarray, N: int, K: int): """Rescale the spline bases. Before the optimization of the spline parameters, we have to fix the diff --git a/pypesto/history/csv.py b/pypesto/history/csv.py index c672fe8fe..c96c0c306 100644 --- a/pypesto/history/csv.py +++ b/pypesto/history/csv.py @@ -20,7 +20,6 @@ RES, SRES, TIME, - X_INNER_OPT, ModeType, X, ) @@ -155,10 +154,6 @@ def _update_trace( else: row[(var, np.nan)] = np.nan - if X_INNER_OPT in result: - for x_inner_id, x_inner_opt_value in result[X_INNER_OPT].items(): - row[(X_INNER_OPT, x_inner_id)] = x_inner_opt_value - self._trace = pd.concat( (self._trace, pd.DataFrame([row])), ) diff --git a/pypesto/objective/amici/amici.py b/pypesto/objective/amici/amici.py index 13e2d87e8..b4ec58efb 100644 --- a/pypesto/objective/amici/amici.py +++ b/pypesto/objective/amici/amici.py @@ -14,6 +14,7 @@ MODE_FUN, MODE_RES, RDATAS, + SPLINE_KNOTS, SUFFIXES_CSV, SUFFIXES_HDF5, ModeType, @@ -232,8 +233,9 @@ def __init__( # `set_custom_timepoints` method for more information. self.custom_timepoints = None - # Initialize the dictionary for saving of inner parameters. + # Initialize the list for saving of inner parameter values. self.inner_parameters: list[float] = None + self.spline_knots: list[list[list[float]]] = None def get_config(self) -> dict: """Return basic information of the objective configuration.""" @@ -504,6 +506,9 @@ def call_unprocessed( if ret.get(INNER_PARAMETERS, None) is not None: self.inner_parameters = ret[INNER_PARAMETERS] + if ret.get(SPLINE_KNOTS, None) is not None: + self.spline_knots = ret[SPLINE_KNOTS] + # check whether we should update data for preequilibration guesses if ( self.guess_steadystate diff --git a/pypesto/optimize/optimizer.py b/pypesto/optimize/optimizer.py index 166020958..e16d01e9d 100644 --- a/pypesto/optimize/optimizer.py +++ b/pypesto/optimize/optimizer.py @@ -11,7 +11,7 @@ import numpy as np import scipy.optimize -from ..C import FVAL, GRAD, INNER_PARAMETERS, MODE_FUN, MODE_RES +from ..C import FVAL, GRAD, INNER_PARAMETERS, MODE_FUN, MODE_RES, SPLINE_KNOTS from ..history import HistoryOptions, NoHistory, OptimizerHistory from ..objective import Objective from ..problem import Problem @@ -68,6 +68,12 @@ def wrapped_minimize( ): result[INNER_PARAMETERS] = problem.objective.inner_parameters + if ( + hasattr(problem.objective, SPLINE_KNOTS) + and problem.objective.spline_knots is not None + ): + result[SPLINE_KNOTS] = problem.objective.spline_knots + return result return wrapped_minimize diff --git a/pypesto/problem/hierarchical.py b/pypesto/problem/hierarchical.py index 26061d691..bc8b2a2ff 100644 --- a/pypesto/problem/hierarchical.py +++ b/pypesto/problem/hierarchical.py @@ -34,6 +34,11 @@ class HierarchicalProblem(Problem): Only relevant if hierarchical is True. Contains the bounds of easily interpretable inner parameters only, e.g. noise parameters, scaling factors, offsets. + semiquant_observable_ids: + The ids of semiquantitative observables. Only relevant if hierarchical + is True. If not None, the optimization result's `spline_knots` will be + a list of lists of spline knots for each semiquantitative observable in + the order of these ids. """ def __init__( @@ -70,3 +75,7 @@ def __init__( self.inner_lb = np.array(inner_lb) self.inner_ub = np.array(inner_ub) + + self.semiquant_observable_ids = ( + self.objective.calculator.semiquant_observable_ids + ) diff --git a/pypesto/result/optimize.py b/pypesto/result/optimize.py index 8d2742ad9..975fb5007 100644 --- a/pypesto/result/optimize.py +++ b/pypesto/result/optimize.py @@ -117,6 +117,7 @@ def __init__( self.optimizer = optimizer self.free_indices = None self.inner_parameters = None + self.spline_knots = None def __getattr__(self, key): try: diff --git a/pypesto/visualize/spline_approximation.py b/pypesto/visualize/spline_approximation.py index 950cd6e77..bf0f24ddb 100644 --- a/pypesto/visualize/spline_approximation.py +++ b/pypesto/visualize/spline_approximation.py @@ -12,8 +12,10 @@ AMICI_Y, CURRENT_SIMULATION, DATAPOINTS, + EXPDATA_MASK, REGULARIZE_SPLINE, SCIPY_X, + SPLINE_KNOTS, ) from ..problem import Problem from ..result import Result @@ -26,6 +28,7 @@ from ..hierarchical.semiquantitative.solver import ( SemiquantInnerSolver, _calculate_regularization_for_group, + extract_expdata_using_mask, get_spline_mapped_simulations, ) except ImportError: @@ -61,6 +64,25 @@ def plot_splines_from_pypesto_result( 'The calculator must be an instance of the InnerCalculatorCollector.' ) + # Get the spline knot values from the pypesto result + spline_knot_values = [ + obs_spline_knots[1] + for obs_spline_knots in pypesto_result.optimize_result.list[ + start_index + ][SPLINE_KNOTS] + ] + + # Get inner parameters per observable as differences of spline knot values + inner_parameters = [ + np.concatenate([[obs_knot_values[0]], np.diff(obs_knot_values)]) + for obs_knot_values in spline_knot_values + ] + + inner_results = [ + {SCIPY_X: obs_inner_parameter} + for obs_inner_parameter in inner_parameters + ] + # Get the parameters from the pypesto result for the start_index. x_dct = dict( zip( @@ -107,7 +129,6 @@ def plot_splines_from_pypesto_result( # Get simulation and sigma. sim = [rdata[AMICI_Y] for rdata in inner_rdatas] - sigma = [rdata[AMICI_SIGMAY] for rdata in inner_rdatas] spline_calculator = None for ( @@ -117,14 +138,23 @@ def plot_splines_from_pypesto_result( spline_calculator = calculator break + if spline_calculator is None: + raise ValueError( + 'No SemiquantCalculator found in the inner_calculators of the objective.' + 'Cannot plot splines.' + ) + # Get the inner solver and problem. inner_solver = spline_calculator.inner_solver inner_problem = spline_calculator.inner_problem - inner_results = inner_solver.solve(inner_problem, sim, sigma) - return plot_splines_from_inner_result( - inner_problem, inner_solver, inner_results, observable_ids, **kwargs + inner_problem, + inner_solver, + inner_results, + sim, + observable_ids, + **kwargs, ) @@ -132,6 +162,7 @@ def plot_splines_from_inner_result( inner_problem: 'pypesto.hierarchical.spline_approximation.problem.SplineInnerProblem', inner_solver: 'pypesto.hierarchical.spline_approximation.solver.SplineInnerSolver', results: list[dict], + sim: list[np.ndarray], observable_ids=None, **kwargs, ): @@ -145,6 +176,10 @@ def plot_splines_from_inner_result( The inner solver. results: The results from the inner solver. + sim: + The simulated model output. + observable_ids: + The ids of the observables. kwargs: Additional arguments to pass to the plotting function. @@ -192,7 +227,9 @@ def plot_splines_from_inner_result( spline_knots = np.dot(lower_trian, s) measurements = inner_problem.groups[group][DATAPOINTS] - simulation = inner_problem.groups[group][CURRENT_SIMULATION] + simulation = extract_expdata_using_mask( + expdata=sim, mask=inner_problem.groups[group][EXPDATA_MASK] + ) # For the simulation, get the spline bases ( @@ -200,7 +237,6 @@ def plot_splines_from_inner_result( spline_bases, n, ) = SemiquantInnerSolver._rescale_spline_bases( - self=None, sim_all=simulation, N=len(spline_knots), K=len(simulation), @@ -402,7 +438,6 @@ def _add_spline_mapped_simulations_to_model_fit( spline_bases, n, ) = SemiquantInnerSolver._rescale_spline_bases( - self=None, sim_all=simulation, N=len(s), K=len(simulation), @@ -543,7 +578,6 @@ def _obtain_regularization_for_start( spline_bases, _, ) = SemiquantInnerSolver._rescale_spline_bases( - self=None, sim_all=simulation, N=len(s), K=len(simulation), From 6d24316b539a1c1f9f4d3bb9260e84e3b453281d Mon Sep 17 00:00:00 2001 From: Doresic Date: Thu, 29 Feb 2024 17:29:49 +0100 Subject: [PATCH 2/6] Documentation update & fix notebook --- doc/example/semiquantitative_data.ipynb | 7 +++++-- pypesto/hierarchical/semiquantitative/problem.py | 10 +++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/doc/example/semiquantitative_data.ipynb b/doc/example/semiquantitative_data.ipynb index e4e683032..29400f562 100644 --- a/doc/example/semiquantitative_data.ipynb +++ b/doc/example/semiquantitative_data.ipynb @@ -431,7 +431,10 @@ " )\n", "\n", " plot_splines_from_inner_result(\n", - " inner_problem, inner_solvers[minimal_diff], results[minimal_diff]\n", + " inner_problem,\n", + " inner_solvers[minimal_diff],\n", + " results[minimal_diff],\n", + " sim=[simulation],\n", " )\n", " plt.show()" ] @@ -467,7 +470,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.2" + "version": "3.10.10" }, "vscode": { "interpreter": { diff --git a/pypesto/hierarchical/semiquantitative/problem.py b/pypesto/hierarchical/semiquantitative/problem.py index 1a9447553..52da466e7 100644 --- a/pypesto/hierarchical/semiquantitative/problem.py +++ b/pypesto/hierarchical/semiquantitative/problem.py @@ -205,15 +205,15 @@ def get_inner_parameter_dictionary(self) -> dict: inner_par_dict[x_id] = x.value return inner_par_dict - def get_spline_knots(self) -> np.ndarray: + def get_spline_knots(self) -> list[list[np.array[float], np.array[float]]]: """Get spline knots of all semiquantitative observables. Returns ------- - list[list[list[float], list[float]]] - A list of lists of lists. Each list in the first list corresponds to a - semiquantitative observable. Each of these lists contains two lists: - the first list contains the spline bases, the second list contains the + list[list[np.array[float], np.array[float]]] + A list of lists with two arrays. Each list in the first level corresponds + to a semiquantitative observable. Each of these lists contains two arrays: + the first array contains the spline bases, the second array contains the spline knot values. The ordering of the observable lists is the same as in `pypesto.problem.hierarchical.semiquant_observable_ids`. """ From 28225a07f311769d9a4e43b0df12539b4f9cd36a Mon Sep 17 00:00:00 2001 From: Doresic Date: Thu, 29 Feb 2024 17:40:38 +0100 Subject: [PATCH 3/6] Fix types --- pypesto/hierarchical/semiquantitative/problem.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pypesto/hierarchical/semiquantitative/problem.py b/pypesto/hierarchical/semiquantitative/problem.py index 52da466e7..359f92f5c 100644 --- a/pypesto/hierarchical/semiquantitative/problem.py +++ b/pypesto/hierarchical/semiquantitative/problem.py @@ -205,12 +205,14 @@ def get_inner_parameter_dictionary(self) -> dict: inner_par_dict[x_id] = x.value return inner_par_dict - def get_spline_knots(self) -> list[list[np.array[float], np.array[float]]]: + def get_spline_knots( + self, + ) -> list[list[np.ndarray[float], np.ndarray[float]]]: """Get spline knots of all semiquantitative observables. Returns ------- - list[list[np.array[float], np.array[float]]] + list[list[np.ndarray[float], np.ndarray[float]]] A list of lists with two arrays. Each list in the first level corresponds to a semiquantitative observable. Each of these lists contains two arrays: the first array contains the spline bases, the second array contains the From 47a9647ca3aa36ff34f2b232a44822f57252cb90 Mon Sep 17 00:00:00 2001 From: Doresic <85789271+Doresic@users.noreply.github.com> Date: Mon, 4 Mar 2024 16:29:34 +0100 Subject: [PATCH 4/6] Daniel review change Co-authored-by: Daniel Weindl --- pypesto/visualize/spline_approximation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pypesto/visualize/spline_approximation.py b/pypesto/visualize/spline_approximation.py index bf0f24ddb..17dc409b4 100644 --- a/pypesto/visualize/spline_approximation.py +++ b/pypesto/visualize/spline_approximation.py @@ -140,7 +140,7 @@ def plot_splines_from_pypesto_result( if spline_calculator is None: raise ValueError( - 'No SemiquantCalculator found in the inner_calculators of the objective.' + 'No SemiquantCalculator found in the inner_calculators of the objective. ' 'Cannot plot splines.' ) From aeca196a74eedc731e905fd68dd475bf831f0f96 Mon Sep 17 00:00:00 2001 From: Doresic Date: Wed, 6 Mar 2024 14:48:48 +0100 Subject: [PATCH 5/6] add test spline knot saving --- test/hierarchical/test_spline.py | 49 ++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/test/hierarchical/test_spline.py b/test/hierarchical/test_spline.py index 3283c1a1f..1d99f05e0 100644 --- a/test/hierarchical/test_spline.py +++ b/test/hierarchical/test_spline.py @@ -8,11 +8,13 @@ import pypesto.logging import pypesto.optimize import pypesto.petab +import pypesto.store from pypesto.C import ( INNER_NOISE_PARS, LIN, MODE_FUN, OPTIMIZE_NOISE, + SPLINE_KNOTS, InnerParameterType, ) from pypesto.hierarchical.semiquantitative import ( @@ -462,3 +464,50 @@ def test_calculate_regularization_for_group(): regularization_gradient, expected_regularization_gradient, ) + + +def test_save_and_load_spline_knots(): + """Test the saving and loading of spline knots in an optimization result.""" + # Run optimization + petab_problem = petab.Problem.from_yaml(example_semiquantitative_yaml) + importer = pypesto.petab.PetabImporter( + petab_problem, + hierarchical=True, + ) + objective = importer.create_objective() + problem = importer.create_problem(objective) + + optimizer = pypesto.optimize.ScipyOptimizer( + method="L-BFGS-B", + options={"disp": None, "ftol": 2.220446049250313e-09, "gtol": 1e-5}, + ) + # Set seed for reproducibility. + np.random.seed(0) + result = pypesto.optimize.minimize( + problem=problem, n_starts=2, optimizer=optimizer + ) + + # Get spline knots + spline_knots_before = [ + result.optimize_result.list[i][SPLINE_KNOTS] for i in range(2) + ] + pypesto.store.write_result( + result=result, + filename="test_spline_knots.hdf5", + ) + # Load spline knots + result_loaded = pypesto.store.read_result("test_spline_knots.hdf5") + spline_knots_after = [ + result_loaded.optimize_result.list[i][SPLINE_KNOTS] for i in range(2) + ] + # Check that the loaded spline knots are the same as the original ones + assert np.all( + [ + np.allclose(knots_before, knots_after) + for knots_before, knots_after in zip( + spline_knots_before, spline_knots_after + ) + ] + ) + # Clean up + Path("test_spline_knots.hdf5").unlink() From 6c52eb2b2ae201181b1751d0069f1c0d37d4bb5e Mon Sep 17 00:00:00 2001 From: Doresic <85789271+Doresic@users.noreply.github.com> Date: Thu, 7 Mar 2024 15:20:29 +0100 Subject: [PATCH 6/6] Update pypesto/visualize/spline_approximation.py Co-authored-by: Maren Philipps <55318391+m-philipps@users.noreply.github.com> --- pypesto/visualize/spline_approximation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pypesto/visualize/spline_approximation.py b/pypesto/visualize/spline_approximation.py index 8977e7d31..51ca0e434 100644 --- a/pypesto/visualize/spline_approximation.py +++ b/pypesto/visualize/spline_approximation.py @@ -129,7 +129,7 @@ def plot_splines_from_pypesto_result( ) return None - # Get simulation and sigma. + # Get simulation. sim = [rdata[AMICI_Y] for rdata in inner_rdatas] spline_calculator = None