Skip to content

Commit

Permalink
Parallel Joblib Process Entries (#3933)
Browse files Browse the repository at this point in the history
Add joblib backend to process entries in parallel
  • Loading branch information
CompRhys authored Aug 2, 2024
1 parent 940eb60 commit 976942c
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 42 deletions.
139 changes: 97 additions & 42 deletions src/pymatgen/entries/compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import TYPE_CHECKING, Union, cast

import numpy as np
from joblib import Parallel, delayed
from monty.design_patterns import cached_class
from monty.json import MSONable
from monty.serialization import loadfn
Expand All @@ -30,6 +31,7 @@
)
from pymatgen.io.vasp.sets import MITRelaxSet, MPRelaxSet, VaspInputSet
from pymatgen.util.due import Doi, due
from pymatgen.util.joblib import set_python_warnings, tqdm_joblib

if TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -538,28 +540,86 @@ def get_adjustments(self, entry: AnyComputedEntry) -> list[EnergyAdjustment]:
"""
raise NotImplementedError

def process_entry(self, entry: ComputedEntry, **kwargs) -> ComputedEntry | None:
def process_entry(self, entry: ComputedEntry, inplace: bool = True, **kwargs) -> ComputedEntry | None:
"""Process a single entry with the chosen Corrections. Note
that this method will change the data of the original entry.
Args:
entry: A ComputedEntry object.
inplace (bool): Whether to adjust the entry in place. Defaults to True.
**kwargs: Will be passed to process_entries().
Returns:
An adjusted entry if entry is compatible, else None.
"""
try:
return self.process_entries(entry, **kwargs)[0]
except IndexError:
if not inplace:
entry = copy.deepcopy(entry)

entry = self._process_entry_inplace(entry, **kwargs)

return entry[0] if entry is not None else None

def _process_entry_inplace(
self,
entry: AnyComputedEntry,
clean: bool = True,
on_error: Literal["ignore", "warn", "raise"] = "ignore",
) -> ComputedEntry | None:
"""Process a single entry with the chosen Corrections. Note
that this method will change the data of the original entry.
Args:
entry: A ComputedEntry object.
clean (bool): Whether to remove any previously-applied energy adjustments.
If True, all EnergyAdjustment are removed prior to processing the Entry.
Defaults to True.
on_error ('ignore' | 'warn' | 'raise'): What to do when get_adjustments(entry)
raises CompatibilityError. Defaults to 'ignore'.
Returns:
An adjusted entry if entry is compatible, else None.
"""
ignore_entry = False
# if clean is True, remove all previous adjustments from the entry
if clean:
entry.energy_adjustments = []

try: # get the energy adjustments
adjustments = self.get_adjustments(entry)
except CompatibilityError as exc:
if on_error == "raise":
raise
if on_error == "warn":
warnings.warn(str(exc))
return None

for ea in adjustments:
# Has this correction already been applied?
if (ea.name, ea.cls, ea.value) in [(ea2.name, ea2.cls, ea2.value) for ea2 in entry.energy_adjustments]:
# we already applied this exact correction. Do nothing.
pass
elif (ea.name, ea.cls) in [(ea2.name, ea2.cls) for ea2 in entry.energy_adjustments]:
# we already applied a correction with the same name
# but a different value. Something is wrong.
ignore_entry = True
warnings.warn(
f"Entry {entry.entry_id} already has an energy adjustment called {ea.name}, but its "
f"value differs from the value of {ea.value:.3f} calculated here. This "
"Entry will be discarded."
)
else:
# Add the correction to the energy_adjustments list
entry.energy_adjustments.append(ea)

return entry, ignore_entry

def process_entries(
self,
entries: AnyComputedEntry | list[AnyComputedEntry],
clean: bool = True,
verbose: bool = False,
inplace: bool = True,
n_workers: int = 1,
on_error: Literal["ignore", "warn", "raise"] = "ignore",
) -> list[AnyComputedEntry]:
"""Process a sequence of entries with the chosen Compatibility scheme.
Expand All @@ -576,6 +636,7 @@ def process_entries(
verbose (bool): Whether to display progress bar for processing multiple entries.
Defaults to False.
inplace (bool): Whether to adjust input entries in place. Defaults to True.
n_workers (int): Number of workers to use for parallel processing. Defaults to 1.
on_error ('ignore' | 'warn' | 'raise'): What to do when get_adjustments(entry)
raises CompatibilityError. Defaults to 'ignore'.
Expand All @@ -593,41 +654,28 @@ def process_entries(
if not inplace:
entries = copy.deepcopy(entries)

for entry in tqdm(entries, disable=not verbose):
ignore_entry = False
# if clean is True, remove all previous adjustments from the entry
if clean:
entry.energy_adjustments = []

try: # get the energy adjustments
adjustments = self.get_adjustments(entry)
except CompatibilityError as exc:
if on_error == "raise":
raise
if on_error == "warn":
warnings.warn(str(exc))
continue

for ea in adjustments:
# Has this correction already been applied?
if (ea.name, ea.cls, ea.value) in [(ea2.name, ea2.cls, ea2.value) for ea2 in entry.energy_adjustments]:
# we already applied this exact correction. Do nothing.
pass
elif (ea.name, ea.cls) in [(ea2.name, ea2.cls) for ea2 in entry.energy_adjustments]:
# we already applied a correction with the same name
# but a different value. Something is wrong.
ignore_entry = True
warnings.warn(
f"Entry {entry.entry_id} already has an energy adjustment called {ea.name}, but its "
f"value differs from the value of {ea.value:.3f} calculated here. This "
"Entry will be discarded."
)
else:
# Add the correction to the energy_adjustments list
entry.energy_adjustments.append(ea)

if not ignore_entry:
processed_entry_list.append(entry)
if n_workers == 1:
for entry in tqdm(entries, disable=not verbose):
result = self._process_entry_inplace(entry, clean, on_error)
if result is None:
continue
entry, ignore_entry = result
if not ignore_entry:
processed_entry_list.append(entry)
elif not inplace:
# set python warnings to ignore otherwise warnings will be printed multiple times
with tqdm_joblib(tqdm(total=len(entries), disable=not verbose)), set_python_warnings("ignore"):
results = Parallel(n_jobs=n_workers)(
delayed(self._process_entry_inplace)(entry, clean, on_error) for entry in entries
)
for result in results:
if result is None:
continue
entry, ignore_entry = result
if not ignore_entry:
processed_entry_list.append(entry)
else:
raise ValueError("Parallel processing is not possible with for 'inplace=True'")

return processed_entry_list

Expand Down Expand Up @@ -1133,7 +1181,9 @@ def get_adjustments(self, entry: AnyComputedEntry) -> list[EnergyAdjustment]:
expected_u = float(u_settings.get(symbol, 0))
actual_u = float(calc_u.get(symbol, 0))
if actual_u != expected_u:
raise CompatibilityError(f"Invalid U value of {actual_u:.3} on {symbol}, expected {expected_u:.3}")
raise CompatibilityError(
f"Invalid U value of {actual_u:.3} on {symbol}, expected {expected_u:.3} for {entry.as_dict()}"
)
if symbol in u_corrections:
adjustments.append(
CompositionEnergyAdjustment(
Expand Down Expand Up @@ -1450,6 +1500,7 @@ def process_entries(
clean: bool = False,
verbose: bool = False,
inplace: bool = True,
n_workers: int = 1,
on_error: Literal["ignore", "warn", "raise"] = "ignore",
) -> list[AnyComputedEntry]:
"""Process a sequence of entries with the chosen Compatibility scheme.
Expand All @@ -1463,6 +1514,7 @@ def process_entries(
Default is False.
inplace (bool): Whether to modify the entries in place. If False, a copy of the
entries is made and processed. Default is True.
n_workers (int): Number of workers to use for parallel processing. Default is 1.
on_error ('ignore' | 'warn' | 'raise'): What to do when get_adjustments(entry)
raises CompatibilityError. Defaults to 'ignore'.
Expand All @@ -1480,7 +1532,8 @@ def process_entries(

# pre-process entries with the given solid compatibility class
if self.solid_compat:
entries = self.solid_compat.process_entries(entries, clean=True)
entries = self.solid_compat.process_entries(entries, clean=True, inplace=inplace, n_workers=n_workers)
return [entries]

# when processing single entries, all H2 polymorphs will get assigned the
# same energy
Expand Down Expand Up @@ -1514,7 +1567,9 @@ def process_entries(
h2_entries = sorted(h2_entries, key=lambda e: e.energy_per_atom)
self.h2_energy = h2_entries[0].energy_per_atom # type: ignore[assignment]

return super().process_entries(entries, clean=clean, verbose=verbose, inplace=inplace, on_error=on_error)
return super().process_entries(
entries, clean=clean, verbose=verbose, inplace=inplace, n_workers=n_workers, on_error=on_error
)


def needs_u_correction(
Expand Down
51 changes: 51 additions & 0 deletions src/pymatgen/util/joblib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""This module provides utility functions for getting progress bar with joblib."""

from __future__ import annotations

import contextlib
import os
from typing import TYPE_CHECKING, Any

import joblib

if TYPE_CHECKING:
from collections.abc import Iterator

from tqdm import tqdm


@contextlib.contextmanager
def tqdm_joblib(tqdm_object: tqdm) -> Iterator[None]:
"""Context manager to patch joblib to report into tqdm progress bar given
as argument.
"""

class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
def __call__(self, *args: tuple, **kwargs: dict[str, Any]) -> None:
"""This will be called after each batch, to update the progress bar."""
tqdm_object.update(n=self.batch_size)
return super().__call__(*args, **kwargs)

old_batch_callback = joblib.parallel.BatchCompletionCallBack
joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
try:
yield tqdm_object
finally:
joblib.parallel.BatchCompletionCallBack = old_batch_callback
tqdm_object.close()


@contextlib.contextmanager
def set_python_warnings(warnings):
"""Context manager to set the PYTHONWARNINGS environment variable to the
given value. This is useful for preventing spam when using parallel processing.
"""
original_warnings = os.environ.get("PYTHONWARNINGS")
os.environ["PYTHONWARNINGS"] = warnings
try:
yield
finally:
if original_warnings is None:
del os.environ["PYTHONWARNINGS"]
else:
os.environ["PYTHONWARNINGS"] = original_warnings
27 changes: 27 additions & 0 deletions tests/entries/test_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,17 @@ def test_process_entries(self):
entries = self.compat.process_entries([self.entry1, self.entry2, self.entry3, self.entry4])
assert len(entries) == 2

def test_parallel_process_entries(self):
with pytest.raises(ValueError, match="Parallel processing is not possible with for 'inplace=True'"):
entries = self.compat.process_entries(
[self.entry1, self.entry2, self.entry3, self.entry4], inplace=True, n_workers=2
)

entries = self.compat.process_entries(
[self.entry1, self.entry2, self.entry3, self.entry4], inplace=False, n_workers=2
)
assert len(entries) == 2

def test_msonable(self):
compat_dict = self.compat.as_dict()
decoder = MontyDecoder()
Expand Down Expand Up @@ -1879,6 +1890,22 @@ def test_processing_entries_inplace(self):
MaterialsProjectAqueousCompatibility().process_entries(entries, inplace=False)
assert all(e.correction == e_copy.correction for e, e_copy in zip(entries, entries_copy))

def test_parallel_process_entries(self):
hydrate_entry = ComputedEntry(Composition("FeH4O2"), -10) # nH2O = 2
hydrate_entry2 = ComputedEntry(Composition("Li2O2H2"), -10) # nH2O = 0

entry_list = [hydrate_entry, hydrate_entry2]

compat = MaterialsProjectAqueousCompatibility(
o2_energy=-10, h2o_energy=-20, h2o_adjustments=-0.5, solid_compat=None
)

with pytest.raises(ValueError, match="Parallel processing is not possible with for 'inplace=True'"):
entries = compat.process_entries(entry_list, inplace=True, n_workers=2)

entries = compat.process_entries(entry_list, inplace=False, n_workers=2, on_error="raise")
assert len(entries) == 2


class TestAqueousCorrection(TestCase):
def setUp(self):
Expand Down

0 comments on commit 976942c

Please sign in to comment.