From 0a37a5807535f1f62c3aef581ce5ecd1ec739c81 Mon Sep 17 00:00:00 2001 From: Thomas George Date: Fri, 15 Jan 2021 15:28:11 -0500 Subject: [PATCH] Gram matrix example using Colab --- examples/NNGeometry_Gram_matrix_example.ipynb | 581 ++++++++++++++++++ 1 file changed, 581 insertions(+) create mode 100644 examples/NNGeometry_Gram_matrix_example.ipynb diff --git a/examples/NNGeometry_Gram_matrix_example.ipynb b/examples/NNGeometry_Gram_matrix_example.ipynb new file mode 100644 index 0000000..ed39c2b --- /dev/null +++ b/examples/NNGeometry_Gram_matrix_example.ipynb @@ -0,0 +1,581 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "NNGeometry Gram matrix example", + "provenance": [], + "authorship_tag": "ABX9TyOuQUTegwJIwTfvuP5m7aOw", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU", + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "1a16d27bd47e4b96a810bb5c63ff6b03": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + "box_style": "", + "layout": "IPY_MODEL_fb4ab718be794f40a5dd929aeadbc816", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_adbb0b6d0f7f4480b02d92432cf24165", + "IPY_MODEL_9cd955730f7a40578b1d7beb547aeb0a" + ] + } + }, + "fb4ab718be794f40a5dd929aeadbc816": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, 
+ "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "adbb0b6d0f7f4480b02d92432cf24165": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_00aa5f0f3ef24734b3f2140e09dddaad", + "_dom_classes": [], + "description": "", + "_model_name": "FloatProgressModel", + "bar_style": "success", + "max": 1, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 1, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_2551523aa11641c1940f096ff2816e8e" + } + }, + "9cd955730f7a40578b1d7beb547aeb0a": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_684309b0605c4a178dc6d5fc46e7902f", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": " 170500096/? 
[00:04<00:00, 40773007.44it/s]", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_af312f18b5b94b619153a6f48476f89c" + } + }, + "00aa5f0f3ef24734b3f2140e09dddaad": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "initial", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "2551523aa11641c1940f096ff2816e8e": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "684309b0605c4a178dc6d5fc46e7902f": { + "model_module": "@jupyter-widgets/controls", + 
"model_name": "DescriptionStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "af312f18b5b94b619153a6f48476f89c": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + } + } + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "kZvhzZJ6k4Tl", + "outputId": "429c52db-9a83-46fc-86c2-9821747f1cc1" + }, + "source": [ + "!pip install git+https://github.com/tfjgeorge/nngeometry.git" + ], + "execution_count": 
1, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Collecting git+https://github.com/tfjgeorge/nngeometry.git\n", + " Cloning https://github.com/tfjgeorge/nngeometry.git to /tmp/pip-req-build-o2thtvxx\n", + " Running command git clone -q https://github.com/tfjgeorge/nngeometry.git /tmp/pip-req-build-o2thtvxx\n", + "Requirement already satisfied: torch>=1.0.0 in /usr/local/lib/python3.6/dist-packages (from nngeometry==0.1) (1.7.0+cu101)\n", + "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch>=1.0.0->nngeometry==0.1) (0.8)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch>=1.0.0->nngeometry==0.1) (3.7.4.3)\n", + "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch>=1.0.0->nngeometry==0.1) (0.16.0)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch>=1.0.0->nngeometry==0.1) (1.19.5)\n", + "Building wheels for collected packages: nngeometry\n", + " Building wheel for nngeometry (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for nngeometry: filename=nngeometry-0.1-cp36-none-any.whl size=20894 sha256=259910e9b81cf140249d080dd17cbf3dbf356063f1440c605eb60cd62a61cff0\n", + " Stored in directory: /tmp/pip-ephem-wheel-cache-kqdm1cui/wheels/0e/82/b3/42a1a59c9ab5dcb2a16c557430ef6bbdce07fe33ac46af6beb\n", + "Successfully built nngeometry\n", + "Installing collected packages: nngeometry\n", + "Successfully installed nngeometry-0.1\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IcEfRHHAnjas" + }, + "source": [ + "# PyTorch dataloader and model definition\n", + "\n", + "The next cells contain just your regular model and dataloader definitions, using standard PyTorch classes. Nothing here is specific to NNGeometry.\n", + "\n", + "We start by defining our model. Here we use a ResNet18." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "AndrmMIik5mR" + }, + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "class BasicBlock(nn.Module):\n", + " expansion = 1\n", + "\n", + " def __init__(self, in_planes, planes, stride=1):\n", + " super(BasicBlock, self).__init__()\n", + " self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1)\n", + " self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1)\n", + "\n", + " self.shortcut = nn.Sequential()\n", + " if stride != 1 or in_planes != self.expansion*planes:\n", + " self.shortcut = nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)\n", + "\n", + " def forward(self, x):\n", + " out = F.relu(self.conv1(x))\n", + " out = self.conv2(out)\n", + " out += self.shortcut(x)\n", + " out = F.relu(out)\n", + " return out\n", + "\n", + "class ResNet(nn.Module):\n", + " def __init__(self, block, num_blocks, num_classes=10):\n", + " super(ResNet, self).__init__()\n", + " self.in_planes = 64\n", + "\n", + " self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)\n", + " self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)\n", + " self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)\n", + " self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)\n", + " self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)\n", + " self.linear = nn.Linear(512*block.expansion, num_classes)\n", + "\n", + " def _make_layer(self, block, planes, num_blocks, stride):\n", + " strides = [stride] + [1]*(num_blocks-1)\n", + " layers = []\n", + " for stride in strides:\n", + " layers.append(block(self.in_planes, planes, stride))\n", + " self.in_planes = planes * block.expansion\n", + " return nn.Sequential(*layers)\n", + "\n", + " def forward(self, x):\n", + " out = self.conv1(x)\n", + " out = F.relu(out)\n", + " out = self.layer1(out)\n", + 
" out = self.layer2(out)\n", + " out = self.layer3(out)\n", + " out = self.layer4(out)\n", + " out = F.avg_pool2d(out, 4)\n", + " out = out.view(out.size(0), -1)\n", + " out = self.linear(out)\n", + " return out\n", + "\n", + "\n", + "def ResNet18():\n", + " return ResNet(BasicBlock, [2,2,2,2])\n", + "\n", + "model = ResNet18().cuda()" + ], + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cQ8OylUVtPA3" + }, + "source": [ + "Next, we define the dataloader on which we compute the Gram matrix. Notice the following specifics:\n", + "\n", + "- In the `DataLoader` instantiation, we pass `shuffle=False` so that examples in the Gram matrix are arranged in a deterministic way, i.e. the first example in the Gram matrix is the first example in the `DataLoader`, and so on.\n", + "- We use a subset of 100 examples of the original test set, since the size of the Gram matrix grows as $n^2$, where $n$ is the number of examples.\n", + "- In order to improve performance, we copy the dataset into GPU memory using the `to_tensordataset` function." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 104, + "referenced_widgets": [ + "1a16d27bd47e4b96a810bb5c63ff6b03", + "fb4ab718be794f40a5dd929aeadbc816", + "adbb0b6d0f7f4480b02d92432cf24165", + "9cd955730f7a40578b1d7beb547aeb0a", + "00aa5f0f3ef24734b3f2140e09dddaad", + "2551523aa11641c1940f096ff2816e8e", + "684309b0605c4a178dc6d5fc46e7902f", + "af312f18b5b94b619153a6f48476f89c" + ] + }, + "id": "ebSMtxcMn799", + "outputId": "5c52c2e4-eb32-4dc0-de25-e2b0d2e99e84" + }, + "source": [ + "from torch.utils.data import DataLoader, TensorDataset, Subset\n", + "\n", + "import torchvision.transforms as transforms\n", + "from torchvision.datasets import CIFAR10\n", + "\n", + "transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n", + "])\n", + "\n", + "testset = Subset(CIFAR10(root='/tmp', train=False, download=True,\n", + " transform=transform), range(100))\n", + "\n", + "def to_tensordataset(dataset):\n", + " d = next(iter(DataLoader(dataset,\n", + " batch_size=len(dataset))))\n", + " return TensorDataset(d[0].to('cuda'), d[1].to('cuda'))\n", + "\n", + "testloader = DataLoader(to_tensordataset(testset), batch_size=100, shuffle=False)" + ], + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /tmp/cifar-10-python.tar.gz\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1a16d27bd47e4b96a810bb5c63ff6b03", + "version_minor": 0, + "version_major": 2 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Extracting /tmp/cifar-10-python.tar.gz to /tmp\n" + ], + "name": "stdout" + 
} + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YALYiOTno-Ye" + }, + "source": [ + "Now that we are done with everything on the PyTorch side, let's get to NNGeometry!\n", + "\n", + "# Computing a Gram matrix" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "jttrl2Dko_Lh", + "outputId": "1f48c4f4-7a73-434e-9f3e-b058d09a06ce" + }, + "source": [ + "from nngeometry.generator import Jacobian\n", + "from nngeometry.object import FMatDense\n", + "\n", + "generator = Jacobian(model, testloader, n_output=10)\n", + "K = FMatDense(generator)" + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "text": [ + "\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "V7JAsCJ0sZPY" + }, + "source": [ + "`K` is an `FMatDense` object; we can convert it to a PyTorch tensor as follows:" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "fJqr-mdcsBPZ" + }, + "source": [ + "K_torch = K.get_dense_tensor()" + ], + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UT3EMTB-syRL" + }, + "source": [ + "`K_torch` is arranged as a 10 x 100 x 10 x 100 tensor (classes x examples x classes x examples), since we are using a 10-class task with 100 examples." + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "h9pKTm9EssO3", + "outputId": "884292c7-49ce-4546-9b94-190e88d4636b" + }, + "source": [ + "K_torch.size()" + ], + "execution_count": 6, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "torch.Size([10, 100, 10, 100])" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 6 + } + ] + } + ] +} \ No newline at end of file