diff --git a/Inference Benchmark.ipynb b/Inference Benchmark.ipynb new file mode 100644 index 0000000..b720f44 --- /dev/null +++ b/Inference Benchmark.ipynb @@ -0,0 +1,236 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [], + "source": [ + "import os, time\n", + "import torch\n", + "import torch_geometric\n", + "from datasets.BatchWSI import BatchWSI\n", + "from models.model_graph_mil import *\n", + "device = torch.device('cuda:0')\n", + "\n", + "dataroot = './data/TCGA/BRCA/'\n", + "large_graph_pt = 'TCGA-BH-A0DV-01Z-00-DX1.2F0B5FB3-40F0-4D27-BFAC-390FB9A42B39.pt' # example input\n", + "\n", + "def count_parameters(model):\n", + " return sum(p.numel() for p in model.parameters() if p.requires_grad)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Graph Data Structure\n", + "- `N`: number of patches\n", + "- `M`: number of edges\n", + "- `centroid`: [N x 2] matrix containing centroids for each patch\n", + "- `edge_index`: [2 x M] matrix containing edges between patches (connected via adjacent spatial coordinates)\n", + "- `edge_latent`: [2 x M] matric containing edges between patches (connected via latent space)\n", + "- `x`: [N x 1024] matrix which uses 1024-dim extracted ResNet features for each iamge patch (features saved for simplicity)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Data(centroid=[23049, 2], edge_index=[2, 161343], edge_latent=[2, 161343], x=[23049, 1024])" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = torch.load(os.path.join(dataroot, large_graph_pt))\n", + "data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In PyTorch Geometric, inference on large graphs is very tractable. Here, adjacency matrices are stacked in a diagonal fashion (creating a giant graph that holds multiple isolated subgraphs), and node and target features are simply concatenated in the node dimension\n", + "\n", + "This procedure has some crucial advantages over other batching procedures:\n", + "\n", + "- GNN operators that rely on a message passing scheme do not need to be modified since messages still cannot be exchanged between two nodes that belong to different graphs.\n", + "\n", + "- There is no computational or memory overhead. For example, this batching procedure works completely without any padding of node or edge features. Note that there is no additional memory overhead for adjacency matrices since they are saved in a sparse fashion holding only non-zero entries, i.e., the edges. \n", + "- For more details, see the advanced mini-batching FAQ in: https://pytorch-geometric.readthedocs.io/en/latest/notes/batching.html" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "BatchWSI(batch=[46098], centroid=[46098, 2], edge_index=[2, 322686], edge_latent=[4, 161343], ptr=[3], x=[46098, 1024])" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = BatchWSI.from_data_list([torch.load(os.path.join(dataroot, large_graph_pt)), \n", + " torch.load(os.path.join(dataroot, large_graph_pt))])\n", + "data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Inference + Backprop using 23K patches" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "BatchWSI(batch=[23049], centroid=[23049, 2], edge_index=[2, 161343], edge_latent=[2, 161343], ptr=[2], x=[23049, 1024])" + ] + }, + "execution_count": 66, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = BatchWSI.from_data_list([torch.load(os.path.join(dataroot, large_graph_pt))])\n", + "data = data.to(device)\n", + "data" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of Parameters: 1382917\n", + "Time Elapsed: 0.06325 seconds\n" + ] + } + ], + "source": [ + "model_dict = {'num_layers': 4, 'edge_agg': 'spatial', 'resample': 0, 'n_classes': 1}\n", + "model = PatchGCN_Surv(**model_dict).to(device)\n", + "print(\"Number of Parameters:\", count_parameters(model))\n", + "\n", + "### Example Forward Paas + Gradient Backprop\n", + "start = time.time()\n", + "out = model(x_path=data)\n", + "out[0].backward()\n", + "print('Time Elapsed: %0.5f seconds' % (time.time() - start))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Inference + Backprop using 92K patches" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "BatchWSI(batch=[92196], centroid=[92196, 2], edge_index=[2, 645372], edge_latent=[8, 161343], ptr=[5], x=[92196, 1024])" + ] + }, + "execution_count": 50, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "### Simulating a very large graph (containing 4 subgraphs of 23K patches each)\n", + "data = BatchWSI.from_data_list([torch.load(os.path.join(dataroot, large_graph_pt)), \n", + " torch.load(os.path.join(dataroot, large_graph_pt)),\n", + " torch.load(os.path.join(dataroot, large_graph_pt)),\n", + " torch.load(os.path.join(dataroot, large_graph_pt))])\n", + "data = data.to(device)\n", + "data" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of Parameters: 1382917\n", + "Time Elapsed: 0.20629 seconds\n" + ] + } + ], + "source": [ + "model_dict = {'num_layers': 4, 'edge_agg': 'spatial', 'resample': 0, 'n_classes': 1}\n", + "model = PatchGCN_Surv(**model_dict).to(device)\n", + "print(\"Number of Parameters:\", count_parameters(model))\n", + "\n", + "### Example Forward Paas + Gradient Backprop\n", + "start = time.time()\n", + "out = model(x_path=data)\n", + "out[0].backward()\n", + "print('Time Elapsed: %0.5f seconds' % (time.time() - start))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Assuming worst case scenario that every graph has ~100K patches, for a dataset of 1000 WSIs, an epoch would take 3.43 minutes, with 20 epochs taking ~ 1 hour." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/datasets/BatchWSI.py b/datasets/BatchWSI.py new file mode 100644 index 0000000..1297932 --- /dev/null +++ b/datasets/BatchWSI.py @@ -0,0 +1,150 @@ +import torch_geometric +from typing import List + +import torch +from torch import Tensor +from torch_sparse import SparseTensor, cat +import torch_geometric +from torch_geometric.data import Data + +class BatchWSI(torch_geometric.data.Batch): + def __init__(self): + super(BatchWSI, self).__init__() + pass + + @classmethod + def from_data_list(cls, data_list, follow_batch=[], exclude_keys=[], update_cat_dims={}): + r"""Constructs a batch object from a python list holding + :class:`torch_geometric.data.Data` objects. + The assignment vector :obj:`batch` is created on the fly. + Additionally, creates assignment batch vectors for each key in + :obj:`follow_batch`. + Will exclude any keys given in :obj:`exclude_keys`.""" + + keys = list(set(data_list[0].keys) - set(exclude_keys)) + assert 'batch' not in keys and 'ptr' not in keys + + batch = cls() + for key in data_list[0].__dict__.keys(): + if key[:2] != '__' and key[-2:] != '__': + batch[key] = None + + batch.__num_graphs__ = len(data_list) + batch.__data_class__ = data_list[0].__class__ + for key in keys + ['batch']: + batch[key] = [] + batch['ptr'] = [0] + cat_dims = {} + device = None + slices = {key: [0] for key in keys} + cumsum = {key: [0] for key in keys} + num_nodes_list = [] + for i, data in enumerate(data_list): + for key in keys: + item = data[key] + + # Increase values by `cumsum` value. + cum = cumsum[key][-1] + if isinstance(item, Tensor) and item.dtype != torch.bool: + if not isinstance(cum, int) or cum != 0: + item = item + cum + elif isinstance(item, SparseTensor): + value = item.storage.value() + if value is not None and value.dtype != torch.bool: + if not isinstance(cum, int) or cum != 0: + value = value + cum + item = item.set_value(value, layout='coo') + elif isinstance(item, (int, float)): + item = item + cum + + # Gather the size of the `cat` dimension. + size = 1 + + if key in update_cat_dims.keys(): + cat_dim = update_cat_dims[key] + else: + cat_dim = data.__cat_dim__(key, data[key]) + # 0-dimensional tensors have no dimension along which to + # concatenate, so we set `cat_dim` to `None`. + if isinstance(item, Tensor) and item.dim() == 0: + cat_dim = None + + cat_dims[key] = cat_dim + + # Add a batch dimension to items whose `cat_dim` is `None`: + if isinstance(item, Tensor) and cat_dim is None: + cat_dim = 0 # Concatenate along this new batch dimension. + item = item.unsqueeze(0) + device = item.device + elif isinstance(item, Tensor): + size = item.size(cat_dim) + device = item.device + elif isinstance(item, SparseTensor): + size = torch.tensor(item.sizes())[torch.tensor(cat_dim)] + device = item.device() + + batch[key].append(item) # Append item to the attribute list. + + slices[key].append(size + slices[key][-1]) + inc = data.__inc__(key, item) + if isinstance(inc, (tuple, list)): + inc = torch.tensor(inc) + cumsum[key].append(inc + cumsum[key][-1]) + + if key in follow_batch: + if isinstance(size, Tensor): + for j, size in enumerate(size.tolist()): + tmp = f'{key}_{j}_batch' + batch[tmp] = [] if i == 0 else batch[tmp] + batch[tmp].append( + torch.full((size, ), i, dtype=torch.long, + device=device)) + else: + tmp = f'{key}_batch' + batch[tmp] = [] if i == 0 else batch[tmp] + batch[tmp].append( + torch.full((size, ), i, dtype=torch.long, + device=device)) + + if hasattr(data, '__num_nodes__'): + num_nodes_list.append(data.__num_nodes__) + else: + num_nodes_list.append(None) + + num_nodes = data.num_nodes + if num_nodes is not None: + item = torch.full((num_nodes, ), i, dtype=torch.long, + device=device) + batch.batch.append(item) + batch.ptr.append(batch.ptr[-1] + num_nodes) + + batch.batch = None if len(batch.batch) == 0 else batch.batch + batch.ptr = None if len(batch.ptr) == 1 else batch.ptr + batch.__slices__ = slices + batch.__cumsum__ = cumsum + batch.__cat_dims__ = cat_dims + batch.__num_nodes_list__ = num_nodes_list + + ref_data = data_list[0] + for key in batch.keys: + items = batch[key] + item = items[0] + + ### <--- Updating Cat Dim + if key in update_cat_dims.keys(): + cat_dim = update_cat_dims[key] + else: + cat_dim = ref_data.__cat_dim__(key, item) + cat_dim = 0 if cat_dim is None else cat_dim + ### ---? + if isinstance(item, Tensor): + batch[key] = torch.cat(items, cat_dim) + elif isinstance(item, SparseTensor): + batch[key] = cat(items, cat_dim) + elif isinstance(item, (int, float)): + batch[key] = torch.tensor(items) + + if torch_geometric.is_debug_enabled(): + batch.debug() + + return batch.contiguous() \ No newline at end of file diff --git a/datasets/__pycache__/BatchWSI.cpython-37.pyc b/datasets/__pycache__/BatchWSI.cpython-37.pyc new file mode 100644 index 0000000..d4667d7 Binary files /dev/null and b/datasets/__pycache__/BatchWSI.cpython-37.pyc differ diff --git a/models/__pycache__/model_graph_mil.cpython-37.pyc b/models/__pycache__/model_graph_mil.cpython-37.pyc new file mode 100644 index 0000000..1a5990e Binary files /dev/null and b/models/__pycache__/model_graph_mil.cpython-37.pyc differ diff --git a/models/model_graph_mil.py b/models/model_graph_mil.py index 86ae62c..3edd2db 100644 --- a/models/model_graph_mil.py +++ b/models/model_graph_mil.py @@ -183,7 +183,7 @@ def __init__(self, num_layers=4, edge_agg='spatial', resample=0, self.resample = resample if self.resample > 0: - self.fc = nn.Sequential(*[nn.Dropout(self.resample), nn.Linear(1024, 256), nn.ReLU(), nn.Dropout(0.25)]) + self.fc = nn.Sequential(*[nn.Dropout(self.resample), nn.Linear(1024, 128), nn.ReLU(), nn.Dropout(0.25)]) else: self.fc = nn.Sequential(*[nn.Linear(1024, 128), nn.ReLU(), nn.Dropout(0.25)])