Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backend selection #106

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 22 additions & 35 deletions spyrmsd/graph.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,30 @@
import numpy as np
import sys

from spyrmsd import constants

## Checking and initializing the available backends
graph_backends = []

try:
from spyrmsd.graphs.gt import (
cycle,
graph_from_adjacency_matrix,
lattice,
match_graphs,
num_edges,
num_vertices,
vertex_property,
)
import spyrmsd.graphs.gt
graph_backends.append("graph_tool")

except ImportError:
try:
from spyrmsd.graphs.nx import (
cycle,
graph_from_adjacency_matrix,
lattice,
match_graphs,
num_edges,
num_vertices,
vertex_property,
)
except ImportError:
raise ImportError("graph_tool or NetworkX libraries not found.")

__all__ = [
"graph_from_adjacency_matrix",
"match_graphs",
"vertex_property",
"num_vertices",
"num_edges",
"lattice",
"cycle",
"adjacency_matrix_from_atomic_coordinates",
]

import numpy as np
print("Graph Tool backend not found")

try:
import spyrmsd.graphs.nx
graph_backends.append("networkx")

except ImportError:
print("NetworkX backend not found")

from spyrmsd import constants
if len(graph_backends) == 0:
sys.exit("No valid graph backends were found, please make sure atleast one of the supported backends is installed correctly")

def get_backends():
return graph_backends

def adjacency_matrix_from_atomic_coordinates(
aprops: np.ndarray, coordinates: np.ndarray
Expand Down
1 change: 1 addition & 0 deletions spyrmsd/graphs/nx.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,4 @@ def cycle(n):
Cycle graph
"""
return nx.cycle_graph(n)

43 changes: 39 additions & 4 deletions spyrmsd/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def strip(self) -> None:

self.stripped = True

def to_graph(self):
def to_graph(self, backend: str = "default"):
"""
Convert molecule to graph.

Expand All @@ -202,9 +202,44 @@ def to_graph(self):

The molecular graph is cached.
"""
if self.G is None:

## Check if the graph exists and is in the correct backend type
if self.G is None or not backend.lower() in self.current_backend:
# Check the backend
available_backends = graph.get_backends()

gt_backends = ["graphtool", "graph-tool", "graph_tool", "gt"]
nx_backends = ["networkx", "nx"]

if backend.lower() in gt_backends:
if "graph_tool" in available_backends:
import spyrmsd.graphs.gt as graph_backend
self.current_backend = gt_backends

else:
raise ImportError("Graph_tools backend not present")
elif backend.lower() in nx_backends:
if "networkx" in available_backends:
import spyrmsd.graphs.nx as graph_backend
self.current_backend = nx_backends
else:
raise ImportError("NetworkX backend not present")
elif backend.lower() == "default":

if len(available_backends) == 0:
raise ValueError("No valid backends were found, please ensure one of the supported backends is installed correctly")
if "graph_tool" in available_backends:
import spyrmsd.graphs.gt as graph_backend
self.current_backend = gt_backends

else:
import spyrmsd.graphs.nx as graph_backend
self.current_backend = nx_backends
else:
raise ValueError(f"Didn't recognize backend '{backend}'")
print(graph_backend)
try:
self.G = graph.graph_from_adjacency_matrix(
self.G = graph_backend.graph_from_adjacency_matrix(
self.adjacency_matrix, self.atomicnums
)
except AttributeError:
Expand All @@ -218,7 +253,7 @@ def to_graph(self):
self.atomicnums, self.coordinates
)

self.G = graph.graph_from_adjacency_matrix(
self.G = graph_backend.graph_from_adjacency_matrix(
self.adjacency_matrix, self.atomicnums
)

Expand Down
43 changes: 38 additions & 5 deletions spyrmsd/rmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def _rmsd_isomorphic_core(
minimize: bool = False,
isomorphisms: Optional[List[Tuple[List[int], List[int]]]] = None,
atol: float = 1e-9,
backend: str = "default",
) -> Tuple[float, List[Tuple[List[int], List[int]]]]:
"""
Compute RMSD using graph isomorphism.
Expand Down Expand Up @@ -158,6 +159,29 @@ def _rmsd_isomorphic_core(
RMSD (after graph matching) and graph isomorphisms
"""

available_backends = graph.get_backends()

# Check the backend
if backend.lower() in ["graphtool", "graph-tool", "graph_tool", "graph tool", "gt"]:
if "graph_tool" in available_backends:
import spyrmsd.graphs.gt as graph_backend
else:
raise ImportError("Graph_tools backend not present")
elif backend.lower() in ["networkx", "nx"]:
if "networkx" in available_backends:
import spyrmsd.graphs.nx as graph_backend
else:
raise ImportError("NetworkX backend not present")
elif backend.lower() == "default":
if len(available_backends) == 0:
raise ValueError("No valid backends were found, please ensure one of the supported backends is installed correctly")
if "graph_tool" in available_backends:
import spyrmsd.graphs.gt as graph_backend
else:
import spyrmsd.graphs.nx as graph_backend
else:
raise ValueError(f"Didn't recognize backend '{backend}'")

assert coords1.shape == coords2.shape

n = coords1.shape[0]
Expand All @@ -169,11 +193,11 @@ def _rmsd_isomorphic_core(
# No cached isomorphisms
if isomorphisms is None:
# Convert molecules to graphs
G1 = graph.graph_from_adjacency_matrix(am1, aprops1)
G2 = graph.graph_from_adjacency_matrix(am2, aprops2)
G1 = graph_backend.graph_from_adjacency_matrix(am1, aprops1)
G2 = graph_backend.graph_from_adjacency_matrix(am2, aprops2)

# Get all the possible graph isomorphisms
isomorphisms = graph.match_graphs(G1, G2)
isomorphisms = graph_backend.match_graphs(G1, G2)

# Minimum result
# Squared displacement (not minimize) or RMSD (minimize)
Expand Down Expand Up @@ -214,6 +238,7 @@ def symmrmsd(
minimize: bool = False,
cache: bool = True,
atol: float = 1e-9,
backend: str = "default",
) -> Any:
"""
Compute RMSD using graph isomorphism for multiple coordinates.
Expand All @@ -240,7 +265,9 @@ def symmrmsd(
Cache graph isomorphisms
atol: float
Absolute tolerance parameter for QCP (see :func:`qcp_rmsd`)

backend: str
Which backend to use (default, graph_tool or networkx)

Returns
-------
float: Union[float, List[float]]
Expand Down Expand Up @@ -276,6 +303,7 @@ def symmrmsd(
minimize=minimize,
isomorphisms=isomorphism,
atol=atol,
backend=backend,
)

RMSD.append(srmsd)
Expand All @@ -292,6 +320,7 @@ def symmrmsd(
minimize=minimize,
isomorphisms=None,
atol=atol,
backend=backend,
)

return RMSD
Expand All @@ -305,9 +334,10 @@ def rmsdwrapper(
minimize: bool = False,
strip: bool = True,
cache: bool = True,
backend: str = "default",
) -> Any:
"""
Compute RMSD between two molecule.
Compute RMSD between two molecules.

Parameters
----------
Expand All @@ -323,6 +353,8 @@ def rmsdwrapper(
Minimised RMSD (using the quaternion polynomial method)
strip: bool, optional
Strip hydrogen atoms
backend: str
Which backend to use (default, graph_tool or networkx)

Returns
-------
Expand Down Expand Up @@ -358,6 +390,7 @@ def rmsdwrapper(
center=center,
minimize=minimize,
cache=cache,
backend=backend,
)
else: # No symmetry
for c in cmols:
Expand Down
Loading
Loading