diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 1f815a9c..28c149dc 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -67,4 +67,4 @@ jobs: - name: Run unit tests and generate coverage report run: pytest . - name: Test notebook execution - run: pytest --nbval-lax notebooks/ --current-env + run: pytest --nbval-lax notebooks/ --current-env --ignore-glob="notebooks/dataloader_tutorial.ipynb" diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 8af8f134..1a1af77a 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -10,6 +10,7 @@ Summaries :maxdepth: 2 :glob: + notebooks/dataloader_tutorial.nblink datasets/pscdb notebooks/pscdb_processing.nblink notebooks/pscdb_baselines.nblink diff --git a/docs/source/notebooks/dataloader_tutorial.nblink b/docs/source/notebooks/dataloader_tutorial.nblink new file mode 100644 index 00000000..a15eb374 --- /dev/null +++ b/docs/source/notebooks/dataloader_tutorial.nblink @@ -0,0 +1,3 @@ +{ + "path": "../../../notebooks/dataloader_tutorial.ipynb" +} \ No newline at end of file diff --git a/graphein/grn/features/node_features.py b/graphein/grn/features/node_features.py index 1991d737..32764442 100644 --- a/graphein/grn/features/node_features.py +++ b/graphein/grn/features/node_features.py @@ -1,14 +1,19 @@ +import logging + from graphein.utils.utils import import_message +log = logging.getLogger(__name__) + try: from bioservices import HGNC, UniProt except ImportError: - import_message( + message = import_message( submodule="graphein.grn.features.node_features", package="bioservices", conda_channel="bioconda", pip_install=True, ) + log.warning(message) def add_sequence_to_nodes(n, d): diff --git a/graphein/ml/__init__.py b/graphein/ml/__init__.py index e69de29b..2474af4e 100644 --- a/graphein/ml/__init__.py +++ b/graphein/ml/__init__.py @@ -0,0 +1,10 @@ +from .conversion import GraphFormatConvertor + +try: + from .datasets import ( + InMemoryProteinGraphDataset, + ProteinGraphDataset, + ProteinGraphListDataset, + ) +except (ImportError, NameError): + pass diff --git a/graphein/ml/datasets/__init__.py b/graphein/ml/datasets/__init__.py new file mode 100644 index 00000000..c76b1fc6 --- /dev/null +++ b/graphein/ml/datasets/__init__.py @@ -0,0 +1,5 @@ +from .torch_geometric_dataset import ( + InMemoryProteinGraphDataset, + ProteinGraphDataset, + ProteinGraphListDataset, +) diff --git a/graphein/ml/datasets/torch_geometric_dataset.py b/graphein/ml/datasets/torch_geometric_dataset.py new file mode 100644 index 00000000..8a62b6ab --- /dev/null +++ b/graphein/ml/datasets/torch_geometric_dataset.py @@ -0,0 +1,547 @@ +"""Pytorch Geometric Dataset classes for Protein Graphs.""" +# Graphein +# Author: Arian Jamasb +# License: MIT +# Project Website: https://github.com/a-r-j/graphein +# Code Repository: https://github.com/a-r-j/graphein +from __future__ import annotations + +import os +from pathlib import Path +from typing import Callable, Dict, List, Optional + +import networkx as nx +from tqdm import tqdm + +from graphein.ml.conversion import GraphFormatConvertor +from graphein.protein.config import ProteinGraphConfig +from graphein.protein.graphs import construct_graphs_mp +from graphein.protein.utils import download_alphafold_structure, download_pdb +from graphein.utils.utils import import_message + +try: + import torch + from torch_geometric.data import Data, Dataset, InMemoryDataset +except ImportError: + import_message( + "graphein.ml.datasets.torch_geometric_dataset", + "torch_geometric", + conda_channel="pyg", + pip_install=True, + ) + + +class InMemoryProteinGraphDataset(InMemoryDataset): + def __init__( + self, + root: str, + name: str, + pdb_codes: Optional[List[str]] = None, + uniprot_ids: Optional[List[str]] = None, + graph_label_map: Optional[Dict[str, torch.Tensor]] = None, + node_label_map: Optional[Dict[str, torch.Tensor]] = None, + chain_selection_map: Optional[Dict[str, List[str]]] = None, + graphein_config: ProteinGraphConfig = ProteinGraphConfig(), + graph_format_convertor: GraphFormatConvertor = GraphFormatConvertor( + src_format="nx", dst_format="pyg" + ), + graph_transformation_funcs: Optional[List[Callable]] = None, + transform: Optional[Callable] = None, + pdb_transform: Optional[List[Callable]] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + num_cores: int = 16, + af_version: int = 2, + ): + """In Memory dataset for protein graphs. + + Dataset base class for creating graph datasets which easily fit + into CPU memory. Inherits from + :class:`torch_geometric.data.InMemoryDataset`, which inherits from + :class:`torch_geometric.data.Dataset`. + See `here `__ for the accompanying + tutorial. + + :param root: Root directory where the dataset should be saved. + :type root: str + :param name: Name of the dataset. Will be saved to ``data_$name.pt``. + :type name: str + :param pdb_codes: List of PDB codes to download and parse from the PDB. + Defaults to None. + :type pdb_codes: Optional[List[str]], optional + :param uniprot_ids: List of Uniprot IDs to download and parse from + Alphafold Database. Defaults to ``None``. + :type uniprot_ids: Optional[List[str]], optional + :param graph_label_map: Dictionary mapping PDB/Uniprot IDs to + graph-level labels. Defaults to ``None``. + :type graph_label_map: Optional[Dict[str, Tensor]], optional + :param node_label_map: Dictionary mapping PDB/Uniprot IDs to node-level + labels. Defaults to ``None``. + :type node_label_map: Optional[Dict[str, torch.Tensor]], optional + :param chain_selection_map: Dictionary mapping, defaults to ``None``. + :type chain_selection_map: Optional[Dict[str, List[str]]], optional + :param graphein_config: Protein graph construction config, defaults to + ``ProteinGraphConfig()``. + :type graphein_config: ProteinGraphConfig, optional + :param graph_format_convertor: Conversion handler for graphs, defaults + to ``GraphFormatConvertor(src_format="nx", dst_format="pyg")``. + :type graph_format_convertor: GraphFormatConvertor, optional + :param pdb_transform: List of functions that consume a list of paths to + the downloaded structures. This provides an entry point to apply + pre-processing from bioinformatics tools of your choosing. Defaults + to ``None``. + :type pdb_transform: Optional[List[Callable]], optional + :param graph_transformation_funcs: List of functions that consume a + ``nx.Graph`` and return a ``nx.Graph``. Applied to graphs after + construction but before conversion to pyg. Defaults to ``None``. + :type graph_transformation_funcs: Optional[List[Callable]], optional + :param transform: A function/transform that takes in a + ``torch_geometric.data.Data`` object and returns a transformed + version. The data object will be transformed before every access. + Defaults to ``None``. + :type transform: Optional[Callable], optional + :param pre_transform: A function/transform that takes in an + ``torch_geometric.data.Data`` object and returns a transformed + version. The data object will be transformed before being saved to + disk. Defaults to ``None``. + :type pre_transform: Optional[Callable], optional + :param pre_filter: A function that takes in a + ``torch_geometric.data.Data`` object and returns a boolean value, + indicating whether the data object should be included in the final + dataset. Optional, defaults to ``None``. + :type pre_filter: Optional[Callable], optional + :param num_cores: Number of cores to use for multiprocessing of graph + construction, defaults to ``16``. + :type num_cores: int, optional + :param af_version: Version of AlphaFoldDB structures to use, + defaults to ``2``. + :type af_version: int, optional + """ + self.name = name + self.pdb_codes = ( + [pdb.lower() for pdb in pdb_codes] + if pdb_codes is not None + else None + ) + self.uniprot_ids = ( + [up.upper() for up in uniprot_ids] + if uniprot_ids is not None + else None + ) + + if self.pdb_codes and self.uniprot_ids: + self.structures = self.pdb_codes + self.uniprot_ids + elif self.pdb_codes: + self.structures = pdb_codes + elif self.uniprot_ids: + self.structures = uniprot_ids + self.af_version = af_version + + # Labels & Chains + self.graph_label_map = graph_label_map + self.node_label_map = node_label_map + self.chain_selection_map = chain_selection_map + + # Configs + self.config = graphein_config + self.graph_format_convertor = graph_format_convertor + self.graph_transformation_funcs = graph_transformation_funcs + self.pdb_transform = pdb_transform + self.num_cores = num_cores + super().__init__( + root, + transform=transform, + pre_transform=pre_transform, + pre_filter=pre_filter, + ) + self.config.pdb_dir = Path(self.raw_dir) + self.data, self.slices = torch.load(self.processed_paths[0]) + + @property + def raw_file_names(self) -> List[str]: + """Name of the raw files in the dataset.""" + return [f"{pdb}.pdb" for pdb in self.structures] + + @property + def processed_file_names(self) -> List[str]: + """Name of the processed file.""" + return [f"data_{self.name}.pt"] + + def download(self): + """Download the PDB files from RCSB or Alphafold.""" + self.config.pdb_dir = Path(self.raw_dir) + if self.pdb_codes: + [download_pdb(self.config, pdb) for pdb in tqdm(self.pdb_codes)] + if self.uniprot_ids: + [ + download_alphafold_structure( + uniprot, + out_dir=self.raw_dir, + version=self.af_version, + aligned_score=False, + ) + for uniprot in tqdm(self.uniprot_ids) + ] + + def __len__(self) -> int: + return len(self.structures) + + def transform_pdbs(self): + """ + Performs pre-processing of PDB structures before constructing graphs. + """ + structure_files = [ + f"{self.raw_dir}/{pdb}.pdb" for pdb in self.structures + ] + for func in self.pdb_transform: + func(structure_files) + + def process(self): + """Process structures into PyG format and save to disk.""" + # Read data into huge `Data` list. + structure_files = [ + f"{self.raw_dir}/{pdb}.pdb" for pdb in self.structures + ] + + # Apply transformations to raw PDB files. + if self.pdb_transform is not None: + self.transform_pdbs() + + if self.chain_selection_map: + chain_selections = [ + self.chain_selection_map[pdb] + if pdb in self.chain_selection_map.keys() + else "all" + for pdb in self.structures + ] + else: + chain_selections = None + + # Create graph objects + graphs = construct_graphs_mp( + pdb_path_it=structure_files, + config=self.config, + chain_selections=chain_selections, + return_dict=True, + num_cores=self.num_cores, + ) + # Transform graphs + if self.graph_transformation_funcs is not None: + for func in self.graph_transformation_funcs: + graphs = {k: func(v) for k, v in graphs.items()} + + # Convert to PyTorch Geometric Data + graphs = {k: self.graph_format_convertor(v) for k, v in graphs.items()} + graphs = dict(zip(self.structures, graphs.values())) + + # Assign labels + if self.graph_label_map: + for k, v in self.graph_label_map.items(): + graphs[k].graph_y = v + if self.node_label_map: + for k, v in self.node_label_map.items(): + graphs[k].node_y = v + + data_list = list(graphs.values()) + del graphs + + if self.pre_filter is not None: + data_list = [g for g in data_list if self.pre_filter(g)] + + if self.pre_transform is not None: + data_list = [self.pre_transform(data) for data in data_list] + + data, slices = self.collate(data_list) + torch.save((data, slices), self.processed_paths[0]) + + +class ProteinGraphDataset(Dataset): + def __init__( + self, + root, + pdb_codes: Optional[List[str]] = None, + uniprot_ids: Optional[List[str]] = None, + graph_label_map: Optional[Dict[str, int]] = None, + node_label_map: Optional[Dict[str, int]] = None, + chain_selection_map: Optional[Dict[str, List[str]]] = None, + graphein_config: ProteinGraphConfig = ProteinGraphConfig(), + graph_format_convertor: GraphFormatConvertor = GraphFormatConvertor( + src_format="nx", dst_format="pyg" + ), + graph_transformation_funcs: Optional[List[Callable]] = None, + pdb_transform: Optional[List[Callable]] = None, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + num_cores: int = 16, + af_version: int = 2, + ): + """Dataset class for protein graphs. + + Dataset base class for creating graph datasets. + See `here `__ for the accompanying tutorial. + + :param root: Root directory where the dataset should be saved. + :type root: str + :param pdb_codes: List of PDB codes to download and parse from the PDB. + Defaults to ``None``. + :type pdb_codes: Optional[List[str]], optional + :param uniprot_ids: List of Uniprot IDs to download and parse from + Alphafold Database. Defaults to ``None``. + :type uniprot_ids: Optional[List[str]], optional + :param graph_label_map: Dictionary mapping PDB/Uniprot IDs to + graph-level labels. Defaults to ``None``. + :type graph_label_map: Optional[Dict[str, Tensor]], optional + :param node_label_map: Dictionary mapping PDB/Uniprot IDs to node-level + labels. Defaults to ``None``. + :type node_label_map: Optional[Dict[str, torch.Tensor]], optional + :param chain_selection_map: Dictionary mapping, defaults to ``None``. + :type chain_selection_map: Optional[Dict[str, List[str]]], optional + :param graphein_config: Protein graph construction config, defaults to + ``ProteinGraphConfig()``. + :type graphein_config: ProteinGraphConfig, optional + :param graph_format_convertor: Conversion handler for graphs, defaults + to ``GraphFormatConvertor(src_format="nx", dst_format="pyg")``. + :type graph_format_convertor: GraphFormatConvertor, optional + :param graph_transformation_funcs: List of functions that consume a + ``nx.Graph`` and return a ``nx.Graph``. Applied to graphs after + construction but before conversion to pyg. Defaults to ``None``. + :type graph_transformation_funcs: Optional[List[Callable]], optional + :param pdb_transform: List of functions that consume a list of paths to + the downloaded structures. This provides an entry point to apply + pre-processing from bioinformatics tools of your choosing. Defaults + to ``None``. + :type pdb_transform: Optional[List[Callable]], optional + :param transform: A function/transform that takes in a + ``torch_geometric.data.Data`` object and returns a transformed + version. The data object will be transformed before every access. + Defaults to ``None``. + :type transform: Optional[Callable], optional + :param pre_transform: A function/transform that takes in an + ``torch_geometric.data.Data`` object and returns a transformed + version. The data object will be transformed before being saved to + disk. Defaults to ``None``. + :type pre_transform: Optional[Callable], optional + :param pre_filter: A function that takes in a + ``torch_geometric.data.Data`` object and returns a boolean value, + indicating whether the data object should be included in the final + dataset. Optional, defaults to ``None``. + :type pre_filter: Optional[Callable], optional + :param num_cores: Number of cores to use for multiprocessing of graph + construction, defaults to ``16``. + :type num_cores: int, optional + :param af_version: Version of AlphaFoldDB structures to use, + defaults to ``2``. + :type af_version: int, optional + """ + self.pdb_codes = ( + [pdb.lower() for pdb in pdb_codes] + if pdb_codes is not None + else None + ) + self.uniprot_ids = ( + [up.upper() for up in uniprot_ids] + if uniprot_ids is not None + else None + ) + + if self.pdb_codes and self.uniprot_ids: + self.structures = self.pdb_codes + self.uniprot_ids + elif self.pdb_codes: + self.structures = pdb_codes + elif self.uniprot_ids: + self.structures = uniprot_ids + self.af_version = af_version + + # Labels & Chains + self.graph_label_map = graph_label_map + self.node_label_map = node_label_map + self.chain_selection_map = chain_selection_map + + # Configs + self.config = graphein_config + self.graph_format_convertor = graph_format_convertor + self.num_cores = num_cores + self.pdb_transform = pdb_transform + self.graph_transformation_funcs = graph_transformation_funcs + super().__init__( + root, + transform=transform, + pre_transform=pre_transform, + pre_filter=pre_filter, + ) + self.config.pdb_dir = Path(self.raw_dir) + + @property + def raw_file_names(self) -> List[str]: + """Names of raw files in the dataset.""" + return [f"{pdb}.pdb" for pdb in self.structures] + + @property + def processed_file_names(self) -> List[str]: + """Names of processed files to look for""" + return [f"{pdb}.pt" for pdb in self.structures] + + def download(self): + """Download the PDB files from RCSB or Alphafold.""" + self.config.pdb_dir = Path(self.raw_dir) + if self.pdb_codes: + [download_pdb(self.config, pdb) for pdb in tqdm(self.pdb_codes)] + if self.uniprot_ids: + [ + download_alphafold_structure( + uniprot, + out_dir=self.raw_dir, + version=self.af_version, + aligned_score=False, + ) + for uniprot in tqdm(self.uniprot_ids) + ] + + def len(self) -> int: + """Returns length of data set (number of structures).""" + return len(self.structures) + + def transform_pdbs(self): + """ + Performs pre-processing of PDB structures before constructing graphs. + """ + structure_files = [ + f"{self.raw_dir}/{pdb}.pdb" for pdb in self.structures + ] + for func in self.pdb_transform: + func(structure_files) + + def transform_graphein_graphs(self, graph: nx.Graph): + for func in self.graph_transformation_funcs: + graph = func(graph) + return graph + + def process(self): + """Processes structures from files into PyTorch Geometric Data.""" + # Preprocess PDB files + if self.pdb_transform: + self.transform_pdbs() + + idx = 0 + # Chunk dataset for parallel processing + chunk_size = 128 + + def divide_chunks(l: List[str], n: int = 2) -> List[List[str]]: + for i in range(0, len(l), n): + yield l[i : i + n] + + chunks = list(divide_chunks(self.structures, chunk_size)) + + for chunk in tqdm(chunks): + # Get chain selections + if self.chain_selection_map: + chain_selections = [ + self.chain_selection_map[pdb] + if pdb in self.chain_selection_map.keys() + else "all" + for pdb in self.structures + ] + else: + chain_selections = None + + # Create graph objects + file_names = [f"{self.raw_dir}/{pdb}.pdb" for pdb in chunk] + graphs = construct_graphs_mp( + pdb_path_it=file_names, + config=self.config, + chain_selections=chain_selections, + return_dict=True, + ) + if self.graph_transformation_funcs is not None: + graphs = { + k: self.transform_graphein_graphs(v) + for k, v in graphs.items() + } + # Convert to PyTorch Geometric Data + graphs = { + k: self.graph_format_convertor(v) for k, v in graphs.items() + } + graphs = dict(zip(chunk, graphs.values())) + + # Assign labels + if self.graph_label_map: + for k, v in self.graph_label_map.items(): + graphs[k].graph_y = v + if self.node_label_map: + for k, v in self.node_label_map.items(): + graphs[k].node_y = v + + data_list = list(graphs.values()) + + del graphs + + if self.pre_filter is not None: + data_list = [g for g in data_list if self.pre_filter(g)] + + if self.pre_transform is not None: + data_list = [self.pre_transform(data) for data in data_list] + + idxs = [ + i + for i in range(idx * chunk_size, idx * chunk_size + len(chunk)) + ] + + for data, id in zip(data_list, idxs): + + torch.save( + data, + os.path.join( + self.processed_dir, f"{self.structures[id]}.pt" + ), + ) + idx += 1 + + def get(self, idx: int): + """ + Returns PyTorch Geometric Data object for a given index. + + :param idx: Index to retrieve. + :type idx: int + :return: PyTorch Geometric Data object. + """ + return torch.load( + os.path.join(self.processed_dir, f"{self.structures[idx]}.pt") + ) + + +class ProteinGraphListDataset(InMemoryDataset): + def __init__( + self, root: str, data_list: List[Data], name: str, transform=None + ): + """Creates a dataset from a list of PyTorch Geometric Data objects. + + :param root: Root directory where the dataset is stored. + :type root: str + :param data_list: List of protein graphs as PyTorch Geometric Data + objects. + :type data_list: List[Data] + :param name: Name of dataset. Data will be saved as ``data_{name}.pt``. + :type name: str + :param transform: A function/transform that takes in an + :obj:`torch_geometric.data.Data` object and returns a transformed + version. The data object will be transformed before every access. + (default: :obj:`None`) + :type transform: Optional[Callable], optional + """ + self.data_list = data_list + self.name = name + super().__init__(root, transform) + self.data, self.slices = torch.load(self.processed_paths[0]) + + @property + def processed_file_names(self): + """The name of the files in the :obj:`self.processed_dir` folder that + must be present in order to skip processing.""" + return f"data_{self.name}.pt" + + def process(self): + """Saves data files to disk.""" + torch.save(self.collate(self.data_list), self.processed_paths[0]) diff --git a/graphein/ppi/features/node_features.py b/graphein/ppi/features/node_features.py index 7d7dde6a..5162c93a 100644 --- a/graphein/ppi/features/node_features.py +++ b/graphein/ppi/features/node_features.py @@ -1,4 +1,6 @@ """Functions for adding nodes features to a PPI Graph""" +import logging + # %% # Graphein # Author: Ramon Vinas, Arian Jamasb @@ -11,15 +13,19 @@ from graphein.utils.utils import import_message +log = logging.getLogger(__name__) + + try: from bioservices import HGNC, UniProt except ImportError: - import_message( + message = import_message( submodule="graphein.ppi.features.nodes_features", package="bioservices", conda_channel="bioconda", pip_install=True, ) + log.warning(message) def add_sequence_to_nodes(n: str, d: Dict[str, Any]): diff --git a/graphein/protein/features/nodes/aaindex.py b/graphein/protein/features/nodes/aaindex.py index f21c5ee8..8e8e9aae 100644 --- a/graphein/protein/features/nodes/aaindex.py +++ b/graphein/protein/features/nodes/aaindex.py @@ -1,17 +1,22 @@ +import logging from typing import Dict, Tuple import networkx as nx from graphein.utils.utils import import_message, protein_letters_3to1_all_caps +log = logging.getLogger(__name__) + + try: from pyaaisc import Aaindex except ImportError: - import_message( + message = import_message( submodule="graphein.protein.features.nodes.aaindex", package="pyaaisc", pip_install=True, ) + log.warning(message) def fetch_AAIndex(accession: str) -> Tuple[str, Dict[str, float]]: diff --git a/graphein/protein/features/sequence/embeddings.py b/graphein/protein/features/sequence/embeddings.py index 02c1a60e..2b30da9f 100644 --- a/graphein/protein/features/sequence/embeddings.py +++ b/graphein/protein/features/sequence/embeddings.py @@ -6,6 +6,7 @@ # Code Repository: https://github.com/a-r-j/graphein from __future__ import annotations +import logging import os from functools import lru_cache, partial from pathlib import Path @@ -18,24 +19,29 @@ ) from graphein.utils.utils import import_message +log = logging.getLogger(__name__) + + try: import torch except ImportError: - import_message( + message = import_message( submodule="graphein.protein.features.sequence.embeddings", package="torch", pip_install=True, conda_channel="pytorch", ) + log.warning(message) try: import biovec except ImportError: - import_message( + message = import_message( submodule="graphein.protein.features.sequence.embeddings", package="biovec", pip_install=True, ) + log.warning(message) @lru_cache() diff --git a/graphein/protein/graphs.py b/graphein/protein/graphs.py index f8818305..1ec1d910 100644 --- a/graphein/protein/graphs.py +++ b/graphein/protein/graphs.py @@ -8,6 +8,9 @@ from __future__ import annotations import logging +import multiprocessing +import traceback +from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple, Union import networkx as nx @@ -617,6 +620,99 @@ def construct_graph( return g +def _mp_graph_constructor( + args: Tuple[str, str], use_pdb_code: bool, config: ProteinGraphConfig +) -> nx.Graph: + """ + Protein graph constructor for use in multiprocessing several protein structure graphs. + + :param args: Tuple of pdb code/path and the chain selection for that PDB + :type args: Tuple[str, str] + :param use_pdb_code: Whether or not we are using pdb codes or paths + :type use_pdb_code: bool + :param config: Protein structure graph construction config + :type config: ProteinGraphConfig + :return: Protein structure graph + :rtype: nx.Graph + """ + log.info(f"Constructing graph for: {args[0]}. Chain selection: {args[1]}") + func = partial(construct_graph, config=config) + try: + return ( + func(pdb_code=args[0], chain_selection=args[1]) + if use_pdb_code + else func(pdb_path=args[0], chain_selection=args[1]) + ) + + except Exception as ex: + log.info( + f"Graph construction error (PDB={args[0]})! {traceback.format_exc()}" + ) + log.info(ex) + return None + + +def construct_graphs_mp( + pdb_code_it: Optional[List[str]] = None, + pdb_path_it: Optional[List[str]] = None, + chain_selections: Optional[list[str]] = None, + config: ProteinGraphConfig = ProteinGraphConfig(), + num_cores: int = 16, + return_dict: bool = True, +) -> Union[List[nx.Graph], Dict[str, nx.Graph]]: + """ + Constructs protein graphs for a list of pdb codes or pdb paths using multiprocessing. + + :param pdb_code_it: List of pdb codes to use for protein graph construction + :type pdb_code_it: Optional[List[str]], defaults to None + :param pdb_path_it: List of paths to PDB files to use for protein graph construction + :type pdb_path_it: Optional[List[str]], defaults to None + :param chain_selections: List of chains to select from the protein structures (e.g. ["ABC", "A", "L", "CD"...]) + :type chain_selections: Optional[List[str]], defaults to None + :param config: ProteinGraphConfig to use. + :type config: graphein.protein.config.ProteinGraphConfig, defaults to default config params + :param num_cores: Number of cores to use for multiprocessing. The more the merrier + :type num_cores: int, defaults to 16 + :param return_dict: Whether or not to return a dictionary (indexed by pdb codes/paths) or a list of graphs. + :type return_dict: bool, default to True + :return: Iterable of protein graphs. None values indicate there was a problem in constructing the graph for this particular pdb + :rtype: Union[List[nx.Graph], Dict[str, nx.Graph]] + """ + assert ( + pdb_code_it is not None or pdb_path_it is not None + ), "Iterable of pdb codes OR pdb paths required." + + if pdb_code_it is not None: + pdbs = pdb_code_it + use_pdb_code = True + + if pdb_path_it is not None: + pdbs = pdb_path_it + use_pdb_code = False + + if chain_selections is None: + chain_selections = ["all"] * len(pdbs) + + constructor = partial( + _mp_graph_constructor, use_pdb_code=use_pdb_code, config=config + ) + + pool = multiprocessing.Pool(num_cores) + graphs = list( + pool.map( + constructor, + [(pdb, chain_selections[i]) for i, pdb in enumerate(pdbs)], + ) + ) + pool.close() + pool.join() + + if return_dict: + graphs = {pdb: graphs[i] for i, pdb in enumerate(pdbs)} + + return graphs + + def compute_chain_graph( g: nx.Graph, chain_list: Optional[List[str]] = None, @@ -684,7 +780,6 @@ def compute_chain_graph( # Compute a weighted graph if required. if return_weighted_graph: return compute_weighted_graph_from_multigraph(h) - return h diff --git a/graphein/protein/meshes.py b/graphein/protein/meshes.py index 0b499dd2..b89f596e 100644 --- a/graphein/protein/meshes.py +++ b/graphein/protein/meshes.py @@ -16,18 +16,19 @@ from graphein.utils.pymol import MolViewer from graphein.utils.utils import import_message +log = logging.getLogger(__name__) + + try: from pytorch3d.structures import Meshes except ImportError: - import_message( + message = import_message( submodule="graphein.protein.meshes", package="pytorch3d", conda_channel="pytorch3d", pip_install=True, ) - - -log = logging.getLogger(__name__) + log.warning(message) def check_for_pymol_installation(): diff --git a/graphein/protein/utils.py b/graphein/protein/utils.py index 10a9a7dc..49339f33 100644 --- a/graphein/protein/utils.py +++ b/graphein/protein/utils.py @@ -65,6 +65,7 @@ def download_pdb(config, pdb_code: str) -> Path: :return: returns filepath to downloaded structure. :rtype: str """ + pdb_code = pdb_code.lower() if not config.pdb_dir: config.pdb_dir = Path("/tmp/") @@ -74,7 +75,7 @@ def download_pdb(config, pdb_code: str) -> Path: pdb_code, pdir=config.pdb_dir, overwrite=True, file_format="pdb" ) # If file not downloaded, check for obsolescence - if not os.path.exists(config.pdb_dir / f"{pdb_code}.pdb"): + if not os.path.exists(config.pdb_dir / f"pdb{pdb_code}.ent"): obs_map = get_obsolete_mapping() try: new_pdb = obs_map[pdb_code.lower()].lower() @@ -164,43 +165,67 @@ def compute_rgroup_dataframe(pdb_df: pd.DataFrame) -> pd.DataFrame: def download_alphafold_structure( uniprot_id: str, + version: int = 2, out_dir: str = ".", + rename: bool = True, pdb: bool = True, mmcif: bool = False, aligned_score: bool = True, ) -> Union[str, Tuple[str, str]]: - BASE_URL = "https://alphafold.ebi.ac.uk/files/" """ Downloads a structure from the Alphafold EBI database (https://alphafold.ebi.ac.uk/files/"). :param uniprot_id: UniProt ID of desired protein. :type uniprot_id: str - :param out_dir: String specifying desired output location. Default is pwd. + :param version: Version of the structure to download + :type version: int + :param out_dir: string specifying desired output location. Default is pwd. :type out_dir: str - :param mmcif: Bool specifying whether to download ``MMCiF`` or ``PDB``. Default is ``False`` (downloads pdb). + :param rename: boolean specifying whether to rename the output file to ``$uniprot_id.pdb``. Default is ``True``. + :type rename: bool + :param pdb: boolean specifying whether to download the PDB file. Default is ``True``. + :type pdb: bool + :param mmcif: Bool specifying whether to download MMCiF or PDB. Default is false (downloads pdb) :type mmcif: bool :param retrieve_aligned_score: Bool specifying whether or not to download score alignment json. :type retrieve_aligned_score: bool :return: path to output. Tuple if several outputs specified. :rtype: Union[str, Tuple[str, str]] """ + BASE_URL = "https://alphafold.ebi.ac.uk/files/" + uniprot_id = uniprot_id.upper() + if not mmcif and not pdb: raise ValueError("Must specify either mmcif or pdb.") if mmcif: - query_url = f"{BASE_URL}AF-{uniprot_id}F1-model_v1.cif" + query_url = f"{BASE_URL}AF-{uniprot_id}-F1-model_v{version}.cif" if pdb: - query_url = f"{BASE_URL}AF-{uniprot_id}-F1-model_v1.pdb" - + query_url = f"{BASE_URL}AF-{uniprot_id}-F1-model_v{version}.pdb" structure_filename = wget.download(query_url, out=out_dir) + if rename: + extension = ".pdb" if pdb else ".cif" + os.rename( + structure_filename, Path(out_dir) / f"{uniprot_id}{extension}" + ) + structure_filename = str( + (Path(out_dir) / f"{uniprot_id}{extension}").resolve() + ) + + log.info(f"Downloaded AlphaFold PDB file for: {uniprot_id}") if aligned_score: score_query = ( BASE_URL + "AF-" + uniprot_id - + "-F1-predicted_aligned_error_v1.json" + + f"-F1-predicted_aligned_error_v{version}.json" ) score_filename = wget.download(score_query, out=out_dir) + if rename: + os.rename(score_filename, Path(out_dir) / f"{uniprot_id}.json") + score_filename = str( + (Path(out_dir) / f"{uniprot_id}.json").resolve() + ) return structure_filename, score_filename return structure_filename diff --git a/graphein/protein/visualisation.py b/graphein/protein/visualisation.py index 8f4a21b6..3cfbd71d 100644 --- a/graphein/protein/visualisation.py +++ b/graphein/protein/visualisation.py @@ -22,14 +22,19 @@ from graphein.protein.subgraphs import extract_k_hop_subgraph from graphein.utils.utils import import_message +log = logging.getLogger(__name__) + + try: from pytorch3d.ops import sample_points_from_meshes except ImportError: - import_message( + message = import_message( submodule="graphein.protein.visualisation", package="pytorch3d", conda_channel="pytorch3d", ) + log.warning(message) + try: from mpl_chord_diagram import chord_diagram except ImportError: @@ -39,8 +44,6 @@ pip_install=True, ) -log = logging.getLogger() - def plot_pointcloud(mesh: Meshes, title: str = "") -> Axes3D: """ diff --git a/graphein/testing/utils.py b/graphein/testing/utils.py index b413c658..1596f963 100644 --- a/graphein/testing/utils.py +++ b/graphein/testing/utils.py @@ -7,12 +7,23 @@ # Code Repository: https://github.com/a-r-j/graphein import logging as log -from ast import Call +from functools import partial from typing import Any, Callable, Dict import networkx as nx import numpy as np +from graphein.utils.utils import import_message + +__all__ = [ + "compare_exact", + "compare_approximate", + "graphs_isomorphic", + "nodes_equal", + "edges_equal", + "edge_data_equal", +] + def compare_exact(first: Dict[str, Any], second: Dict[str, Any]) -> bool: """Return whether two dicts of arrays are exactly equal. diff --git a/graphein/utils/utils.py b/graphein/utils/utils.py index 2c02b304..dd46e4c3 100644 --- a/graphein/utils/utils.py +++ b/graphein/utils/utils.py @@ -1,4 +1,4 @@ -"""Utilities for working with graph objects""" +"""Utilities for working with graph objects.""" # Graphein # Author: Arian Jamasb , Eric Ma # License: MIT @@ -328,7 +328,7 @@ def import_message( package: str, conda_channel: Optional[str] = None, pip_install: bool = False, -): +) -> str: """ Return warning if package is not found. Generic message for indicating to the user when a function relies on an @@ -358,17 +358,12 @@ def import_message( installable = False installation = f"{package} cannot be installed via pip" - print( - f"To use the Graphein submodule {submodule}, you need to install " - f"{package}." - ) - print() + message = f"To use the Graphein submodule {submodule}, you need to install: {package} " if installable: - print("To do so, use the following command:") - print() - print(f" {installation}") + message += f"\nTo do so, use the following command: {installation}" else: - print(f"{installation}") + message += f"\n{installation}" + return message def ping(host: str) -> bool: diff --git a/notebooks/dataloader_tutorial.ipynb b/notebooks/dataloader_tutorial.ipynb new file mode 100644 index 00000000..57d68cd4 --- /dev/null +++ b/notebooks/dataloader_tutorial.ipynb @@ -0,0 +1,676 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Graphein Protein Structure Dataloaders\n", + "## PyTorch Geometric Datasets\n", + "\n", + "[API Reference](https://graphein.ai/modules/graphein.ml.html#graphein.ml.datasets.torch_geometric_dataset)\n", + "\n", + "Graphein provides three dataset classes for working with [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/):\n", + "\n", + "* [`ProteinGraphDataset`]() - For processing large datasets that can't be kept in memory\n", + "* [`InMemoryProteinGraphDataset`]() - For smaller datasets that can be kept in memory\n", + "* [`ProteinGraphListDataset`]() - For creating a dataset from a list of pre-computed PyTorch Geometric graphs.\n", + "\n", + "Both `ProteinGraphDataset` and `InMemoryGraphDataset` will take care of downloading structures from either the [RCSB PDB](https://www.rcsb.org/), [EBI AlphaFold database](https://alphafold.com/), or both!\n", + "`ProteinGraphListDataset` is a lightweight alternative for creating a dataset from a collection of graphs you have pre-computed.\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/a-r-j/graphein/blob/master/notebooks/dataloader_tutorial.ipynb)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# Install graphein if necessary\n", + "# !pip install graphein\n", + "\n", + "# Install torch if necessary. See https://pytorch.org/get-started/locally/\n", + "# pip install torch==1.11.0\n", + "\n", + "# Install torch geometric if necessary. See: https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html\n", + "# pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.11.0+cpu.html" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### ProteinGraphDataset\n", + "\n", + "[API Reference](https://graphein.ai/modules/graphein.ml.html#graphein.ml.datasets.torch_geometric_dataset.ProteinGraphDataset)\n", + "\n", + "`ProteinGraphDataset` will download structures from the PDB/AlphafoldDB, process the structures into graphs according to a `ProteinGraphConfig`.\n", + "\n", + "#### Parameters\n", + "```python\n", + "ProteinGraphDataset(\n", + " root: str, \n", + " # Root directory where the dataset should be saved.\n", + " name: str, \n", + " # Name of the dataset. Will be saved to ``data_$name.pt``.\n", + " pdb_codes: Optional[List[str]] = None, \n", + " # List of PDB codes to download and parse from the PDB.\n", + " uniprot_ids: Optional[List[str]] = None, \n", + " # List of Uniprot IDs to download and parse from Alphafold Database\n", + " graph_label_map: Optional[Dict[str, torch.Tensor]] = None, \n", + " # Dictionary mapping PDB/Uniprot IDs to graph-level labels.\n", + " node_label_map: Optional[Dict[str, torch.Tensor]] = None, \n", + " # Dictionary mapping PDB/Uniprot IDs to node-level labels.\n", + " chain_selection_map: Optional[Dict[str, List[str]]] = None, \n", + " # Dictionary mapping PDB/Uniprot IDs to the desired chains in the PDB files\n", + " graphein_config: ProteinGraphConfig = ProteinGraphConfig(), \n", + " # Protein graph construction config\n", + " graph_format_convertor: GraphFormatConvertor = GraphFormatConvertor( \n", + " src_format=\"nx\", dst_format=\"pyg\"\n", + " ),\n", + " # Conversion handler for graphs\n", + " graph_transformation_funcs: Optional[List[Callable]] = None, \n", + " # List of functions that consume a nx.Graph and return a nx.Graph. Applied to graphs after construction but before conversion to pyg\n", + " transform: Optional[Callable] = None, \n", + " # A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before every access.\n", + " pdb_transform: Optional[List[Callable]] = None,\n", + " pre_transform: Optional[Callable] = None, \n", + " # A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before being saved to disk\n", + " pre_filter: Optional[Callable] = None, \n", + " # A function that takes in a torch_geometric.data.Data object and returns a boolean value, indicating whether the data object should be included in the final dataset\n", + " num_cores: int = 16, \n", + " # Number of cores to use for multiprocessing of graph construction\n", + " af_version: int = 2, \n", + " # Version of AlphaFoldDB structures to use,\n", + " )\n", + "```\n", + "\n", + "\n", + "#### Directory Structure\n", + "Creating a ``ProteinGraphDataset`` will create two directories under ``root``:\n", + "\n", + "* ``root/raw`` - Contains raw PDB files\n", + "* ``root/processed`` - Contains processed graphs (in ``pytorch_geometric.data.Data`` format) saved as ``$PDB.pt / $UNIPROT_ID.pt``" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from graphein.ml import ProteinGraphDataset\n", + "import graphein.protein as gp\n", + "\n", + "# Create some labels\n", + "g_labels = torch.randn([5])\n", + "n_labels = torch.randn([5, 10])\n", + "\n", + "g_lab_map = {\"3eiy\": g_labels[0], \"4hhb\": g_labels[1], \"Q5VSL9\": g_labels[2], \"1lds\": g_labels[3], \"Q8W3K0\": g_labels[4]}\n", + "node_lab_map = {\"3eiy\": n_labels[0], \"4hhb\": n_labels[1], \"Q5VSL9\": n_labels[2], \"1lds\": n_labels[3], \"Q8W3K0\": n_labels[4]}\n", + "\n", + "# Select some chains\n", + "chain_selection_map = {\"4hhb\": \"A\"}\n", + "\n", + "\n", + "# Create the dataset\n", + "ds = ProteinGraphDataset(\n", + " root = \"../graphein/ml/datasets/test\",\n", + " pdb_codes=[\"3eiy\", \"4hhb\", \"1lds\"],\n", + " uniprot_ids=[\"Q5VSL9\", \"Q8W3K0\"],\n", + " graph_label_map=g_lab_map,\n", + " node_label_map=node_lab_map,\n", + " chain_selection_map=chain_selection_map,\n", + " graphein_config=gp.ProteinGraphConfig()\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DataBatch(edge_index=[2, 236], node_id=[2], coords=[2], name=[2], dist_mat=[2], num_nodes=238, graph_y=[2], node_y=[20], batch=[238], ptr=[3])\n", + "Graph labels: tensor([ 0.5660, -0.7161])\n", + "Node labels: tensor([-1.2430, 0.8221, -0.0296, -0.3522, 1.7685, -2.3006, -0.1209, -1.4377,\n", + " -1.2816, -0.7039, -0.8580, -0.5647, -1.6848, -1.5069, -2.8355, -0.4000,\n", + " 0.3203, 0.1497, -1.0708, 0.3418])\n" + ] + } + ], + "source": [ + "# Create a dataloader from dataset and inspect a batch\n", + "from torch_geometric.loader import DataLoader\n", + "\n", + "dl = DataLoader(ds, batch_size=2, shuffle=True, drop_last=True)\n", + "for i in dl:\n", + " print(i)\n", + " print(\"Graph labels: \", i.graph_y)\n", + " print(\"Node labels: \", i.node_y)\n", + " break" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### InMemoryProteinGraphDataset\n", + "\n", + "[API Reference](https://graphein.ai/modules/graphein.ml.html#graphein.ml.datasets.torch_geometric_dataset.InMemoryProteinGraphDataset)\n", + "\n", + "#### Parameters\n", + "```python\n", + "InMemoryProteinGraphDataset(\n", + " root: str, \n", + " # Root directory where the dataset should be saved.\n", + " name: str, \n", + " # Name of the dataset. Will be saved to ``data_$name.pt``.\n", + " pdb_codes: Optional[List[str]] = None, \n", + " # List of PDB codes to download and parse from the PDB.\n", + " uniprot_ids: Optional[List[str]] = None, \n", + " # List of Uniprot IDs to download and parse from Alphafold Database\n", + " graph_label_map: Optional[Dict[str, torch.Tensor]] = None, \n", + " # Dictionary mapping PDB/Uniprot IDs to graph-level labels.\n", + " node_label_map: Optional[Dict[str, torch.Tensor]] = None, \n", + " # Dictionary mapping PDB/Uniprot IDs to node-level labels.\n", + " chain_selection_map: Optional[Dict[str, List[str]]] = None, \n", + " # Dictionary mapping PDB/Uniprot IDs to the desired chains in the PDB files\n", + " graphein_config: ProteinGraphConfig = ProteinGraphConfig(), \n", + " # Protein graph construction config\n", + " graph_format_convertor: GraphFormatConvertor = GraphFormatConvertor( \n", + " src_format=\"nx\", dst_format=\"pyg\"\n", + " ),\n", + " # Conversion handler for graphs\n", + " graph_transformation_funcs: Optional[List[Callable]] = None, \n", + " # List of functions that consume a nx.Graph and return a nx.Graph. Applied to graphs after construction but before conversion to pyg\n", + " transform: Optional[Callable] = None, \n", + " # A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before every access.\n", + " pdb_transform: Optional[List[Callable]] = None,\n", + " pre_transform: Optional[Callable] = None, \n", + " # A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before being saved to disk\n", + " pre_filter: Optional[Callable] = None, \n", + " # A function that takes in a torch_geometric.data.Data object and returns a boolean value, indicating whether the data object should be included in the final dataset\n", + " num_cores: int = 16, \n", + " # Number of cores to use for multiprocessing of graph construction\n", + " af_version: int = 2, \n", + " # Version of AlphaFoldDB structures to use,\n", + " )\n", + "```\n", + "\n", + "#### Directory Structure\n", + "Creating an ``InMemoryProteinGraphDataset`` will create two directories under ``root``:\n", + "* ``root/raw`` - Contains raw PDB files\n", + "* ``root/processed`` - Contains processed datasets saved as ``data_{name}.pt``" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processing...\n", + "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n", + "To do so, use the following command: conda install -c pytorch3d pytorch3d\n", + "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n", + "To do so, use the following command: conda install -c pytorch3d pytorch3d\n", + "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n", + "To do so, use the following command: conda install -c pytorch3d pytorch3d\n", + "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n", + "To do so, use the following command: conda install -c pytorch3d pytorch3d\n", + "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n", + "To do so, use the following command: conda install -c pytorch3d pytorch3d\n", + "INFO:graphein.protein.graphs:Constructing graph for: ../graphein/ml/datasets/test/raw/4hhb.pdb. Chain selection: A\n", + "INFO:graphein.protein.graphs:Constructing graph for: ../graphein/ml/datasets/test/raw/Q5VSL9.pdb. Chain selection: all\n", + "INFO:graphein.protein.graphs:Constructing graph for: ../graphein/ml/datasets/test/raw/1lds.pdb. Chain selection: all\n", + "INFO:graphein.protein.graphs:Constructing graph for: ../graphein/ml/datasets/test/raw/2ll6.pdb. Chain selection: all\n", + "INFO:graphein.protein.graphs:Constructing graph for: ../graphein/ml/datasets/test/raw/3eiy.pdb. Chain selection: all\n", + "DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe\n", + "DEBUG:graphein.protein.graphs:Detected 97 total nodes\n", + "DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe\n", + "DEBUG:graphein.protein.features.nodes.amino_acid:Reading meiler embeddings from: /Users/arianjamasb/github/graphein/graphein/protein/features/nodes/meiler_embeddings.csv\n", + "DEBUG:graphein.protein.graphs:Detected 174 total nodes\n", + "DEBUG:graphein.protein.features.nodes.amino_acid:Reading meiler embeddings from: /Users/arianjamasb/github/graphein/graphein/protein/features/nodes/meiler_embeddings.csv\n", + "DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe\n", + "DEBUG:graphein.protein.graphs:Detected 141 total nodes\n", + "DEBUG:graphein.protein.features.nodes.amino_acid:Reading meiler embeddings from: /Users/arianjamasb/github/graphein/graphein/protein/features/nodes/meiler_embeddings.csv\n", + "DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe\n", + "DEBUG:graphein.protein.graphs:Detected 837 total nodes\n", + "DEBUG:graphein.protein.features.nodes.amino_acid:Reading meiler embeddings from: /Users/arianjamasb/github/graphein/graphein/protein/features/nodes/meiler_embeddings.csv\n", + "DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe\n", + "DEBUG:graphein.protein.graphs:Detected 165 total nodes\n", + "DEBUG:graphein.protein.features.nodes.amino_acid:Reading meiler embeddings from: /Users/arianjamasb/github/graphein/graphein/protein/features/nodes/meiler_embeddings.csv\n", + "Done!\n" + ] + } + ], + "source": [ + "from graphein.ml import InMemoryProteinGraphDataset\n", + "\n", + "g_lab_map = {\"3eiy\": 1, \"4hhb\": 2, \"Q5VSL9\": 3, \"1lds\": 10, \"2ll6\": 4}\n", + "node_lab_map = {\"3eiy\": 1, \"4hhb\": 2, \"Q5VSL9\": 3, \"1lds\": 10, \"2ll6\": 4}\n", + "chain_selection_map = {\"4hhb\": \"A\"}\n", + "\n", + "ds = InMemoryProteinGraphDataset(\n", + " root = \"../graphein/ml/datasets/test\",\n", + " name=\"test\",\n", + " pdb_codes=[\"3eiy\", \"4hhb\", \"1lds\", \"2ll6\"],\n", + " uniprot_ids=[\"Q5VSL9\"],\n", + " graph_label_map=g_lab_map,\n", + " node_label_map=node_lab_map,\n", + " chain_selection_map=chain_selection_map\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DataBatch(edge_index=[2, 236], node_id=[2], coords=[2], name=[2], dist_mat=[2], graph_y=[2], node_y=[2], num_nodes=238, batch=[238], ptr=[3])\n" + ] + } + ], + "source": [ + "# Create a dataloader from dataset and inspect a batch\n", + "dl = DataLoader(ds, batch_size=2, shuffle=True, drop_last=True)\n", + "for i in dl:\n", + " print(i)\n", + " break" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### ProteinGraphListDataset\n", + "\n", + "[API Reference](https://graphein.ai/modules/graphein.ml.html#graphein.ml.datasets.torch_geometric_dataset.ProteinGraphListDataset)\n", + "\n", + "The `ProteinGraphListDataset` class is a lightweight class for wrapping a list of pre-computed `pytorch_geometric.data.Data` graphs.\n", + "\n", + "#### Parameters\n", + "\n", + "```python\n", + "ProteinGraphListDataset(\n", + " root: str, # Root directory where the dataset is stored.\n", + " data_list: List[Data], # List of protein graphs as PyTorch Geometric Data objects.\n", + " name: str, # Name of dataset. Data will be saved as ``data_{name}.pt``.\n", + " transform: Optional[Callable]=None # A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before every access.\n", + " )\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n", + "To do so, use the following command: conda install -c pytorch3d pytorch3d\n", + "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n", + "To do so, use the following command: conda install -c pytorch3d pytorch3d\n", + "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n", + "To do so, use the following command: conda install -c pytorch3d pytorch3d\n", + "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n", + "To do so, use the following command: conda install -c pytorch3d pytorch3d\n", + "INFO:graphein.protein.graphs:Constructing graph for: 4hhb. Chain selection: all\n", + "INFO:graphein.protein.graphs:Constructing graph for: 3eiy. Chain selection: all\n", + "INFO:graphein.protein.graphs:Constructing graph for: 1lds. Chain selection: all\n", + "INFO:graphein.protein.graphs:Constructing graph for: 2ll6. Chain selection: all\n", + "DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe\n", + "DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe\n", + "DEBUG:graphein.protein.graphs:Detected 174 total nodes\n", + "DEBUG:graphein.protein.graphs:Detected 97 total nodes\n", + "DEBUG:graphein.protein.features.nodes.amino_acid:Reading meiler embeddings from: /Users/arianjamasb/github/graphein/graphein/protein/features/nodes/meiler_embeddings.csv\n", + "DEBUG:graphein.protein.features.nodes.amino_acid:Reading meiler embeddings from: /Users/arianjamasb/github/graphein/graphein/protein/features/nodes/meiler_embeddings.csv\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "174\n", + "97\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe\n", + "DEBUG:graphein.protein.graphs:Detected 574 total nodes\n", + "DEBUG:graphein.protein.features.nodes.amino_acid:Reading meiler embeddings from: /Users/arianjamasb/github/graphein/graphein/protein/features/nodes/meiler_embeddings.csv\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "574\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe\n", + "DEBUG:graphein.protein.graphs:Detected 165 total nodes\n", + "DEBUG:graphein.protein.features.nodes.amino_acid:Reading meiler embeddings from: /Users/arianjamasb/github/graphein/graphein/protein/features/nodes/meiler_embeddings.csv\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "165\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "DEBUG:graphein.protein.subgraphs:Found 174 nodes in the chain subgraph.\n", + "DEBUG:graphein.protein.subgraphs:Creating subgraph from nodes: ['A:VAL:107', 'A:SER:26', 'A:PHE:45', 'A:PRO:53', 'A:VAL:29', 'A:ALA:144', 'A:LEU:39', 'A:GLN:14', 'A:TYR:56', 'A:PRO:69', 'A:MET:117', 'A:ASP:98', 'A:SER:123', 'A:VAL:72', 'A:PRO:60', 'A:ILE:135', 'A:THR:48', 'A:ALA:33', 'A:PHE:78', 'A:ALA:108', 'A:ASN:55', 'A:ARG:87', 'A:GLU:21', 'A:THR:167', 'A:GLU:27', 'A:MET:96', 'A:ALA:119', 'A:VAL:74', 'A:TYR:31', 'A:LEU:63', 'A:VAL:114', 'A:ALA:36', 'A:ASP:168', 'A:GLY:83', 'A:PHE:173', 'A:GLN:25', 'A:TRP:156', 'A:PHE:16', 'A:ASP:71', 'A:LEU:94', 'A:ASP:112', 'A:LEU:12', 'A:GLY:65', 'A:LEU:81', 'A:PRO:23', 'A:VAL:102', 'A:ALA:8', 'A:LYS:136', 'A:GLU:140', 'A:ILE:22', 'A:LYS:35', 'A:ALA:82', 'A:GLU:165', 'A:PHE:139', 'A:LEU:80', 'A:LYS:152', 'A:GLY:38', 'A:GLN:61', 'A:SER:2', 'A:GLY:169', 'A:LEU:106', 'A:ASP:157', 'A:ASN:172', 'A:LYS:10', 'A:VAL:18', 'A:ARG:44', 'A:GLU:99', 'A:ASP:125', 'A:ILE:19', 'A:ILE:159', 'A:GLY:49', 'A:ARG:89', 'A:ILE:59', 'A:TYR:52', 'A:ALA:90', 'A:LYS:95', 'A:VAL:170', 'A:LEU:91', 'A:ASP:68', 'A:GLU:146', 'A:LYS:143', 'A:LYS:132', 'A:PRO:28', 'A:SER:84', 'A:LEU:40', 'A:PRO:77', 'A:GLY:92', 'A:GLY:67', 'A:LYS:30', 'A:LEU:121', 'A:VAL:41', 'A:ILE:124', 'A:VAL:54', 'A:ILE:166', 'A:ALA:171', 'A:GLY:101', 'A:GLY:155', 'A:VAL:85', 'A:LYS:174', 'A:ASN:17', 'A:ILE:20', 'A:ALA:162', 'A:ASP:15', 'A:VAL:86', 'A:MET:93', 'A:LYS:175', 'A:VAL:70', 'A:HIS:163', 'A:LYS:149', 'A:LEU:73', 'A:ALA:161', 'A:LEU:37', 'A:GLY:158', 'A:SER:64', 'A:PRO:128', 'A:ARG:51', 'A:ALA:24', 'A:LYS:105', 'A:LYS:147', 'A:SER:100', 'A:GLN:134', 'A:ILE:75', 'A:ASP:103', 'A:GLY:57', 'A:LEU:145', 'A:LYS:164', 'A:LYS:122', 'A:PHE:58', 'A:GLU:154', 'A:THR:97', 'A:ASP:126', 'A:LEU:131', 'A:ALA:129', 'A:ALA:104', 'A:VAL:109', 'A:THR:118', 'A:LYS:113', 'A:VAL:127', 'A:GLY:148', 'A:TYR:130', 'A:GLY:47', 'A:VAL:153', 'A:PRO:116', 'A:GLU:32', 'A:ASP:43', 'A:CYS:115', 'A:ASN:5', 'A:ASP:34', 'A:ASP:160', 'A:THR:62', 'A:ASN:120', 'A:THR:76', 'A:HIS:111', 'A:ASP:11', 'A:VAL:6', 'A:ASP:133', 'A:ASP:66', 'A:TYR:142', 'A:PRO:7', 'A:PRO:13', 'A:HIS:137', 'A:SER:4', 'A:PHE:3', 'A:MET:50', 'A:VAL:151', 'A:ILE:46', 'A:PRO:79', 'A:PHE:138', 'A:VAL:42', 'A:ALA:88', 'A:GLY:9', 'A:TRP:150', 'A:GLN:141', 'A:PRO:110'].\n", + "DEBUG:graphein.protein.subgraphs:Found 141 nodes in the chain subgraph.\n", + "DEBUG:graphein.protein.subgraphs:Creating subgraph from nodes: ['A:VAL:107', 'A:LEU:83', 'A:GLY:22', 'A:GLY:25', 'A:ALA:26', 'A:ALA:123', 'A:LEU:2', 'A:LEU:105', 'A:ALA:130', 'A:LYS:7', 'A:ALA:79', 'A:TYR:24', 'A:ASN:78', 'A:SER:131', 'A:TYR:42', 'A:LEU:100', 'A:LEU:101', 'A:SER:102', 'A:LYS:11', 'A:ALA:69', 'A:SER:35', 'A:HIS:50', 'A:HIS:58', 'A:TYR:140', 'A:HIS:20', 'A:ALA:71', 'A:LEU:136', 'A:PHE:43', 'A:PHE:46', 'A:LEU:66', 'A:VAL:121', 'A:MET:76', 'A:GLU:27', 'A:ALA:120', 'A:HIS:89', 'A:VAL:10', 'A:VAL:93', 'A:ARG:141', 'A:SER:3', 'A:SER:133', 'A:ASP:64', 'A:GLY:51', 'A:GLU:23', 'A:SER:81', 'A:GLN:54', 'A:PRO:95', 'A:THR:38', 'A:HIS:87', 'A:LYS:99', 'A:LYS:90', 'A:ASN:68', 'A:ALA:82', 'A:THR:39', 'A:LYS:139', 'A:THR:108', 'A:HIS:45', 'A:ASP:75', 'A:LEU:80', 'A:SER:124', 'A:VAL:17', 'A:LEU:86', 'A:ALA:13', 'A:LYS:127', 'A:ASP:85', 'A:THR:67', 'A:LEU:106', 'A:LYS:61', 'A:ALA:63', 'A:ASP:47', 'A:ALA:111', 'A:ALA:21', 'A:ALA:12', 'A:LEU:91', 'A:SER:138', 'A:GLU:116', 'A:LEU:48', 'A:GLU:30', 'A:SER:52', 'A:VAL:62', 'A:SER:84', 'A:PRO:77', 'A:GLY:59', 'A:PHE:98', 'A:ALA:19', 'A:ASN:9', 'A:HIS:103', 'A:ASP:6', 'A:ARG:92', 'A:LYS:60', 'A:GLY:18', 'A:PHE:36', 'A:PRO:44', 'A:PRO:4', 'A:ALA:28', 'A:LYS:40', 'A:VAL:96', 'A:THR:134', 'A:HIS:122', 'A:VAL:70', 'A:SER:49', 'A:PRO:119', 'A:THR:137', 'A:LYS:16', 'A:ASN:97', 'A:ARG:31', 'A:VAL:1', 'A:ALA:53', 'A:TRP:14', 'A:ALA:5', 'A:ALA:115', 'A:LEU:34', 'A:GLY:57', 'A:HIS:112', 'A:ALA:65', 'A:ASP:126', 'A:LEU:125', 'A:PRO:37', 'A:HIS:72', 'A:THR:118', 'A:CYS:104', 'A:ASP:94', 'A:THR:8', 'A:PHE:33', 'A:VAL:135', 'A:LYS:56', 'A:LEU:29', 'A:PRO:114', 'A:ASP:74', 'A:LEU:109', 'A:LEU:113', 'A:VAL:132', 'A:GLY:15', 'A:MET:32', 'A:VAL:55', 'A:PHE:128', 'A:ALA:88', 'A:ALA:110', 'A:PHE:117', 'A:VAL:73', 'A:THR:41', 'A:LEU:129'].\n", + "DEBUG:graphein.protein.subgraphs:Found 97 nodes in the chain subgraph.\n", + "DEBUG:graphein.protein.subgraphs:Creating subgraph from nodes: ['A:GLU:50', 'A:LEU:64', 'A:PHE:30', 'A:THR:4', 'A:ARG:81', 'A:LEU:39', 'A:ASN:83', 'A:ALA:15', 'A:VAL:9', 'A:ALA:79', 'A:PHE:22', 'A:ASP:96', 'A:ARG:45', 'A:TYR:78', 'A:SER:55', 'A:THR:86', 'A:ARG:12', 'A:LYS:48', 'A:SER:28', 'A:GLY:29', 'A:VAL:93', 'A:PRO:32', 'A:GLU:16', 'A:ILE:35', 'A:SER:57', 'A:SER:33', 'A:ASP:53', 'A:ILE:7', 'A:LYS:19', 'A:LEU:65', 'A:GLU:77', 'A:TRP:60', 'A:SER:20', 'A:TYR:63', 'A:PHE:70', 'A:TRP:95', 'A:HIS:13', 'A:ASP:38', 'A:LEU:54', 'A:PRO:5', 'A:TYR:10', 'A:LEU:87', 'A:THR:68', 'A:HIS:31', 'A:TYR:66', 'A:SER:11', 'A:CYS:25', 'A:SER:52', 'A:VAL:82', 'A:LYS:91', 'A:LEU:23', 'A:LEU:40', 'A:ARG:3', 'A:VAL:27', 'A:GLU:69', 'A:ASP:76', 'A:LYS:75', 'A:THR:73', 'A:PRO:90', 'A:GLY:43', 'A:PHE:56', 'A:VAL:85', 'A:GLY:18', 'A:GLU:44', 'A:ASN:17', 'A:VAL:37', 'A:SER:88', 'A:ASP:59', 'A:GLN:8', 'A:ASN:24', 'A:HIS:84', 'A:CYS:80', 'A:ILE:1', 'A:GLU:36', 'A:HIS:51', 'A:LYS:41', 'A:LYS:6', 'A:VAL:49', 'A:GLU:47', 'A:TYR:67', 'A:PHE:62', 'A:PRO:72', 'A:ASP:34', 'A:GLN:2', 'A:ASN:21', 'A:ILE:92', 'A:SER:61', 'A:MET:0', 'A:PRO:14', 'A:GLN:89', 'A:ILE:46', 'A:GLU:74', 'A:THR:71', 'A:LYS:94', 'A:LYS:58', 'A:ASN:42', 'A:TYR:26'].\n", + "DEBUG:graphein.protein.subgraphs:Found 148 nodes in the chain subgraph.\n", + "DEBUG:graphein.protein.subgraphs:Creating subgraph from nodes: ['A:THR:117', 'A:LEU:18', 'A:ALA:1', 'A:GLY:25', 'A:PHE:68', 'A:ASN:111', 'A:LEU:39', 'A:LEU:105', 'A:ALA:15', 'A:THR:26', 'A:ASP:129', 'A:VAL:91', 'A:ALA:147', 'A:GLU:82', 'A:PRO:43', 'A:ASN:53', 'A:MET:124', 'A:ALA:10', 'A:ASP:24', 'A:LYS:13', 'A:ALA:46', 'A:SER:38', 'A:GLN:143', 'A:THR:110', 'A:VAL:121', 'A:MET:76', 'A:ARG:90', 'A:ASP:2', 'A:LYS:115', 'A:GLU:45', 'A:ARG:74', 'A:PHE:16', 'A:GLY:33', 'A:THR:5', 'A:ASP:64', 'A:TYR:99', 'A:PHE:92', 'A:GLU:123', 'A:SER:81', 'A:GLN:49', 'A:GLU:140', 'A:MET:51', 'A:MET:36', 'A:ALA:103', 'A:SER:17', 'A:VAL:136', 'A:THR:29', 'A:ASP:56', 'A:ILE:27', 'A:ARG:86', 'A:ASP:122', 'A:PHE:12', 'A:THR:70', 'A:GLU:120', 'A:LEU:48', 'A:ILE:9', 'A:ALA:128', 'A:LEU:32', 'A:LYS:77', 'A:GLU:139', 'A:GLY:59', 'A:GLU:31', 'A:GLU:127', 'A:GLU:7', 'A:GLY:40', 'A:LYS:75', 'A:LYS:30', 'A:GLU:6', 'A:MET:145', 'A:ALA:57', 'A:ARG:106', 'A:ALA:73', 'A:PHE:141', 'A:GLY:98', 'A:GLY:113', 'A:THR:79', 'A:GLN:135', 'A:ASN:137', 'A:GLU:104', 'A:TYR:138', 'A:ARG:126', 'A:ASN:60', 'A:ASP:78', 'A:GLN:8', 'A:ASN:97', 'A:MET:144', 'A:GLU:114', 'A:GLU:84', 'A:VAL:35', 'A:ASP:131', 'A:PRO:66', 'A:ASP:58', 'A:MET:71', 'A:GLN:41', 'A:ASP:80', 'A:LEU:69', 'A:VAL:142', 'A:GLY:134', 'A:SER:101', 'A:ASP:20', 'A:GLN:3', 'A:THR:34', 'A:LYS:148', 'A:LEU:112', 'A:ASP:22', 'A:PHE:65', 'A:GLU:67', 'A:THR:146', 'A:MET:109', 'A:ILE:130', 'A:THR:44', 'A:MET:72', 'A:GLY:23', 'A:VAL:108', 'A:ILE:63', 'A:LYS:21', 'A:GLU:119', 'A:GLU:47', 'A:ILE:100', 'A:ARG:37', 'A:ASP:118', 'A:PHE:19', 'A:GLY:61', 'A:THR:62', 'A:GLY:132', 'A:LEU:4', 'A:GLU:87', 'A:ALA:102', 'A:ASP:133', 'A:HIS:107', 'A:ASP:95', 'A:GLY:96', 'A:GLU:54', 'A:GLU:83', 'A:ASP:50', 'A:THR:28', 'A:VAL:55', 'A:PHE:89', 'A:ASP:93', 'A:ILE:85', 'A:GLU:14', 'A:ALA:88', 'A:LYS:94', 'A:LEU:116', 'A:ILE:52', 'A:GLU:11', 'A:ASN:42', 'A:ILE:125'].\n", + "Processing...\n", + "Done!\n" + ] + } + ], + "source": [ + "from graphein.ml import ProteinGraphListDataset, GraphFormatConvertor\n", + "import graphein.protein as gp\n", + "\n", + "# Construct graphs\n", + "graphs = gp.construct_graphs_mp(\n", + " pdb_code_it=[\"3eiy\", \"4hhb\", \"1lds\", \"2ll6\"],\n", + " return_dict=False\n", + " )\n", + "\n", + "# do some transformation\n", + "graphs = [gp.extract_subgraph_from_chains(g, [\"A\"]) for g in graphs]\n", + "\n", + "# Convert to PyG Data format\n", + "convertor = GraphFormatConvertor(src_format=\"nx\", dst_format=\"pyg\")\n", + "graphs = [convertor(g) for g in graphs]\n", + "\n", + "# Create dataset\n", + "ds = ProteinGraphListDataset(root=\".\", data_list=graphs, name=\"list_test\")" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Data(edge_index=[2, 173], node_id=[174], coords=[1], name=[1], dist_mat=[1], num_nodes=174)\n", + "Data(edge_index=[2, 140], node_id=[141], coords=[1], name=[1], dist_mat=[1], num_nodes=141)\n", + "Data(edge_index=[2, 96], node_id=[97], coords=[1], name=[1], dist_mat=[1], num_nodes=97)\n", + "Data(edge_index=[2, 147], node_id=[148], coords=[1], name=[1], dist_mat=[1], num_nodes=148)\n" + ] + } + ], + "source": [ + "for i in ds:\n", + " print(i)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DataBatch(edge_index=[2, 303], node_id=[2], coords=[2], name=[2], dist_mat=[2], graph_y=[2], node_y=[2], num_nodes=306, batch=[306], ptr=[3])\n", + "DataBatch(edge_index=[2, 1009], node_id=[2], coords=[2], name=[2], dist_mat=[2], graph_y=[2], node_y=[2], num_nodes=1011, batch=[1011], ptr=[3])\n", + "DataBatch(edge_index=[2, 96], node_id=[1], coords=[1], name=[1], dist_mat=[1], graph_y=[1], node_y=[1], num_nodes=97, batch=[97], ptr=[2])\n" + ] + } + ], + "source": [ + "# Create a dataloader from dataset and inspect a few batches\n", + "dl = DataLoader(ds, batch_size=2, shuffle=True, drop_last=False)\n", + "for i in dl:\n", + " print(i)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Transforms\n", + "\n", + "We can supply various functions to `ProteinGraphDataset` and `InMemoryProteinGraphDataset` to alter the composition of the dataset.\n", + "\n", + "* ``pdb_transform`` (``list(callable)``, optional) - A function that receives a list of paths to the downloaded structures. This provides an entry point to apply pre-processing from bioinformatics tools of your choosing\n", + "\n", + "* ``graph_transformation_funcs``: (``List[Callable]``, optional) List of functions that consume a ``nx.Graph`` and return a ``nx.Graph``. Applied to graphs after construction but before conversion to ``torch_geometric.data.Data``. Defaults to ``None``.\n", + "\n", + "* ``transform`` (``callable``, optional) – A function/transform that takes in a ``torch_geometric.data.Data`` object and returns a transformed version. The data object will be transformed before every access. (default: ``None``)\n", + "\n", + "* ``pre_transform`` (``callable``, optional) – A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before being saved to disk. (default: ``None``)\n", + "\n", + "* ``pre_filter`` (``callable,`` optional) – A function that takes in a ``torch_geometric.data.Data`` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: ``None``)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import List\n", + "import networkx as nx\n", + "from torch_geometric.data import Data\n", + "\n", + "# Create dummy transforms\n", + "def pdb_transform_fn(files: List[str]):\n", + " \"\"\"Transforms raw pdbs prior to computing graphs.\"\"\"\n", + " return\n", + "\n", + "def graph_transform_fn(graph: nx.Graph) -> nx.Graph:\n", + " \"\"\"Transforms graphein nx.Graph prior to conversion to torch_geometric.data.Data.\"\"\"\n", + " return graph\n", + "\n", + "def transform_fn(data: Data) -> Data:\n", + " \"\"\"Transforms torch_geometric.data.Data prior to every access.\"\"\"\n", + " return data\n", + "\n", + "def pre_transform_fn(data: Data) -> Data:\n", + " \"\"\"Transforms torch_geometric.data.Data prior to saving to disk.\"\"\"\n", + " return data\n", + "\n", + "def pre_filter_fn(data: Data) -> bool:\n", + " \"\"\"Takes in a torch_geometric.data.Data and returns True if the data should be included in the dataset.\"\"\"\n", + " return True" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n", + "To do so, use the following command: conda install -c pytorch3d pytorch3d\n", + " 0%| | 0/4 [00:00