diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 74c963c..10caace 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -26,10 +26,12 @@ jobs: os: [macOS-latest, ubuntu-latest, windows-latest] python-version: ["3.9", "3.10", "3.11", "3.12"] chemlib: [obabel, rdkit] - graphlib: [nx, gt] + graphlib: [nx, gt, all] exclude: # graph-tools does not work on Windows - {os: "windows-latest", graphlib: "gt"} + - {os: "windows-latest", graphlib: "all"} + - {graphlib: "all", chemlib: "obabel"} include: - {os: "macOS-14", graphlib: "gt", chemlib: "obabel", python-version: "3.12"} - {os: "macOS-14", graphlib: "nx", chemlib: "rdkit", python-version: "3.12"} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 176a260..6da5c7c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,6 +25,7 @@ repos: rev: 5.12.0 hooks: - id: isort + args: ["--profile", "black"] - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.5.1 hooks: diff --git a/CHANGELOG.md b/CHANGELOG.md index 6bf0dc7..1894a41 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,15 @@ ## Version X.Y.Z Date: XX/YY/ZZZZ -Contributors: @RMeli, @takluyver +Contributors: @RMeli, @takluyver, @Jnelen + +### Added + +* Functionality to manually select the backend [PR #107 | @Jnelen] + +### Changed + +* Molecular graphs are now cached per backend using a dictionary [PR #107 | @Jnelen] ### Fixed diff --git a/CITATION.cff b/CITATION.cff index d1d1a7c..1006b56 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -4,6 +4,9 @@ authors: - family-names: "Meli" given-names: "Rocco" orcid: "https://orcid.org/0000-0002-2845-3410" +- family-names: "Nelen" + given-names: "Jochem" + orcid: "https://orcid.org/0000-0002-9970-4950" title: "spyrmsd" version: 0.6.0 doi: 10.5281/zenodo.3631876 diff --git a/devtools/conda-envs/spyrmsd-test-rdkit-all.yaml b/devtools/conda-envs/spyrmsd-test-rdkit-all.yaml new file mode 100644 index 0000000..9212381 --- /dev/null +++ b/devtools/conda-envs/spyrmsd-test-rdkit-all.yaml @@ -0,0 +1,28 @@ +name: spyrmsd +channels: + - conda-forge + - rdkit +dependencies: + # Base + - python + - setuptools + + # Maths + - numpy + - scipy + - graph-tool + - networkx>=2 + + # Chemistry + - rdkit + + # Testing + - pytest + - pytest-cov + - pytest-benchmark + + # Dev + - mypy + - flake8 + - black + - codecov diff --git a/spyrmsd/__init__.py b/spyrmsd/__init__.py index b38f4f9..36cb386 100644 --- a/spyrmsd/__init__.py +++ b/spyrmsd/__init__.py @@ -4,6 +4,12 @@ from .due import Doi, due +# Make the backend related functions available from base spyrmsd +# Add noqa to avoid flake8 error +from .graph import _available_backends as available_backends # noqa: F401 +from .graph import _get_backend as get_backend # noqa: F401 +from .graph import _set_backend as set_backend # noqa: F401 + __version__ = "0.7.0-dev" # This will print latest Zenodo version diff --git a/spyrmsd/graph.py b/spyrmsd/graph.py index 0235a25..3d27a16 100644 --- a/spyrmsd/graph.py +++ b/spyrmsd/graph.py @@ -1,42 +1,124 @@ +import warnings + +import numpy as np + +from spyrmsd import constants + +_available_backends = [] +_current_backend = None + +## Backend aliases +_graph_tool_aliases = ["graph_tool", "graphtool", "graph-tool", "graph tool", "gt"] +_networkx_aliases = ["networkx", "nx"] + +## Construct the alias dictionary +_alias_to_backend = {} +for alias in _graph_tool_aliases: + _alias_to_backend[alias.lower()] = "graph-tool" +for alias in _networkx_aliases: + _alias_to_backend[alias.lower()] = "networkx" + + +def _dummy(*args, **kwargs): + """ + Dummy function for backend not set. + """ + raise NotImplementedError("No backend is set.") + + +## Assigning the properties/methods associated with a backend to a temporary dummy function +cycle = _dummy +graph_from_adjacency_matrix = _dummy +lattice = _dummy +match_graphs = _dummy +num_edges = _dummy +num_vertices = _dummy +vertex_property = _dummy + try: + from spyrmsd.graphs.gt import cycle as gt_cycle from spyrmsd.graphs.gt import ( - cycle, - graph_from_adjacency_matrix, - lattice, - match_graphs, - num_edges, - num_vertices, - vertex_property, + graph_from_adjacency_matrix as gt_graph_from_adjacency_matrix, ) + from spyrmsd.graphs.gt import lattice as gt_lattice + from spyrmsd.graphs.gt import match_graphs as gt_match_graphs + from spyrmsd.graphs.gt import num_edges as gt_num_edges + from spyrmsd.graphs.gt import num_vertices as gt_num_vertices + from spyrmsd.graphs.gt import vertex_property as gt_vertex_property + _available_backends.append("graph-tool") except ImportError: - try: - from spyrmsd.graphs.nx import ( - cycle, - graph_from_adjacency_matrix, - lattice, - match_graphs, - num_edges, - num_vertices, - vertex_property, - ) - except ImportError: - raise ImportError("graph_tool or NetworkX libraries not found.") - -__all__ = [ - "graph_from_adjacency_matrix", - "match_graphs", - "vertex_property", - "num_vertices", - "num_edges", - "lattice", - "cycle", - "adjacency_matrix_from_atomic_coordinates", -] + warnings.warn("The graph-tool backend does not seem to be installed.") -import numpy as np +try: + from spyrmsd.graphs.nx import cycle as nx_cycle + from spyrmsd.graphs.nx import ( + graph_from_adjacency_matrix as nx_graph_from_adjacency_matrix, + ) + from spyrmsd.graphs.nx import lattice as nx_lattice + from spyrmsd.graphs.nx import match_graphs as nx_match_graphs + from spyrmsd.graphs.nx import num_edges as nx_num_edges + from spyrmsd.graphs.nx import num_vertices as nx_num_vertices + from spyrmsd.graphs.nx import vertex_property as nx_vertex_property -from spyrmsd import constants + _available_backends.append("networkx") +except ImportError: + warnings.warn("The networkx backend does not seem to be installed.") + + +def _validate_backend(backend): + standardized_backend = _alias_to_backend.get(backend.lower()) + if standardized_backend is None: + raise ValueError(f"The {backend} backend is not recognized or supported") + if standardized_backend not in _available_backends: + raise ImportError(f"The {backend} backend doesn't seem to be installed") + return standardized_backend + + +def _set_backend(backend): + global _current_backend + backend = _validate_backend(backend) + + ## Check if we actually need to switch backends + if backend == _current_backend: + warnings.warn(f"The backend is already {backend}.") + return + + global cycle, graph_from_adjacency_matrix, lattice, match_graphs, num_edges, num_vertices, vertex_property + + if backend == "graph-tool": + cycle = gt_cycle + graph_from_adjacency_matrix = gt_graph_from_adjacency_matrix + lattice = gt_lattice + match_graphs = gt_match_graphs + num_edges = gt_num_edges + num_vertices = gt_num_vertices + vertex_property = gt_vertex_property + + elif backend == "networkx": + cycle = nx_cycle + graph_from_adjacency_matrix = nx_graph_from_adjacency_matrix + lattice = nx_lattice + match_graphs = nx_match_graphs + num_edges = nx_num_edges + num_vertices = nx_num_vertices + vertex_property = nx_vertex_property + + _current_backend = backend + + +if len(_available_backends) == 0: + raise ImportError( + "No valid backends found. Please ensure that either graph-tool or NetworkX are installed." + ) +else: + if _current_backend is None: + ## Set the backend to the first available (preferred) backend + _set_backend(backend=_available_backends[0]) + + +def _get_backend(): + return _current_backend def adjacency_matrix_from_atomic_coordinates( diff --git a/spyrmsd/molecule.py b/spyrmsd/molecule.py index 3d593a0..b2ab87c 100644 --- a/spyrmsd/molecule.py +++ b/spyrmsd/molecule.py @@ -1,5 +1,5 @@ import warnings -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union import numpy as np @@ -50,7 +50,7 @@ def __init__( self.adjacency_matrix: np.ndarray = np.asarray(adjacency_matrix, dtype=int) # Molecular graph - self.G = None + self.G: Dict[str, object] = {} self.masses: Optional[List[float]] = None @@ -182,7 +182,7 @@ def strip(self) -> None: self.adjacency_matrix = self.adjacency_matrix[np.ix_(idx, idx)] # Reset molecular graph when stripping - self.G = None + self.G = {} self.stripped = True @@ -200,11 +200,13 @@ def to_graph(self): If the molecule does not have an associated adjacency matrix, a simple bond perception is used. - The molecular graph is cached. + The molecular graph is cached per backend. """ - if self.G is None: + _current_backend = graph._current_backend + + if _current_backend not in self.G.keys(): try: - self.G = graph.graph_from_adjacency_matrix( + self.G[_current_backend] = graph.graph_from_adjacency_matrix( self.adjacency_matrix, self.atomicnums ) except AttributeError: @@ -218,11 +220,11 @@ def to_graph(self): self.atomicnums, self.coordinates ) - self.G = graph.graph_from_adjacency_matrix( + self.G[_current_backend] = graph.graph_from_adjacency_matrix( self.adjacency_matrix, self.atomicnums ) - return self.G + return self.G[_current_backend] def __len__(self) -> int: """ diff --git a/tests/test_graph.py b/tests/test_graph.py index 029e604..88c9643 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,6 +1,7 @@ import numpy as np import pytest +import spyrmsd from spyrmsd import constants, graph, io, molecule from spyrmsd.exceptions import NonIsomorphicGraphs from spyrmsd.graphs import _common as gc @@ -153,3 +154,29 @@ def test_build_graph_node_features_unsupported() -> None: with pytest.raises(ValueError, match="Unsupported property type:"): _ = graph.graph_from_adjacency_matrix(A, property) + + +@pytest.mark.skipif( + len(spyrmsd.available_backends) < 2, + reason="Not all of the required backends are installed", +) +def test_set_backend() -> None: + import graph_tool as gt + import networkx as nx + + A = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 1]]) + + spyrmsd.set_backend("networkx") + assert spyrmsd.get_backend() == "networkx" + + Gnx = graph.graph_from_adjacency_matrix(A) + assert isinstance(Gnx, nx.Graph) + + spyrmsd.set_backend("graph-tool") + assert spyrmsd.get_backend() == "graph-tool" + + Ggt = graph.graph_from_adjacency_matrix(A) + assert isinstance(Ggt, gt.Graph) + + with pytest.raises(ValueError, match="backend is not recognized or supported"): + spyrmsd.set_backend("unknown") diff --git a/tests/test_molecule.py b/tests/test_molecule.py index 43efda4..450f453 100644 --- a/tests/test_molecule.py +++ b/tests/test_molecule.py @@ -6,6 +6,7 @@ import numpy as np import pytest +import spyrmsd from spyrmsd import constants, graph, io, molecule, utils from tests import molecules @@ -167,7 +168,7 @@ def test_graph_from_atomic_coordinates_perception( m = copy.deepcopy(mol) delattr(m, "adjacency_matrix") - m.G = None + m.G = {} with pytest.warns(UserWarning): # Uses automatic bond perception @@ -236,3 +237,35 @@ def test_from_rdmol(adjacency): with pytest.raises(AttributeError): # No adjacency_matrix attribute mol.adjacency_matrix + + +@pytest.mark.skipif( + len(spyrmsd.available_backends) < 2, + reason="Not all of the required backends are installed", +) +@pytest.mark.parametrize( + "mol", [(molecules.benzene), (molecules.ethanol), (molecules.dialanine)] +) +def test_molecule_graph_cache(mol) -> None: + import graph_tool as gt + import networkx as nx + + ## Graph cache persists from previous tests, manually reset them + mol.G = {} + spyrmsd.set_backend("networkx") + mol.to_graph() + + assert isinstance(mol.G["networkx"], nx.Graph) + assert "graph-tool" not in mol.G.keys() + + spyrmsd.set_backend("graph-tool") + mol.to_graph() + + ## Make sure both backends (still) have a cache + assert isinstance(mol.G["networkx"], nx.Graph) + assert isinstance(mol.G["graph-tool"], gt.Graph) + + ## Strip the molecule to ensure the cache is reset + mol.strip() + + assert len(mol.G.items()) == 0