Skip to content

Commit

Permalink
Formatting customization for PWInput (#4001)
Browse files Browse the repository at this point in the history
* Enable formatting customization of PWInput files

- Default values set to preserve backwards compatibility

* Fix: Preserve format upon write-then-read

- Preserve indent and max decimal precision
  • Loading branch information
jsukpark authored Aug 15, 2024
1 parent cadcae4 commit b7b7389
Show file tree
Hide file tree
Showing 3 changed files with 215 additions and 29 deletions.
118 changes: 92 additions & 26 deletions src/pymatgen/io/pwscf.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
kpoints_mode="automatic",
kpoints_grid=(1, 1, 1),
kpoints_shift=(0, 0, 0),
format_options=None,
):
"""Initialize a PWSCF input file.
Expand All @@ -60,6 +61,15 @@ def __init__(
kpoints_grid (sequence): The kpoint grid. Default to (1, 1, 1).
kpoints_shift (sequence): The shift for the kpoints. Defaults to
(0, 0, 0).
format_options (dict): Formatting options when writing into a string.
Can be used to specify e.g., the number of decimal places
(including trailing zeros) for real-space coordinate values
(atomic positions, cell parameters). Defaults to None,
in which case the following default values are used
(so as to maintain backwards compatibility):
{"indent": 2, "kpoints_crystal_b_indent": 1,
"coord_decimals": 6, "atomic_mass_decimals": 4,
"kpoints_grid_decimals": 4}.
"""
self.structure = structure
sections = {}
Expand All @@ -84,6 +94,25 @@ def __init__(
self.kpoints_mode = kpoints_mode
self.kpoints_grid = kpoints_grid
self.kpoints_shift = kpoints_shift
self.format_options = {
# Default to 2 spaces for indentation
"indent": 2,
# Default to 1 space for indent in kpoint grid entries
# when kpoints_mode == "crystal_b"
"kpoints_crystal_b_indent": 1,
# Default to 6 decimal places
# for atomic position and cell vector coordinates
"coord_decimals": 6,
# Default to 4 decimal places for atomic mass values
"atomic_mass_decimals": 4,
# Default to 4 decimal places
# for kpoint grid entries
# when kpoints_mode == "crystal_b"
"kpoints_grid_decimals": 4,
}
if format_options is None:
format_options = {}
self.format_options.update(format_options)

def __str__(self):
out = []
Expand Down Expand Up @@ -115,6 +144,7 @@ def to_str(v):
return ".FALSE."
return v

indent = " " * self.format_options["indent"]
for k1 in ["control", "system", "electrons", "ions", "cell"]:
v1 = self.sections[k1]
out.append(f"&{k1.upper()}")
Expand All @@ -123,54 +153,64 @@ def to_str(v):
if isinstance(v1[k2], list):
n = 1
for _ in v1[k2][: len(site_descriptions)]:
sub.append(f" {k2}({n}) = {to_str(v1[k2][n - 1])}")
sub.append(f"{indent}{k2}({n}) = {to_str(v1[k2][n - 1])}")
n += 1
else:
sub.append(f" {k2} = {to_str(v1[k2])}")
sub.append(f"{indent}{k2} = {to_str(v1[k2])}")
if k1 == "system":
if "ibrav" not in self.sections[k1]:
sub.append(" ibrav = 0")
sub.append(f"{indent}ibrav = 0")
if "nat" not in self.sections[k1]:
sub.append(f" nat = {len(self.structure)}")
sub.append(f"{indent}nat = {len(self.structure)}")
if "ntyp" not in self.sections[k1]:
sub.append(f" ntyp = {len(site_descriptions)}")
sub.append(f"{indent}ntyp = {len(site_descriptions)}")
sub.append("/")
out.append(",\n".join(sub))

out.append("ATOMIC_SPECIES")
prec = self.format_options["atomic_mass_decimals"]
for k, v in sorted(site_descriptions.items(), key=lambda i: i[0]):
e = re.match(r"[A-Z][a-z]?", k)[0]
p = v if self.pseudo is not None else v["pseudo"]
out.append(f" {k} {Element(e).atomic_mass:.4f} {p}")
out.append(f"{indent}{k} {Element(e).atomic_mass:.{prec}f} {p}")

out.append("ATOMIC_POSITIONS crystal")
prec = self.format_options["coord_decimals"]
if self.pseudo is not None:
for site in self.structure:
out.append(f" {site.specie} {site.a:.6f} {site.b:.6f} {site.c:.6f}")
pos_str = [f"{site.specie}"]
pos_str.extend([f"{v:.{prec}f}" for v in site.frac_coords])
out.append(f"{indent}{' '.join(pos_str)}")
else:
for site in self.structure:
name = None
for k, v in sorted(site_descriptions.items(), key=lambda i: i[0]):
if v == site.properties:
name = k
out.append(f" {name} {site.a:.6f} {site.b:.6f} {site.c:.6f}")
pos_str = [f"{name}"]
pos_str.extend([f"{v:.{prec}f}" for v in site.frac_coords])
out.append(f"{indent}{' '.join(pos_str)}")

out.append(f"K_POINTS {self.kpoints_mode}")
if self.kpoints_mode == "automatic":
kpt_str = [f"{i}" for i in self.kpoints_grid]
kpt_str.extend([f"{i}" for i in self.kpoints_shift])
out.append(f" {' '.join(kpt_str)}")
out.append(f"{indent}{' '.join(kpt_str)}")
elif self.kpoints_mode == "crystal_b":
out.append(f" {len(self.kpoints_grid)}")
kpt_indent = " " * self.format_options["kpoints_crystal_b_indent"]
out.append(f"{kpt_indent}{len(self.kpoints_grid)}")
prec = self.format_options["kpoints_grid_decimals"]
for i in range(len(self.kpoints_grid)):
kpt_str = [f"{entry:.4f}" for entry in self.kpoints_grid[i]]
out.append(f" {' '.join(kpt_str)}")
kpt_str = [f"{entry:.{prec}f}" for entry in self.kpoints_grid[i]]
out.append(f"{kpt_indent}{' '.join(kpt_str)}")
elif self.kpoints_mode == "gamma":
pass

out.append("CELL_PARAMETERS angstrom")
prec = self.format_options["coord_decimals"]
for vec in self.structure.lattice.matrix:
out.append(f" {vec[0]:f} {vec[1]:f} {vec[2]:f}")
vec_str = [f"{v:.{prec}f}" for v in vec]
out.append(f"{indent}{' '.join(vec_str)}")
return "\n".join(out)

def as_dict(self):
Expand All @@ -187,6 +227,7 @@ def as_dict(self):
"kpoints_mode": self.kpoints_mode,
"kpoints_grid": self.kpoints_grid,
"kpoints_shift": self.kpoints_shift,
"format_options": self.format_options,
}

