diff --git a/pypesto/profile/options.py b/pypesto/profile/options.py index 21cd0b858..f3784db7d 100644 --- a/pypesto/profile/options.py +++ b/pypesto/profile/options.py @@ -66,6 +66,8 @@ def __init__( self.magic_factor_obj_value = magic_factor_obj_value self.whole_path = whole_path + self.validate() + def __getattr__(self, key): """Allow usage of keys like attributes.""" try: @@ -91,3 +93,24 @@ def create_instance( return maybe_options options = ProfileOptions(**maybe_options) return options + + def validate(self): + """Check if options are valid. + + Raises ``ValueError`` if current settings aren't valid. + """ + if self.min_step_size <= 0: + raise ValueError("min_step_size must be > 0.") + if self.max_step_size <= 0: + raise ValueError("max_step_size must be > 0.") + if self.min_step_size > self.max_step_size: + raise ValueError("min_step_size must be <= max_step_size.") + if self.default_step_size <= 0: + raise ValueError("default_step_size must be > 0.") + if self.default_step_size > self.max_step_size: + raise ValueError("default_step_size must be <= max_step_size.") + if self.default_step_size < self.min_step_size: + raise ValueError("default_step_size must be >= min_step_size.") + + if self.magic_factor_obj_value < 0 or self.magic_factor_obj_value >= 1: + raise ValueError("magic_factor_obj_value must be >= 0 and < 1.") diff --git a/pypesto/profile/profile.py b/pypesto/profile/profile.py index bdbaf7ab7..116ef21d8 100644 --- a/pypesto/profile/profile.py +++ b/pypesto/profile/profile.py @@ -89,6 +89,7 @@ def parameter_profile( if profile_options is None: profile_options = ProfileOptions() profile_options = ProfileOptions.create_instance(profile_options) + profile_options.validate() # create a function handle that will be called later to get the next point if isinstance(next_guess_method, str): diff --git a/test/profile/test_profile.py b/test/profile/test_profile.py index 855af1ada..133c2f3ab 100644 --- a/test/profile/test_profile.py +++ b/test/profile/test_profile.py @@ -7,6 +7,7 @@ from copy import deepcopy import numpy as np +import pytest from numpy.testing import assert_almost_equal import pypesto @@ -412,3 +413,25 @@ def test_approximate_ci(): # bound value assert np.isclose(lb, -3) assert np.isclose(ub, 9) + + +def test_options_valid(): + """Test ProfileOptions validity checks.""" + # default settings are valid + profile.ProfileOptions() + + # try to set invalid values + with pytest.raises(ValueError): + profile.ProfileOptions(default_step_size=-1) + with pytest.raises(ValueError): + profile.ProfileOptions(default_step_size=1, min_step_size=2) + with pytest.raises(ValueError): + profile.ProfileOptions( + default_step_size=2, + min_step_size=1, + ) + with pytest.raises(ValueError): + profile.ProfileOptions( + min_step_size=2, + max_step_size=1, + )