Skip to content

Commit

Permalink
Add support for PyG 2.4+ (#350)
Browse files Browse the repository at this point in the history
* change dssp conda install to apt

* add looseversion dependency

* bump changelog

* use apt-get

* fix broken test from deprecated PDB

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make pdb code in test lowercase

* make pdb code in test lowercase

* pin biopython version <1.81 for dssp compatibility

* switch dssp back to apt-get

* pin biopython <=1.79 for dssp

* pin dssp version in CI

* switch dssp install back to salilab

* install specific boost version for DSSP

* add python 3.10 and torch 2.1.0 to CI matrix

* make 3.10 python version string

* unpin biopython version

---------

Co-authored-by: Arian Jamasb <arian.jamasb@roche.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 25, 2023
1 parent 9a55b58 commit 028f416
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 16 deletions.
8 changes: 5 additions & 3 deletions .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8, 3.9]
torch: [1.12.0, 1.13.0, 2.0.0]
python-version: [3.8, 3.9, "3.10"]
torch: [1.13.0, 2.0.0, 2.1.0]
#include:
# - torch: 1.6.0
# torchvision: 0.7.0
Expand Down Expand Up @@ -60,8 +60,10 @@ jobs:
# run: conda env create -n graphein-dev python=${{ matrix.python-version }}
#- name: Activate Conda Environment
# run: source activate graphein-dev
- name: Install Boost 1.7.3 (for DSSP)
run: conda install -c anaconda libboost=1.73.0
- name: Install DSSP
run: conda install -c salilab dssp
run: conda install dssp -c salilab
- name: Install mmseqs
run: mamba install -c conda-forge -c bioconda mmseqs2
- name: Install PyTorch
Expand Down
1 change: 1 addition & 0 deletions .requirements/base.in
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ biopython
bioservices>=1.10.0
deepdiff
loguru
looseversion
matplotlib>=3.4.3
multipledispatch
networkx
Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
### 1.7.4 - 24/10/2023

* Adds support for PyG 2.4+ ([#350](https://www.github.com/a-r-j/graphein/pull/339))

### 1.7.3 - 30/08/2023

* Fixes edge case in FoldComp database download if target directory has same name as database ([#339](https://github.com/a-r-j/graphein/pull/339))
Expand Down
62 changes: 53 additions & 9 deletions graphein/protein/tensor/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
# Code Repository: https://github.com/a-r-j/graphein
from typing import Any, Callable, List, Optional, Tuple, Union

import looseversion
import pandas as pd
import plotly.graph_objects as go
import torch
import torch.nn.functional as F
import torch_geometric
from biopandas.pdb import PandasPdb
from loguru import logger as log
from torch_geometric.data import Batch, Data
Expand Down Expand Up @@ -64,6 +66,8 @@
TorsionTensor,
)

PYG_VERSION = looseversion.LooseVersion(torch_geometric.__version__)


class Protein(Data):
""" "A data object describing a homogeneous graph. ``Protein`` inherits from
Expand Down Expand Up @@ -249,7 +253,12 @@ def from_data(self, data: Data) -> "Protein":
:return: ``Protein`` object containing the same keys and values
:rtype: Protein
"""
keys = data.keys
keys = (
data.keys()
if PYG_VERSION >= looseversion.LooseVersion("2.4.0")
else data.keys
)

for key in keys:
setattr(self, key, getattr(data, key))
return self
Expand All @@ -271,7 +280,12 @@ def to_data(self) -> Data:
:rtype: Data
"""
data = Data()
for i in self.keys:
keys = (
self.keys()
if PYG_VERSION >= looseversion.LooseVersion("2.4.0")
else self.keys
)
for i in keys:
setattr(data, i, getattr(self, i))
return data

Expand Down Expand Up @@ -732,7 +746,13 @@ def has_complete_backbone(self) -> bool:

def __eq__(self, __o: object) -> bool:
# sourcery skip: merge-duplicate-blocks, merge-else-if-into-elif
for i in self.keys:
keys = (
self.keys()
if PYG_VERSION >= looseversion.LooseVersion("2.4.0")
else self.keys
)

for i in keys:
attr_self = getattr(self, i)
attr_other = getattr(__o, i)

Expand Down Expand Up @@ -760,9 +780,15 @@ def plot_distance_matrix(
return plot_distance_matrix(x)

def plot_dihedrals(self) -> go.Figure:
keys = (
self.keys()
if PYG_VERSION >= looseversion.LooseVersion("2.4.0")
else self.keys
)

dh = (
dihedrals(self.coords)
if "dihedrals" not in self.keys
if "dihedrals" not in keys
else self.dihedrals
)
return plot_dihedrals(dh)
Expand Down Expand Up @@ -833,7 +859,12 @@ def __init__(
def from_batch(
self, batch: Batch, fill_value: float = 1e-5
) -> "ProteinBatch":
for key in batch.keys:
keys = (
batch.keys()
if PYG_VERSION >= looseversion.LooseVersion("2.4.0")
else batch.keys
)
for key in keys:
setattr(self, key, getattr(batch, key))

if hasattr(batch, "_slice_dict"):
Expand Down Expand Up @@ -930,7 +961,11 @@ def from_pdb_files(
def to_batch(self) -> Batch:
"""Returns the ProteinBatch as a torch_geometric.data.Batch object."""
batch = Batch()
keys = self.keys
keys = (
self.keys()
if PYG_VERSION >= looseversion.LooseVersion("2.4.0")
else self.keys
)
for key in keys:
setattr(batch, key, getattr(self, key))
return batch
Expand Down Expand Up @@ -1190,8 +1225,12 @@ def to_protein_list(self) -> List["Protein"]:
proteins = [Protein() for _ in range(self.num_graphs)]

# Iterate over attributes
for k in self.keys:
print(k)
keys = (
self.keys()
if PYG_VERSION >= looseversion.LooseVersion("2.4.0")
else self.keys
)
for k in keys:
# Get attribute
attr = getattr(self, k)
# Skip ptr
Expand All @@ -1218,7 +1257,12 @@ def to_protein_list(self) -> List["Protein"]:

def __eq__(self, __o: object) -> bool:
# sourcery skip: merge-duplicate-blocks, merge-else-if-into-elif
for i in self.keys:
keys = (
self.keys()
if PYG_VERSION >= looseversion.LooseVersion("2.4.0")
else self.keys
)
for i in keys:
attr_self = getattr(self, i)
attr_other = getattr(__o, i)

Expand Down
8 changes: 4 additions & 4 deletions tests/protein/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test_construct_graph():
def test_construct_graph_with_dssp():
"""Makes sure protein graphs can be constructed with dssp
Uses uses both a pdb code (6REW) and a local pdb file to do so.
Uses uses both a pdb code (6YC3) and a local pdb file to do so.
"""
dssp_config_functions = {
"edge_construction_functions": [
Expand All @@ -129,10 +129,10 @@ def test_construct_graph_with_dssp():
dssp_prot_config = ProteinGraphConfig(**dssp_config_functions)

g_pdb = construct_graph(
config=dssp_prot_config, pdb_code="6rew"
) # should download 6rew.pdb to pdb_dir
config=dssp_prot_config, pdb_code="6yc3"
) # should download 6yc3.pdb to pdb_dir

assert g_pdb.graph["pdb_code"] == "6rew"
assert g_pdb.graph["pdb_code"] == "6yc3"
assert g_pdb.graph["path"] is None
assert g_pdb.graph["name"] == g_pdb.graph["pdb_code"]
assert len(g_pdb.graph["dssp_df"]) == 1365
Expand Down

0 comments on commit 028f416

Please sign in to comment.