From 6d2e77ef2a8027cbc2dffb972a9f8667f8683666 Mon Sep 17 00:00:00 2001 From: Anthony Onwuli <30937913+AntObi@users.noreply.github.com> Date: Fri, 26 Jul 2024 10:05:21 +0100 Subject: [PATCH] Fix MPRester tests and access phonon properties from the new API without having `mp-api` installed. (#3950) --- src/pymatgen/ext/matproj.py | 28 +++++++++++++++++++++++++++- tests/ext/test_matproj.py | 30 ++++++++++++++++++------------ 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/src/pymatgen/ext/matproj.py b/src/pymatgen/ext/matproj.py index 71a4cff7a7a..23d867529ff 100644 --- a/src/pymatgen/ext/matproj.py +++ b/src/pymatgen/ext/matproj.py @@ -178,7 +178,7 @@ def get_summary_by_material_id(self, material_id: str, fields: list | None = Non Dict """ get = "_all_fields=True" if fields is None else "_fields=" + ",".join(fields) - return self.request(f"materials/summary/{material_id}/?{get}")[0] + return self.request(f"materials/summary/?{get}", payload={"material_ids": material_id})[0] get_doc = get_summary_by_material_id @@ -349,6 +349,32 @@ def get_entries_in_chemsys(self, elements, *args, **kwargs): return self.get_entries(criteria, *args, **kwargs) + def get_phonon_bandstructure_by_material_id(self, material_id: str): + """Get phonon bandstructure by material_id. + + Args: + material_id (str): Materials Project material_id + + Returns: + PhononBandStructureSymmLine: A phonon band structure. + """ + prop = "ph_bs" + response = self.request(f"materials/phonon/?material_ids={material_id}&_fields={prop}") + return response[0][prop] + + def get_phonon_dos_by_material_id(self, material_id: str): + """Get phonon density of states by material_id. + + Args: + material_id (str): Materials Project material_id + + Returns: + CompletePhononDos: A phonon DOS object. + """ + prop = "ph_dos" + response = self.request(f"materials/phonon/?material_ids={material_id}&_fields={prop}") + return response[0][prop] + class MPRester: """A class to conveniently interface with the new and legacy Materials Project REST interface. diff --git a/tests/ext/test_matproj.py b/tests/ext/test_matproj.py index 8489cf438c3..99e13761f82 100644 --- a/tests/ext/test_matproj.py +++ b/tests/ext/test_matproj.py @@ -16,23 +16,28 @@ from pymatgen.electronic_structure.dos import CompleteDos from pymatgen.entries.compatibility import MaterialsProject2020Compatibility from pymatgen.entries.computed_entries import ComputedEntry -from pymatgen.ext.matproj import MP_LOG_FILE, MPRestError, _MPResterBasic -from pymatgen.ext.matproj_legacy import TaskType, _MPResterLegacy +from pymatgen.ext.matproj import MP_LOG_FILE, _MPResterBasic +from pymatgen.ext.matproj_legacy import MPRestError, TaskType, _MPResterLegacy from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine from pymatgen.phonon.dos import CompletePhononDos from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest from pytest import approx from ruamel.yaml import YAML +PMG_MAPI_KEY = SETTINGS.get("PMG_MAPI_KEY", "") +if (10 < len(PMG_MAPI_KEY) <= 20) and "PMG_MAPI_KEY" in SETTINGS: + MP_URL = "https://legacy.materialsproject.org" +elif len(PMG_MAPI_KEY) > 20: + MP_URL = "https://api.materialsproject.org" +else: + MP_URL = "https://materialsproject.org" try: - skip_mprester_tests = requests.get("https://materialsproject.org", timeout=600).status_code != 200 + skip_mprester_tests = requests.get(MP_URL, timeout=600).status_code != 200 except (ModuleNotFoundError, ImportError, requests.exceptions.ConnectionError): # Skip all MPRester tests if some downstream problem on the website, mp-api or whatever. skip_mprester_tests = True -PMG_MAPI_KEY = SETTINGS.get("PMG_MAPI_KEY", "") - @pytest.mark.skipif( skip_mprester_tests or (not 10 < len(PMG_MAPI_KEY) <= 20), @@ -513,8 +518,8 @@ def test_api_key_is_none(self): skip_mprester_tests or (not len(PMG_MAPI_KEY) > 20), reason="PMG_MAPI_KEY environment variable not set or MP API is down.", ) -class TestMPResterNewBasic: - def setup(self): +class TestMPResterNewBasic(PymatgenTest): + def setUp(self): self.rester = _MPResterBasic() def test_get_summary(self): @@ -679,11 +684,12 @@ def test_get_entry_by_material_id(self): # assert isinstance(bs_unif, BandStructure) # assert not isinstance(bs_unif, BandStructureSymmLine) # - # def test_get_phonon_data_by_material_id(self): - # bs = self.rester.get_phonon_bandstructure_by_material_id("mp-661") - # assert isinstance(bs, PhononBandStructureSymmLine) - # dos = self.rester.get_phonon_dos_by_material_id("mp-661") - # assert isinstance(dos, CompletePhononDos) + def test_get_phonon_data_by_material_id(self): + bs = self.rester.get_phonon_bandstructure_by_material_id("mp-661") + assert isinstance(bs, PhononBandStructureSymmLine) + dos = self.rester.get_phonon_dos_by_material_id("mp-661") + assert isinstance(dos, CompletePhononDos) + # ddb_str = self.rester.get_phonon_ddb_by_material_id("mp-661") # assert isinstance(ddb_str, str)