-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3f8ecc3
commit 5fefa05
Showing
5 changed files
with
387 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,236 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 61, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os, time\n", | ||
"import torch\n", | ||
"import torch_geometric\n", | ||
"from datasets.BatchWSI import BatchWSI\n", | ||
"from models.model_graph_mil import *\n", | ||
"device = torch.device('cuda:0')\n", | ||
"\n", | ||
"dataroot = './data/TCGA/BRCA/'\n", | ||
"large_graph_pt = 'TCGA-BH-A0DV-01Z-00-DX1.2F0B5FB3-40F0-4D27-BFAC-390FB9A42B39.pt' # example input\n", | ||
"\n", | ||
"def count_parameters(model):\n", | ||
" return sum(p.numel() for p in model.parameters() if p.requires_grad)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Graph Data Structure\n", | ||
"- `N`: number of patches\n", | ||
"- `M`: number of edges\n", | ||
"- `centroid`: [N x 2] matrix containing centroids for each patch\n", | ||
"- `edge_index`: [2 x M] matrix containing edges between patches (connected via adjacent spatial coordinates)\n", | ||
"- `edge_latent`: [2 x M] matric containing edges between patches (connected via latent space)\n", | ||
"- `x`: [N x 1024] matrix which uses 1024-dim extracted ResNet features for each iamge patch (features saved for simplicity)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 34, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"Data(centroid=[23049, 2], edge_index=[2, 161343], edge_latent=[2, 161343], x=[23049, 1024])" | ||
] | ||
}, | ||
"execution_count": 34, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"data = torch.load(os.path.join(dataroot, large_graph_pt))\n", | ||
"data" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"In PyTorch Geometric, inference on large graphs is very tractable. Here, adjacency matrices are stacked in a diagonal fashion (creating a giant graph that holds multiple isolated subgraphs), and node and target features are simply concatenated in the node dimension\n", | ||
"\n", | ||
"This procedure has some crucial advantages over other batching procedures:\n", | ||
"\n", | ||
"- GNN operators that rely on a message passing scheme do not need to be modified since messages still cannot be exchanged between two nodes that belong to different graphs.\n", | ||
"\n", | ||
"- There is no computational or memory overhead. For example, this batching procedure works completely without any padding of node or edge features. Note that there is no additional memory overhead for adjacency matrices since they are saved in a sparse fashion holding only non-zero entries, i.e., the edges. \n", | ||
"- For more details, see the advanced mini-batching FAQ in: https://pytorch-geometric.readthedocs.io/en/latest/notes/batching.html" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 36, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"BatchWSI(batch=[46098], centroid=[46098, 2], edge_index=[2, 322686], edge_latent=[4, 161343], ptr=[3], x=[46098, 1024])" | ||
] | ||
}, | ||
"execution_count": 36, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"data = BatchWSI.from_data_list([torch.load(os.path.join(dataroot, large_graph_pt)), \n", | ||
" torch.load(os.path.join(dataroot, large_graph_pt))])\n", | ||
"data" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Inference + Backprop using 23K patches" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 66, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"BatchWSI(batch=[23049], centroid=[23049, 2], edge_index=[2, 161343], edge_latent=[2, 161343], ptr=[2], x=[23049, 1024])" | ||
] | ||
}, | ||
"execution_count": 66, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"data = BatchWSI.from_data_list([torch.load(os.path.join(dataroot, large_graph_pt))])\n", | ||
"data = data.to(device)\n", | ||
"data" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 67, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Number of Parameters: 1382917\n", | ||
"Time Elapsed: 0.06325 seconds\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"model_dict = {'num_layers': 4, 'edge_agg': 'spatial', 'resample': 0, 'n_classes': 1}\n", | ||
"model = PatchGCN_Surv(**model_dict).to(device)\n", | ||
"print(\"Number of Parameters:\", count_parameters(model))\n", | ||
"\n", | ||
"### Example Forward Paas + Gradient Backprop\n", | ||
"start = time.time()\n", | ||
"out = model(x_path=data)\n", | ||
"out[0].backward()\n", | ||
"print('Time Elapsed: %0.5f seconds' % (time.time() - start))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Inference + Backprop using 92K patches" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 50, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"BatchWSI(batch=[92196], centroid=[92196, 2], edge_index=[2, 645372], edge_latent=[8, 161343], ptr=[5], x=[92196, 1024])" | ||
] | ||
}, | ||
"execution_count": 50, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"### Simulating a very large graph (containing 4 subgraphs of 23K patches each)\n", | ||
"data = BatchWSI.from_data_list([torch.load(os.path.join(dataroot, large_graph_pt)), \n", | ||
" torch.load(os.path.join(dataroot, large_graph_pt)),\n", | ||
" torch.load(os.path.join(dataroot, large_graph_pt)),\n", | ||
" torch.load(os.path.join(dataroot, large_graph_pt))])\n", | ||
"data = data.to(device)\n", | ||
"data" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 55, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Number of Parameters: 1382917\n", | ||
"Time Elapsed: 0.20629 seconds\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"model_dict = {'num_layers': 4, 'edge_agg': 'spatial', 'resample': 0, 'n_classes': 1}\n", | ||
"model = PatchGCN_Surv(**model_dict).to(device)\n", | ||
"print(\"Number of Parameters:\", count_parameters(model))\n", | ||
"\n", | ||
"### Example Forward Paas + Gradient Backprop\n", | ||
"start = time.time()\n", | ||
"out = model(x_path=data)\n", | ||
"out[0].backward()\n", | ||
"print('Time Elapsed: %0.5f seconds' % (time.time() - start))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Assuming worst case scenario that every graph has ~100K patches, for a dataset of 1000 WSIs, an epoch would take 3.43 minutes, with 20 epochs taking ~ 1 hour." | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.7.7" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
import torch_geometric | ||
from typing import List | ||
|
||
import torch | ||
from torch import Tensor | ||
from torch_sparse import SparseTensor, cat | ||
import torch_geometric | ||
from torch_geometric.data import Data | ||
|
||
class BatchWSI(torch_geometric.data.Batch): | ||
def __init__(self): | ||
super(BatchWSI, self).__init__() | ||
pass | ||
|
||
@classmethod | ||
def from_data_list(cls, data_list, follow_batch=[], exclude_keys=[], update_cat_dims={}): | ||
r"""Constructs a batch object from a python list holding | ||
:class:`torch_geometric.data.Data` objects. | ||
The assignment vector :obj:`batch` is created on the fly. | ||
Additionally, creates assignment batch vectors for each key in | ||
:obj:`follow_batch`. | ||
Will exclude any keys given in :obj:`exclude_keys`.""" | ||
|
||
keys = list(set(data_list[0].keys) - set(exclude_keys)) | ||
assert 'batch' not in keys and 'ptr' not in keys | ||
|
||
batch = cls() | ||
for key in data_list[0].__dict__.keys(): | ||
if key[:2] != '__' and key[-2:] != '__': | ||
batch[key] = None | ||
|
||
batch.__num_graphs__ = len(data_list) | ||
batch.__data_class__ = data_list[0].__class__ | ||
for key in keys + ['batch']: | ||
batch[key] = [] | ||
batch['ptr'] = [0] | ||
cat_dims = {} | ||
device = None | ||
slices = {key: [0] for key in keys} | ||
cumsum = {key: [0] for key in keys} | ||
num_nodes_list = [] | ||
for i, data in enumerate(data_list): | ||
for key in keys: | ||
item = data[key] | ||
|
||
# Increase values by `cumsum` value. | ||
cum = cumsum[key][-1] | ||
if isinstance(item, Tensor) and item.dtype != torch.bool: | ||
if not isinstance(cum, int) or cum != 0: | ||
item = item + cum | ||
elif isinstance(item, SparseTensor): | ||
value = item.storage.value() | ||
if value is not None and value.dtype != torch.bool: | ||
if not isinstance(cum, int) or cum != 0: | ||
value = value + cum | ||
item = item.set_value(value, layout='coo') | ||
elif isinstance(item, (int, float)): | ||
item = item + cum | ||
|
||
# Gather the size of the `cat` dimension. | ||
size = 1 | ||
|
||
if key in update_cat_dims.keys(): | ||
cat_dim = update_cat_dims[key] | ||
else: | ||
cat_dim = data.__cat_dim__(key, data[key]) | ||
# 0-dimensional tensors have no dimension along which to | ||
# concatenate, so we set `cat_dim` to `None`. | ||
if isinstance(item, Tensor) and item.dim() == 0: | ||
cat_dim = None | ||
|
||
cat_dims[key] = cat_dim | ||
|
||
# Add a batch dimension to items whose `cat_dim` is `None`: | ||
if isinstance(item, Tensor) and cat_dim is None: | ||
cat_dim = 0 # Concatenate along this new batch dimension. | ||
item = item.unsqueeze(0) | ||
device = item.device | ||
elif isinstance(item, Tensor): | ||
size = item.size(cat_dim) | ||
device = item.device | ||
elif isinstance(item, SparseTensor): | ||
size = torch.tensor(item.sizes())[torch.tensor(cat_dim)] | ||
device = item.device() | ||
|
||
batch[key].append(item) # Append item to the attribute list. | ||
|
||
slices[key].append(size + slices[key][-1]) | ||
inc = data.__inc__(key, item) | ||
if isinstance(inc, (tuple, list)): | ||
inc = torch.tensor(inc) | ||
cumsum[key].append(inc + cumsum[key][-1]) | ||
|
||
if key in follow_batch: | ||
if isinstance(size, Tensor): | ||
for j, size in enumerate(size.tolist()): | ||
tmp = f'{key}_{j}_batch' | ||
batch[tmp] = [] if i == 0 else batch[tmp] | ||
batch[tmp].append( | ||
torch.full((size, ), i, dtype=torch.long, | ||
device=device)) | ||
else: | ||
tmp = f'{key}_batch' | ||
batch[tmp] = [] if i == 0 else batch[tmp] | ||
batch[tmp].append( | ||
torch.full((size, ), i, dtype=torch.long, | ||
device=device)) | ||
|
||
if hasattr(data, '__num_nodes__'): | ||
num_nodes_list.append(data.__num_nodes__) | ||
else: | ||
num_nodes_list.append(None) | ||
|
||
num_nodes = data.num_nodes | ||
if num_nodes is not None: | ||
item = torch.full((num_nodes, ), i, dtype=torch.long, | ||
device=device) | ||
batch.batch.append(item) | ||
batch.ptr.append(batch.ptr[-1] + num_nodes) | ||
|
||
batch.batch = None if len(batch.batch) == 0 else batch.batch | ||
batch.ptr = None if len(batch.ptr) == 1 else batch.ptr | ||
batch.__slices__ = slices | ||
batch.__cumsum__ = cumsum | ||
batch.__cat_dims__ = cat_dims | ||
batch.__num_nodes_list__ = num_nodes_list | ||
|
||
ref_data = data_list[0] | ||
for key in batch.keys: | ||
items = batch[key] | ||
item = items[0] | ||
|
||
### <--- Updating Cat Dim | ||
if key in update_cat_dims.keys(): | ||
cat_dim = update_cat_dims[key] | ||
else: | ||
cat_dim = ref_data.__cat_dim__(key, item) | ||
cat_dim = 0 if cat_dim is None else cat_dim | ||
### ---? | ||
if isinstance(item, Tensor): | ||
batch[key] = torch.cat(items, cat_dim) | ||
elif isinstance(item, SparseTensor): | ||
batch[key] = cat(items, cat_dim) | ||
elif isinstance(item, (int, float)): | ||
batch[key] = torch.tensor(items) | ||
|
||
if torch_geometric.is_debug_enabled(): | ||
batch.debug() | ||
|
||
return batch.contiguous() |
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters