diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 7824e8b97..27ad35988 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -16,8 +16,8 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8, 3.9] - torch: [1.12.0, 1.13.0, 2.0.0] + python-version: [3.8, 3.9, "3.10"] + torch: [1.13.0, 2.0.0, 2.1.0] #include: # - torch: 1.6.0 # torchvision: 0.7.0 @@ -60,8 +60,10 @@ jobs: # run: conda env create -n graphein-dev python=${{ matrix.python-version }} #- name: Activate Conda Environment # run: source activate graphein-dev + - name: Install Boost 1.7.3 (for DSSP) + run: conda install -c anaconda libboost=1.73.0 - name: Install DSSP - run: conda install -c salilab dssp + run: conda install dssp -c salilab - name: Install mmseqs run: mamba install -c conda-forge -c bioconda mmseqs2 - name: Install PyTorch diff --git a/.requirements/base.in b/.requirements/base.in index 24f311f31..739045494 100644 --- a/.requirements/base.in +++ b/.requirements/base.in @@ -4,6 +4,7 @@ biopython bioservices>=1.10.0 deepdiff loguru +looseversion matplotlib>=3.4.3 multipledispatch networkx diff --git a/CHANGELOG.md b/CHANGELOG.md index ffcc07476..bafeb190d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +### 1.7.4 - 24/10/2023 + +* Adds support for PyG 2.4+ ([#350](https://www.github.com/a-r-j/graphein/pull/339)) + ### 1.7.3 - 30/08/2023 * Fixes edge case in FoldComp database download if target directory has same name as database ([#339](https://github.com/a-r-j/graphein/pull/339)) diff --git a/graphein/protein/tensor/data.py b/graphein/protein/tensor/data.py index 3a2cd39f5..43d460fba 100644 --- a/graphein/protein/tensor/data.py +++ b/graphein/protein/tensor/data.py @@ -11,10 +11,12 @@ # Code Repository: https://github.com/a-r-j/graphein from typing import Any, Callable, List, Optional, Tuple, Union +import looseversion import pandas as pd import plotly.graph_objects as go import torch import torch.nn.functional as F +import torch_geometric from biopandas.pdb import PandasPdb from loguru import logger as log from torch_geometric.data import Batch, Data @@ -64,6 +66,8 @@ TorsionTensor, ) +PYG_VERSION = looseversion.LooseVersion(torch_geometric.__version__) + class Protein(Data): """ "A data object describing a homogeneous graph. ``Protein`` inherits from @@ -249,7 +253,12 @@ def from_data(self, data: Data) -> "Protein": :return: ``Protein`` object containing the same keys and values :rtype: Protein """ - keys = data.keys + keys = ( + data.keys() + if PYG_VERSION >= looseversion.LooseVersion("2.4.0") + else data.keys + ) + for key in keys: setattr(self, key, getattr(data, key)) return self @@ -271,7 +280,12 @@ def to_data(self) -> Data: :rtype: Data """ data = Data() - for i in self.keys: + keys = ( + self.keys() + if PYG_VERSION >= looseversion.LooseVersion("2.4.0") + else self.keys + ) + for i in keys: setattr(data, i, getattr(self, i)) return data @@ -732,7 +746,13 @@ def has_complete_backbone(self) -> bool: def __eq__(self, __o: object) -> bool: # sourcery skip: merge-duplicate-blocks, merge-else-if-into-elif - for i in self.keys: + keys = ( + self.keys() + if PYG_VERSION >= looseversion.LooseVersion("2.4.0") + else self.keys + ) + + for i in keys: attr_self = getattr(self, i) attr_other = getattr(__o, i) @@ -760,9 +780,15 @@ def plot_distance_matrix( return plot_distance_matrix(x) def plot_dihedrals(self) -> go.Figure: + keys = ( + self.keys() + if PYG_VERSION >= looseversion.LooseVersion("2.4.0") + else self.keys + ) + dh = ( dihedrals(self.coords) - if "dihedrals" not in self.keys + if "dihedrals" not in keys else self.dihedrals ) return plot_dihedrals(dh) @@ -833,7 +859,12 @@ def __init__( def from_batch( self, batch: Batch, fill_value: float = 1e-5 ) -> "ProteinBatch": - for key in batch.keys: + keys = ( + batch.keys() + if PYG_VERSION >= looseversion.LooseVersion("2.4.0") + else batch.keys + ) + for key in keys: setattr(self, key, getattr(batch, key)) if hasattr(batch, "_slice_dict"): @@ -930,7 +961,11 @@ def from_pdb_files( def to_batch(self) -> Batch: """Returns the ProteinBatch as a torch_geometric.data.Batch object.""" batch = Batch() - keys = self.keys + keys = ( + self.keys() + if PYG_VERSION >= looseversion.LooseVersion("2.4.0") + else self.keys + ) for key in keys: setattr(batch, key, getattr(self, key)) return batch @@ -1190,8 +1225,12 @@ def to_protein_list(self) -> List["Protein"]: proteins = [Protein() for _ in range(self.num_graphs)] # Iterate over attributes - for k in self.keys: - print(k) + keys = ( + self.keys() + if PYG_VERSION >= looseversion.LooseVersion("2.4.0") + else self.keys + ) + for k in keys: # Get attribute attr = getattr(self, k) # Skip ptr @@ -1218,7 +1257,12 @@ def to_protein_list(self) -> List["Protein"]: def __eq__(self, __o: object) -> bool: # sourcery skip: merge-duplicate-blocks, merge-else-if-into-elif - for i in self.keys: + keys = ( + self.keys() + if PYG_VERSION >= looseversion.LooseVersion("2.4.0") + else self.keys + ) + for i in keys: attr_self = getattr(self, i) attr_other = getattr(__o, i) diff --git a/tests/protein/test_graphs.py b/tests/protein/test_graphs.py index 856ae9385..d77915f60 100644 --- a/tests/protein/test_graphs.py +++ b/tests/protein/test_graphs.py @@ -106,7 +106,7 @@ def test_construct_graph(): def test_construct_graph_with_dssp(): """Makes sure protein graphs can be constructed with dssp - Uses uses both a pdb code (6REW) and a local pdb file to do so. + Uses uses both a pdb code (6YC3) and a local pdb file to do so. """ dssp_config_functions = { "edge_construction_functions": [ @@ -129,10 +129,10 @@ def test_construct_graph_with_dssp(): dssp_prot_config = ProteinGraphConfig(**dssp_config_functions) g_pdb = construct_graph( - config=dssp_prot_config, pdb_code="6rew" - ) # should download 6rew.pdb to pdb_dir + config=dssp_prot_config, pdb_code="6yc3" + ) # should download 6yc3.pdb to pdb_dir - assert g_pdb.graph["pdb_code"] == "6rew" + assert g_pdb.graph["pdb_code"] == "6yc3" assert g_pdb.graph["path"] is None assert g_pdb.graph["name"] == g_pdb.graph["pdb_code"] assert len(g_pdb.graph["dssp_df"]) == 1365