Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functionality to select backend #107

Merged
merged 21 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
34702f3
Update graph.py to support set_backend function
Jnelen Feb 29, 2024
42c3015
Apply private _available_backends suggestions from code review
Jnelen Mar 2, 2024
a79aa0d
Print warning when backend is already set + make sure we use _availab…
Jnelen Mar 2, 2024
8d906de
remove reliance on environment variables
Jnelen Mar 2, 2024
3f50f8f
make _validate_backend function
Jnelen Mar 3, 2024
bced3a3
Use precommit hooks
Jnelen Mar 3, 2024
806d5cb
Update ValueError message
Jnelen Mar 3, 2024
7dfac0e
Don't return the backend when setting it
Jnelen Mar 3, 2024
35623cb
Add dummy function to make mypy happy
Jnelen Mar 3, 2024
4697eb1
update initializing the backend
Jnelen Mar 4, 2024
5b44a54
Add caching support for multiple backends
Jnelen Mar 4, 2024
227d0ab
reset the m.G to be an empty dict instead of being None (test_graph_f…
Jnelen Mar 4, 2024
b3dc025
Add backend functions to base of spyrmsd
Jnelen Mar 5, 2024
293f4a0
Enable isort compatibility with black in pre-commit.yml
Jnelen Mar 6, 2024
5750ad6
Added (conditional) test_backend test
Jnelen Mar 6, 2024
af5c2a8
Add CI test environment with both graph-tool and networkx
Jnelen Mar 6, 2024
1045a91
Fix and cleanup CI tests
Jnelen Mar 6, 2024
41a8a4c
Add testing for the new graph caching with multiple backends
Jnelen Mar 6, 2024
228c88b
exclude Windows and the all backends option for CI (since Windows doe…
Jnelen Mar 6, 2024
72dc339
update changelog and CITATION.cff
Jnelen Mar 6, 2024
c8f20a7
update changelog
Jnelen Mar 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ repos:
rev: 5.12.0
hooks:
- id: isort
args: ["--profile", "black"]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.5.1
hooks:
Expand Down
6 changes: 6 additions & 0 deletions spyrmsd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@

from .due import Doi, due

# Make the backend related functions available from base spyrmsd
# Add noqa to avoid flake8 error
from .graph import _available_backends as available_backends # noqa: F401
from .graph import _get_backend as get_backend # noqa: F401
from .graph import _set_backend as set_backend # noqa: F401

__version__ = "0.7.0-dev"

# This will print latest Zenodo version
Expand Down
146 changes: 114 additions & 32 deletions spyrmsd/graph.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,124 @@
import warnings

import numpy as np

from spyrmsd import constants

_available_backends = []
_current_backend = None

## Backend aliases
_graph_tool_aliases = ["graph_tool", "graphtool", "graph-tool", "graph tool", "gt"]
_networkx_aliases = ["networkx", "nx"]

## Construct the alias dictionary
_alias_to_backend = {}
for alias in _graph_tool_aliases:
_alias_to_backend[alias.lower()] = "graph-tool"
for alias in _networkx_aliases:
_alias_to_backend[alias.lower()] = "networkx"


def _dummy(*args, **kwargs):
"""
Dummy function for backend not set.
"""
raise NotImplementedError("No backend is set.")


## Assigning the properties/methods associated with a backend to a temporary dummy function
cycle = _dummy
graph_from_adjacency_matrix = _dummy
lattice = _dummy
match_graphs = _dummy
num_edges = _dummy
num_vertices = _dummy
vertex_property = _dummy

try:
from spyrmsd.graphs.gt import cycle as gt_cycle
from spyrmsd.graphs.gt import (
cycle,
graph_from_adjacency_matrix,
lattice,
match_graphs,
num_edges,
num_vertices,
vertex_property,
graph_from_adjacency_matrix as gt_graph_from_adjacency_matrix,
)
from spyrmsd.graphs.gt import lattice as gt_lattice
from spyrmsd.graphs.gt import match_graphs as gt_match_graphs
from spyrmsd.graphs.gt import num_edges as gt_num_edges
from spyrmsd.graphs.gt import num_vertices as gt_num_vertices
from spyrmsd.graphs.gt import vertex_property as gt_vertex_property

_available_backends.append("graph-tool")
except ImportError:
try:
from spyrmsd.graphs.nx import (
cycle,
graph_from_adjacency_matrix,
lattice,
match_graphs,
num_edges,
num_vertices,
vertex_property,
)
except ImportError:
raise ImportError("graph_tool or NetworkX libraries not found.")

__all__ = [
"graph_from_adjacency_matrix",
"match_graphs",
"vertex_property",
"num_vertices",
"num_edges",
"lattice",
"cycle",
"adjacency_matrix_from_atomic_coordinates",
]
Jnelen marked this conversation as resolved.
Show resolved Hide resolved
warnings.warn("The graph-tool backend does not seem to be installed.")

import numpy as np
try:
from spyrmsd.graphs.nx import cycle as nx_cycle
from spyrmsd.graphs.nx import (
graph_from_adjacency_matrix as nx_graph_from_adjacency_matrix,
)
from spyrmsd.graphs.nx import lattice as nx_lattice
from spyrmsd.graphs.nx import match_graphs as nx_match_graphs
from spyrmsd.graphs.nx import num_edges as nx_num_edges
from spyrmsd.graphs.nx import num_vertices as nx_num_vertices
from spyrmsd.graphs.nx import vertex_property as nx_vertex_property

from spyrmsd import constants
_available_backends.append("networkx")
except ImportError:
warnings.warn("The networkx backend does not seem to be installed.")


def _validate_backend(backend):
standardized_backend = _alias_to_backend.get(backend.lower())
if standardized_backend is None:
raise ValueError(f"The {backend} backend is not recognized or supported")
if standardized_backend not in _available_backends:
raise ImportError(f"The {backend} backend doesn't seem to be installed")
return standardized_backend


def _set_backend(backend):
global _current_backend
backend = _validate_backend(backend)

## Check if we actually need to switch backends
if backend == _current_backend:
warnings.warn(f"The backend is already {backend}.")
return

global cycle, graph_from_adjacency_matrix, lattice, match_graphs, num_edges, num_vertices, vertex_property
RMeli marked this conversation as resolved.
Show resolved Hide resolved

if backend == "graph-tool":
cycle = gt_cycle
graph_from_adjacency_matrix = gt_graph_from_adjacency_matrix
lattice = gt_lattice
match_graphs = gt_match_graphs
num_edges = gt_num_edges
num_vertices = gt_num_vertices
vertex_property = gt_vertex_property

elif backend == "networkx":
cycle = nx_cycle
graph_from_adjacency_matrix = nx_graph_from_adjacency_matrix
lattice = nx_lattice
match_graphs = nx_match_graphs
num_edges = nx_num_edges
num_vertices = nx_num_vertices
vertex_property = nx_vertex_property

_current_backend = backend


if len(_available_backends) == 0:
raise ImportError(
"No valid backends found. Please ensure that either graph-tool or NetworkX are installed."
)
else:
if _current_backend is None:
## Set the backend to the first available (preferred) backend
_set_backend(backend=_available_backends[0])


def _get_backend():
return _current_backend


def adjacency_matrix_from_atomic_coordinates(
Expand Down
18 changes: 10 additions & 8 deletions spyrmsd/molecule.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union

import numpy as np

Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(
self.adjacency_matrix: np.ndarray = np.asarray(adjacency_matrix, dtype=int)

# Molecular graph
self.G = None
self.G: Dict[str, object] = {}

self.masses: Optional[List[float]] = None

Expand Down Expand Up @@ -182,7 +182,7 @@ def strip(self) -> None:
self.adjacency_matrix = self.adjacency_matrix[np.ix_(idx, idx)]

# Reset molecular graph when stripping
self.G = None
self.G = {}

self.stripped = True

Expand All @@ -200,11 +200,13 @@ def to_graph(self):
If the molecule does not have an associated adjacency matrix, a simple
bond perception is used.

The molecular graph is cached.
The molecular graph is cached per backend.
"""
if self.G is None:
_current_backend = graph._current_backend

if _current_backend not in self.G.keys():
try:
self.G = graph.graph_from_adjacency_matrix(
self.G[_current_backend] = graph.graph_from_adjacency_matrix(
self.adjacency_matrix, self.atomicnums
)
except AttributeError:
Expand All @@ -218,11 +220,11 @@ def to_graph(self):
self.atomicnums, self.coordinates
)

self.G = graph.graph_from_adjacency_matrix(
self.G[_current_backend] = graph.graph_from_adjacency_matrix(
self.adjacency_matrix, self.atomicnums
)

return self.G
return self.G[_current_backend]

def __len__(self) -> int:
"""
Expand Down
27 changes: 27 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pytest

import spyrmsd
from spyrmsd import constants, graph, io, molecule
from spyrmsd.exceptions import NonIsomorphicGraphs
from spyrmsd.graphs import _common as gc
Expand Down Expand Up @@ -153,3 +154,29 @@ def test_build_graph_node_features_unsupported() -> None:

with pytest.raises(ValueError, match="Unsupported property type:"):
_ = graph.graph_from_adjacency_matrix(A, property)


@pytest.mark.skipif(
len(spyrmsd.available_backends) < 2,
reason="Not all of the required backends are installed",
)
def test_set_backend() -> None:
A = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 1]])

import graph_tool as gt
import networkx as nx
Jnelen marked this conversation as resolved.
Show resolved Hide resolved

spyrmsd.set_backend("networkx")
assert spyrmsd.get_backend() == "networkx"

Gnx = graph.graph_from_adjacency_matrix(A)
assert isinstance(Gnx, nx.Graph)

spyrmsd.set_backend("graph-tool")
assert spyrmsd.get_backend() == "graph-tool"

Ggt = graph.graph_from_adjacency_matrix(A)
assert isinstance(Ggt, gt.Graph)

with pytest.raises(ValueError, match="backend is not recognized or supported"):
spyrmsd.set_backend("unknown")
2 changes: 1 addition & 1 deletion tests/test_molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def test_graph_from_atomic_coordinates_perception(
m = copy.deepcopy(mol)

delattr(m, "adjacency_matrix")
m.G = None
m.G = {}
RMeli marked this conversation as resolved.
Show resolved Hide resolved

with pytest.warns(UserWarning):
# Uses automatic bond perception
Expand Down
Loading