Skip to content

Commit

Permalink
Breaking: return sum of Species with matching Element in `Composi…
Browse files Browse the repository at this point in the history
…tion.__getitem__` (#3427)

* don't rename ElementTree import

* return sum of Species with matching element in Composition.__getitem__

fixes MP2020Compatibility not applying anion correction when passing in ComputedEntry with oxidation states

* test_composition.py add test_getitem

* improve coverage in test_process_entry_with_oxidation_state with 2nd example for ComputedStructureEntry
  • Loading branch information
janosh authored Oct 26, 2023
1 parent b4aedf0 commit bd13ea4
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 37 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ ci:

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.2
rev: v0.1.3
hooks:
- id: ruff
args: [--fix]
Expand Down
9 changes: 7 additions & 2 deletions pymatgen/core/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,12 @@ def __init__(self, *args, strict: bool = False, **kwargs) -> None:
def __getitem__(self, key: SpeciesLike) -> float:
try:
sp = get_el_sp(key)
return self._data.get(sp, 0)
if isinstance(sp, Species):
return self._data.get(sp, 0)
# sp is Element or str
return sum(
val for key, val in self._data.items() if getattr(key, "symbol", key) == getattr(sp, "symbol", sp)
)
except ValueError as exc:
raise KeyError(f"Invalid {key=}") from exc

Expand All @@ -153,7 +158,7 @@ def __contains__(self, key) -> bool:
sp = get_el_sp(key)
if isinstance(sp, Species):
return sp in self._data
# Element or str
# key is Element or str
return any(sp.symbol == s.symbol for s in self._data)
except ValueError as exc:
raise TypeError(f"Invalid {key=} for Composition") from exc
Expand Down
23 changes: 16 additions & 7 deletions tests/core/test_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,15 @@ def test_in(self):
assert Element("Fe") not in comp
assert Species("Fe2+") not in comp

def test_getitem(self):
comp = Composition({"Li+": 1, "Mn3+": 2, "O2-": 4, "Li": 1})
assert comp["Li"] == 2
assert comp["Li+"] == 1
assert comp["Mn3+"] == 2
assert comp["Mn"] == 2
assert comp["O2-"] == 4
assert comp["O"] == 4

def test_hill_formula(self):
c = Composition("CaCO3")
assert c.hill_formula == "C Ca O3"
Expand Down Expand Up @@ -269,8 +278,8 @@ def test_reduced_formula(self):
assert Composition("H6CN").get_integer_formula_and_factor(iupac_ordering=True)[0] == "CNH6"

# test rounding
c = Composition({"Na": 2 - Composition.amount_tolerance / 2, "Cl": 2})
assert c.reduced_formula == "NaCl"
comp = Composition({"Na": 2 - Composition.amount_tolerance / 2, "Cl": 2})
assert comp.reduced_formula == "NaCl"

def test_integer_formula(self):
correct_reduced_formulas = [
Expand Down Expand Up @@ -299,8 +308,8 @@ def test_integer_formula(self):
def test_num_atoms(self):
correct_num_atoms = [20, 10, 7, 8, 20, 75, 2, 3]

all_natoms = [c.num_atoms for c in self.comps]
assert all_natoms == correct_num_atoms
all_n_atoms = [c.num_atoms for c in self.comps]
assert all_n_atoms == correct_num_atoms

def test_weight(self):
correct_weights = [
Expand Down Expand Up @@ -360,7 +369,7 @@ def test_from_weight_dict(self):
for el in c1.elements:
assert c1[el] == approx(c2[el], abs=1e-3)

def test_tofrom_weight_dict(self):
def test_to_from_weight_dict(self):
for comp in self.comps:
c2 = Composition().from_weight_dict(comp.to_weight_dict)
comp.almost_equals(c2)
Expand Down Expand Up @@ -520,8 +529,8 @@ def test_negative_compositions(self):
# test species
c1 = Composition({"Mg": 1, "Mg2+": -1}, allow_negative=True)
assert c1.num_atoms == 2
assert c1.element_composition == Composition()
assert c1.average_electroneg == 1.31
assert c1.element_composition == Composition("Mg-1", allow_negative=True)
assert c1.average_electroneg == 0.655

def test_special_formulas(self):
special_formulas = {
Expand Down
6 changes: 3 additions & 3 deletions tests/core/test_sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ def test_distance_and_image(self):
dist_old, jimage_old = get_distance_and_image_old(site1, site2)
dist_new, jimage_new = site1.distance_and_image(site2)
assert dist_old - dist_new > -1e-8, "New distance algo should give smaller answers!"
assert not (abs(dist_old - dist_new) < 1e-8) ^ (
jimage_old == jimage_new
).all(), "If old dist == new dist, images must be the same!"
assert (
not (abs(dist_old - dist_new) < 1e-8) ^ (jimage_old == jimage_new).all()
), "If old dist == new dist, images must be the same!"
latt = Lattice.from_parameters(3.0, 3.1, 10.0, 2.96, 2.0, 1.0)
site = PeriodicSite("Fe", [0.1, 0.1, 0.1], latt)
site2 = PeriodicSite("Fe", [0.99, 0.99, 0.99], latt)
Expand Down
42 changes: 36 additions & 6 deletions tests/entries/test_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,18 +1000,48 @@ def test_check_potcar(self):
def test_process_entry_with_oxidation_state(self):
from pymatgen.core.periodic_table import Species

entry = ComputedEntry(
{Species("Fe2+"): 2, Species("O2-"): 3},
-1,
parameters={"is_hubbard": True, "hubbards": {"Fe": 5.3, "O": 0}, "run_type": "GGA+U"},
)
params = {"is_hubbard": True, "hubbards": {"Fe": 5.3, "O": 0}, "run_type": "GGA+U"}
entry = ComputedEntry({Species("Fe2+"): 2, Species("O2-"): 3}, -1, parameters=params)

# Test that MaterialsProject2020Compatibility can process entries with oxidation states
# https://github.com/materialsproject/pymatgen/issues/3154
compat = MaterialsProject2020Compatibility(check_potcar=False)
[processed_entry] = compat.process_entries(entry, clean=True, inplace=False)
processed_entry = compat.process_entry(entry, clean=True, inplace=False)

assert len(processed_entry.energy_adjustments) == 2
assert processed_entry.energy_adjustments[0].name == "MP2020 anion correction (oxide)"
assert processed_entry.energy_adjustments[1].name == "MP2020 GGA/GGA+U mixing correction (Fe)"
assert processed_entry.correction == approx(-6.572999)
assert processed_entry.energy == approx(-1 + -6.572999)

# for https://github.com/materialsproject/pymatgen/issues/3425
frac_coords = [
[0.5, 0.5, 0.3797505],
[0.0, 0.0, 0.6202495],
[0.5, 0.5, 0.8632525],
[0.0, 0.0, 0.1367475],
[0.5, 0.0, 0.3608245],
[0.0, 0.5, 0.0985135],
[0.5, 0.0, 0.9014865],
[0.0, 0.5, 0.6391755],
]
lattice = [
[2.86877900, 0.00000000e00, 1.75662051e-16],
[-2.83779749e-16, 4.63447500e00, 2.83779749e-16],
[0.00000000e00, 0.00000000e00, 5.83250700e00],
]
species = ["Li+", "Li+", "Mn3+", "Mn3+", "O2-", "O2-", "O2-", "O2-"]
li_mn_o = Structure(lattice, species, frac_coords)

params = {"hubbards": {"Mn": 3.9, "O": 0, "Li": 0}, "run_type": "GGA+U"}
cse = ComputedStructureEntry(li_mn_o, -58.97, parameters=params)
processed_entry = compat.process_entry(cse, clean=True, inplace=False)

assert len(processed_entry.energy_adjustments) == 2
assert processed_entry.energy_adjustments[0].name == "MP2020 anion correction (oxide)"
assert processed_entry.energy_adjustments[1].name == "MP2020 GGA/GGA+U mixing correction (Mn)"
assert processed_entry.correction == approx(-6.084)
assert processed_entry.energy == approx(-58.97 + -6.084)


class TestMITCompatibility(unittest.TestCase):
Expand Down
8 changes: 4 additions & 4 deletions tests/io/exciting/test_inputs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

import xml.etree.ElementTree as ET
from xml.etree import ElementTree

from numpy.testing import assert_allclose

Expand Down Expand Up @@ -121,7 +121,7 @@ def test_writebandstr(self):
"S",
"R",
]
root = ET.fromstring(bandstr)
root = ElementTree.fromstring(bandstr)
for plot1d in root.iter("plot1d"):
for point in plot1d.iter("point"):
coord.append([float(i) for i in point.get("coord").split()])
Expand Down Expand Up @@ -159,8 +159,8 @@ def test_paramdict(self):

# read reference file
filepath = f"{TEST_FILES_DIR}/input_exciting2.xml"
tree = ET.parse(filepath)
tree = ElementTree.parse(filepath)
root = tree.getroot()
ref_string = ET.tostring(root, encoding="unicode")
ref_string = ElementTree.tostring(root, encoding="unicode")

assert ref_string.strip() == test_string.strip()
17 changes: 3 additions & 14 deletions tests/transformations/test_standard_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,21 +338,10 @@ def test_no_oxidation(self):

def test_symmetrized_structure(self):
trafo = OrderDisorderedStructureTransformation(symmetrized_structures=True)
c = []
sp = []
c.append([0.5, 0.5, 0.5])
sp.append("Si4+")
c.append([0.45, 0.45, 0.45])
sp.append({"Si4+": 0.5})
c.append([0.56, 0.56, 0.56])
sp.append({"Si4+": 0.5})
c.append([0.25, 0.75, 0.75])
sp.append({"Si4+": 0.5})
c.append([0.75, 0.25, 0.25])
sp.append({"Si4+": 0.5})
latt = Lattice.cubic(5)
struct = Structure(latt, sp, c)
test_site = PeriodicSite("Si4+", c[2], latt)
coords = [[0.5, 0.5, 0.5], [0.45, 0.45, 0.45], [0.56, 0.56, 0.56], [0.25, 0.75, 0.75], [0.75, 0.25, 0.25]]
struct = Structure(latt, [{"Si4+": 1}, *[{"Si4+": 0.5}] * 4], coords)
test_site = PeriodicSite("Si4+", coords[2], latt)
struct = SymmetrizedStructure(struct, "not_real", [0, 1, 1, 2, 2], ["a", "b", "b", "c", "c"])
output = trafo.apply_transformation(struct)
assert test_site in output
Expand Down

0 comments on commit bd13ea4

Please sign in to comment.