Skip to content

Commit

Permalink
Previously, only writing Result was supported. (#1528)
Browse files Browse the repository at this point in the history
Changes:
*  `OptimizeResult`, `OptimizerResult`(s) can be written by `OptimizationResultHDF5Writer.write()` (Closes #1526)
*  `OptimizerResult`s can be written incrementally by `OptimizationResultHDF5Writer.write_optimizer_result()` (Closes #1527)
* Fixed an `AttributeError` in `pypesto.store.save_to_hdf5.check_overwrite` with `h5py.Group`s
  • Loading branch information
dweindl authored Nov 28, 2024
1 parent 2ed5f8f commit 861318a
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 33 deletions.
105 changes: 74 additions & 31 deletions pypesto/store/save_to_hdf5.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
"""Include functions for saving various results to hdf5."""
from __future__ import annotations

import logging
import os
from numbers import Integral
from pathlib import Path
from typing import Union

import h5py
import numpy as np

from .. import OptimizeResult, OptimizerResult
from ..result import ProfilerResult, Result, SampleResult
from .hdf5 import write_array, write_float_array

logger = logging.getLogger(__name__)


def check_overwrite(
f: Union[h5py.File, h5py.Group], overwrite: bool, target: str
):
def check_overwrite(f: h5py.File | h5py.Group, overwrite: bool, target: str):
"""
Check whether target already exists.
Expand All @@ -36,7 +35,7 @@ def check_overwrite(
del f[target]
else:
raise RuntimeError(
f"File `{f.filename}` already exists and contains "
f"File `{f.file.filename}` already exists and contains "
f"information about {target} result. "
f"If you wish to overwrite the file, set "
f"`overwrite=True`."
Expand All @@ -53,7 +52,7 @@ class ProblemHDF5Writer:
HDF5 result file name
"""

def __init__(self, storage_filename: Union[str, Path]):
def __init__(self, storage_filename: str | Path):
"""
Initialize writer.
Expand Down Expand Up @@ -106,7 +105,7 @@ class OptimizationResultHDF5Writer:
HDF5 result file name
"""

def __init__(self, storage_filename: Union[str, Path]):
def __init__(self, storage_filename: str | Path):
"""
Initialize Writer.
Expand All @@ -117,32 +116,76 @@ def __init__(self, storage_filename: Union[str, Path]):
"""
self.storage_filename = str(storage_filename)

def write(self, result: Result, overwrite=False):
"""Write HDF5 result file from pyPESTO result object."""
# Create destination directory
if isinstance(self.storage_filename, str):
basedir = os.path.dirname(self.storage_filename)
if basedir:
os.makedirs(basedir, exist_ok=True)
def write(
self,
result: Result
| OptimizeResult
| OptimizerResult
| list[OptimizerResult],
overwrite=False,
):
"""Write HDF5 result file from pyPESTO result object.
Parameters
----------
result: Result to be saved.
overwrite: Boolean, whether already existing results should be
overwritten. This applies to the whole list of results, not only to
individual results. See :meth:`write_optimizer_result` for
incrementally writing a sequence of `OptimizerResult`.
"""
Path(self.storage_filename).parent.mkdir(parents=True, exist_ok=True)

if isinstance(result, Result):
results = result.optimize_result.list
elif isinstance(result, OptimizeResult):
results = result.list
elif isinstance(result, list):
results = result
elif isinstance(result, OptimizerResult):
results = [result]
else:
raise ValueError(f"Unsupported type for `result`: {type(result)}.")

with h5py.File(self.storage_filename, "a") as f:
check_overwrite(f, overwrite, "optimization")
optimization_grp = f.require_group("optimization")
# settings =
# optimization_grp.create_dataset("settings", settings, dtype=)
results_grp = optimization_grp.require_group("results")

for start in result.optimize_result.list:
start_id = start["id"]
start_grp = results_grp.require_group(start_id)
for key in start.keys():
if key == "history":
continue
if isinstance(start[key], np.ndarray):
write_array(start_grp, key, start[key])
elif start[key] is not None:
start_grp.attrs[key] = start[key]
f.flush()
for start in results:
self._do_write_optimizer_result(start, results_grp, overwrite)

def write_optimizer_result(
self, result: OptimizerResult, overwrite: bool = False
):
"""Write HDF5 result file from pyPESTO result object.
Parameters
----------
result: Result to be saved.
overwrite: Boolean, whether already existing results with the same ID
should be overwritten.s
"""
Path(self.storage_filename).parent.mkdir(parents=True, exist_ok=True)

with h5py.File(self.storage_filename, "a") as f:
results_grp = f.require_group("optimization/results")
self._do_write_optimizer_result(result, results_grp, overwrite)

def _do_write_optimizer_result(
self, result: OptimizerResult, g: h5py.Group = None, overwrite=False
):
"""Write an OptimizerResult to the given group."""
sub_group_id = result["id"]
check_overwrite(g, overwrite, sub_group_id)
start_grp = g.require_group(sub_group_id)
for key in result.keys():
if key == "history":
continue
if isinstance(result[key], np.ndarray):
write_array(start_grp, key, result[key])
elif result[key] is not None:
start_grp.attrs[key] = result[key]


class SamplingResultHDF5Writer:
Expand All @@ -155,7 +198,7 @@ class SamplingResultHDF5Writer:
HDF5 result file name
"""

def __init__(self, storage_filename: Union[str, Path]):
def __init__(self, storage_filename: str | Path):
"""
Initialize Writer.
Expand Down Expand Up @@ -208,7 +251,7 @@ class ProfileResultHDF5Writer:
HDF5 result file name
"""

def __init__(self, storage_filename: Union[str, Path]):
def __init__(self, storage_filename: str | Path):
"""
Initialize Writer.
Expand Down Expand Up @@ -241,7 +284,7 @@ def write(self, result: Result, overwrite: bool = False):

@staticmethod
def _write_profiler_result(
parameter_profile: Union[ProfilerResult, None], result_grp: h5py.Group
parameter_profile: ProfilerResult | None, result_grp: h5py.Group
) -> None:
"""Write a single ProfilerResult to hdf5.
Expand All @@ -267,7 +310,7 @@ def _write_profiler_result(

def write_result(
result: Result,
filename: Union[str, Path],
filename: str | Path,
overwrite: bool = False,
problem: bool = True,
optimize: bool = False,
Expand Down
27 changes: 25 additions & 2 deletions test/base/test_store.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Test the `pypesto.store` module."""

import os
import tempfile
from pathlib import Path
from tempfile import TemporaryDirectory

import numpy as np
import pytest
import scipy.optimize as so

import pypesto
Expand Down Expand Up @@ -52,7 +54,7 @@

def test_storage_opt_result():
minimize_result = create_optimization_result()
with tempfile.TemporaryDirectory(dir=".") as tmpdirname:
with TemporaryDirectory(dir=".") as tmpdirname:
result_file_name = os.path.join(tmpdirname, "a", "b", "result.h5")
opt_result_writer = OptimizationResultHDF5Writer(result_file_name)
opt_result_writer.write(minimize_result)
Expand Down Expand Up @@ -89,6 +91,27 @@ def test_storage_opt_result_update(hdf5_file):
assert opt_res[key] == read_result.optimize_result[i][key]


def test_write_optimizer_results_incrementally():
"""Test writing optimizer results incrementally to the same file."""
res = create_optimization_result()
res1, res2 = res.optimize_result.list[:2]

with TemporaryDirectory() as tmp_dir:
result_path = Path(tmp_dir, "result.h5")
writer = OptimizationResultHDF5Writer(result_path)
writer.write_optimizer_result(res1)
writer.write_optimizer_result(res2)
reader = OptimizationResultHDF5Reader(result_path)
read_res = reader.read()
assert len(read_res.optimize_result) == 2

# overwriting works
writer.write_optimizer_result(res1, overwrite=True)
# overwriting attempt fails without overwrite=True
with pytest.raises(RuntimeError):
writer.write_optimizer_result(res1)


def test_storage_problem(hdf5_file):
problem = create_problem()
problem_writer = ProblemHDF5Writer(hdf5_file)
Expand Down

0 comments on commit 861318a

Please sign in to comment.