From c46f3492d162bc850adcde9a91fd705abd521d89 Mon Sep 17 00:00:00 2001 From: Zhiyi Wu Date: Sat, 30 Dec 2023 12:17:47 +0000 Subject: [PATCH] update --- CHANGES | 3 ++ devtools/conda-envs/test_env.yaml | 2 +- src/alchemlyb/parsing/parquet.py | 39 +++++++++++++++++++-- src/alchemlyb/tests/conftest.py | 2 +- src/alchemlyb/tests/parsing/test_parquet.py | 35 +++++++++++++++++- src/alchemlyb/tests/test_preprocessing.py | 6 ++-- src/alchemlyb/tests/test_workflow_ABFE.py | 2 +- 7 files changed, 79 insertions(+), 10 deletions(-) diff --git a/CHANGES b/CHANGES index 7a8ed754..6c3b5f60 100644 --- a/CHANGES +++ b/CHANGES @@ -17,6 +17,9 @@ The rules for this file: */*/2023 hl2500 * 2.2.0 + +Changes + - For pandas>=2.1, metadata will be loaded from the parquet file (issue #331, PR #326). Enhancements - Add a TI estimator using gaussian quadrature to calculate the free energy. diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index 87470b55..a8795f1c 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -4,7 +4,7 @@ channels: dependencies: - python - numpy -- pandas +- pandas>=2.1 - pymbar>=4 - scipy - scikit-learn diff --git a/src/alchemlyb/parsing/parquet.py b/src/alchemlyb/parsing/parquet.py index 180817ae..c06664c7 100644 --- a/src/alchemlyb/parsing/parquet.py +++ b/src/alchemlyb/parsing/parquet.py @@ -1,9 +1,42 @@ import pandas as pd +from loguru import logger from . import _init_attrs -@_init_attrs +def _check_metadata(path: str, T: float) -> pd.DataFrame: + """ + Check if the metadata is included in the Dataframe and has the correct + temperature. + + Parameters + ---------- + path : str + Path to parquet file to extract dataframe from. + T : float + Temperature in Kelvin of the simulations. + + Returns + ------- + DataFrame + """ + df = pd.read_parquet(path) + if "temperature" not in df.attrs: + logger.warning( + f"No temperature metadata found in {path}. " + f"Serialise the Dataframe with pandas>=2.1 to preserve the metadata." + ) + df.attrs["temperature"] = T + df.attrs["energy_unit"] = "kT" + else: + if df.attrs["temperature"] != T: + raise ValueError( + f"Temperature in the input ({T}) doesn't match the temperature " + f"in the dataframe ({df.attrs['temperature']})." + ) + return df + + def extract_u_nk(path, T): r"""Return reduced potentials `u_nk` (unit: kT) from a pandas parquet file. @@ -36,7 +69,7 @@ def extract_u_nk(path, T): .. versionadded:: 2.1.0 """ - u_nk = pd.read_parquet(path) + u_nk = _check_metadata(path, T) columns = list(u_nk.columns) if isinstance(columns[0], str) and columns[0][0] == "(": new_columns = [] @@ -81,4 +114,4 @@ def extract_dHdl(path, T): .. versionadded:: 2.1.0 """ - return pd.read_parquet(path) + return _check_metadata(path, T) diff --git a/src/alchemlyb/tests/conftest.py b/src/alchemlyb/tests/conftest.py index d7a3749d..ac216471 100644 --- a/src/alchemlyb/tests/conftest.py +++ b/src/alchemlyb/tests/conftest.py @@ -82,7 +82,7 @@ def gmx_ABFE(): @pytest.fixture -def gmx_ABFE_complex_n_uk(gmx_ABFE): +def gmx_ABFE_complex_u_nk(gmx_ABFE): return [gmx.extract_u_nk(file, T=300) for file in gmx_ABFE["complex"]] diff --git a/src/alchemlyb/tests/parsing/test_parquet.py b/src/alchemlyb/tests/parsing/test_parquet.py index a2d788f6..c42cccb6 100644 --- a/src/alchemlyb/tests/parsing/test_parquet.py +++ b/src/alchemlyb/tests/parsing/test_parquet.py @@ -12,12 +12,45 @@ def test_extract_dHdl(dHdl_list, request, tmp_path): new_dHdl = extract_dHdl(str(tmp_path / "dhdl.parquet"), T=300) assert (new_dHdl.columns == dHdl.columns).all() assert (new_dHdl.index == dHdl.index).all() + assert new_dHdl.attrs["temperature"] == 300 + assert new_dHdl.attrs["energy_unit"] == "kT" -@pytest.mark.parametrize("u_nk_list", ["gmx_benzene_VDW_u_nk", "gmx_ABFE_complex_n_uk"]) +@pytest.mark.parametrize("u_nk_list", ["gmx_benzene_VDW_u_nk", "gmx_ABFE_complex_u_nk"]) def test_extract_dHdl(u_nk_list, request, tmp_path): u_nk = request.getfixturevalue(u_nk_list)[0] u_nk.to_parquet(path=str(tmp_path / "u_nk.parquet"), index=True) new_u_nk = extract_u_nk(str(tmp_path / "u_nk.parquet"), T=300) assert (new_u_nk.columns == u_nk.columns).all() assert (new_u_nk.index == u_nk.index).all() + assert new_u_nk.attrs["temperature"] == 300 + assert new_u_nk.attrs["energy_unit"] == "kT" + + +@pytest.fixture() +def u_nk(gmx_ABFE_complex_u_nk): + return gmx_ABFE_complex_u_nk[0] + + +def test_no_T(u_nk, tmp_path, caplog): + u_nk.attrs = {} + u_nk.to_parquet(path=str(tmp_path / "temp.parquet"), index=True) + extract_u_nk(str(tmp_path / "temp.parquet"), 300) + assert ( + "Serialise the Dataframe with pandas>=2.1 to preserve the metadata." + in caplog.text + ) + + +def test_wrong_T(u_nk, tmp_path, caplog): + u_nk.to_parquet(path=str(tmp_path / "temp.parquet"), index=True) + with pytest.raises(ValueError, match="doesn't match the temperature"): + extract_u_nk(str(tmp_path / "temp.parquet"), 400) + + +def test_metadata_unchanged(u_nk, tmp_path): + u_nk.attrs = {"temperature": 400, "energy_unit": "kcal/mol"} + u_nk.to_parquet(path=str(tmp_path / "temp.parquet"), index=True) + new_u_nk = extract_u_nk(str(tmp_path / "temp.parquet"), 400) + assert new_u_nk.attrs["temperature"] == 400 + assert new_u_nk.attrs["energy_unit"] == "kcal/mol" diff --git a/src/alchemlyb/tests/test_preprocessing.py b/src/alchemlyb/tests/test_preprocessing.py index 00e3c030..b59f66b4 100644 --- a/src/alchemlyb/tests/test_preprocessing.py +++ b/src/alchemlyb/tests/test_preprocessing.py @@ -41,8 +41,8 @@ def u_nk(gmx_benzene_Coulomb_u_nk): @pytest.fixture() -def multi_index_u_nk(gmx_ABFE_complex_n_uk): - return gmx_ABFE_complex_n_uk[0] +def multi_index_u_nk(gmx_ABFE_complex_u_nk): + return gmx_ABFE_complex_u_nk[0] @pytest.fixture() @@ -470,7 +470,7 @@ def test_decorrelate_dhdl_multiple_l(multi_index_dHdl): ) -def test_raise_non_uk(multi_index_dHdl): +def test_raise_nou_nk(multi_index_dHdl): with pytest.raises(ValueError): decorrelate_u_nk( multi_index_dHdl, diff --git a/src/alchemlyb/tests/test_workflow_ABFE.py b/src/alchemlyb/tests/test_workflow_ABFE.py index 82cc83ab..c2dc4b15 100644 --- a/src/alchemlyb/tests/test_workflow_ABFE.py +++ b/src/alchemlyb/tests/test_workflow_ABFE.py @@ -258,7 +258,7 @@ def test_single_estimator_ti(self, workflow, monkeypatch): summary = workflow.generate_result() assert np.isclose(summary["TI"]["Stages"]["TOTAL"], 21.51472826028906, 0.1) - def test_unprocessed_n_uk(self, workflow, monkeypatch): + def test_unprocessed_u_nk(self, workflow, monkeypatch): monkeypatch.setattr(workflow, "u_nk_sample_list", None) monkeypatch.setattr(workflow, "estimator", dict()) workflow.estimate()