Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
shoubhikraj committed Dec 25, 2023
1 parent c1a90fe commit f2e5f65
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 49 deletions.
93 changes: 46 additions & 47 deletions autode/opt/coordinates/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import copy

import numpy as np
from enum import Enum
import itertools
from typing import Any, Optional, Type, List, TYPE_CHECKING
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -98,7 +97,9 @@ def __init__(self, *args: Any):
def append(self, item: Primitive) -> None:
"""Append an item to this set of primitives"""
assert isinstance(item, Primitive), "Must be a Primitive type!"
super().append(item)
# prevent duplicate primitives
if item not in self:
super().append(item)

@property
def B(self) -> np.ndarray:
Expand Down Expand Up @@ -251,7 +252,6 @@ def _populate_all(self, x: np.ndarray) -> None:

def build_pic_from_species(
mol: "Species",
aux_bonds=False,
) -> AnyPIC:
"""
Build a set of primitives from the species, using the graph as
Expand All @@ -261,14 +261,13 @@ def build_pic_from_species(
Args:
mol:
aux_bonds:
Returns:
(AnyPIC): The set of primitive internals
"""
pic = AnyPIC()
core_graph = _get_connected_graph_from_species(mol)
_add_bonds_from_species(pic, mol, core_graph, aux_bonds=aux_bonds)
_add_bonds_from_species(pic, mol, core_graph)
_add_angles_from_species(pic, mol, core_graph)
_add_dihedrals_from_species(pic, mol, core_graph)
return pic
Expand Down Expand Up @@ -340,7 +339,6 @@ def _add_bonds_from_species(
pic: AnyPIC,
mol: "Species",
core_graph: "MolecularGraph",
aux_bonds: bool = False,
):
"""
Modify the supplied AnyPIC instance in-place by adding bonds, from the
Expand All @@ -350,7 +348,6 @@ def _add_bonds_from_species(
pic: The AnyPIC instance (modified in-place)
mol: The species object
core_graph: The connectivity graph
aux_bonds: Whether to add auxiliary bonds (< 2.5 * covalent radii sum)
"""
n = 0
for i, j in sorted(core_graph.edges):
Expand All @@ -365,23 +362,14 @@ def _add_bonds_from_species(
pic.append(PrimitiveDistance(i, j))
assert n == mol.constraints.n_distance

if not aux_bonds:
return None

# add auxiliary bonds if specified
for i, j in itertools.combinations(range(mol.n_atoms), r=2):
if core_graph.has_edge(i, j):
continue
if mol.distance(i, j) < 2.5 * mol.eqm_bond_distance(i, j):
pic.append(PrimitiveDistance(i, j))
return None


def _add_angles_from_species(
pic: AnyPIC,
mol: "Species",
core_graph: "MolecularGraph",
lin_thresh=Angle(170, "deg"),
lin_thresh: Angle = Angle(170, "deg"),
) -> None:
"""
Modify the set of primitives in-place by adding angles, from the
Expand All @@ -393,27 +381,36 @@ def _add_angles_from_species(
core_graph (MolecularGraph): The connectivity graph
lin_thresh (Angle): The angle threshold for linearity
"""
lin_thresh = lin_thresh.to("rad")

def get_ref_atom(a, b, c):
def get_ref_atom(a, b, c, bonded=False):
"""get a reference atom for a-b-c linear angle"""
# all atoms in 4 A radius except a, b, c
near_atoms = [
idx
for idx in range(mol.n_atoms)
if mol.distance(b, idx) < Distance(4.0, "ang")
and idx not in (a, b, c)
]
# only check bonded atoms if requested
if bonded:
near_atoms = list(core_graph.neighbors(b))
near_atoms.remove(a)
near_atoms.remove(c)

# otherwise get all atoms in 4 A radius except a, b, c
else:
near_atoms = [
idx
for idx in range(mol.n_atoms)
if mol.distance(b, idx) < Distance(4.0, "ang")
and idx not in (a, b, c)
]

# get atoms closest to perpendicular
deviations_from_90 = {}
for atom in near_atoms:
i_b_a = mol.angle(atom, b, a)
if i_b_a > lin_thresh or i_b_a < (Angle(180, "deg") - lin_thresh):
if i_b_a > lin_thresh or i_b_a < (np.pi - lin_thresh):
continue
i_b_c = mol.angle(atom, b, c)
if i_b_c > lin_thresh or i_b_c < (Angle(180, "deg") - lin_thresh):
if i_b_c > lin_thresh or i_b_c < (np.pi - lin_thresh):
continue
deviation_a = abs(i_b_a - Angle(90, "deg"))
deviation_c = abs(i_b_c - Angle(90, "deg"))
deviation_a = abs(i_b_a - np.pi / 2)
deviation_c = abs(i_b_c - np.pi / 2)
avg_dev = (deviation_a + deviation_c) / 2
deviations_from_90[atom] = avg_dev

Expand All @@ -424,23 +421,20 @@ def get_ref_atom(a, b, c):

for o in range(mol.n_atoms):
for n, m in itertools.combinations(core_graph.neighbors(o), r=2):
# avoid almost linear angles
if mol.angle(m, o, n) < lin_thresh:
pic.append(PrimitiveBondAngle(m=m, o=o, n=n))
else:
# if central atom is connected to another, no need to include
other_neighbours = list(core_graph.neighbors(o))
other_neighbours.remove(m)
other_neighbours.remove(n)
if any(
mol.angle(m, o, x) < lin_thresh for x in other_neighbours
) or any(
mol.angle(n, o, x) < lin_thresh for x in other_neighbours
):
# If central atom is connected to another atom, then the
# linear angle is skipped and instead an out-of-plane
# (improper dihedral) coordinate is used
r = get_ref_atom(m, o, n, bonded=True)
if r is not None:
pic.append(PrimitiveDihedralAngle(m, r, o, n))
continue

# for linear bends, ideally a reference atom is needed
r = get_ref_atom(m, o, n)
# Otherwise, we use a nearby (< 4.0 A) reference atom to
# define two orthogonal linear bends
r = get_ref_atom(m, o, n, bonded=False)
if r is not None:
pic.append(
PrimitiveLinearAngle(m, o, n, r, LinearBendType.BEND)
Expand All @@ -450,7 +444,10 @@ def get_ref_atom(a, b, c):
m, o, n, r, LinearBendType.COMPLEMENT
)
)
else: # these use dummy atom for reference

