diff --git a/pypesto/store/save_to_hdf5.py b/pypesto/store/save_to_hdf5.py index 3804bfbbb..a4ac4e703 100644 --- a/pypesto/store/save_to_hdf5.py +++ b/pypesto/store/save_to_hdf5.py @@ -1,23 +1,22 @@ """Include functions for saving various results to hdf5.""" +from __future__ import annotations import logging import os from numbers import Integral from pathlib import Path -from typing import Union import h5py import numpy as np +from .. import OptimizeResult, OptimizerResult from ..result import ProfilerResult, Result, SampleResult from .hdf5 import write_array, write_float_array logger = logging.getLogger(__name__) -def check_overwrite( - f: Union[h5py.File, h5py.Group], overwrite: bool, target: str -): +def check_overwrite(f: h5py.File | h5py.Group, overwrite: bool, target: str): """ Check whether target already exists. @@ -36,7 +35,7 @@ def check_overwrite( del f[target] else: raise RuntimeError( - f"File `{f.filename}` already exists and contains " + f"File `{f.file.filename}` already exists and contains " f"information about {target} result. " f"If you wish to overwrite the file, set " f"`overwrite=True`." @@ -53,7 +52,7 @@ class ProblemHDF5Writer: HDF5 result file name """ - def __init__(self, storage_filename: Union[str, Path]): + def __init__(self, storage_filename: str | Path): """ Initialize writer. @@ -106,7 +105,7 @@ class OptimizationResultHDF5Writer: HDF5 result file name """ - def __init__(self, storage_filename: Union[str, Path]): + def __init__(self, storage_filename: str | Path): """ Initialize Writer. 
@@ -117,32 +116,76 @@ def __init__(self, storage_filename: Union[str, Path]): """ self.storage_filename = str(storage_filename) - def write(self, result: Result, overwrite=False): - """Write HDF5 result file from pyPESTO result object.""" - # Create destination directory - if isinstance(self.storage_filename, str): - basedir = os.path.dirname(self.storage_filename) - if basedir: - os.makedirs(basedir, exist_ok=True) + def write( + self, + result: Result + | OptimizeResult + | OptimizerResult + | list[OptimizerResult], + overwrite=False, + ): + """Write HDF5 result file from pyPESTO result object. + + Parameters + ---------- + result: Result to be saved. + overwrite: Boolean, whether already existing results should be + overwritten. This applies to the whole list of results, not only to + individual results. See :meth:`write_optimizer_result` for + incrementally writing a sequence of `OptimizerResult`. + """ + Path(self.storage_filename).parent.mkdir(parents=True, exist_ok=True) + + if isinstance(result, Result): + results = result.optimize_result.list + elif isinstance(result, OptimizeResult): + results = result.list + elif isinstance(result, list): + results = result + elif isinstance(result, OptimizerResult): + results = [result] + else: + raise ValueError(f"Unsupported type for `result`: {type(result)}.") with h5py.File(self.storage_filename, "a") as f: check_overwrite(f, overwrite, "optimization") optimization_grp = f.require_group("optimization") - # settings = - # optimization_grp.create_dataset("settings", settings, dtype=) results_grp = optimization_grp.require_group("results") - for start in result.optimize_result.list: - start_id = start["id"] - start_grp = results_grp.require_group(start_id) - for key in start.keys(): - if key == "history": - continue - if isinstance(start[key], np.ndarray): - write_array(start_grp, key, start[key]) - elif start[key] is not None: - start_grp.attrs[key] = start[key] - f.flush() + for start in results: + 
self._do_write_optimizer_result(start, results_grp, overwrite) + + def write_optimizer_result( + self, result: OptimizerResult, overwrite: bool = False + ): + """Write a single OptimizerResult to the HDF5 result file. + + Parameters + ---------- + result: Result to be saved. + overwrite: Boolean, whether already existing results with the same ID + should be overwritten. + """ + Path(self.storage_filename).parent.mkdir(parents=True, exist_ok=True) + + with h5py.File(self.storage_filename, "a") as f: + results_grp = f.require_group("optimization/results") + self._do_write_optimizer_result(result, results_grp, overwrite) + + def _do_write_optimizer_result( + self, result: OptimizerResult, g: h5py.Group = None, overwrite=False + ): + """Write an OptimizerResult to the given group.""" + sub_group_id = result["id"] + check_overwrite(g, overwrite, sub_group_id) + start_grp = g.require_group(sub_group_id) + for key in result.keys(): + if key == "history": + continue + if isinstance(result[key], np.ndarray): + write_array(start_grp, key, result[key]) + elif result[key] is not None: + start_grp.attrs[key] = result[key] class SamplingResultHDF5Writer: @@ -155,7 +198,7 @@ class SamplingResultHDF5Writer: HDF5 result file name """ - def __init__(self, storage_filename: Union[str, Path]): + def __init__(self, storage_filename: str | Path): """ Initialize Writer. @@ -208,7 +251,7 @@ class ProfileResultHDF5Writer: HDF5 result file name """ - def __init__(self, storage_filename: Union[str, Path]): + def __init__(self, storage_filename: str | Path): """ Initialize Writer. @@ -241,7 +284,7 @@ def write(self, result: Result, overwrite: bool = False): @staticmethod def _write_profiler_result( - parameter_profile: Union[ProfilerResult, None], result_grp: h5py.Group + parameter_profile: ProfilerResult | None, result_grp: h5py.Group ) -> None: """Write a single ProfilerResult to hdf5. 
@@ -267,7 +310,7 @@ def _write_profiler_result( def write_result( result: Result, - filename: Union[str, Path], + filename: str | Path, overwrite: bool = False, problem: bool = True, optimize: bool = False, diff --git a/test/base/test_store.py b/test/base/test_store.py index d90a8030d..840440c70 100644 --- a/test/base/test_store.py +++ b/test/base/test_store.py @@ -1,9 +1,11 @@ """Test the `pypesto.store` module.""" import os -import tempfile +from pathlib import Path +from tempfile import TemporaryDirectory import numpy as np +import pytest import scipy.optimize as so import pypesto @@ -52,7 +54,7 @@ def test_storage_opt_result(): minimize_result = create_optimization_result() - with tempfile.TemporaryDirectory(dir=".") as tmpdirname: + with TemporaryDirectory(dir=".") as tmpdirname: result_file_name = os.path.join(tmpdirname, "a", "b", "result.h5") opt_result_writer = OptimizationResultHDF5Writer(result_file_name) opt_result_writer.write(minimize_result) @@ -89,6 +91,27 @@ def test_storage_opt_result_update(hdf5_file): assert opt_res[key] == read_result.optimize_result[i][key] +def test_write_optimizer_results_incrementally(): + """Test writing optimizer results incrementally to the same file.""" + res = create_optimization_result() + res1, res2 = res.optimize_result.list[:2] + + with TemporaryDirectory() as tmp_dir: + result_path = Path(tmp_dir, "result.h5") + writer = OptimizationResultHDF5Writer(result_path) + writer.write_optimizer_result(res1) + writer.write_optimizer_result(res2) + reader = OptimizationResultHDF5Reader(result_path) + read_res = reader.read() + assert len(read_res.optimize_result) == 2 + + # overwriting works + writer.write_optimizer_result(res1, overwrite=True) + # overwriting attempt fails without overwrite=True + with pytest.raises(RuntimeError): + writer.write_optimizer_result(res1) + + def test_storage_problem(hdf5_file): problem = create_problem() problem_writer = ProblemHDF5Writer(hdf5_file)