Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
xiki-tempula committed Dec 30, 2023
1 parent 862343b commit c46f349
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 10 deletions.
3 changes: 3 additions & 0 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ The rules for this file:
*/*/2023 hl2500

* 2.2.0

Changes
- For pandas>=2.1, metadata will be loaded from the parquet file (issue #331, PR #326).

Enhancements
- Add a TI estimator using gaussian quadrature to calculate the free energy.
Expand Down
2 changes: 1 addition & 1 deletion devtools/conda-envs/test_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
dependencies:
- python
- numpy
- pandas
- pandas>=2.1
- pymbar>=4
- scipy
- scikit-learn
Expand Down
39 changes: 36 additions & 3 deletions src/alchemlyb/parsing/parquet.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,42 @@
import pandas as pd
from loguru import logger

from . import _init_attrs


@_init_attrs
def _check_metadata(path: str, T: float) -> pd.DataFrame:
    """
    Load a parquet file and reconcile its metadata with the requested
    temperature.

    The dataframe is read with :func:`pandas.read_parquet`. When its ``attrs``
    hold no ``"temperature"`` entry, a warning is logged and ``attrs`` is
    filled with ``temperature=T`` and ``energy_unit="kT"``. When an entry is
    present it must be equal to ``T``; any other entries in ``attrs`` are left
    as they were read.

    Parameters
    ----------
    path : str
        Path to parquet file to extract dataframe from.
    T : float
        Temperature in Kelvin of the simulations.

    Returns
    -------
    DataFrame
        The dataframe read from ``path``.

    Raises
    ------
    ValueError
        If the ``"temperature"`` stored in the dataframe differs from ``T``.
    """
    frame = pd.read_parquet(path)

    if "temperature" in frame.attrs:
        # Metadata survived serialisation: only validate it, never overwrite.
        if frame.attrs["temperature"] != T:
            raise ValueError(
                f"Temperature in the input ({T}) doesn't match the temperature "
                f"in the dataframe ({frame.attrs['temperature']})."
            )
        return frame

    # No metadata in the file: fall back to the caller-supplied temperature.
    logger.warning(
        f"No temperature metadata found in {path}. "
        f"Serialise the Dataframe with pandas>=2.1 to preserve the metadata."
    )
    frame.attrs["temperature"] = T
    frame.attrs["energy_unit"] = "kT"
    return frame


def extract_u_nk(path, T):
r"""Return reduced potentials `u_nk` (unit: kT) from a pandas parquet file.
Expand Down Expand Up @@ -36,7 +69,7 @@ def extract_u_nk(path, T):
.. versionadded:: 2.1.0
"""
u_nk = pd.read_parquet(path)
u_nk = _check_metadata(path, T)
columns = list(u_nk.columns)
if isinstance(columns[0], str) and columns[0][0] == "(":
new_columns = []
Expand Down Expand Up @@ -81,4 +114,4 @@ def extract_dHdl(path, T):
.. versionadded:: 2.1.0
"""
return pd.read_parquet(path)
return _check_metadata(path, T)
2 changes: 1 addition & 1 deletion src/alchemlyb/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def gmx_ABFE():


@pytest.fixture
def gmx_ABFE_complex_n_uk(gmx_ABFE):
def gmx_ABFE_complex_u_nk(gmx_ABFE):
    """Parse every file in ``gmx_ABFE["complex"]`` with ``gmx.extract_u_nk`` at T=300.

    Returns the parsed results as a list, in the order the files are listed.
    """
    parsed = []
    for file in gmx_ABFE["complex"]:
        parsed.append(gmx.extract_u_nk(file, T=300))
    return parsed


Expand Down
35 changes: 34 additions & 1 deletion src/alchemlyb/tests/parsing/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,45 @@ def test_extract_dHdl(dHdl_list, request, tmp_path):
new_dHdl = extract_dHdl(str(tmp_path / "dhdl.parquet"), T=300)
assert (new_dHdl.columns == dHdl.columns).all()
assert (new_dHdl.index == dHdl.index).all()
assert new_dHdl.attrs["temperature"] == 300
assert new_dHdl.attrs["energy_unit"] == "kT"


@pytest.mark.parametrize("u_nk_list", ["gmx_benzene_VDW_u_nk", "gmx_ABFE_complex_n_uk"])
@pytest.mark.parametrize("u_nk_list", ["gmx_benzene_VDW_u_nk", "gmx_ABFE_complex_u_nk"])
def test_extract_u_nk(u_nk_list, request, tmp_path):
    """Round-trip a ``u_nk`` dataframe through parquet and ``extract_u_nk``.

    The first entry of the requested fixture is written to a parquet file in
    ``tmp_path`` and read back with ``T=300``; columns, index and the
    ``temperature`` / ``energy_unit`` metadata must match.

    This test is named ``test_extract_u_nk`` on purpose: calling it
    ``test_extract_dHdl`` would rebind the name of the dHdl round-trip test
    that already exists in this module, so pytest would silently stop
    collecting that one.
    """
    parquet_path = str(tmp_path / "u_nk.parquet")
    u_nk = request.getfixturevalue(u_nk_list)[0]
    u_nk.to_parquet(path=parquet_path, index=True)

    new_u_nk = extract_u_nk(parquet_path, T=300)

    assert (new_u_nk.columns == u_nk.columns).all()
    assert (new_u_nk.index == u_nk.index).all()
    assert new_u_nk.attrs["temperature"] == 300
    assert new_u_nk.attrs["energy_unit"] == "kT"


@pytest.fixture()
def u_nk(gmx_ABFE_complex_u_nk):
    """Return the first element of the ``gmx_ABFE_complex_u_nk`` fixture."""
    first_entry = gmx_ABFE_complex_u_nk[0]
    return first_entry


def test_no_T(u_nk, tmp_path, caplog):
    """A parquet file written with empty ``attrs`` must trigger the metadata hint.

    The dataframe's ``attrs`` are cleared before serialisation, then the file
    is read back through ``extract_u_nk`` and the captured log text is checked
    for the "Serialise the Dataframe ..." message.
    """
    parquet_file = str(tmp_path / "temp.parquet")
    expected_hint = "Serialise the Dataframe with pandas>=2.1 to preserve the metadata."

    u_nk.attrs = {}
    u_nk.to_parquet(path=parquet_file, index=True)
    extract_u_nk(parquet_file, 300)

    assert expected_hint in caplog.text


def test_wrong_T(u_nk, tmp_path, caplog):
    """Reading the serialised dataframe back with ``T=400`` must raise ``ValueError``.

    The error message has to contain "doesn't match the temperature".
    NOTE(review): this relies on the ``u_nk`` fixture carrying a temperature
    other than 400 in its ``attrs`` -- confirm against the fixture definition.
    """
    parquet_file = str(tmp_path / "temp.parquet")
    u_nk.to_parquet(path=parquet_file, index=True)

    expect_mismatch = pytest.raises(ValueError, match="doesn't match the temperature")
    with expect_mismatch:
        extract_u_nk(parquet_file, 400)


def test_metadata_unchanged(u_nk, tmp_path):
    """Metadata stored in the parquet file must come back untouched.

    The dataframe is serialised with ``temperature=400`` and
    ``energy_unit="kcal/mol"``; after ``extract_u_nk(..., 400)`` both entries
    must still hold exactly those values.
    """
    expected = {"temperature": 400, "energy_unit": "kcal/mol"}
    parquet_file = str(tmp_path / "temp.parquet")

    # Hand over a copy so ``expected`` stays an independent reference.
    u_nk.attrs = dict(expected)
    u_nk.to_parquet(path=parquet_file, index=True)
    new_u_nk = extract_u_nk(parquet_file, 400)

    for key, value in expected.items():
        assert new_u_nk.attrs[key] == value
6 changes: 3 additions & 3 deletions src/alchemlyb/tests/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def u_nk(gmx_benzene_Coulomb_u_nk):


@pytest.fixture()
def multi_index_u_nk(gmx_ABFE_complex_n_uk):
return gmx_ABFE_complex_n_uk[0]
def multi_index_u_nk(gmx_ABFE_complex_u_nk):
    """Return the first element of the ``gmx_ABFE_complex_u_nk`` fixture."""
    first_entry = gmx_ABFE_complex_u_nk[0]
    return first_entry


@pytest.fixture()
Expand Down Expand Up @@ -470,7 +470,7 @@ def test_decorrelate_dhdl_multiple_l(multi_index_dHdl):
)


def test_raise_non_uk(multi_index_dHdl):
def test_raise_nou_nk(multi_index_dHdl):
with pytest.raises(ValueError):
decorrelate_u_nk(
multi_index_dHdl,
Expand Down
2 changes: 1 addition & 1 deletion src/alchemlyb/tests/test_workflow_ABFE.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def test_single_estimator_ti(self, workflow, monkeypatch):
summary = workflow.generate_result()
assert np.isclose(summary["TI"]["Stages"]["TOTAL"], 21.51472826028906, 0.1)

def test_unprocessed_n_uk(self, workflow, monkeypatch):
def test_unprocessed_u_nk(self, workflow, monkeypatch):
monkeypatch.setattr(workflow, "u_nk_sample_list", None)
monkeypatch.setattr(workflow, "estimator", dict())
workflow.estimate()
Expand Down

0 comments on commit c46f349

Please sign in to comment.