# For completely linear molecules (CO2), there will be no such
# reference atoms, so use dummy atoms instead
else:
pic.append(
PrimitiveDummyLinearAngle(m, o, n, LinearBendType.BEND)
)
Expand All @@ -467,22 +464,24 @@ def _add_dihedrals_from_species(
pic: AnyPIC,
mol: "Species",
core_graph: "MolecularGraph",
lin_thresh=Angle(170, "deg"),
lin_thresh: Angle = Angle(170, "deg"),
) -> None:
"""
Modify the set of primitives in-place by adding dihedrals (torsions),
from the connectivity graph supplied
Args:
pic: The AnyPIC instance (modified in-place)
mol: The species
core_graph: The connectivity graph
pic (AnyPIC): The AnyPIC instance (modified in-place)
mol (Species): The species
core_graph (MolecularGraph): The connectivity graph
lin_thresh (Angle): The threshold for linearity
"""
# no dihedrals possible with less than 4 atoms
if mol.n_atoms < 4:
return

zero_angle_thresh = Angle(180, "deg") - lin_thresh
lin_thresh = lin_thresh.to("rad")
zero_angle_thresh = np.pi - lin_thresh

def is_dihedral_well_defined(w, x, y, z):
"""A dihedral is well-defined if any angle is not linear"""
Expand Down
36 changes: 34 additions & 2 deletions tests/test_opt/test_coordiantes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
h2o2_mol,
feco5_mol,
cumulene_mol,
acetylene_mol,
)
from autode.utils import work_in_tmp_dir
from autode.atoms import Atom
from autode.species.molecule import Molecule
from autode.values import Angle
Expand Down Expand Up @@ -831,10 +833,31 @@ def test_pic_generation_linear_angle_ref():
assert not any(ic1 == ic2 for ic1, ic2 in itertools.combinations(pic, r=2))
# check that linear bends use reference atoms, not dummy
assert not any(isinstance(ic, PrimitiveDummyLinearAngle) for ic in pic)
# there should not be any dihedral for this geometry
assert not any(isinstance(ic, PrimitiveDihedralAngle) for ic in pic)
# for C-Fe-C, one out-of-plane dihedral should be present
assert PrimitiveDihedralAngle(3, 5, 2, 1) in pic
# check degrees of freedom = 3N - 6
_ = pic(m.coordinates.flatten())
assert np.linalg.matrix_rank(pic.B) == 3 * m.n_atoms - 6


def test_pic_generation_linear_angle_dummy():
# acetylene molecule
mol = acetylene_mol()
pic = build_pic_from_species(mol)

# there should not be any usual bond angles
assert not any(isinstance(ic, PrimitiveBondAngle) for ic in pic)
# there should not be any linear angles with reference atom
assert not any(isinstance(ic, PrimitiveLinearAngle) for ic in pic)
# there should be linear angles with dummy
assert any(isinstance(ic, PrimitiveDummyLinearAngle) for ic in pic)

# degrees of freedom = 3N - 5 for linear molecules
_ = pic(mol.coordinates.flatten())
assert np.linalg.matrix_rank(pic.B) == 3 * mol.n_atoms - 5


@work_in_tmp_dir()
def test_pic_generation_disjoint_graph():
# the algorithm should fully connect the graph
xyz_string = (
Expand Down Expand Up @@ -872,6 +895,15 @@ def test_pic_generation_disjoint_graph():
# the other distance between fragments is 2, 3 which should not be connected
assert PrimitiveDistance(2, 3) not in pic
assert PrimitiveBondAngle(1, 2, 3) not in pic
# check degrees of freedom = 3N - 6
_ = pic(mol.coordinates.flatten())
assert np.linalg.matrix_rank(pic.B) == 3 * mol.n_atoms - 6

# if the bond between 2, 3 is made into a constraint, it will generate angles
mol.constraints.distance = {(2, 3): mol.distance(2, 3)}
pic = build_pic_from_species(mol)
assert ConstrainedPrimitiveDistance(2, 3, mol.distance(2, 3)) in pic
assert PrimitiveBondAngle(1, 2, 3) in pic


def test_pic_generation_chain_dihedrals():
Expand Down

0 comments on commit f2e5f65

Please sign in to comment.