From ec4a59e5656e3649ff47fd85cd0700d70fb094e6 Mon Sep 17 00:00:00 2001 From: Julian Fong <44014224+julian-fong@users.noreply.github.com> Date: Mon, 25 Nov 2024 15:58:33 -0500 Subject: [PATCH] Add support for AdapterPlus (#746) This PR aims to add support for AdapterPlus Github: https://github.com/visinf/adapter_plus Paper: https://arxiv.org/pdf/2406.06820 Integration of AdapterPlus into the `adapters` library will involve adding new parameters/options to the `BnConfig` Checklist of things that are added/to be added 1. New type of `scaling` called `channel`, in which we add learnable parameters for the channel/input_size dimension 2. New type of `init_weights` called `houlsby`, where the projection matrices $W_{down}$ and $W_{up}$ will be initialized with zero-centered Gaussian with a standard deviation of $10^{-2}$ truncated at 2 standard deviations, and zeros for bias 3. Support for `drop_path`, also known as stochastic depth **ONLY** applicable for vision based tasks using residual networks - located under a new file called `/methods/vision.py` --- .github/workflows/tests_torch.yml | 6 +- docs/classes/adapter_config.rst | 3 + docs/methods.md | 3 +- notebooks/README.md | 1 + notebooks/ViT_AdapterPlus_FineTuning.ipynb | 529 +++++++++++++++++++ setup.py | 4 +- src/adapters/__init__.py | 2 + src/adapters/configuration/adapter_config.py | 30 +- src/adapters/methods/modeling.py | 22 +- 9 files changed, 591 insertions(+), 9 deletions(-) create mode 100644 notebooks/ViT_AdapterPlus_FineTuning.ipynb diff --git a/.github/workflows/tests_torch.yml b/.github/workflows/tests_torch.yml index f7a394ce4a..fd5930ebb6 100644 --- a/.github/workflows/tests_torch.yml +++ b/.github/workflows/tests_torch.yml @@ -63,7 +63,7 @@ jobs: - name: Install run: | pip install torch==2.3 - pip install .[sklearn,testing,sentencepiece] + pip install .[sklearn,testing,sentencepiece,torchvision] - name: Test run: | make test-adapter-methods @@ -86,7 +86,7 @@ jobs: - name: Install run: | pip install torch==2.3 - pip install .[sklearn,testing,sentencepiece] + pip install .[sklearn,testing,sentencepiece,torchvision] - name: Test run: | make test-adapter-models @@ -109,7 +109,7 @@ jobs: - name: Install run: | pip install torch==2.3 - pip install .[sklearn,testing,sentencepiece] + pip install .[sklearn,testing,sentencepiece,torchvision] pip install conllu seqeval - name: Test Examples run: | diff --git a/docs/classes/adapter_config.rst b/docs/classes/adapter_config.rst index f72c7bbf94..a7ccbec2c9 100644 --- a/docs/classes/adapter_config.rst +++ b/docs/classes/adapter_config.rst @@ -34,6 +34,9 @@ Single (bottleneck) adapters .. autoclass:: adapters.CompacterPlusPlusConfig :members: +.. autoclass:: adapters.AdapterPlusConfig + :members: + Prefix Tuning ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/methods.md b/docs/methods.md index 10fa0f8627..302b1973d3 100644 --- a/docs/methods.md +++ b/docs/methods.md @@ -42,7 +42,7 @@ A visualization of further configuration options related to the adapter structur - [`DoubleSeqBnConfig`](adapters.DoubleSeqBnConfig), as proposed by [Houlsby et al. (2019)](https://arxiv.org/pdf/1902.00751.pdf) places adapter layers after both the multi-head attention and feed-forward block in each Transformer layer. - [`SeqBnConfig`](adapters.SeqBnConfig), as proposed by [Pfeiffer et al. (2020)](https://arxiv.org/pdf/2005.00052.pdf) places an adapter layer only after the feed-forward block in each Transformer layer. - [`ParBnConfig`](adapters.ParBnConfig), as proposed by [He et al. (2021)](https://arxiv.org/pdf/2110.04366.pdf) places adapter layers in parallel to the original Transformer layers. - +- [`AdapterPlusConfig`](adapters.AdapterPlusConfig), as proposed by [Steitz and Roth (2024)](https://arxiv.org/pdf/2406.06820) places adapter layers adapter layers after the multi-head attention and has channel wise scaling and houlsby weight initialization _Example_: ```python from adapters import BnConfig @@ -56,6 +56,7 @@ _Papers:_ * [Parameter-Efficient Transfer Learning for NLP](https://arxiv.org/pdf/1902.00751.pdf) (Houlsby et al., 2019) * [Simple, Scalable Adaptation for Neural Machine Translation](https://arxiv.org/pdf/1909.08478.pdf) (Bapna and Firat, 2019) * [AdapterFusion: Non-Destructive Task Composition for Transfer Learning](https://aclanthology.org/2021.eacl-main.39.pdf) (Pfeiffer et al., 2021) +* [Adapters Strike Back](https://arxiv.org/pdf/2406.06820) (Steitz and Roth., 2024) * [AdapterHub: A Framework for Adapting Transformers](https://arxiv.org/pdf/2007.07779.pdf) (Pfeiffer et al., 2020) ## Language Adapters - Invertible Adapters diff --git a/notebooks/README.md b/notebooks/README.md index db6bbaa16f..052cdafe4d 100644 --- a/notebooks/README.md +++ b/notebooks/README.md @@ -35,3 +35,4 @@ As adapters is fully compatible with HuggingFace's Transformers, you can also us | [NER on Wikiann](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/08_NER_Wikiann.ipynb) | Evaluating adapters on NER on the wikiann dataset | [data:image/s3,"s3://crabby-images/e7985/e79852128a5f83c92496b9d734ca52d01e009a39" alt="Open In Colab"](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/08_NER_Wikiann.ipynb) | | [Finetuning Whisper with Adapters](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/Adapter_Whisper_Audio_FineTuning.ipynb) | Fine Tuning Whisper using LoRA | [data:image/s3,"s3://crabby-images/e7985/e79852128a5f83c92496b9d734ca52d01e009a39" alt="Open In Colab"](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/Adapter_Whisper_Audio_FineTuning.ipynb) | | [Adapter Training with ReFT](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/ReFT_Adapters_Finetuning.ipynb) | Fine Tuning using ReFT Adapters | [data:image/s3,"s3://crabby-images/e7985/e79852128a5f83c92496b9d734ca52d01e009a39" alt="Open In Colab"](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/ReFT_Adapters_Finetuning.ipynb) | +| [ViT Fine-Tuning with AdapterPlus](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/ViT_AdapterPlus_FineTuning.ipynb) | ViT Fine-Tuning with AdapterPlus | [data:image/s3,"s3://crabby-images/e7985/e79852128a5f83c92496b9d734ca52d01e009a39" alt="Open In Colab"](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/ViT_AdapterPlus_FineTuning.ipynb) | diff --git a/notebooks/ViT_AdapterPlus_FineTuning.ipynb b/notebooks/ViT_AdapterPlus_FineTuning.ipynb new file mode 100644 index 0000000000..8dfbcd341a --- /dev/null +++ b/notebooks/ViT_AdapterPlus_FineTuning.ipynb @@ -0,0 +1,529 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Fine-Tuning ViT Classification models using AdapterPlus\n", + "\n", + "In this vision tutorial, we will show how to fine-tune ViT Image models using the `AdapterPlus` Config, which is a bottleneck adapter using the parameters as defined in the `AdapterPlus` paper. For more information on bottleneck adapters, you can visit our basic [tutorial](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/01_Adapter_Training.ipynb) and our docs [page](https://docs.adapterhub.ml/methods#bottleneck-adapters).\n", + "\n", + "You can find the link to the `AdapterPlus` paper by Steitz and Roth [here](https://openaccess.thecvf.com/content/CVPR2024/papers/Steitz_Adapters_Strike_Back_CVPR_2024_paper.pdf) and their GitHub page [here](https://github.com/visinf/adapter_plus)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Installation\n", + "\n", + "Before we can get started, we need to ensure the proper packages are installed. Here's a breakdown of what we need:\n", + "\n", + "- `adapters` to load the model and the adapter configuration\n", + "- `accelerate` for training optimization\n", + "- `evaluate` for metric computation and model evaluation\n", + "- `datasets` to import the datasets for training and evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -qq -U adapters datasets accelerate torchvision evaluate" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", + "\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Dataset\n", + "\n", + "For this tutorial, we will be using a light-weight image dataset `cifar100`, which contains 60,000 images with 100 classes. We use the `datasets` library to directly import the dataset and split it into its training and evaluation datasets." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "\n", + "num_classes = 100\n", + "train_dataset = load_dataset(\"uoft-cs/cifar100\", split = \"train\")\n", + "eval_dataset = load_dataset(\"uoft-cs/cifar100\", split = \"test\")\n", + "\n", + "train_dataset.set_format(\"torch\")\n", + "eval_dataset.set_format(\"torch\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For this tutorial, we will be using the fine_label, to match the number of classes (100) that were used in the paper." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset[0].keys()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will now initialize our `ViT` image processor to convert the images into a more friendly format. It will also apply transformations to each image in order to improve the performance during training." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "model_name_or_path = 'google/vit-base-patch16-224-in21k'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import ViTImageProcessor\n", + "processor = ViTImageProcessor.from_pretrained(model_name_or_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We'll print out the processor here in order to get an idea of what types of transformations it is applying onto the iamge" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "processor" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Data Preprocessing\n", + "\n", + "We will pre-process every image as defined in the `processor` above, and add the `label` key which contains our labels" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "def preprocess_image(example):\n", + " image = processor(example[\"img\"], return_tensors='pt')\n", + " image[\"label\"] = example[\"fine_label\"]\n", + " return image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset = train_dataset.map(preprocess_image)\n", + "eval_dataset = eval_dataset.map(preprocess_image)\n", + "#remove uneccessary columns\n", + "train_dataset = train_dataset.remove_columns(['img', 'fine_label', 'coarse_label'])\n", + "eval_dataset = eval_dataset.remove_columns(['img', 'fine_label', 'coarse_label'])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Defining a Datacollator\n", + "\n", + "We'll be using a very simple custom datacollator to help us combine multiple data samples into one batch" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Any\n", + "from dataclasses import dataclass\n", + "\n", + "@dataclass\n", + "class DataCollator:\n", + " processor : Any\n", + " def __call__(self, inputs):\n", + "\n", + " pixel_values = [input[\"pixel_values\"].squeeze() for input in inputs]\n", + " labels = [input[\"label\"] for input in inputs]\n", + "\n", + " pixel_values = torch.stack(pixel_values)\n", + " labels = torch.stack(labels)\n", + " return {\n", + " 'pixel_values': pixel_values,\n", + " 'labels': labels,\n", + " }\n", + "\n", + "data_collator = DataCollator(processor = processor)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Loading the `ViT` model and the `AdapterPlusConfig`\n", + "\n", + "Here we load the `vit-base-patch16-224-in21k` model similar to the one used in the `AdapterConfig` paper. We will load the model using the `adapters` `AutoAdapterModel` and add the corresponding `AdapterPlusConfig`. To read more about the config, you can check out the docs page [here](https://docs.adapterhub.ml/methods#bottleneck-adapters) under `AdapterPlusConfig`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from adapters import ViTAdapterModel\n", + "from adapters import AdapterPlusConfig\n", + "\n", + "model = ViTAdapterModel.from_pretrained(model_name_or_path)\n", + "config = AdapterPlusConfig(original_ln_after=True)\n", + "\n", + "model.add_adapter(\"adapterplus_config\", config)\n", + "model.add_image_classification_head(\"adapterplus_config\", num_labels=num_classes)\n", + "model.train_adapter(\"adapterplus_config\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "================================================================================\n", + "Name Architecture #Param %Param Active Train\n", + "--------------------------------------------------------------------------------\n", + "adapterplus_config bottleneck 165,984 0.192 1 1\n", + "--------------------------------------------------------------------------------\n", + "Full model 86,389,248 100.000 0\n", + "================================================================================\n" + ] + } + ], + "source": [ + "print(model.adapter_summary())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Evaluation Metrics\n", + "\n", + "We'll use accuracy as our main metric to evaluate the perforce of the reft model on the `cifar100` dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import evaluate\n", + "accuracy = evaluate.load(\"accuracy\")\n", + "\n", + "def compute_metrics(p):\n", + " return accuracy.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training\n", + "\n", + "Now we are ready to train our model. The same set of hyper-parameters that were used in the original paper will be re-used, except for the number of training epochs that the model will be trained on. You can always adjust the number of epochs yourself or any other hyperparameter in the notebook." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "from adapters import AdapterTrainer\n", + "from transformers import TrainingArguments\n", + "\n", + "training_args = TrainingArguments(\n", + " output_dir='./training_results',\n", + " eval_strategy='epoch',\n", + " learning_rate=10e-3,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", + " num_train_epochs=5,\n", + " weight_decay=10e-4,\n", + " report_to = \"none\",\n", + " remove_unused_columns=False,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = AdapterTrainer(\n", + " model=model,\n", + " args=training_args,\n", + " data_collator=data_collator,\n", + " train_dataset=train_dataset,\n", + " eval_dataset=eval_dataset,\n", + " tokenizer=processor,\n", + " compute_metrics = compute_metrics\n", + ")\n", + "\n", + "trainer.train()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
Epoch | \n", + "Training Loss | \n", + "Validation Loss | \n", + "Accuracy | \n", + "
---|---|---|---|
1 | \n", + "0.523500 | \n", + "0.351386 | \n", + "0.897700 | \n", + "
2 | \n", + "0.264200 | \n", + "0.329842 | \n", + "0.906500 | \n", + "
3 | \n", + "0.184700 | \n", + "0.324673 | \n", + "0.911400 | \n", + "
4 | \n", + "0.135400 | \n", + "0.347025 | \n", + "0.909100 | \n", + "
5 | \n", + "0.114700 | \n", + "0.349541 | \n", + "0.910900 | \n", + "
6 | \n", + "0.074300 | \n", + "0.370867 | \n", + "0.909600 | \n", + "
7 | \n", + "0.058100 | \n", + "0.373732 | \n", + "0.912400 | \n", + "
8 | \n", + "0.039300 | \n", + "0.376220 | \n", + "0.913900 | \n", + "
9 | \n", + "0.028200 | \n", + "0.381238 | \n", + "0.913300 | \n", + "
10 | \n", + "0.021600 | \n", + "0.380825 | \n", + "0.912400 | \n", + "
"
+ ],
+ "text/plain": [
+ "