Skip to content

Commit

Permalink
Adding inference benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
miccai2021anon committed May 21, 2021
1 parent 3f8ecc3 commit 5fefa05
Show file tree
Hide file tree
Showing 5 changed files with 387 additions and 1 deletion.
236 changes: 236 additions & 0 deletions Inference Benchmark.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"import os, time\n",
"import torch\n",
"import torch_geometric\n",
"from datasets.BatchWSI import BatchWSI\n",
"from models.model_graph_mil import *\n",
"device = torch.device('cuda:0')\n",
"\n",
"dataroot = './data/TCGA/BRCA/'\n",
"large_graph_pt = 'TCGA-BH-A0DV-01Z-00-DX1.2F0B5FB3-40F0-4D27-BFAC-390FB9A42B39.pt' # example input\n",
"\n",
"def count_parameters(model):\n",
" return sum(p.numel() for p in model.parameters() if p.requires_grad)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Graph Data Structure\n",
"- `N`: number of patches\n",
"- `M`: number of edges\n",
"- `centroid`: [N x 2] matrix containing centroids for each patch\n",
"- `edge_index`: [2 x M] matrix containing edges between patches (connected via adjacent spatial coordinates)\n",
"- `edge_latent`: [2 x M] matric containing edges between patches (connected via latent space)\n",
"- `x`: [N x 1024] matrix which uses 1024-dim extracted ResNet features for each iamge patch (features saved for simplicity)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Data(centroid=[23049, 2], edge_index=[2, 161343], edge_latent=[2, 161343], x=[23049, 1024])"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = torch.load(os.path.join(dataroot, large_graph_pt))\n",
"data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In PyTorch Geometric, inference on large graphs is very tractable. Here, adjacency matrices are stacked in a diagonal fashion (creating a giant graph that holds multiple isolated subgraphs), and node and target features are simply concatenated in the node dimension\n",
"\n",
"This procedure has some crucial advantages over other batching procedures:\n",
"\n",
"- GNN operators that rely on a message passing scheme do not need to be modified since messages still cannot be exchanged between two nodes that belong to different graphs.\n",
"\n",
"- There is no computational or memory overhead. For example, this batching procedure works completely without any padding of node or edge features. Note that there is no additional memory overhead for adjacency matrices since they are saved in a sparse fashion holding only non-zero entries, i.e., the edges. \n",
"- For more details, see the advanced mini-batching FAQ in: https://pytorch-geometric.readthedocs.io/en/latest/notes/batching.html"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"BatchWSI(batch=[46098], centroid=[46098, 2], edge_index=[2, 322686], edge_latent=[4, 161343], ptr=[3], x=[46098, 1024])"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = BatchWSI.from_data_list([torch.load(os.path.join(dataroot, large_graph_pt)), \n",
" torch.load(os.path.join(dataroot, large_graph_pt))])\n",
"data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Inference + Backprop using 23K patches"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"BatchWSI(batch=[23049], centroid=[23049, 2], edge_index=[2, 161343], edge_latent=[2, 161343], ptr=[2], x=[23049, 1024])"
]
},
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = BatchWSI.from_data_list([torch.load(os.path.join(dataroot, large_graph_pt))])\n",
"data = data.to(device)\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of Parameters: 1382917\n",
"Time Elapsed: 0.06325 seconds\n"
]
}
],
"source": [
"model_dict = {'num_layers': 4, 'edge_agg': 'spatial', 'resample': 0, 'n_classes': 1}\n",
"model = PatchGCN_Surv(**model_dict).to(device)\n",
"print(\"Number of Parameters:\", count_parameters(model))\n",
"\n",
"### Example Forward Paas + Gradient Backprop\n",
"start = time.time()\n",
"out = model(x_path=data)\n",
"out[0].backward()\n",
"print('Time Elapsed: %0.5f seconds' % (time.time() - start))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Inference + Backprop using 92K patches"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"BatchWSI(batch=[92196], centroid=[92196, 2], edge_index=[2, 645372], edge_latent=[8, 161343], ptr=[5], x=[92196, 1024])"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"### Simulating a very large graph (containing 4 subgraphs of 23K patches each)\n",
"data = BatchWSI.from_data_list([torch.load(os.path.join(dataroot, large_graph_pt)), \n",
" torch.load(os.path.join(dataroot, large_graph_pt)),\n",
" torch.load(os.path.join(dataroot, large_graph_pt)),\n",
" torch.load(os.path.join(dataroot, large_graph_pt))])\n",
"data = data.to(device)\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of Parameters: 1382917\n",
"Time Elapsed: 0.20629 seconds\n"
]
}
],
"source": [
"model_dict = {'num_layers': 4, 'edge_agg': 'spatial', 'resample': 0, 'n_classes': 1}\n",
"model = PatchGCN_Surv(**model_dict).to(device)\n",
"print(\"Number of Parameters:\", count_parameters(model))\n",
"\n",
"### Example Forward Paas + Gradient Backprop\n",
"start = time.time()\n",
"out = model(x_path=data)\n",
"out[0].backward()\n",
"print('Time Elapsed: %0.5f seconds' % (time.time() - start))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Assuming worst case scenario that every graph has ~100K patches, for a dataset of 1000 WSIs, an epoch would take 3.43 minutes, with 20 epochs taking ~ 1 hour."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
150 changes: 150 additions & 0 deletions datasets/BatchWSI.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import torch_geometric
from typing import List

import torch
from torch import Tensor
from torch_sparse import SparseTensor, cat
import torch_geometric
from torch_geometric.data import Data

class BatchWSI(torch_geometric.data.Batch):
def __init__(self):
super(BatchWSI, self).__init__()
pass

@classmethod
def from_data_list(cls, data_list, follow_batch=[], exclude_keys=[], update_cat_dims={}):
r"""Constructs a batch object from a python list holding
:class:`torch_geometric.data.Data` objects.
The assignment vector :obj:`batch` is created on the fly.
Additionally, creates assignment batch vectors for each key in
:obj:`follow_batch`.
Will exclude any keys given in :obj:`exclude_keys`."""

keys = list(set(data_list[0].keys) - set(exclude_keys))
assert 'batch' not in keys and 'ptr' not in keys

batch = cls()
for key in data_list[0].__dict__.keys():
if key[:2] != '__' and key[-2:] != '__':
batch[key] = None

batch.__num_graphs__ = len(data_list)
batch.__data_class__ = data_list[0].__class__
for key in keys + ['batch']:
batch[key] = []
batch['ptr'] = [0]
cat_dims = {}
device = None
slices = {key: [0] for key in keys}
cumsum = {key: [0] for key in keys}
num_nodes_list = []
for i, data in enumerate(data_list):
for key in keys:
item = data[key]

# Increase values by `cumsum` value.
cum = cumsum[key][-1]
if isinstance(item, Tensor) and item.dtype != torch.bool:
if not isinstance(cum, int) or cum != 0:
item = item + cum
elif isinstance(item, SparseTensor):
value = item.storage.value()
if value is not None and value.dtype != torch.bool:
if not isinstance(cum, int) or cum != 0:
value = value + cum
item = item.set_value(value, layout='coo')
elif isinstance(item, (int, float)):
item = item + cum

# Gather the size of the `cat` dimension.
size = 1

if key in update_cat_dims.keys():
cat_dim = update_cat_dims[key]
else:
cat_dim = data.__cat_dim__(key, data[key])
# 0-dimensional tensors have no dimension along which to
# concatenate, so we set `cat_dim` to `None`.
if isinstance(item, Tensor) and item.dim() == 0:
cat_dim = None

cat_dims[key] = cat_dim

# Add a batch dimension to items whose `cat_dim` is `None`:
if isinstance(item, Tensor) and cat_dim is None:
cat_dim = 0 # Concatenate along this new batch dimension.
item = item.unsqueeze(0)
device = item.device
elif isinstance(item, Tensor):
size = item.size(cat_dim)
device = item.device
elif isinstance(item, SparseTensor):
size = torch.tensor(item.sizes())[torch.tensor(cat_dim)]
device = item.device()

batch[key].append(item) # Append item to the attribute list.

slices[key].append(size + slices[key][-1])
inc = data.__inc__(key, item)
if isinstance(inc, (tuple, list)):
inc = torch.tensor(inc)
cumsum[key].append(inc + cumsum[key][-1])

if key in follow_batch:
if isinstance(size, Tensor):
for j, size in enumerate(size.tolist()):
tmp = f'{key}_{j}_batch'
batch[tmp] = [] if i == 0 else batch[tmp]
batch[tmp].append(
torch.full((size, ), i, dtype=torch.long,
device=device))
else:
tmp = f'{key}_batch'
batch[tmp] = [] if i == 0 else batch[tmp]
batch[tmp].append(
torch.full((size, ), i, dtype=torch.long,
device=device))

if hasattr(data, '__num_nodes__'):
num_nodes_list.append(data.__num_nodes__)
else:
num_nodes_list.append(None)

num_nodes = data.num_nodes
if num_nodes is not None:
item = torch.full((num_nodes, ), i, dtype=torch.long,
device=device)
batch.batch.append(item)
batch.ptr.append(batch.ptr[-1] + num_nodes)

batch.batch = None if len(batch.batch) == 0 else batch.batch
batch.ptr = None if len(batch.ptr) == 1 else batch.ptr
batch.__slices__ = slices
batch.__cumsum__ = cumsum
batch.__cat_dims__ = cat_dims
batch.__num_nodes_list__ = num_nodes_list

ref_data = data_list[0]
for key in batch.keys:
items = batch[key]
item = items[0]

### <--- Updating Cat Dim
if key in update_cat_dims.keys():
cat_dim = update_cat_dims[key]
else:
cat_dim = ref_data.__cat_dim__(key, item)
cat_dim = 0 if cat_dim is None else cat_dim
### ---?
if isinstance(item, Tensor):
batch[key] = torch.cat(items, cat_dim)
elif isinstance(item, SparseTensor):
batch[key] = cat(items, cat_dim)
elif isinstance(item, (int, float)):
batch[key] = torch.tensor(items)

if torch_geometric.is_debug_enabled():
batch.debug()

return batch.contiguous()
Binary file added datasets/__pycache__/BatchWSI.cpython-37.pyc
Binary file not shown.
Binary file added models/__pycache__/model_graph_mil.cpython-37.pyc
Binary file not shown.
2 changes: 1 addition & 1 deletion models/model_graph_mil.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def __init__(self, num_layers=4, edge_agg='spatial', resample=0,
self.resample = resample

if self.resample > 0:
self.fc = nn.Sequential(*[nn.Dropout(self.resample), nn.Linear(1024, 256), nn.ReLU(), nn.Dropout(0.25)])
self.fc = nn.Sequential(*[nn.Dropout(self.resample), nn.Linear(1024, 128), nn.ReLU(), nn.Dropout(0.25)])
else:
self.fc = nn.Sequential(*[nn.Linear(1024, 128), nn.ReLU(), nn.Dropout(0.25)])

Expand Down

0 comments on commit 5fefa05

Please sign in to comment.