Skip to content

Commit

Permalink
change Trajectory._check_site_props: Assertion->ValueError
Browse files Browse the repository at this point in the history
replace raise AssertionError in tests with equiv assert statement
  • Loading branch information
janosh committed Aug 8, 2024
1 parent ce360f4 commit 61b02a5
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 51 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ ci:

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.6
rev: v0.5.7
hooks:
- id: ruff
args: [--fix, --unsafe-fixes]
Expand Down Expand Up @@ -66,6 +66,6 @@ repos:
args: [--drop-empty-cells, --keep-output]

- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.374
rev: v1.1.375
hooks:
- id: pyright
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,6 @@ exclude_also = [
"if settings.DEBUG",
"if typing.TYPE_CHECKING:",
"pragma: no cover",
"raise AssertionError",
"raise NotImplementedError",
"show_plot",
]
Expand Down
10 changes: 5 additions & 5 deletions src/pymatgen/core/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,15 +690,15 @@ def _check_site_props(self, site_props: SitePropsType | None) -> None:
if isinstance(site_props, dict):
site_props = [site_props]
elif len(site_props) != len(self):
raise AssertionError(
f"Size of the site properties {len(site_props)} does not equal to the number of frames {len(self)}"
raise ValueError(
f"Size of the site properties {len(site_props)} does not equal the number of frames {len(self)}"
)

n_sites = len(self.coords[0])
for dct in site_props:
for key, val in dct.items():
assert len(val) == n_sites, (
f"Size of site property {key} {len(val)}) does not equal to the "
f"Size of site property {key} {len(val)}) does not equal the "
f"number of sites in the structure {n_sites}."
)

Expand All @@ -708,8 +708,8 @@ def _check_frame_props(self, frame_props: list[dict] | None) -> None:
return

if len(frame_props) != len(self):
raise AssertionError(
f"Size of the frame properties {len(frame_props)} does not equal to the number of frames {len(self)}"
raise ValueError(
f"Size of the frame properties {len(frame_props)} does not equal the number of frames {len(self)}"
)

def _get_site_props(self, frames: ValidIndex) -> SitePropsType | None:
Expand Down
4 changes: 2 additions & 2 deletions src/pymatgen/io/multiwfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,10 @@ def parse_cp(lines: list[str]) -> tuple[str | None, dict[str, Any]]:
# Figure out what kind of critical-point we're dealing with
if "(3,-3)" in lines_split[0]:
cp_type = "atom"
conditionals = {k: v for k, v in QTAIM_CONDITIONALS.items() if k not in ["connected_bond_paths"]}
conditionals = {k: v for k, v in QTAIM_CONDITIONALS.items() if k != "connected_bond_paths"}
elif "(3,-1)" in lines_split[0]:
cp_type = "bond"
conditionals = {k: v for k, v in QTAIM_CONDITIONALS.items() if k not in ["ele_info"]}
conditionals = {k: v for k, v in QTAIM_CONDITIONALS.items() if k != "ele_info"}
elif "(3,+1)" in lines_split[0]:
cp_type = "ring"
conditionals = {k: v for k, v in QTAIM_CONDITIONALS.items() if k not in ["connected_bond_paths", "ele_info"]}
Expand Down
42 changes: 18 additions & 24 deletions tests/core/test_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,51 +78,45 @@ def test_slice(self):
sliced_traj = self.traj[2:99:3]
sliced_traj_from_structs = Trajectory.from_structures(self.structures[2:99:3])

if len(sliced_traj) == len(sliced_traj_from_structs):
assert all(sliced_traj[i] == sliced_traj_from_structs[i] for i in range(len(sliced_traj)))
else:
raise AssertionError
assert len(sliced_traj) == len(
sliced_traj_from_structs
), f"{len(sliced_traj)=} != {len(sliced_traj_from_structs)=}"
assert all(sliced_traj[i] == sliced_traj_from_structs[i] for i in range(len(sliced_traj)))

sliced_traj = self.traj[:-4:2]
sliced_traj_from_structs = Trajectory.from_structures(self.structures[:-4:2])

if len(sliced_traj) == len(sliced_traj_from_structs):
assert all(sliced_traj[idx] == sliced_traj_from_structs[idx] for idx in range(len(sliced_traj)))
else:
raise AssertionError
assert len(sliced_traj) == len(
sliced_traj_from_structs
), f"{len(sliced_traj)=} != {len(sliced_traj_from_structs)=}"
assert all(sliced_traj[idx] == sliced_traj_from_structs[idx] for idx in range(len(sliced_traj)))

sliced_traj = self.traj_mols[:2]
sliced_traj_from_mols = Trajectory.from_molecules(self.molecules[:2])

if len(sliced_traj) == len(sliced_traj_from_mols):
assert all(sliced_traj[i] == sliced_traj_from_mols[i] for i in range(len(sliced_traj)))
else:
raise AssertionError
assert len(sliced_traj) == len(sliced_traj_from_mols), f"{len(sliced_traj)=} != {len(sliced_traj_from_mols)=}"
assert all(sliced_traj[i] == sliced_traj_from_mols[i] for i in range(len(sliced_traj)))

sliced_traj = self.traj_mols[:-2]
sliced_traj_from_mols = Trajectory.from_molecules(self.molecules[:-2])

if len(sliced_traj) == len(sliced_traj_from_mols):
assert all(sliced_traj[i] == sliced_traj_from_mols[i] for i in range(len(sliced_traj)))
else:
raise AssertionError
assert len(sliced_traj) == len(sliced_traj_from_mols), f"{len(sliced_traj)=} != {len(sliced_traj_from_mols)=}"
assert all(sliced_traj[i] == sliced_traj_from_mols[i] for i in range(len(sliced_traj)))

def test_list_slice(self):
sliced_traj = self.traj[[10, 30, 70]]
sliced_traj_from_structs = Trajectory.from_structures([self.structures[i] for i in [10, 30, 70]])

if len(sliced_traj) == len(sliced_traj_from_structs):
assert all(sliced_traj[i] == sliced_traj_from_structs[i] for i in range(len(sliced_traj)))
else:
raise AssertionError
assert len(sliced_traj) == len(
sliced_traj_from_structs
), f"{len(sliced_traj)=} != {len(sliced_traj_from_structs)=}"
assert all(sliced_traj[i] == sliced_traj_from_structs[i] for i in range(len(sliced_traj)))

sliced_traj = self.traj_mols[[1, 3]]
sliced_traj_from_mols = Trajectory.from_molecules([self.molecules[i] for i in [1, 3]])

if len(sliced_traj) == len(sliced_traj_from_mols):
assert all(sliced_traj[i] == sliced_traj_from_mols[i] for i in range(len(sliced_traj)))
else:
raise AssertionError
assert len(sliced_traj) == len(sliced_traj_from_mols), f"{len(sliced_traj)=} != {len(sliced_traj_from_mols)=}"
assert all(sliced_traj[i] == sliced_traj_from_mols[i] for i in range(len(sliced_traj)))

def test_conversion(self):
# Convert to displacements and back, and then check structures.
Expand Down
12 changes: 4 additions & 8 deletions tests/entries/test_mixing_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,11 +1406,9 @@ def test_state_gga_2_scan_same(self, mixing_scheme_no_compat, ms_gga_2_scan_same
if entry.entry_id in ["r2scan-4", "r2scan-6"]:
assert entry.correction == 3
assert entry.parameters["run_type"] == "R2SCAN"
elif entry.entry_id == "gga-4":
raise AssertionError("Entry gga-4 should have been discarded")
elif entry.entry_id == "gga-6":
raise AssertionError("Entry gga-6 should have been discarded")
else:
assert entry.entry_id != "gga-4", f"{entry.entry_id=} should have been discarded"
assert entry.entry_id != "gga-6", f"{entry.entry_id=} should have been discarded"
assert entry.correction == 0, f"{entry.entry_id}"
assert entry.parameters["run_type"] == "GGA"

Expand Down Expand Up @@ -1455,9 +1453,8 @@ def test_state_gga_2_scan_diff_match(self, mixing_scheme_no_compat, ms_gga_2_sca
assert entry.correction == 3
elif entry.entry_id == "r2scan-7":
assert entry.correction == 15
elif entry.entry_id == "gga-4":
raise AssertionError(f"Entry {entry.entry_id} should have been discarded")
else:
assert entry.entry_id != "gga-4", f"{entry.entry_id=} should have been discarded"
assert entry.correction == 0, f"{entry.entry_id}"
assert entry.parameters["run_type"] == "GGA"

Expand Down Expand Up @@ -1501,9 +1498,8 @@ def test_state_gga_2_scan_diff_nomatch(self, mixing_scheme_no_compat, ms_gga_2_s
if entry.entry_id == "r2scan-4":
assert entry.correction == 3
assert entry.parameters["run_type"] == "R2SCAN"
elif entry.entry_id in ["gga-4", "r2scan-8"]:
raise AssertionError(f"Entry {entry.entry_id} should have been discarded")
else:
assert entry.entry_id not in ("gga-4", "r2scan-8"), f"{entry.entry_id=} should have been discarded"
assert entry.correction == 0, f"{entry.entry_id}"
assert entry.parameters["run_type"] == "GGA"

Expand Down
18 changes: 9 additions & 9 deletions tests/io/test_multiwfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_parse_single_cp():
name1, desc1 = parse_cp(contents)

contents_split = [line.split() for line in contents]
conditionals = {k: v for k, v in QTAIM_CONDITIONALS.items() if k not in ["connected_bond_paths"]}
conditionals = {k: v for k, v in QTAIM_CONDITIONALS.items() if k != "connected_bond_paths"}
name2, desc2 = extract_info_from_cp_text(contents_split, "atom", conditionals)

assert name1 == name2
Expand Down Expand Up @@ -170,22 +170,22 @@ def test_add_atoms():
separated = separate_cps_by_type(all_descs)

# Test ValueErrors
mol_minatom = Molecule(["O"], [[0.0, 0.0, 0.0]])
mol_min_atom = Molecule(["O"], [[0.0, 0.0, 0.0]])

with pytest.raises(ValueError, match=r"bond CP"):
add_atoms(mol_minatom, separated)
add_atoms(mol_min_atom, separated)

sep_minbonds = copy.deepcopy(separated)
sep_minbonds["bond"] = {k: separated["bond"][k] for k in ["1_bond", "2_bond"]}
sep_min_bonds = copy.deepcopy(separated)
sep_min_bonds["bond"] = {k: separated["bond"][k] for k in ["1_bond", "2_bond"]}

with pytest.raises(ValueError, match=r"ring CP"):
add_atoms(mol, sep_minbonds)
add_atoms(mol, sep_min_bonds)

sep_minrings = copy.deepcopy(separated)
sep_minrings["ring"] = {k: separated["ring"][k] for k in ["13_ring", "14_ring"]}
sep_min_rings = copy.deepcopy(separated)
sep_min_rings["ring"] = {k: separated["ring"][k] for k in ["13_ring", "14_ring"]}

with pytest.raises(ValueError, match=r"cage CP"):
add_atoms(mol, sep_minrings)
add_atoms(mol, sep_min_rings)

# Test distance-based metric
modified = add_atoms(mol, separated, bond_atom_criterion="distance")
Expand Down

0 comments on commit 61b02a5

Please sign in to comment.