@classmethod
Expand All @@ -211,6 +252,7 @@ def from_dict(cls, dct: dict) -> Self:
kpoints_mode=dct["kpoints_mode"],
kpoints_grid=dct["kpoints_grid"],
kpoints_shift=dct["kpoints_shift"],
format_options=dct["format_options"],
)

def write_file(self, filename):
Expand Down Expand Up @@ -247,7 +289,7 @@ def from_str(cls, string: str) -> Self:
Returns:
PWInput object
"""
lines = list(clean_lines(string.splitlines()))
lines = list(clean_lines(string.splitlines(), rstrip_only=True))

def input_mode(line):
if line[0] == "&":
Expand Down Expand Up @@ -282,17 +324,19 @@ def input_mode(line):
kpoints_grid = (1, 1, 1)
kpoints_shift = (0, 0, 0)
coords_are_cartesian = False
format_options = {}

for line in lines:
mode = input_mode(line)
if mode is None:
pass
elif mode[0] == "sections":
section = mode[1]
if match := re.match(r"(\w+)\(?(\d*?)\)?\s*=\s*(.*)", line):
key = match[1].strip()
key_ = match[2].strip()
val = match[3].strip().rstrip(",")
if match := re.match(r"^(\s*)(\w+)\(?(\d*?)\)?\s*=\s*(.*)", line):
format_options["indent"] = len(match[1])
key = match[2].strip()
key_ = match[3].strip()
val = match[4].strip().rstrip(",")
if key_ != "":
if sections[section].get(key) is None:
val_ = [0.0] * 20 # MAX NTYP DEFINITION
Expand All @@ -306,8 +350,9 @@ def input_mode(line):
sections[section][key] = PWInput.proc_val(key, val)

elif mode[0] == "pseudo":
if match := re.match(r"(\w+\d*[\+-]?)\s+(\d*.\d*)\s+(.*)", line):
pseudo[match[1].strip()] = match[3].strip()
if match := re.match(r"^(\s*)(\w+\d*[\+-]?)\s+(\d*.\d*)\s+(.*)", line):
format_options["indent"] = len(match[1])
pseudo[match[2].strip()] = match[4].strip()

elif mode[0] == "kpoints":
if match := re.match(r"(\d+)\s+(\d+)\s+(\d+)\s+(\d+)\s+(\d+)\s+(\d+)", line):
Expand All @@ -317,19 +362,39 @@ def input_mode(line):
kpoints_mode = mode[1]

elif mode[0] == "structure":
m_l = re.match(r"(-?\d+\.?\d*)\s+(-?\d+\.?\d*)\s+(-?\d+\.?\d*)", line)
m_p = re.match(r"(\w+\d*[\+-]?)\s+(-?\d+\.\d*)\s+(-?\d+\.?\d*)\s+(-?\d+\.?\d*)", line)
m_l = re.match(r"^(\s*)(-?\d+\.?\d*)\s+(-?\d+\.?\d*)\s+(-?\d+\.?\d*)", line)
m_p = re.match(r"^(\s*)(\w+\d*[\+-]?)\s+(-?\d+\.\d*)\s+(-?\d+\.?\d*)\s+(-?\d+\.?\d*)", line)
if m_l:
format_options["indent"] = len(m_l[1])
lattice += [
float(m_l[1]),
float(m_l[2]),
float(m_l[3]),
float(m_l[4]),
]
decimals = max(
# length of decimal digits; 0 if no decimal digits
(len(dec[1]) if len(dec := v.split(".")) == 2 else 0)
for v in (m_l[2], m_l[3], m_l[4])
)
format_options["coord_decimals"] = max(
format_options.get("coord_decimals", 0),
decimals,
)

elif m_p:
site_properties["pseudo"].append(pseudo[m_p[1]])
species.append(m_p[1])
coords += [[float(m_p[2]), float(m_p[3]), float(m_p[4])]]
format_options["indent"] = len(m_p[1])
site_properties["pseudo"].append(pseudo[m_p[2]])
species.append(m_p[2])
coords += [[float(m_p[3]), float(m_p[4]), float(m_p[5])]]
decimals = max(
# length of decimal digits; 0 if no decimal digits
(len(dec[1]) if len(dec := v.split(".")) == 2 else 0)
for v in (m_p[3], m_p[4], m_p[5])
)
format_options["coord_decimals"] = max(
format_options.get("coord_decimals", 0),
decimals,
)

if mode[1] == "angstrom":
coords_are_cartesian = True
Expand All @@ -352,6 +417,7 @@ def input_mode(line):
kpoints_mode=kpoints_mode,
kpoints_grid=kpoints_grid,
kpoints_shift=kpoints_shift,
format_options=format_options,
)

@staticmethod
Expand Down
13 changes: 10 additions & 3 deletions src/pymatgen/util/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,30 @@
__date__ = "Sep 23, 2011"


def clean_lines(
    string_list,
    remove_empty_lines=True,
    rstrip_only=False,
) -> Generator[str, None, None]:
    """Strip comments, whitespace, carriage returns and empty lines from strings.

    For each input string, everything from the first "#" to the end of the
    string is discarded, and the remainder is stripped of whitespace.

    Args:
        string_list: List of strings
        remove_empty_lines: Set to True to skip lines which are empty after
            stripping.
        rstrip_only: Set to True to strip trailing whitespaces only (i.e.,
            to retain leading whitespaces). Defaults to False.

    Yields:
        str: Each cleaned string with no leading or trailing whitespace. If
            rstrip_only is True, only trailing whitespace is removed and
            leading whitespace is kept.
    """
    for s in string_list:
        # Keep only the text before the first "#"; the rest is a comment.
        clean_s = s.partition("#")[0]
        clean_s = clean_s.rstrip() if rstrip_only else clean_s.strip()
        if (not remove_empty_lines) or clean_s != "":
            yield clean_s

Expand Down
Loading

0 comments on commit b7b7389

Please sign in to comment.