diff --git a/.flake8 b/.flake8 index 9e1edba0a..9a183dbd0 100644 --- a/.flake8 +++ b/.flake8 @@ -9,6 +9,14 @@ extend-ignore = E203 # Don't be crazy if line too long E501 + # Missing docstring in public module + D100 + # Missing docstring in public method + # D102 + # Missing docstring in magic method + D105 + # Missing docstring in __init__ + D107 per-file-ignores = # Imported but unused @@ -20,7 +28,3 @@ per-file-ignores = pypesto/problem.py:D400,D205,D107 pypesto/util.py:D400,D205,D107 pypesto/C.py:D400,D205,D107 - pypesto/version.py:D - # ignore D100='Missing docstring in public module', - # D105='Missing docstring in magic method' and D107='Missing docstring in __init__'. - *:D100,D105,D107 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 07407e658..520068b01 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -37,7 +37,7 @@ jobs: run: .github/workflows/install_deps.sh amici - name: Run tests - timeout-minutes: 12 + timeout-minutes: 15 run: tox -e base - name: Coverage @@ -176,7 +176,7 @@ jobs: run: tox -e size - name: Run quality checks - timeout-minutes: 1 + timeout-minutes: 5 run: tox -e project,flake8 - name: Run pre-commit hooks diff --git a/pypesto/C.py b/pypesto/C.py index f22e974cd..e30835811 100644 --- a/pypesto/C.py +++ b/pypesto/C.py @@ -34,8 +34,21 @@ X0 = 'x0' ID = 'id' -EXITFLAG = 'exitflag' -MESSAGE = 'message' + +############################################################################### +# HISTORY + +HISTORY = "history" +TRACE = "trace" +N_ITERATIONS = "n_iterations" +MESSAGES = "messages" +MESSAGE = "message" +EXITFLAG = "exitflag" +TRACE_SAVE_ITER = "trace_save_iter" + +SUFFIXES_CSV = ["csv"] +SUFFIXES_HDF5 = ["hdf5", "h5"] +SUFFIXES = SUFFIXES_CSV + SUFFIXES_HDF5 ############################################################################### @@ -166,7 +179,7 @@ class EnsembleType(Enum): ############################################################################### -# Environment variables +# ENVIRONMENT VARIABLES PYPESTO_MAX_N_STARTS: str = "PYPESTO_MAX_N_STARTS" PYPESTO_MAX_N_SAMPLES: str = "PYPESTO_MAX_N_SAMPLES" diff --git a/pypesto/objective/__init__.py b/pypesto/objective/__init__.py index 3f5d172b1..5bfc53840 100644 --- a/pypesto/objective/__init__.py +++ b/pypesto/objective/__init__.py @@ -11,10 +11,12 @@ from .function import Objective from .history import ( CsvHistory, + CsvHistoryTemplateError, Hdf5History, History, HistoryBase, HistoryOptions, + HistoryTypeError, MemoryHistory, OptimizerHistory, ) diff --git a/pypesto/objective/base.py b/pypesto/objective/base.py index a55d54ac6..b4a1a80dd 100644 --- a/pypesto/objective/base.py +++ b/pypesto/objective/base.py @@ -226,7 +226,6 @@ def call_unprocessed( result: A dict containing the results. """ - raise NotImplementedError() def check_mode(self, mode: ModeType) -> bool: """ @@ -405,15 +404,13 @@ def update_from_problem( Vector of the same length as x_fixed_indices, containing the values of the fixed parameters. """ - pre_post_processor = FixedParametersProcessor( + self.pre_post_processor = FixedParametersProcessor( dim_full=dim_full, x_free_indices=x_free_indices, x_fixed_indices=x_fixed_indices, x_fixed_vals=x_fixed_vals, ) - self.pre_post_processor = pre_post_processor - def check_grad_multi_eps( self, *args, diff --git a/pypesto/objective/history.py b/pypesto/objective/history.py index 3b81b1113..c9dcfed2c 100644 --- a/pypesto/objective/history.py +++ b/pypesto/objective/history.py @@ -1,5 +1,6 @@ import abc import copy +import logging import numbers import os import time @@ -12,26 +13,37 @@ from ..C import ( CHI2, + EXITFLAG, FVAL, GRAD, HESS, + HISTORY, + MESSAGE, + MESSAGES, MODE_FUN, MODE_RES, N_FVAL, N_GRAD, N_HESS, + N_ITERATIONS, N_RES, N_SRES, RES, SCHI2, SRES, + SUFFIXES, + SUFFIXES_CSV, + SUFFIXES_HDF5, TIME, + TRACE, + TRACE_SAVE_ITER, ModeType, X, ) +from ..util import allclose, is_none_or_nan, is_none_or_nan_array, isclose from .util import ( + chi2_to_fval, res_to_chi2, - res_to_fval, schi2_to_grad, sres_to_fim, sres_to_schi2, @@ -40,12 +52,33 @@ ResultDict = Dict[str, Union[float, np.ndarray]] MaybeArray = Union[np.ndarray, 'np.nan'] +logger = logging.getLogger(__name__) + + +class HistoryTypeError(ValueError): + """Error raised when an unsupported history type is requested.""" + + def __init__(self, history_type: str): + super().__init__( + f"Unsupported history type: {history_type}, expected {SUFFIXES}" + ) + + +class CsvHistoryTemplateError(ValueError): + """Error raised when no template is given for CSV history.""" + + def __init__(self, storage_file: str): + super().__init__( + "CSV History requires an `{id}` template in the `storage_file`, " + f"but is {storage_file}" + ) + def trace_wrap(f): """ Wrap around trace getters. - Transform input `ix` vectors to a valid index list, and reduces for + Transform input `ix` vectors to a valid index list, and reduce for integer `ix` the output to a single value. """ @@ -75,7 +108,7 @@ def wrapped_f( class HistoryOptions(dict): """ - Options for the objective that are used in optimization. + Options for what values to record. In addition implements a factory pattern to generate history objects. @@ -152,21 +185,15 @@ def _sanity_check(self): return # extract storage type - type_ = Path(self.storage_file).suffix + suffix = Path(self.storage_file).suffix[1:] # check storage format is valid - if type_ not in [".csv", ".hdf5", ".h5"]: - raise ValueError( - "Only history storage to '.csv' and '.hdf5' is supported, got " - f"{type_}", - ) + if suffix not in SUFFIXES: + raise HistoryTypeError(suffix) # check csv histories are parametrized - if type_ == ".csv" and "{id}" not in self.storage_file: - raise ValueError( - "For csv history, the `storage_file` must contain an `{id}` " - "template" - ) + if suffix in SUFFIXES_CSV and "{id}" not in self.storage_file: + raise CsvHistoryTemplateError(self.storage_file) @staticmethod def assert_instance( @@ -188,8 +215,8 @@ def create_history( self, id: str, x_names: Sequence[str], - ) -> 'History': - """Create a :class:`History` object; Factory method. + ) -> 'HistoryBase': + """Create a :class:`HistoryBase` object; Factory method. Parameters ---------- @@ -205,29 +232,35 @@ def create_history( else: return History(options=self) + # replace id template in storage file storage_file = self.storage_file.replace("{id}", id) - _, type_ = os.path.splitext(storage_file) + # evaluate type + suffix = Path(storage_file).suffix[1:] - if type_ == '.csv': + # create history type based on storage type + if suffix in SUFFIXES_CSV: return CsvHistory(x_names=x_names, file=storage_file, options=self) - elif type_ in ['.hdf5', '.h5']: + elif suffix in SUFFIXES_HDF5: return Hdf5History(id=id, file=storage_file, options=self) else: - raise ValueError( - "Only history storage to '.csv' and '.hdf5' is supported, got " - f"{type_}", - ) + raise HistoryTypeError(suffix) class HistoryBase(abc.ABC): - """Abstract base class for history objects. + """Base class for history objects. - Can be used as a dummy history, but does not implement any history - functionality. + Can be used as a dummy history, but does not implement any functionality. """ - def __len__(self): + # values calculated by the objective function + RESULT_KEYS = (FVAL, GRAD, HESS, RES, SRES) + # history also knows chi2, schi2 + FULL_RESULT_KEYS = (*RESULT_KEYS, CHI2, SCHI2) + # all possible history entries + ALL_KEYS = (X, *FULL_RESULT_KEYS, TIME) + + def __len__(self) -> int: """Define length by number of stored entries in the history.""" raise NotImplementedError() @@ -258,9 +291,9 @@ def finalize( self, message: str = None, exitflag: str = None, - ): + ) -> None: """ - Finalize history. Called after a run. + Finalize history. Called after a run. Default: Do nothing. Parameters ---------- @@ -417,7 +450,7 @@ def get_time_trace( """ raise NotImplementedError() - def get_trimmed_indices(self): + def get_trimmed_indices(self) -> np.ndarray: """Get indices for a monotonically decreasing history.""" fval_trace = self.get_fval_trace() return np.where(fval_trace <= np.fmin.accumulate(fval_trace))[0] @@ -425,7 +458,7 @@ def get_trimmed_indices(self): class History(HistoryBase): """ - Track number of function evaluations only, no trace. + Tracks number of function evaluations only, no trace. Parameters ---------- @@ -467,16 +500,8 @@ def update( The objective function values for parameters `x`, sensitivities `sensi_orders` and mode `mode`. """ - res = result.get(RES, None) - if res is not None and FVAL not in result: - # no option trace_record_fval - result[FVAL] = res_to_fval(res) self._update_counts(sensi_orders, mode) - def finalize(self, message: str = None, exitflag: str = None): - """See `HistoryBase` docstring.""" - pass - def _update_counts( self, sensi_orders: Tuple[int, ...], @@ -531,7 +556,7 @@ class MemoryHistory(History): """ Class for optimization history stored in memory. - Track number of function evaluations and keeps an in-memory + Tracks number of function evaluations and keeps an in-memory trace of function evaluations. Parameters @@ -542,10 +567,9 @@ class MemoryHistory(History): def __init__(self, options: Union[HistoryOptions, Dict] = None): super().__init__(options=options) - self._trace_keys = {X, FVAL, GRAD, HESS, RES, SRES, CHI2, SCHI2, TIME} - self._trace: Dict[str, Any] = {key: [] for key in self._trace_keys} + self._trace: Dict[str, Any] = {key: [] for key in HistoryBase.ALL_KEYS} - def __len__(self): + def __len__(self) -> int: """Define length of history object.""" return len(self._trace[TIME]) @@ -562,12 +586,19 @@ def update( def _update_trace(self, x, mode, result): """Update internal trace representation.""" - ret = extract_values(mode, result, self.options) - for key in self._trace_keys - {X, TIME}: - self._trace[key].append(ret[key]) + # calculating function values from residuals + # and reduce via requested history options + result: dict = reduce_result_via_options( + add_fun_from_res(result), self.options + ) + + result[X] = x + used_time = time.time() - self._start_time - self._trace[X].append(x) - self._trace[TIME].append(used_time) + result[TIME] = used_time + + for key in HistoryBase.ALL_KEYS: + self._trace[key].append(result[key]) @trace_wrap def get_x_trace( @@ -699,16 +730,16 @@ def __init__( self.x_names = trace[X].columns self._update_counts_from_trace() - def __len__(self): + def __len__(self) -> int: """Define length of history object.""" return len(self._trace) - def _update_counts_from_trace(self): - self._n_fval = self._trace[('n_fval', np.NaN)].max() - self._n_grad = self._trace[('n_grad', np.NaN)].max() - self._n_hess = self._trace[('n_hess', np.NaN)].max() - self._n_res = self._trace[('n_res', np.NaN)].max() - self._n_sres = self._trace[('n_sres', np.NaN)].max() + def _update_counts_from_trace(self) -> None: + self._n_fval = self._trace[(N_FVAL, np.NaN)].max() + self._n_grad = self._trace[(N_GRAD, np.NaN)].max() + self._n_hess = self._trace[(N_HESS, np.NaN)].max() + self._n_res = self._trace[(N_RES, np.NaN)].max() + self._n_sres = self._trace[(N_SRES, np.NaN)].max() def update( self, @@ -740,8 +771,11 @@ def _update_trace( if self._trace is None: self._init_trace(x) - # extract function values - ret = extract_values(mode, result, self.options) + # calculating function values from residuals + # and reduce via requested history options + result = reduce_result_via_options( + add_fun_from_res(result), self.options + ) used_time = time.time() - self._start_time @@ -757,21 +791,25 @@ def _update_trace( N_HESS: self._n_hess, N_RES: self._n_res, N_SRES: self._n_sres, - FVAL: ret[FVAL], - RES: ret[RES], - SRES: ret[SRES], - CHI2: ret[CHI2], - HESS: ret[HESS], + FVAL: result[FVAL], + RES: result[RES], + SRES: result[SRES], + CHI2: result[CHI2], + HESS: result[HESS], } for var, val in values.items(): - row[(var, float('nan'))] = val + row[(var, np.nan)] = val - for var, val in {X: x, GRAD: ret[GRAD], SCHI2: ret[SCHI2]}.items(): + for var, val in { + X: x, + GRAD: result[GRAD], + SCHI2: result[SCHI2], + }.items(): if var == X or self.options[f'trace_record_{var}']: row[var] = val else: - row[(var, float('nan'))] = np.NaN + row[(var, np.nan)] = np.nan self._trace = pd.concat( (self._trace, pd.DataFrame([row])), @@ -786,7 +824,7 @@ def _init_trace(self, x: np.ndarray): self.x_names = [f'x{i}' for i, _ in enumerate(x)] columns: List[Tuple] = [ - (c, float('nan')) + (c, np.nan) for c in [ TIME, N_FVAL, @@ -803,7 +841,7 @@ def _init_trace(self, x: np.ndarray): ] for var in [X, GRAD, SCHI2]: - if var == 'x' or self.options[f'trace_record_{var}']: + if var == X or self.options[f'trace_record_{var}']: columns.extend([(var, x_name) for x_name in self.x_names]) else: columns.extend([(var,)]) @@ -826,7 +864,7 @@ def _init_trace(self, x: np.ndarray): } for var, dtype in trace_dtypes.items(): - self._trace[(var, np.NaN)] = self._trace[(var, np.NaN)].astype( + self._trace[(var, np.nan)] = self._trace[(var, np.nan)].astype( dtype ) @@ -846,7 +884,7 @@ def _save_trace(self, finalize: bool = False): ): # save trace_copy = copy.deepcopy(self._trace) - for field in [('hess', np.NaN), ('res', np.NaN), ('sres', np.NaN)]: + for field in [(HESS, np.nan), (RES, np.nan), (SRES, np.nan)]: trace_copy[field] = trace_copy[field].apply( ndarray2string_full ) @@ -936,14 +974,15 @@ def __init__( self, id: str, file: str, options: Union[HistoryOptions, Dict] = None ): super().__init__(options=options) - self.id = id - self.file, self.editable = self._check_file_id(file) + self.id: str = id + self.file: str = file + self.editable: bool = self._is_editable(file) self._generate_hdf5_group() - def __len__(self): + def __len__(self) -> int: """Define length of history object.""" with h5py.File(self.file, 'r') as f: - return f[f'history/{self.id}/trace/'].attrs['n_iterations'] + return f[f'{HISTORY}/{self.id}/{TRACE}/'].attrs[N_ITERATIONS] def update( self, @@ -961,42 +1000,42 @@ def update( super().update(x, sensi_orders, mode, result) self._update_trace(x, sensi_orders, mode, result) - def get_history_directory(self): - """Return filepath.""" - return self.file - - def finalize(self, message: str = None, exitflag: str = None): + def finalize(self, message: str = None, exitflag: str = None) -> None: """See `HistoryBase` docstring.""" super().finalize() + + # add message and exitflag to trace with h5py.File(self.file, 'a') as f: - if f'history/{self.id}/messages/' not in f: - f.create_group(f'history/{self.id}/messages/') - grp = f[f'history/{self.id}/messages/'] + if f'{HISTORY}/{self.id}/{MESSAGES}/' not in f: + f.create_group(f'{HISTORY}/{self.id}/{MESSAGES}/') + grp = f[f'{HISTORY}/{self.id}/{MESSAGES}/'] if message is not None: - grp.attrs['message'] = message + grp.attrs[MESSAGE] = message if exitflag is not None: - grp.attrs['exitflag'] = exitflag + grp.attrs[EXITFLAG] = exitflag @staticmethod - def load(id: str, file: str): + def load( + id: str, file: str, options: Union[HistoryOptions, Dict] = None + ) -> 'Hdf5History': """Load the History object from memory.""" - loaded_h5history = Hdf5History(id, file) - loaded_h5history.recover_options(file) - return loaded_h5history + history = Hdf5History(id=id, file=file, options=options) + if options is None: + history.recover_options(file) + return history def recover_options(self, file: str): """Recover options when loading the hdf5 history from memory. Done by testing which entries were recorded. """ - trace_record = self._check_for_not_nan_entries(X) - trace_record_grad = self._check_for_not_nan_entries(GRAD) - trace_record_hess = self._check_for_not_nan_entries(HESS) - trace_record_res = self._check_for_not_nan_entries(RES) - trace_record_sres = self._check_for_not_nan_entries(SRES) - trace_record_chi2 = self._check_for_not_nan_entries(CHI2) - trace_record_schi2 = self._check_for_not_nan_entries(SCHI2) - storage_file = file + trace_record = self._has_non_nan_entries(X) + trace_record_grad = self._has_non_nan_entries(GRAD) + trace_record_hess = self._has_non_nan_entries(HESS) + trace_record_res = self._has_non_nan_entries(RES) + trace_record_sres = self._has_non_nan_entries(SRES) + trace_record_chi2 = self._has_non_nan_entries(CHI2) + trace_record_schi2 = self._has_non_nan_entries(SCHI2) restored_history_options = HistoryOptions( trace_record=trace_record, @@ -1007,13 +1046,13 @@ def recover_options(self, file: str): trace_record_chi2=trace_record_chi2, trace_record_schi2=trace_record_schi2, trace_save_iter=self.trace_save_iter, - storage_file=storage_file, + storage_file=file, ) self.options = restored_history_options - def _check_for_not_nan_entries(self, hdf5_group: str) -> bool: - """Check if there exist not-nan entries stored for a given group.""" + def _has_non_nan_entries(self, hdf5_group: str) -> bool: + """Check if there exist non-nan entries stored for a given group.""" group = self._get_hdf5_entries(hdf5_group, ix=None) for entry in group: @@ -1029,68 +1068,68 @@ def _update_counts(self, sensi_orders: Tuple[int, ...], mode: ModeType): if mode == MODE_FUN: if 0 in sensi_orders: - f[f'history/{self.id}/trace/'].attrs['n_fval'] += 1 + f[f'{HISTORY}/{self.id}/{TRACE}/'].attrs[N_FVAL] += 1 if 1 in sensi_orders: - f[f'history/{self.id}/trace/'].attrs['n_grad'] += 1 + f[f'{HISTORY}/{self.id}/{TRACE}/'].attrs[N_GRAD] += 1 if 2 in sensi_orders: - f[f'history/{self.id}/trace/'].attrs['n_hess'] += 1 + f[f'{HISTORY}/{self.id}/{TRACE}/'].attrs[N_HESS] += 1 elif mode == MODE_RES: if 0 in sensi_orders: - f[f'history/{self.id}/trace/'].attrs['n_res'] += 1 + f[f'{HISTORY}/{self.id}/{TRACE}/'].attrs[N_RES] += 1 if 1 in sensi_orders: - f[f'history/{self.id}/trace/'].attrs['n_sres'] += 1 + f[f'{HISTORY}/{self.id}/{TRACE}/'].attrs[N_SRES] += 1 @property def n_fval(self) -> int: """See `HistoryBase` docstring.""" with h5py.File(self.file, 'r') as f: - return f[f'history/{self.id}/trace/'].attrs['n_fval'] + return f[f'{HISTORY}/{self.id}/{TRACE}/'].attrs[N_FVAL] @property def n_grad(self) -> int: """See `HistoryBase` docstring.""" with h5py.File(self.file, 'r') as f: - return f[f'history/{self.id}/trace/'].attrs['n_grad'] + return f[f'{HISTORY}/{self.id}/{TRACE}/'].attrs[N_GRAD] @property def n_hess(self) -> int: """See `HistoryBase` docstring.""" with h5py.File(self.file, 'r') as f: - return f[f'history/{self.id}/trace/'].attrs['n_hess'] + return f[f'{HISTORY}/{self.id}/{TRACE}/'].attrs[N_HESS] @property def n_res(self) -> int: """See `HistoryBase` docstring.""" with h5py.File(self.file, 'r') as f: - return f[f'history/{self.id}/trace/'].attrs['n_res'] + return f[f'{HISTORY}/{self.id}/{TRACE}/'].attrs[N_RES] @property def n_sres(self) -> int: """See `HistoryBase` docstring.""" with h5py.File(self.file, 'r') as f: - return f[f'history/{self.id}/trace/'].attrs['n_sres'] + return f[f'{HISTORY}/{self.id}/{TRACE}/'].attrs[N_SRES] @property - def trace_save_iter(self): + def trace_save_iter(self) -> int: """After how many iterations to store the trace.""" with h5py.File(self.file, 'r') as f: - return f[f'history/{self.id}/trace/'].attrs['trace_save_iter'] + return f[f'{HISTORY}/{self.id}/{TRACE}/'].attrs[TRACE_SAVE_ITER] @property - def message(self): + def message(self) -> str: """Optimizer message in case of finished optimization.""" with h5py.File(self.file, 'r') as f: try: - return f[f'history/{self.id}/messages/'].attrs['message'] + return f[f'{HISTORY}/{self.id}/{MESSAGES}/'].attrs[MESSAGE] except KeyError: return None @property - def exitflag(self): + def exitflag(self) -> str: """Optimizer exitflag in case of finished optimization.""" with h5py.File(self.file, 'r') as f: try: - return f[f'history/{self.id}/messages/'].attrs['exitflag'] + return f[f'{HISTORY}/{self.id}/{MESSAGES}/'].attrs[EXITFLAG] except KeyError: return None @@ -1100,53 +1139,57 @@ def _update_trace( sensi_orders: Tuple[int], mode: ModeType, result: ResultDict, - ): + ) -> None: """Update and possibly store the trace.""" if not self.options.trace_record: return - # extract function values - ret = extract_values(mode, result, self.options) + # calculating function values from residuals + # and reduce via requested history options + result = reduce_result_via_options( + add_fun_from_res(result), self.options + ) used_time = time.time() - self._start_time values = { - TIME: used_time, X: x, - FVAL: ret[FVAL], - GRAD: ret[GRAD], - RES: ret[RES], - SRES: ret[SRES], - CHI2: ret[CHI2], - SCHI2: ret[SCHI2], - HESS: ret[HESS], + FVAL: result[FVAL], + GRAD: result[GRAD], + RES: result[RES], + SRES: result[SRES], + CHI2: result[CHI2], + SCHI2: result[SCHI2], + HESS: result[HESS], + TIME: used_time, } with h5py.File(self.file, 'a') as f: - - iteration = f[f'history/{self.id}/trace/'].attrs['n_iterations'] + iteration = f[f'{HISTORY}/{self.id}/{TRACE}/'].attrs[N_ITERATIONS] for key in values.keys(): if values[key] is not None: f[ - f'history/{self.id}/trace/' f'{str(iteration)}/{key}' + f'{HISTORY}/{self.id}/{TRACE}/{iteration}/{key}' ] = values[key] - f[f'history/{self.id}/trace/'].attrs['n_iterations'] += 1 + f[f'{HISTORY}/{self.id}/{TRACE}/'].attrs[N_ITERATIONS] += 1 - def _generate_hdf5_group(self, f: h5py.File = None): - """Generate the group in the hdf5 file, if it does not exist yet.""" + def _generate_hdf5_group(self, f: h5py.File = None) -> None: + """Generate group in the hdf5 file, if it does not exist yet.""" try: with h5py.File(self.file, 'a') as f: - if f'history/{self.id}/trace/' not in f: - grp = f.create_group(f'history/{self.id}/trace/') - grp.attrs['n_iterations'] = 0 - grp.attrs['n_fval'] = 0 - grp.attrs['n_grad'] = 0 - grp.attrs['n_hess'] = 0 - grp.attrs['n_res'] = 0 - grp.attrs['n_sres'] = 0 - grp.attrs['trace_save_iter'] = self.options.trace_save_iter + if f'{HISTORY}/{self.id}/{TRACE}/' not in f: + grp = f.create_group(f'{HISTORY}/{self.id}/{TRACE}/') + grp.attrs[N_ITERATIONS] = 0 + grp.attrs[N_FVAL] = 0 + grp.attrs[N_GRAD] = 0 + grp.attrs[N_HESS] = 0 + grp.attrs[N_RES] = 0 + grp.attrs[N_SRES] = 0 + # TODO Y it makes no sense to save this here + # Also, we do not seem to evaluate this at all + grp.attrs[TRACE_SAVE_ITER] = self.options.trace_save_iter except OSError: pass @@ -1175,11 +1218,10 @@ def _get_hdf5_entries( trace_result = [] with h5py.File(self.file, 'r') as f: - for iteration in ix: try: dataset = f[ - f'history/{self.id}/trace/{str(iteration)}/{entry_id}' + f'{HISTORY}/{self.id}/{TRACE}/{iteration}/{entry_id}' ] if dataset.shape == (): entry = dataset[()] # scalar @@ -1254,9 +1296,9 @@ def get_time_trace( """See `HistoryBase` docstring.""" return self._get_hdf5_entries(TIME, ix) - def _check_file_id(self, file: str): + def _is_editable(self, file: str) -> bool: """ - Check, whether the id is already existent in the file. + Check whether the id is already existent in the file. Parameters ---------- @@ -1265,27 +1307,25 @@ def _check_file_id(self, file: str): Returns ------- - file: - HDF5 file name. editable: - Boolean, whether this hdf5 file should be editable. Returns - false if the history is a loaded one to prevent overwriting. - + Boolean, whether this hdf5 file should be editable. + Returns true if the file or the id entry does not exist yet. """ try: - with h5py.File(file, 'r') as f: - return file, ( - 'history' not in f.keys() or self.id not in f['history'] - ) - except OSError: # if the file is non-existent, return editable = True - return file, True + with h5py.File(file, 'a') as f: + # editable if the id entry does not exist + return 'history' not in f.keys() or self.id not in f['history'] + except OSError: + # editable if the file does not exist + return True class OptimizerHistory: """ - Objective call history. + Optimizer objective call history. - Container around a History object, which keeps track of optimal values. + Container around a History object, additionally keeping track of optimal + values. Attributes ---------- @@ -1318,6 +1358,9 @@ class OptimizerHistory: function based on the provided history. """ + # optimal point values + MIN_KEYS = (X, *HistoryBase.RESULT_KEYS) + def __init__( self, history: History, @@ -1345,7 +1388,7 @@ def __init__( self.sres_min: Union[np.ndarray, None] = None if generate_from_history: - self._compute_vals_from_trace() + self._maybe_compute_init_and_min_vals_from_trace() def update( self, @@ -1355,8 +1398,9 @@ def update( result: ResultDict, ) -> None: """Update history and best found value.""" - self.history.update(x, sensi_orders, mode, result) + result = add_fun_from_res(result) self._update_vals(x, result) + self.history.update(x, sensi_orders, mode, result) def finalize(self, message: str = None, exitflag: int = None): """ @@ -1371,148 +1415,186 @@ def finalize(self, message: str = None, exitflag: int = None): """ self.history.finalize(message=message, exitflag=exitflag) - def _update_vals(self, x: np.ndarray, result: ResultDict): + # There can be entries in the history e.g. for grad that are not + # recorded in ..._min, e.g. when evaluated before fval. + # On the other hand, not all variables may be recorded in the history. + # Thus, here at the end we go over the history once and try to fill + # in what is available. + + # check if a useful history exists + # TODO Y This can be solved prettier + try: + self.history.get_x_trace() + except NotImplementedError: + return + + # find optimal point + result = self._get_optimal_point_from_history() + + fval = result[FVAL] + if fval is None: + # nothing to be improved + return + + # check if history has a better point (should not really happen) + if ( + fval < self.fval_min + and not isclose(fval, self.fval_min) + and not allclose(result[X], self.x_min) + ): + # issue a warning, as if this happens, then something may be wrong + logger.warn( + f"History has a better point {fval} than the current best " + "point {self.fval_min}." + ) + # update everything + for key in self.MIN_KEYS: + setattr(self, key + '_min', result[key]) + + # check if history has same point + if isclose(fval, self.fval_min) and allclose(result[X], self.x_min): + # update only missing entries + # (e.g. grad and hess may be recorded but not in history) + for key in self.MIN_KEYS: + if result[key] is not None: + # if getattr(self, f'{key}_min') is None: + setattr(self, f'{key}_min', result[key]) + + def _update_vals(self, x: np.ndarray, result: ResultDict) -> None: """Update initial and best function values.""" # update initial point - if np.allclose(x, self.x0): - if self.fval0 is None: - self.fval0 = result.get(FVAL, None) - self.x0 = x + if is_none_or_nan(self.fval0) and np.array_equal(x, self.x0): + self.fval0 = result.get(FVAL) # don't update optimal point if point is not admissible if not self._admissible(x): return - # update best point - fval = result.get(FVAL, None) - grad = result.get(GRAD, None) - hess = result.get(HESS, None) - res = result.get(RES, None) - sres = result.get(SRES, None) - - if fval is not None and fval < self.fval_min: - self.fval_min = fval + # update if fval is better + if ( + not is_none_or_nan(fval := result.get(FVAL)) + and fval < self.fval_min + ): + # need to update all values, as better fval found + for key in HistoryBase.RESULT_KEYS: + setattr(self, f'{key}_min', result.get(key)) self.x_min = x - self.grad_min = grad - self.hess_min = hess - self.res_min = res - self.sres_min = sres - - # sometimes sensitivities are evaluated on subsequent calls. We can - # identify this situation by checking that x hasn't changed - if self.x_min is not None and np.allclose(self.x_min, x): - if self.grad_min is None and grad is not None: - self.grad_min = grad - if self.hess_min is None and hess is not None: - self.hess_min = hess - if self.res_min is None and res is not None: - self.res_min = res - if self.sres_min is None and sres is not None: - self.sres_min = sres - - def _compute_vals_from_trace(self): - """Set initial and best function value from trace (at start).""" + return + + # Sometimes sensitivities are evaluated on subsequent calls. We can + # identify this situation by checking that x hasn't changed. + if self.x_min is not None and np.array_equal(self.x_min, x): + for key in (GRAD, HESS, SRES): + val_min = getattr(self, f'{key}_min', None) + if is_none_or_nan_array(val_min) and not is_none_or_nan_array( + val := result.get(key) + ): + setattr(self, f'{key}_min', val) + + def _maybe_compute_init_and_min_vals_from_trace(self) -> None: + """Try to set initial and best function value from trace. + + Only possible if history has a trace. + """ if not len(self.history): # nothing to be computed from empty history return # some optimizers may evaluate hess+grad first to compute trust region # etc - max_init_iter = 3 - for it in range(min(len(self.history), max_init_iter)): - candidate = self.history.get_fval_trace(it) - if not np.isnan(candidate) and np.allclose( - self.history.get_x_trace(it), self.x0 - ): - self.fval0 = float(candidate) + for ix in range(len(self.history)): + fval = self.history.get_fval_trace(ix) + x = self.history.get_x_trace(ix) + if not is_none_or_nan(fval) and allclose(x, self.x0): + self.fval0 = float(fval) break + # find best fval + result = self._get_optimal_point_from_history() + + # assign values + for key in OptimizerHistory.MIN_KEYS: + setattr(self, f'{key}_min', result[key]) + + def _admissible(self, x: np.ndarray) -> bool: + """Check whether point `x` is admissible (i.e. within bounds). + + Parameters + ---------- + x: A single parameter vector. + + Returns + ------- + admissible: Whether the point fulfills the problem requirements. + """ + return np.all(x <= self.ub) and np.all(x >= self.lb) + + def _get_optimal_point_from_history(self) -> ResultDict: + """Extract optimal point from `self.history`.""" + result = {} + # get indices of admissible trace entries # shape (n_sample, n_x) xs = np.asarray(self.history.get_x_trace()) ixs_admit = [ix for ix, x in enumerate(xs) if self._admissible(x)] - # we prioritize fval over chi2 as fval is written whenever possible + if len(ixs_admit) == 0: + # no admittable indices + return {key: None for key in OptimizerHistory.MIN_KEYS} + + # index of minimum of fval values ix_min = np.nanargmin(self.history.get_fval_trace(ixs_admit)) - # np.argmin returns ndarray when multiple minimal values are found, we - # generally want the first occurrence + # np.argmin returns ndarray when multiple minimal values are found, + # we want the first occurrence if isinstance(ix_min, np.ndarray): ix_min = ix_min[0] # select index in original array ix_min = ixs_admit[ix_min] - for var in ['fval', 'chi2', 'x']: - self.extract_from_history(var, ix_min) - if var == 'fval': - self.fval_min = float(self.fval_min) - - for var in ['res', 'grad', 'sres', 'hess']: - if not getattr(self.history.options, f'trace_record_{var}'): - continue # var not saved in history - # first try index of optimal function value - if self.extract_from_history(var, ix_min): - continue - # gradients may be evaluated at different indices, therefore - # iterate over all and check whether any has the same parameter - # and the desired field filled - # for res we do the same because otherwise randomly None - # (TODO investigate why, but ok this way) - for ix in reversed(range(len(self.history))): - if not np.allclose(self.x_min, self.history.get_x_trace(ix)): + # fill in parameter and function value from that index + for var in (X, FVAL, RES): + val = getattr(self.history, f'get_{var}_trace')(ix_min) + if val is not None and not np.all(np.isnan(val)): + result[var] = val + # convert to float if var is FVAL to be sure + if var == FVAL: + result[var] = float(result[var]) + + # derivatives may be evaluated at different indices, therefore + # iterate over all and check whether any has the same parameter + # and the desired field filled + for var in (GRAD, HESS, SRES): + for ix in range(len(self.history)): + if not allclose(result[X], self.history.get_x_trace(ix)): + # different parameter continue - if self.extract_from_history(var, ix): - # successfully assigned + val = getattr(self.history, f'get_{var}_trace')(ix) + if not is_none_or_nan_array(val): + result[var] = val + # successfuly found break - def extract_from_history(self, var: str, ix: int) -> bool: - """Get value of `var` at iteration `ix` and assign to `{var}_min`. - - Parameters - ---------- - var: Variable to extract, e.g. 'grad', 'x'. - ix: Trace index. - - Returns - ------- - successful: - Whether extraction and assignment worked. False in particular if - the history value is nan. - """ - val = getattr(self.history, f'get_{var}_trace')(ix) - if not np.all(np.isnan(val)): - setattr(self, f'{var}_min', val) - return True - return False - - def _admissible(self, x: np.ndarray) -> bool: - """Check whether point `x` is admissible (i.e. within bounds). - - Parameters - ---------- - x: A single parameter vector. + # fill remaining keys with None + for key in OptimizerHistory.MIN_KEYS: + if key not in result: + result[key] = None - Returns - ------- - admissible: Whether the point fulfills the problem requirements. - """ - return np.all(x <= self.ub) and np.all(x >= self.lb) + return result def ndarray2string_full(x: Union[np.ndarray, None]) -> Union[str, None]: """ - Convert numpy arrays to string. + Convert numpy array to string. - Use 16 digit numerical precision and no truncation for large arrays + Use 16-digit numerical precision and no truncation for large arrays. Parameters ---------- - x: - array to convert + x: array to convert. Returns ------- - x: - array as string + x: array as string. """ if not isinstance(x, np.ndarray): return x @@ -1523,17 +1605,15 @@ def ndarray2string_full(x: Union[np.ndarray, None]) -> Union[str, None]: def string2ndarray(x: Union[str, float]) -> Union[np.ndarray, float]: """ - Convert string to numpy arrays. + Convert string to numpy array. Parameters ---------- - x: - array to convert + x: array to convert. Returns ------- - x: - array as np.ndarray + x: array as np.ndarray. """ if not isinstance(x, str): return x @@ -1545,40 +1625,63 @@ def string2ndarray(x: Union[str, float]) -> Union[np.ndarray, float]: return np.fromstring(x[1:-1], sep=' ') -def extract_values( - mode: ModeType, result: ResultDict, options: HistoryOptions -) -> Dict: - """Extract values to record from result.""" - ret = {} - ret_vars = [FVAL, GRAD, HESS, RES, SRES, CHI2, SCHI2] - for var in ret_vars: - if options.get(f'trace_record_{var}', True) and var in result: - ret[var] = result[var] - - # write values that weren't set yet with alternative methods - if mode == MODE_RES: - res_result = result.get(RES, None) - sres_result = result.get(SRES, None) - chi2 = res_to_chi2(res_result) - schi2 = sres_to_schi2(res_result, sres_result) - fim = sres_to_fim(sres_result) - alt_values = {CHI2: chi2, SCHI2: schi2, HESS: fim} - if schi2 is not None: - alt_values[GRAD] = schi2_to_grad(schi2) - - # filter according to options - alt_values = { - key: val - for key, val in alt_values.items() - if options.get(f'trace_record_{key}', True) - } - for var, val in alt_values.items(): - if val is not None: - ret[var] = ret.get(var, val) +def add_fun_from_res(result: ResultDict) -> ResultDict: + """Calculate function values from residual values. + + Copies the result, but apart performs calculations only if entries + are not present yet in the result object + (thus can be called repeatedly). + + Parameters + ---------- + result: Result dictionary from the objective function. - # set everything missing to NaN - for var in ret_vars: - if var not in ret: - ret[var] = np.NaN + Returns + ------- + full_result: + Result dicionary, adding whatever is possible to calculate. + """ + result = result.copy() + + # calculate function values from residuals + if result.get(CHI2) is None: + result[CHI2] = res_to_chi2(result.get(RES)) + if result.get(SCHI2) is None: + result[SCHI2] = sres_to_schi2(result.get(RES), result.get(SRES)) + if result.get(FVAL) is None: + result[FVAL] = chi2_to_fval(result.get(CHI2)) + if result.get(GRAD) is None: + result[GRAD] = schi2_to_grad(result.get(SCHI2)) + if result.get(HESS) is None: + result[HESS] = sres_to_fim(result.get(SRES)) + + return result + + +def reduce_result_via_options( + result: ResultDict, options: HistoryOptions +) -> ResultDict: + """Set values not to be stored in history or missing to NaN. + + Parameters + ---------- + result: + Result dictionary with all fields present. + options: + History options. + + Returns + ------- + result: + Result reduced to what is intended to be stored in history. + """ + result = result.copy() + + # apply options to result + for key in HistoryBase.FULL_RESULT_KEYS: + if result.get(key) is None or not options.get( + f'trace_record_{key}', True + ): + result[key] = np.nan - return ret + return result diff --git a/pypesto/objective/util.py b/pypesto/objective/util.py index 6150dabcb..6bd2ac497 100644 --- a/pypesto/objective/util.py +++ b/pypesto/objective/util.py @@ -1,11 +1,11 @@ """Objective utilities.""" -from typing import Union +from typing import Any, Callable, Union import numpy as np -def _check_none(fun): +def _check_none(fun: Callable[..., Any]) -> Callable[..., Union[Any, None]]: """Return None if any input argument is None; Wrapper function.""" def checked_fun(*args, **kwargs): @@ -17,14 +17,14 @@ def checked_fun(*args, **kwargs): @_check_none -def res_to_chi2(res: np.ndarray) -> Union[float, None]: +def res_to_chi2(res: np.ndarray) -> float: """Translate residuals to chi2 values, `chi2 = sum(res**2)`.""" return float(np.dot(res, res)) @_check_none -def chi2_to_fval(chi2: float) -> Union[float, None]: - """Translate chi2 to function value, `fval = 0.5*chi2 = 0.5*sum(res**2)`. +def chi2_to_fval(chi2: float) -> float: + """Translate chi2 to function value, `fval = 0.5*chi2 = 0.5*sum(res**2) + C`. Note that for the function value we thus employ a probabilistic interpretation, as the log-likelihood of a standard normal noise model. @@ -34,13 +34,13 @@ def chi2_to_fval(chi2: float) -> Union[float, None]: @_check_none -def res_to_fval(res: np.ndarray) -> Union[float, None]: - """Translate residuals to function value, `fval = 0.5*sum(res**2)`.""" +def res_to_fval(res: np.ndarray) -> float: + """Translate residuals to function value, `fval = 0.5*sum(res**2) + C`.""" return chi2_to_fval(res_to_chi2(res)) @_check_none -def sres_to_schi2(res: np.ndarray, sres: np.ndarray): +def sres_to_schi2(res: np.ndarray, sres: np.ndarray) -> np.ndarray: """Translate residual sensitivities to chi2 gradient.""" return 2 * res.dot(sres) @@ -55,7 +55,7 @@ def schi2_to_grad(schi2: np.ndarray) -> np.ndarray: @_check_none -def sres_to_grad(res: np.ndarray, sres: np.ndarray): +def sres_to_grad(res: np.ndarray, sres: np.ndarray) -> np.ndarray: """Translate residual sensitivities to function value gradient. Assumes `fval = 0.5*sum(res**2)`. @@ -66,7 +66,7 @@ def sres_to_grad(res: np.ndarray, sres: np.ndarray): @_check_none -def sres_to_fim(sres: np.ndarray): +def sres_to_fim(sres: np.ndarray) -> np.ndarray: """Translate residual sensitivities to FIM. The FIM is based on the function values, not chi2, i.e. has a normalization diff --git a/pypesto/optimize/load.py b/pypesto/optimize/load.py index e0f36dae3..f67a80d08 100644 --- a/pypesto/optimize/load.py +++ b/pypesto/optimize/load.py @@ -2,17 +2,30 @@ import logging import os +from pathlib import Path import h5py import numpy as np import pypesto -from ..C import FVAL, GRAD, HESS, RES, SRES, X +from ..C import ( + FVAL, + GRAD, + HESS, + HISTORY, + RES, + SRES, + SUFFIXES_CSV, + SUFFIXES_HDF5, + TRACE, + X, +) from ..objective import ( CsvHistory, Hdf5History, HistoryOptions, + HistoryTypeError, OptimizerHistory, ) from ..problem import Problem @@ -51,8 +64,8 @@ def fill_result_from_history( fval_match = fval_exist and np.isclose(history_fval, result_fval) if fval_exist and not fval_match: logger.debug( - f"Minimal function value mismatch: history {history_fval:8e}, " - f"result {result_fval:8e}" + "Minimal function value mismatch: " + f"history {history_fval:8e}, result {result_fval:8e}" ) # parameters history_x, result_x = optimizer_history.x_min, result.x @@ -82,6 +95,7 @@ def fill_result_from_history( if not optimize_options.history_beats_optimizer: return result + # exit flag and message if isinstance(optimizer_history.history, Hdf5History): if (message := optimizer_history.history.message) is not None: result.message = message @@ -118,19 +132,23 @@ def read_result_from_file( if history_options.storage_file is None: raise ValueError("No history file specified.") - if history_options.storage_file.endswith('.csv'): + # evaluate type + suffix = Path(history_options.storage_file).suffix[1:] + + if suffix in SUFFIXES_CSV: history = CsvHistory( file=history_options.storage_file.format(id=identifier), options=history_options, load_from_file=True, ) - elif history_options.storage_file.endswith(('.h5', '.hdf5')): + elif suffix in SUFFIXES_HDF5: history = Hdf5History.load( id=identifier, file=history_options.storage_file.format(id=identifier), + options=history_options, ) else: - raise NotImplementedError() + raise HistoryTypeError(suffix) opt_hist = OptimizerHistory( history=history, @@ -218,12 +236,12 @@ def optimization_result_from_history( """ result = Result() with h5py.File(filename, 'r') as f: - for id_name in f['history'].keys(): + for id_name in f[HISTORY].keys(): history = Hdf5History(id=id_name, file=filename) history.recover_options(filename) optimizer_history = OptimizerHistory( history=history, - x0=f[f'history/{id_name}/trace/0/x'][()], + x0=f[f'{HISTORY}/{id_name}/{TRACE}/0/{X}'][()], lb=problem.lb, ub=problem.ub, generate_from_history=True, diff --git a/pypesto/optimize/util.py b/pypesto/optimize/util.py index 1eb3f6f6c..68228f78a 100644 --- a/pypesto/optimize/util.py +++ b/pypesto/optimize/util.py @@ -7,9 +7,13 @@ import h5py -from ..C import PYPESTO_MAX_N_STARTS +from .. import C from ..engine import Engine, SingleCoreEngine -from ..objective import HistoryOptions +from ..objective import ( + CsvHistoryTemplateError, + HistoryOptions, + HistoryTypeError, +) from ..result import Result from ..store.save_to_hdf5 import get_or_create_group from .optimizer import OptimizerResult @@ -47,20 +51,14 @@ def preprocess_hdf5_history( path = Path(storage_file) # nothing to do if csv history and correctly set - if path.suffix == ".csv": + if path.suffix[1:] in C.SUFFIXES_CSV: if "{id}" not in storage_file: - raise ValueError( - "For csv history, the `storage_file` must contain an `{id}` " - "template" - ) + raise CsvHistoryTemplateError(storage_file) return False # assuming hdf5 history henceforth - if path.suffix not in [".h5", ".hdf5"]: - raise ValueError( - "Only history storage to '.csv' and '.hdf5' is supported, got " - f"{path.suffix}", - ) + if path.suffix[1:] not in C.SUFFIXES_HDF5: + raise HistoryTypeError(path.suffix) # nothing to do if no parallelization if isinstance(engine, SingleCoreEngine): @@ -129,13 +127,13 @@ def bound_n_starts_from_env(n_starts: int): The original number of starts, or the minimum with the environment variable, if exists. """ - if PYPESTO_MAX_N_STARTS not in os.environ: + if C.PYPESTO_MAX_N_STARTS not in os.environ: return n_starts - n_starts_new = min(n_starts, int(os.environ[PYPESTO_MAX_N_STARTS])) + n_starts_new = min(n_starts, int(os.environ[C.PYPESTO_MAX_N_STARTS])) logger.info( f"Bounding number of samples from {n_starts} to {n_starts_new} via " - f"environment variable {PYPESTO_MAX_N_STARTS}" + f"environment variable {C.PYPESTO_MAX_N_STARTS}" ) return n_starts_new diff --git a/pypesto/util.py b/pypesto/util.py index 0be236f57..55097c9a5 100644 --- a/pypesto/util.py +++ b/pypesto/util.py @@ -5,12 +5,86 @@ Package-wide utilities. """ -from typing import Optional, Tuple +from numbers import Number +from typing import Optional, Tuple, Union import numpy as np from scipy import cluster +def is_none_or_nan(x: Union[Number, None]) -> bool: + """ + Check if x is None or NaN. + + Parameters + ---------- + x: + object to be checked + + Returns + ------- + True if x is None or NaN, False otherwise. + """ + return x is None or np.isnan(x) + + +def is_none_or_nan_array(x: Union[Number, np.ndarray, None]) -> bool: + """ + Check if x is None or NaN array. + + Parameters + ---------- + x: + object to be checked + + Returns + ------- + True if x is None or NaN array, False otherwise. + """ + return x is None or np.isnan(x).all() + + +def allclose( + x: Union[Number, np.ndarray], y: Union[Number, np.ndarray] +) -> bool: + """ + Check if two arrays are close. + + Parameters + ---------- + x: first array + y: second array + + Returns + ------- + True if all elements of x and y are close, False otherwise. + """ + # Note: We use this wrapper around np.allclose in order to more easily + # adjust hyper parameters for the tolerance. + return np.allclose(x, y) + + +def isclose( + x: Union[Number, np.ndarray], + y: Union[Number, np.ndarray], +) -> Union[bool, np.ndarray]: + """ + Check if two values or arrays are close, element-wise. + + Parameters + ---------- + x: first array + y: second array + + Returns + ------- + Element-wise boolean comparison of x and y. + """ + # Note: We use this wrapper around np.isclose in order to more easily + # adjust hyper parameters for the tolerance. + return np.isclose(x, y) + + def get_condition_label(condition_id: str) -> str: """Convert a condition ID to a label. diff --git a/test/base/test_history.py b/test/base/test_history.py index 2b564aaba..2b47441b7 100644 --- a/test/base/test_history.py +++ b/test/base/test_history.py @@ -609,27 +609,27 @@ def test_trace_subset(history: pypesto.History): partial_trace = getter(arr) # check partial traces coincide - assert len(partial_trace) == len(arr) + assert len(partial_trace) == len(arr), var for a, b in zip(partial_trace, [full_trace[i] for i in arr]): if var != 'schi2': - assert np.all(a == b) or np.isnan(a) and np.isnan(b) + assert np.all(a == b) or np.isnan(a) and np.isnan(b), var else: assert ( np.all(a == b) or np.all(np.isnan(a)) and np.all(np.isnan(b)) - ) + ), var # check sequence type - assert isinstance(full_trace, Sequence) - assert isinstance(partial_trace, Sequence) + assert isinstance(full_trace, Sequence), var + assert isinstance(partial_trace, Sequence), var # check individual type val = getter(0) if var in ['fval', 'chi2', 'time']: - assert isinstance(val, float) + assert isinstance(val, float), var else: - assert isinstance(val, np.ndarray) or np.isnan(val) + assert isinstance(val, np.ndarray) or np.isnan(val), var def test_hdf5_history_mp(): diff --git a/test/base/test_store.py b/test/base/test_store.py index ada4c775c..29a8aa73e 100644 --- a/test/base/test_store.py +++ b/test/base/test_store.py @@ -64,13 +64,10 @@ def test_storage_opt_result(): for key in opt_res: if isinstance(opt_res[key], np.ndarray): np.testing.assert_array_equal( - opt_res[key], read_result.optimize_result.list[i][key] + opt_res[key], read_result.optimize_result[i][key] ) else: - assert ( - opt_res[key] - == read_result.optimize_result.list[i][key] - ) + assert opt_res[key] == read_result.optimize_result[i][key] def test_storage_opt_result_update(hdf5_file): @@ -88,10 +85,10 @@ def test_storage_opt_result_update(hdf5_file): for key in opt_res: if isinstance(opt_res[key], np.ndarray): np.testing.assert_array_equal( - opt_res[key], read_result.optimize_result.list[i][key] + opt_res[key], read_result.optimize_result[i][key] ) else: - assert opt_res[key] == read_result.optimize_result.list[i][key] + assert opt_res[key] == read_result.optimize_result[i][key] def test_storage_problem(hdf5_file): @@ -290,7 +287,7 @@ def test_storage_sampling(): filename=None, progress_bar=False, ) - x_0 = result_optimization.optimize_result.list[0]['x'] + x_0 = result_optimization.optimize_result[0]['x'] sampler = sample.AdaptiveParallelTemperingSampler( internal_sampler=sample.AdaptiveMetropolisSampler(), options={ @@ -391,13 +388,10 @@ def test_storage_all(): continue if isinstance(opt_res[key], np.ndarray): np.testing.assert_array_equal( - opt_res[key], result_read.optimize_result.list[i][key] + opt_res[key], result_read.optimize_result[i][key] ) else: - assert ( - opt_res[key] - == result_read.optimize_result.list[i][key] - ) + assert opt_res[key] == result_read.optimize_result[i][key] # test profile for key in result.profile_result.list[0][0].keys(): @@ -484,18 +478,27 @@ def test_storage_objective_config(): def test_result_from_hdf5_history(hdf5_file): - problem = create_petab_problem() + """Test whether we can recover a result from a hdf5 file. + + This means that the result obtained directly via minimize, and the result + obtained from the history should coincide. + For this aim, run a simple problem and record the full history, such that + recovery should be possible. + """ + problem = create_petab_problem() history_options_hdf5 = pypesto.HistoryOptions( trace_record=True, storage_file=hdf5_file, ) + # optimize with history saved to hdf5 result = optimize.minimize( problem=problem, n_starts=1, history_options=history_options_hdf5, progress_bar=False, + options={"allow_failed_starts": False}, ) result_from_hdf5 = optimization_result_from_history( @@ -522,15 +525,15 @@ def test_result_from_hdf5_history(hdf5_file): MESSAGE, ] for key in arguments: - if result.optimize_result.list[0][key] is None: - assert result_from_hdf5.optimize_result.list[0][key] is None - elif isinstance(result.optimize_result.list[0][key], np.ndarray): + if result.optimize_result[0][key] is None: + assert result_from_hdf5.optimize_result[0][key] is None, key + elif isinstance(result.optimize_result[0][key], np.ndarray): assert np.allclose( - result.optimize_result.list[0][key], - result_from_hdf5.optimize_result.list[0][key], + result.optimize_result[0][key], + result_from_hdf5.optimize_result[0][key], ), key else: assert ( - result.optimize_result.list[0][key] - == result_from_hdf5.optimize_result.list[0][key] + result.optimize_result[0][key] + == result_from_hdf5.optimize_result[0][key] ), key diff --git a/tox.ini b/tox.ini index 5c49a10ac..496a1be0e 100644 --- a/tox.ini +++ b/tox.ini @@ -119,12 +119,13 @@ deps = flake8-bandit >= 2.1.2 flake8-bugbear >= 20.1.4 flake8-colors >= 0.1.6 - # flake8-commas >= 2.0.0 flake8-comprehensions >= 3.2.3 flake8-print >= 5.0.0 + flake8-black >= 0.2.3 + flake8-isort >= 4.0.0 flake8-docstrings >= 1.6.0 commands = - flake8 pypesto test setup.py + flake8 pypesto test description = Run flake8 with various plugins.