Skip to content

Commit

Permalink
Make AimsSpeciesFile a dataclass (#4054)
Browse files Browse the repository at this point in the history
* Modify for proper dataclass

remove commented code and __init__ function

* refactor and use Self return type for from_... methods

---------

Co-authored-by: Janosh Riebesell <janosh.riebesell@gmail.com>
  • Loading branch information
tpurcell90 and janosh authored Sep 8, 2024
1 parent 149e115 commit 326beb9
Showing 1 changed file with 24 additions and 52 deletions.
76 changes: 24 additions & 52 deletions src/pymatgen/io/aims/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,70 +689,43 @@ def from_dict(cls, dct: dict[str, Any]) -> Self:
return cls(_parameters=decoded["parameters"])


@dataclass
class AimsSpeciesFile:
"""An FHI-aims single species' defaults file."""
"""An FHI-aims single species' defaults file.
def __init__(self, data: str, label: str | None = None) -> None:
"""
Args:
data (str): A string of the complete species defaults file
label (str): A string representing the name of species
"""
self.data = data
self.label = label
Attributes:
data (str): A string of the complete species defaults file
label (str): A string representing the name of species
"""

data: str = ""
label: str | None = None

def __post_init__(self) -> None:
"""Set default label"""
if self.label is None:
for line in data.splitlines():
for line in self.data.splitlines():
if "species" in line:
self.label = line.split()[1]

def __eq__(self, other: object) -> bool:
"""True if two species are equal."""
if not isinstance(other, AimsSpeciesFile):
return NotImplemented
return self.data == other.data

def __lt__(self, other: object) -> bool:
"""True if self is less than other."""
if not isinstance(other, AimsSpeciesFile):
return NotImplemented
return self.data < other.data

def __le__(self, other: object) -> bool:
"""True if self is less than or equal to other."""
if not isinstance(other, AimsSpeciesFile):
return NotImplemented
return self.data <= other.data

def __gt__(self, other: object) -> bool:
"""True if self is greater than other."""
if not isinstance(other, AimsSpeciesFile):
return NotImplemented
return self.data > other.data

def __ge__(self, other: object) -> bool:
"""True if self is greater than or equal to other."""
if not isinstance(other, AimsSpeciesFile):
return NotImplemented
return self.data >= other.data

@classmethod
def from_file(cls, filename: str, label: str | None = None) -> AimsSpeciesFile:
def from_file(cls, filename: str, label: str | None = None) -> Self:
"""Initialize from file.
Args:
filename (str): The filename of the species' defaults file
label (str): A string representing the name of species
Returns:
The AimsSpeciesFile instance
AimsSpeciesFile
"""
with zopen(filename, mode="rt") as file:
return cls(file.read(), label)
return cls(data=file.read(), label=label)

@classmethod
def from_element_and_basis_name(
cls, element: str, basis: str, *, species_dir: str | Path | None = None, label: str | None = None
) -> AimsSpeciesFile:
) -> Self:
"""Initialize from element and basis names.
Args:
Expand All @@ -763,7 +736,7 @@ def from_element_and_basis_name(
then equal to element
Returns:
an AimsSpeciesFile instance
AimsSpeciesFile
"""
# check if element is in the Periodic Table (+ Emptium)
if element != "Emptium":
Expand Down Expand Up @@ -795,25 +768,24 @@ def from_element_and_basis_name(
f"Can't find the species' defaults file for {element} in {basis} basis set. Paths tried: {paths_to_try}"
)

def __str__(self):
def __str__(self) -> str:
"""String representation of the species' defaults file"""
return re.sub(r"^ *species +\w+", f" species {self.label}", self.data, flags=re.MULTILINE)

@property
def element(self) -> str:
match = re.search(r"^ *species +(\w+)", self.data, flags=re.MULTILINE)
if match is None:
raise ValueError("Can't find element in species' defaults file")
return match.group(1)
if match := re.search(r"^ *species +(\w+)", self.data, flags=re.MULTILINE):
return match[1]
raise ValueError("Can't find element in species' defaults file")

def as_dict(self) -> dict[str, Any]:
"""Dictionary representation of the species' defaults file."""
return {"label": self.label, "data": self.data, "@module": type(self).__module__, "@class": type(self).__name__}

@classmethod
def from_dict(cls, dct: dict[str, Any]) -> AimsSpeciesFile:
def from_dict(cls, dct: dict[str, Any]) -> Self:
"""Deserialization of the AimsSpeciesFile object"""
return AimsSpeciesFile(data=dct["data"], label=dct["label"])
return cls(**dct)


class SpeciesDefaults(list, MSONable):
Expand Down

0 comments on commit 326beb9

Please sign in to comment.