From b7b73898d9b9af3ac6d76da33019d4a915fc44fc Mon Sep 17 00:00:00 2001 From: Eugene Park Date: Thu, 15 Aug 2024 11:44:57 -0500 Subject: [PATCH] Formatting customization for `PWInput` (#4001) * Enable formatting customization of PWInput files - Default values set to preserve backwards compatibility * Fix: Preserve format upon write-then-read - Preserve indent and max decimal precision --- src/pymatgen/io/pwscf.py | 118 ++++++++++++++++++++++++++-------- src/pymatgen/util/io_utils.py | 13 +++- tests/io/test_pwscf.py | 113 ++++++++++++++++++++++++++++++++ 3 files changed, 215 insertions(+), 29 deletions(-) diff --git a/src/pymatgen/io/pwscf.py b/src/pymatgen/io/pwscf.py index abcbf2bb8c2..b790eed3ca6 100644 --- a/src/pymatgen/io/pwscf.py +++ b/src/pymatgen/io/pwscf.py @@ -37,6 +37,7 @@ def __init__( kpoints_mode="automatic", kpoints_grid=(1, 1, 1), kpoints_shift=(0, 0, 0), + format_options=None, ): """Initialize a PWSCF input file. @@ -60,6 +61,15 @@ def __init__( kpoints_grid (sequence): The kpoint grid. Default to (1, 1, 1). kpoints_shift (sequence): The shift for the kpoints. Defaults to (0, 0, 0). + format_options (dict): Formatting options when writing into a string. + Can be used to specify e.g., the number of decimal places + (including trailing zeros) for real-space coordinate values + (atomic positions, cell parameters). Defaults to None, + in which case the following default values are used + (so as to maintain backwards compatibility): + {"indent": 2, "kpoints_crystal_b_indent": 1, + "coord_decimals": 6, "atomic_mass_decimals": 4, + "kpoints_grid_decimals": 4}. """ self.structure = structure sections = {} @@ -84,6 +94,25 @@ def __init__( self.kpoints_mode = kpoints_mode self.kpoints_grid = kpoints_grid self.kpoints_shift = kpoints_shift + self.format_options = { + # Default to 2 spaces for indentation + "indent": 2, + # Default to 1 space for indent in kpoint grid entries + # when kpoints_mode == "crystal_b" + "kpoints_crystal_b_indent": 1, + # Default to 6 decimal places + # for atomic position and cell vector coordinates + "coord_decimals": 6, + # Default to 4 decimal places for atomic mass values + "atomic_mass_decimals": 4, + # Default to 4 decimal places + # for kpoint grid entries + # when kpoints_mode == "crystal_b" + "kpoints_grid_decimals": 4, + } + if format_options is None: + format_options = {} + self.format_options.update(format_options) def __str__(self): out = [] @@ -115,6 +144,7 @@ def to_str(v): return ".FALSE." return v + indent = " " * self.format_options["indent"] for k1 in ["control", "system", "electrons", "ions", "cell"]: v1 = self.sections[k1] out.append(f"&{k1.upper()}") @@ -123,54 +153,64 @@ def to_str(v): if isinstance(v1[k2], list): n = 1 for _ in v1[k2][: len(site_descriptions)]: - sub.append(f" {k2}({n}) = {to_str(v1[k2][n - 1])}") + sub.append(f"{indent}{k2}({n}) = {to_str(v1[k2][n - 1])}") n += 1 else: - sub.append(f" {k2} = {to_str(v1[k2])}") + sub.append(f"{indent}{k2} = {to_str(v1[k2])}") if k1 == "system": if "ibrav" not in self.sections[k1]: - sub.append(" ibrav = 0") + sub.append(f"{indent}ibrav = 0") if "nat" not in self.sections[k1]: - sub.append(f" nat = {len(self.structure)}") + sub.append(f"{indent}nat = {len(self.structure)}") if "ntyp" not in self.sections[k1]: - sub.append(f" ntyp = {len(site_descriptions)}") + sub.append(f"{indent}ntyp = {len(site_descriptions)}") sub.append("/") out.append(",\n".join(sub)) out.append("ATOMIC_SPECIES") + prec = self.format_options["atomic_mass_decimals"] for k, v in sorted(site_descriptions.items(), key=lambda i: i[0]): e = re.match(r"[A-Z][a-z]?", k)[0] p = v if self.pseudo is not None else v["pseudo"] - out.append(f" {k} {Element(e).atomic_mass:.4f} {p}") + out.append(f"{indent}{k} {Element(e).atomic_mass:.{prec}f} {p}") out.append("ATOMIC_POSITIONS crystal") + prec = self.format_options["coord_decimals"] if self.pseudo is not None: for site in self.structure: - out.append(f" {site.specie} {site.a:.6f} {site.b:.6f} {site.c:.6f}") + pos_str = [f"{site.specie}"] + pos_str.extend([f"{v:.{prec}f}" for v in site.frac_coords]) + out.append(f"{indent}{' '.join(pos_str)}") else: for site in self.structure: name = None for k, v in sorted(site_descriptions.items(), key=lambda i: i[0]): if v == site.properties: name = k - out.append(f" {name} {site.a:.6f} {site.b:.6f} {site.c:.6f}") + pos_str = [f"{name}"] + pos_str.extend([f"{v:.{prec}f}" for v in site.frac_coords]) + out.append(f"{indent}{' '.join(pos_str)}") out.append(f"K_POINTS {self.kpoints_mode}") if self.kpoints_mode == "automatic": kpt_str = [f"{i}" for i in self.kpoints_grid] kpt_str.extend([f"{i}" for i in self.kpoints_shift]) - out.append(f" {' '.join(kpt_str)}") + out.append(f"{indent}{' '.join(kpt_str)}") elif self.kpoints_mode == "crystal_b": - out.append(f" {len(self.kpoints_grid)}") + kpt_indent = " " * self.format_options["kpoints_crystal_b_indent"] + out.append(f"{kpt_indent}{len(self.kpoints_grid)}") + prec = self.format_options["kpoints_grid_decimals"] for i in range(len(self.kpoints_grid)): - kpt_str = [f"{entry:.4f}" for entry in self.kpoints_grid[i]] - out.append(f" {' '.join(kpt_str)}") + kpt_str = [f"{entry:.{prec}f}" for entry in self.kpoints_grid[i]] + out.append(f"{kpt_indent}{' '.join(kpt_str)}") elif self.kpoints_mode == "gamma": pass out.append("CELL_PARAMETERS angstrom") + prec = self.format_options["coord_decimals"] for vec in self.structure.lattice.matrix: - out.append(f" {vec[0]:f} {vec[1]:f} {vec[2]:f}") + vec_str = [f"{v:.{prec}f}" for v in vec] + out.append(f"{indent}{' '.join(vec_str)}") return "\n".join(out) def as_dict(self): @@ -187,6 +227,7 @@ def as_dict(self): "kpoints_mode": self.kpoints_mode, "kpoints_grid": self.kpoints_grid, "kpoints_shift": self.kpoints_shift, + "format_options": self.format_options, } @classmethod @@ -211,6 +252,7 @@ def from_dict(cls, dct: dict) -> Self: kpoints_mode=dct["kpoints_mode"], kpoints_grid=dct["kpoints_grid"], kpoints_shift=dct["kpoints_shift"], + format_options=dct["format_options"], ) def write_file(self, filename): @@ -247,7 +289,7 @@ def from_str(cls, string: str) -> Self: Returns: PWInput object """ - lines = list(clean_lines(string.splitlines())) + lines = list(clean_lines(string.splitlines(), rstrip_only=True)) def input_mode(line): if line[0] == "&": @@ -282,6 +324,7 @@ def input_mode(line): kpoints_grid = (1, 1, 1) kpoints_shift = (0, 0, 0) coords_are_cartesian = False + format_options = {} for line in lines: mode = input_mode(line) @@ -289,10 +332,11 @@ def input_mode(line): pass elif mode[0] == "sections": section = mode[1] - if match := re.match(r"(\w+)\(?(\d*?)\)?\s*=\s*(.*)", line): - key = match[1].strip() - key_ = match[2].strip() - val = match[3].strip().rstrip(",") + if match := re.match(r"^(\s*)(\w+)\(?(\d*?)\)?\s*=\s*(.*)", line): + format_options["indent"] = len(match[1]) + key = match[2].strip() + key_ = match[3].strip() + val = match[4].strip().rstrip(",") if key_ != "": if sections[section].get(key) is None: val_ = [0.0] * 20 # MAX NTYP DEFINITION @@ -306,8 +350,9 @@ def input_mode(line): sections[section][key] = PWInput.proc_val(key, val) elif mode[0] == "pseudo": - if match := re.match(r"(\w+\d*[\+-]?)\s+(\d*.\d*)\s+(.*)", line): - pseudo[match[1].strip()] = match[3].strip() + if match := re.match(r"^(\s*)(\w+\d*[\+-]?)\s+(\d*.\d*)\s+(.*)", line): + format_options["indent"] = len(match[1]) + pseudo[match[2].strip()] = match[4].strip() elif mode[0] == "kpoints": if match := re.match(r"(\d+)\s+(\d+)\s+(\d+)\s+(\d+)\s+(\d+)\s+(\d+)", line): @@ -317,19 +362,39 @@ def input_mode(line): kpoints_mode = mode[1] elif mode[0] == "structure": - m_l = re.match(r"(-?\d+\.?\d*)\s+(-?\d+\.?\d*)\s+(-?\d+\.?\d*)", line) - m_p = re.match(r"(\w+\d*[\+-]?)\s+(-?\d+\.\d*)\s+(-?\d+\.?\d*)\s+(-?\d+\.?\d*)", line) + m_l = re.match(r"^(\s*)(-?\d+\.?\d*)\s+(-?\d+\.?\d*)\s+(-?\d+\.?\d*)", line) + m_p = re.match(r"^(\s*)(\w+\d*[\+-]?)\s+(-?\d+\.\d*)\s+(-?\d+\.?\d*)\s+(-?\d+\.?\d*)", line) if m_l: + format_options["indent"] = len(m_l[1]) lattice += [ - float(m_l[1]), float(m_l[2]), float(m_l[3]), + float(m_l[4]), ] + decimals = max( + # length of decimal digits; 0 if no decimal digits + (len(dec[1]) if len(dec := v.split(".")) == 2 else 0) + for v in (m_l[2], m_l[3], m_l[4]) + ) + format_options["coord_decimals"] = max( + format_options.get("coord_decimals", 0), + decimals, + ) elif m_p: - site_properties["pseudo"].append(pseudo[m_p[1]]) - species.append(m_p[1]) - coords += [[float(m_p[2]), float(m_p[3]), float(m_p[4])]] + format_options["indent"] = len(m_p[1]) + site_properties["pseudo"].append(pseudo[m_p[2]]) + species.append(m_p[2]) + coords += [[float(m_p[3]), float(m_p[4]), float(m_p[5])]] + decimals = max( + # length of decimal digits; 0 if no decimal digits + (len(dec[1]) if len(dec := v.split(".")) == 2 else 0) + for v in (m_p[3], m_p[4], m_p[5]) + ) + format_options["coord_decimals"] = max( + format_options.get("coord_decimals", 0), + decimals, + ) if mode[1] == "angstrom": coords_are_cartesian = True @@ -352,6 +417,7 @@ def input_mode(line): kpoints_mode=kpoints_mode, kpoints_grid=kpoints_grid, kpoints_shift=kpoints_shift, + format_options=format_options, ) @staticmethod diff --git a/src/pymatgen/util/io_utils.py b/src/pymatgen/util/io_utils.py index b8fa1148f17..b40a6edb50d 100644 --- a/src/pymatgen/util/io_utils.py +++ b/src/pymatgen/util/io_utils.py @@ -20,23 +20,30 @@ __date__ = "Sep 23, 2011" -def clean_lines(string_list, remove_empty_lines=True) -> Generator[str, None, None]: +def clean_lines( + string_list, + remove_empty_lines=True, + rstrip_only=False, +) -> Generator[str, None, None]: """Strips whitespace, carriage returns and empty lines from a list of strings. Args: string_list: List of strings remove_empty_lines: Set to True to skip lines which are empty after stripping. + rstrip_only: Set to True to strip trailing whitespaces only (i.e., + to retain leading whitespaces). Defaults to False. Yields: - list: clean strings with no whitespaces. + list: clean strings with no whitespaces. If rstrip_only == True, + clean strings with no trailing whitespaces. """ for s in string_list: clean_s = s if "#" in s: ind = s.index("#") clean_s = s[:ind] - clean_s = clean_s.strip() + clean_s = clean_s.rstrip() if rstrip_only else clean_s.strip() if (not remove_empty_lines) or clean_s != "": yield clean_s diff --git a/tests/io/test_pwscf.py b/tests/io/test_pwscf.py index 76870670178..ee7f959ede2 100644 --- a/tests/io/test_pwscf.py +++ b/tests/io/test_pwscf.py @@ -402,6 +402,119 @@ def test_write_and_read_str_with_oxidation(self): pw_str = str(pw) assert pw_str.strip() == str(PWInput.from_str(pw_str)).strip() + def test_custom_decimal_precision(self): + struct = self.get_structure("Li2O") + pw = PWInput( + struct, + control={"calculation": "scf", "pseudo_dir": "./"}, + pseudo={ + "Li+": "Li.pbe-n-kjpaw_psl.0.1.UPF", + "O2-": "O.pbe-n-kjpaw_psl.0.1.UPF", + }, + system={"ecutwfc": 50}, + format_options={"coord_decimals": 9, "indent": 0}, + ) + expected = """&CONTROL +calculation = 'scf', +pseudo_dir = './', +/ +&SYSTEM +ecutwfc = 50, +ibrav = 0, +nat = 3, +ntyp = 2, +/ +&ELECTRONS +/ +&IONS +/ +&CELL +/ +ATOMIC_SPECIES +Li+ 6.9410 Li.pbe-n-kjpaw_psl.0.1.UPF +O2- 15.9994 O.pbe-n-kjpaw_psl.0.1.UPF +ATOMIC_POSITIONS crystal +O2- 0.000000000 0.000000000 0.000000000 +Li+ 0.750178290 0.750178290 0.750178290 +Li+ 0.249821710 0.249821710 0.249821710 +K_POINTS automatic +1 1 1 0 0 0 +CELL_PARAMETERS angstrom +2.917388570 0.097894370 1.520004660 +0.964634060 2.755035610 1.520004660 +0.133206350 0.097894430 3.286917710 +""" + assert str(pw).strip() == expected.strip() + + def test_custom_decimal_precision_kpoint_grid_crystal_b(self): + struct = self.get_structure("Li2O") + struct.remove_oxidation_states() + kpoints = [[0.0, 0.0, 0.0], [0.0, 0.5, 0.5], [0.5, 0.0, 0.0], [0.0, 0.0, 0.5], [0.5, 0.5, 0.5]] + pw = PWInput( + struct, + control={"calculation": "scf", "pseudo_dir": "./"}, + pseudo={ + "Li": "Li.pbe-n-kjpaw_psl.0.1.UPF", + "O": "O.pbe-n-kjpaw_psl.0.1.UPF", + }, + system={"ecutwfc": 50}, + kpoints_mode="crystal_b", + kpoints_grid=kpoints, + format_options={"kpoints_crystal_b_indent": 2}, + ) + expected = """ +&CONTROL + calculation = 'scf', + pseudo_dir = './', +/ +&SYSTEM + ecutwfc = 50, + ibrav = 0, + nat = 3, + ntyp = 2, +/ +&ELECTRONS +/ +&IONS +/ +&CELL +/ +ATOMIC_SPECIES + Li 6.9410 Li.pbe-n-kjpaw_psl.0.1.UPF + O 15.9994 O.pbe-n-kjpaw_psl.0.1.UPF +ATOMIC_POSITIONS crystal + O 0.000000 0.000000 0.000000 + Li 0.750178 0.750178 0.750178 + Li 0.249822 0.249822 0.249822 +K_POINTS crystal_b + 5 + 0.0000 0.0000 0.0000 + 0.0000 0.5000 0.5000 + 0.5000 0.0000 0.0000 + 0.0000 0.0000 0.5000 + 0.5000 0.5000 0.5000 +CELL_PARAMETERS angstrom + 2.917389 0.097894 1.520005 + 0.964634 2.755036 1.520005 + 0.133206 0.097894 3.286918 +""" + assert str(pw).strip() == expected.strip() + + def test_custom_decimal_precision_write_and_read_str(self): + struct = self.get_structure("Li2O") + pw = PWInput( + struct, + control={"calculation": "scf", "pseudo_dir": "./"}, + pseudo={ + "Li+": "Li.pbe-n-kjpaw_psl.0.1.UPF", + "O2-": "O.pbe-n-kjpaw_psl.0.1.UPF", + }, + system={"ecutwfc": 50}, + format_options={"coord_decimals": 9}, + ) + pw_str = str(pw) + assert pw_str.strip() == str(PWInput.from_str(pw_str)).strip() + class TestPWOutput(PymatgenTest): def setUp(self):