Skip to content

Commit

Permalink
Add Robot.getattrs_alleq method
Browse files Browse the repository at this point in the history
  • Loading branch information
gmatteo committed Sep 15, 2024
1 parent abc2c9a commit 7df0777
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 99 deletions.
2 changes: 2 additions & 0 deletions abipy/abilab.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from abipy.eph.rta import RtaFile, RtaRobot
from abipy.eph.transportfile import TransportFile
from abipy.eph.gstore import GstoreFile
from abipy.eph.gpath import GpathFile
from abipy.wannier90 import WoutFile, AbiwanFile, AbiwanRobot
from abipy.electrons.lobster import CoxpFile, ICoxpFile, LobsterDoscarFile, LobsterInput, LobsterAnalyzer

Expand Down Expand Up @@ -173,6 +174,7 @@ def _straceback():
("A2F.nc", A2fFile),
("SIGEPH.nc", SigEPhFile),
("GSTORE.nc", GstoreFile),
("GPATH.nc", GpathFile),
("TRANSPORT.nc",TransportFile),
("RTA.nc",RtaFile),
("V1SYM.nc", V1symFile),
Expand Down
53 changes: 40 additions & 13 deletions abipy/abio/robots.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,23 +613,50 @@ def _repr_html_(self) -> str:
"""Integration with jupyter_ notebooks."""
return '<ol start="0">\n{}\n</ol>'.format("\n".join("<li>%s</li>" % label for label, abifile in self.items()))

def getattr_alleq(self, aname : str):
def getattrs_alleq(self, *aname_args) -> list:
"""
Return the value of attribute aname.
Return list of attribute values for each attribute name in *aname_args.
"""
return [self.getattr_alleq(aname) for aname in aname_args]

def getattr_alleq(self, aname: str):
"""
Return the value of attribute aname. Try firs in self then in self.r
Raises ValueError if value is not the same across all the files in the robot.
"""
val1 = getattr(self.abifiles[0], aname)

for abifile in self.abifiles[1:]:
val2 = getattr(abifile, aname)
if isinstance(val1, (str, int, float)):
eq = val1 == val2
elif isinstance(val1, np.ndarray):
eq = np.allclose(val1, val2)
if not eq:
raise ValueError(f"Different values of {aname=}, {val1=}, {val2=}")

return val1
def get_obj_list(what: str):
if what == "abifiles":
return self.abifiles
elif what == "r":
return [abifile.r for abifile in self.abifiles]

raise ValueError(f"Invalid {what=}")

err_msg = []

for what in ["abifiles", "r"]:
objs = get_obj_list(what)

try:
val1 = getattr(objs[0], aname)
except AttributeError as exc:
err_msg.append(str(exc))
continue

for obj in objs[1:]:
val2 = getattr(obj, aname)
if isinstance(val1, (str, int, float)):
eq = val1 == val2
elif isinstance(val1, np.ndarray):
eq = np.allclose(val1, val2)
if not eq:
raise ValueError(f"Different values of {aname=}, {val1=}, {val2=}")

return val1

if err_msg:
raise ValueError("\n".join(err_msg))

@property
def abifiles(self) -> list:
Expand Down
Loading

0 comments on commit 7df0777

Please sign in to comment.