Skip to content

Commit

Permalink
Fix MPRester tests and access phonon properties from the new API with…
Browse files Browse the repository at this point in the history
…out having `mp-api` installed. (#3950)
  • Loading branch information
AntObi authored Jul 26, 2024
1 parent 1bdca30 commit 6d2e77e
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 13 deletions.
28 changes: 27 additions & 1 deletion src/pymatgen/ext/matproj.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def get_summary_by_material_id(self, material_id: str, fields: list | None = Non
Dict
"""
get = "_all_fields=True" if fields is None else "_fields=" + ",".join(fields)
return self.request(f"materials/summary/{material_id}/?{get}")[0]
return self.request(f"materials/summary/?{get}", payload={"material_ids": material_id})[0]

get_doc = get_summary_by_material_id

Expand Down Expand Up @@ -349,6 +349,32 @@ def get_entries_in_chemsys(self, elements, *args, **kwargs):

return self.get_entries(criteria, *args, **kwargs)

def get_phonon_bandstructure_by_material_id(self, material_id: str):
"""Get phonon bandstructure by material_id.
Args:
material_id (str): Materials Project material_id
Returns:
PhononBandStructureSymmLine: A phonon band structure.
"""
prop = "ph_bs"
response = self.request(f"materials/phonon/?material_ids={material_id}&_fields={prop}")
return response[0][prop]

def get_phonon_dos_by_material_id(self, material_id: str):
"""Get phonon density of states by material_id.
Args:
material_id (str): Materials Project material_id
Returns:
CompletePhononDos: A phonon DOS object.
"""
prop = "ph_dos"
response = self.request(f"materials/phonon/?material_ids={material_id}&_fields={prop}")
return response[0][prop]


class MPRester:
"""A class to conveniently interface with the new and legacy Materials Project REST interface.
Expand Down
30 changes: 18 additions & 12 deletions tests/ext/test_matproj.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,28 @@
from pymatgen.electronic_structure.dos import CompleteDos
from pymatgen.entries.compatibility import MaterialsProject2020Compatibility
from pymatgen.entries.computed_entries import ComputedEntry
from pymatgen.ext.matproj import MP_LOG_FILE, MPRestError, _MPResterBasic
from pymatgen.ext.matproj_legacy import TaskType, _MPResterLegacy
from pymatgen.ext.matproj import MP_LOG_FILE, _MPResterBasic
from pymatgen.ext.matproj_legacy import MPRestError, TaskType, _MPResterLegacy
from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine
from pymatgen.phonon.dos import CompletePhononDos
from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest
from pytest import approx
from ruamel.yaml import YAML

PMG_MAPI_KEY = SETTINGS.get("PMG_MAPI_KEY", "")
if (10 < len(PMG_MAPI_KEY) <= 20) and "PMG_MAPI_KEY" in SETTINGS:
MP_URL = "https://legacy.materialsproject.org"
elif len(PMG_MAPI_KEY) > 20:
MP_URL = "https://api.materialsproject.org"
else:
MP_URL = "https://materialsproject.org"
try:
skip_mprester_tests = requests.get("https://materialsproject.org", timeout=600).status_code != 200
skip_mprester_tests = requests.get(MP_URL, timeout=600).status_code != 200

except (ModuleNotFoundError, ImportError, requests.exceptions.ConnectionError):
# Skip all MPRester tests if some downstream problem on the website, mp-api or whatever.
skip_mprester_tests = True

PMG_MAPI_KEY = SETTINGS.get("PMG_MAPI_KEY", "")


@pytest.mark.skipif(
skip_mprester_tests or (not 10 < len(PMG_MAPI_KEY) <= 20),
Expand Down Expand Up @@ -513,8 +518,8 @@ def test_api_key_is_none(self):
skip_mprester_tests or (not len(PMG_MAPI_KEY) > 20),
reason="PMG_MAPI_KEY environment variable not set or MP API is down.",
)
class TestMPResterNewBasic:
def setup(self):
class TestMPResterNewBasic(PymatgenTest):
def setUp(self):
self.rester = _MPResterBasic()

def test_get_summary(self):
Expand Down Expand Up @@ -679,11 +684,12 @@ def test_get_entry_by_material_id(self):
# assert isinstance(bs_unif, BandStructure)
# assert not isinstance(bs_unif, BandStructureSymmLine)
#
# def test_get_phonon_data_by_material_id(self):
# bs = self.rester.get_phonon_bandstructure_by_material_id("mp-661")
# assert isinstance(bs, PhononBandStructureSymmLine)
# dos = self.rester.get_phonon_dos_by_material_id("mp-661")
# assert isinstance(dos, CompletePhononDos)
def test_get_phonon_data_by_material_id(self):
bs = self.rester.get_phonon_bandstructure_by_material_id("mp-661")
assert isinstance(bs, PhononBandStructureSymmLine)
dos = self.rester.get_phonon_dos_by_material_id("mp-661")
assert isinstance(dos, CompletePhononDos)

# ddb_str = self.rester.get_phonon_ddb_by_material_id("mp-661")
# assert isinstance(ddb_str, str)

Expand Down

0 comments on commit 6d2e77e

Please sign in to comment.