From ec4a59e5656e3649ff47fd85cd0700d70fb094e6 Mon Sep 17 00:00:00 2001 From: Julian Fong <44014224+julian-fong@users.noreply.github.com> Date: Mon, 25 Nov 2024 15:58:33 -0500 Subject: [PATCH] Add support for AdapterPlus (#746) This PR aims to add support for AdapterPlus Github: https://github.com/visinf/adapter_plus Paper: https://arxiv.org/pdf/2406.06820 Integration of AdapterPlus into the `adapters` library will involve adding new parameters/options to the `BnConfig` Checklist of things that are added/to be added 1. New type of `scaling` called `channel`, in which we add learnable parameters for the channel/input_size dimension 2. New type of `init_weights` called `houlsby`, where the projection matrices $W_{down}$ and $W_{up}$ will be initialized with zero-centered Gaussian with a standard deviation of $10^{-2}$ truncated at 2 standard deviations, and zeros for bias 3. Support for `drop_path`, also known as stochastic depth **ONLY** applicable for vision based tasks using residual networks - located under a new file called `/methods/vision.py` --- .github/workflows/tests_torch.yml | 6 +- docs/classes/adapter_config.rst | 3 + docs/methods.md | 3 +- notebooks/README.md | 1 + notebooks/ViT_AdapterPlus_FineTuning.ipynb | 529 +++++++++++++++++++ setup.py | 4 +- src/adapters/__init__.py | 2 + src/adapters/configuration/adapter_config.py | 30 +- src/adapters/methods/modeling.py | 22 +- 9 files changed, 591 insertions(+), 9 deletions(-) create mode 100644 notebooks/ViT_AdapterPlus_FineTuning.ipynb diff --git a/.github/workflows/tests_torch.yml b/.github/workflows/tests_torch.yml index f7a394ce4a..fd5930ebb6 100644 --- a/.github/workflows/tests_torch.yml +++ b/.github/workflows/tests_torch.yml @@ -63,7 +63,7 @@ jobs: - name: Install run: | pip install torch==2.3 - pip install .[sklearn,testing,sentencepiece] + pip install .[sklearn,testing,sentencepiece,torchvision] - name: Test run: | make test-adapter-methods @@ -86,7 +86,7 @@ jobs: - name: Install run: | pip install torch==2.3 - pip install .[sklearn,testing,sentencepiece] + pip install .[sklearn,testing,sentencepiece,torchvision] - name: Test run: | make test-adapter-models @@ -109,7 +109,7 @@ jobs: - name: Install run: | pip install torch==2.3 - pip install .[sklearn,testing,sentencepiece] + pip install .[sklearn,testing,sentencepiece,torchvision] pip install conllu seqeval - name: Test Examples run: | diff --git a/docs/classes/adapter_config.rst b/docs/classes/adapter_config.rst index f72c7bbf94..a7ccbec2c9 100644 --- a/docs/classes/adapter_config.rst +++ b/docs/classes/adapter_config.rst @@ -34,6 +34,9 @@ Single (bottleneck) adapters .. autoclass:: adapters.CompacterPlusPlusConfig :members: +.. autoclass:: adapters.AdapterPlusConfig + :members: + Prefix Tuning ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/methods.md b/docs/methods.md index 10fa0f8627..302b1973d3 100644 --- a/docs/methods.md +++ b/docs/methods.md @@ -42,7 +42,7 @@ A visualization of further configuration options related to the adapter structur - [`DoubleSeqBnConfig`](adapters.DoubleSeqBnConfig), as proposed by [Houlsby et al. (2019)](https://arxiv.org/pdf/1902.00751.pdf) places adapter layers after both the multi-head attention and feed-forward block in each Transformer layer. - [`SeqBnConfig`](adapters.SeqBnConfig), as proposed by [Pfeiffer et al. (2020)](https://arxiv.org/pdf/2005.00052.pdf) places an adapter layer only after the feed-forward block in each Transformer layer. - [`ParBnConfig`](adapters.ParBnConfig), as proposed by [He et al. 
(2021)](https://arxiv.org/pdf/2110.04366.pdf) places adapter layers in parallel to the original Transformer layers. - +- [`AdapterPlusConfig`](adapters.AdapterPlusConfig), as proposed by [Steitz and Roth (2024)](https://arxiv.org/pdf/2406.06820) places adapter layers adapter layers after the multi-head attention and has channel wise scaling and houlsby weight initialization _Example_: ```python from adapters import BnConfig @@ -56,6 +56,7 @@ _Papers:_ * [Parameter-Efficient Transfer Learning for NLP](https://arxiv.org/pdf/1902.00751.pdf) (Houlsby et al., 2019) * [Simple, Scalable Adaptation for Neural Machine Translation](https://arxiv.org/pdf/1909.08478.pdf) (Bapna and Firat, 2019) * [AdapterFusion: Non-Destructive Task Composition for Transfer Learning](https://aclanthology.org/2021.eacl-main.39.pdf) (Pfeiffer et al., 2021) +* [Adapters Strike Back](https://arxiv.org/pdf/2406.06820) (Steitz and Roth., 2024) * [AdapterHub: A Framework for Adapting Transformers](https://arxiv.org/pdf/2007.07779.pdf) (Pfeiffer et al., 2020) ## Language Adapters - Invertible Adapters diff --git a/notebooks/README.md b/notebooks/README.md index db6bbaa16f..052cdafe4d 100644 --- a/notebooks/README.md +++ b/notebooks/README.md @@ -35,3 +35,4 @@ As adapters is fully compatible with HuggingFace's Transformers, you can also us | [NER on Wikiann](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/08_NER_Wikiann.ipynb) | Evaluating adapters on NER on the wikiann dataset | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/08_NER_Wikiann.ipynb) | | [Finetuning Whisper with Adapters](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/Adapter_Whisper_Audio_FineTuning.ipynb) | Fine Tuning Whisper using LoRA | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/Adapter_Whisper_Audio_FineTuning.ipynb) | | [Adapter Training with ReFT](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/ReFT_Adapters_Finetuning.ipynb) | Fine Tuning using ReFT Adapters | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/ReFT_Adapters_Finetuning.ipynb) | +| [ViT Fine-Tuning with AdapterPlus](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/ViT_AdapterPlus_FineTuning.ipynb) | ViT Fine-Tuning with AdapterPlus | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/ViT_AdapterPlus_FineTuning.ipynb) | diff --git a/notebooks/ViT_AdapterPlus_FineTuning.ipynb b/notebooks/ViT_AdapterPlus_FineTuning.ipynb new file mode 100644 index 0000000000..8dfbcd341a --- /dev/null +++ b/notebooks/ViT_AdapterPlus_FineTuning.ipynb @@ -0,0 +1,529 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Fine-Tuning ViT Classification models using AdapterPlus\n", + "\n", + "In this vision tutorial, we will show how to fine-tune ViT Image models using the `AdapterPlus` Config, which is a bottleneck adapter using the parameters as defined in the `AdapterPlus` paper. 
For more information on bottleneck adapters, you can visit our basic [tutorial](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/01_Adapter_Training.ipynb) and our docs [page](https://docs.adapterhub.ml/methods#bottleneck-adapters).\n", + "\n", + "You can find the link to the `AdapterPlus` paper by Steitz and Roth [here](https://openaccess.thecvf.com/content/CVPR2024/papers/Steitz_Adapters_Strike_Back_CVPR_2024_paper.pdf) and their GitHub page [here](https://github.com/visinf/adapter_plus)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Installation\n", + "\n", + "Before we can get started, we need to ensure the proper packages are installed. Here's a breakdown of what we need:\n", + "\n", + "- `adapters` to load the model and the adapter configuration\n", + "- `accelerate` for training optimization\n", + "- `evaluate` for metric computation and model evaluation\n", + "- `datasets` to import the datasets for training and evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -qq -U adapters datasets accelerate torchvision evaluate" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", + "\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Dataset\n", + "\n", + "For this tutorial, we will be using a light-weight image dataset `cifar100`, which contains 60,000 images with 100 classes. We use the `datasets` library to directly import the dataset and split it into its training and evaluation datasets." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "\n", + "num_classes = 100\n", + "train_dataset = load_dataset(\"uoft-cs/cifar100\", split = \"train\")\n", + "eval_dataset = load_dataset(\"uoft-cs/cifar100\", split = \"test\")\n", + "\n", + "train_dataset.set_format(\"torch\")\n", + "eval_dataset.set_format(\"torch\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For this tutorial, we will be using the fine_label, to match the number of classes (100) that were used in the paper." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset[0].keys()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will now initialize our `ViT` image processor to convert the images into a more friendly format. It will also apply transformations to each image in order to improve the performance during training." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "model_name_or_path = 'google/vit-base-patch16-224-in21k'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import ViTImageProcessor\n", + "processor = ViTImageProcessor.from_pretrained(model_name_or_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We'll print out the processor here to get an idea of what transformations it applies to each image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "processor" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Data Preprocessing\n", + "\n", + "We will preprocess every image as defined in the `processor` above, and add the `label` key which contains our labels" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "def preprocess_image(example):\n", + " image = processor(example[\"img\"], return_tensors='pt')\n", + " image[\"label\"] = example[\"fine_label\"]\n", + " return image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset = train_dataset.map(preprocess_image)\n", + "eval_dataset = eval_dataset.map(preprocess_image)\n", + "# remove unnecessary columns\n", + "train_dataset = train_dataset.remove_columns(['img', 'fine_label', 'coarse_label'])\n", + "eval_dataset = eval_dataset.remove_columns(['img', 'fine_label', 'coarse_label'])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Defining a Data Collator\n", + "\n", + "We'll be using a very simple custom data collator to help us combine multiple data samples into one batch" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Any\n", + "from dataclasses import dataclass\n", + "\n", + "@dataclass\n", + "class DataCollator:\n", + " processor : Any\n", + " def __call__(self, inputs):\n", + "\n", + " pixel_values = [input[\"pixel_values\"].squeeze() for input in inputs]\n", + " labels = [input[\"label\"] for input in inputs]\n", + "\n", + " pixel_values = torch.stack(pixel_values)\n", + " labels = torch.stack(labels)\n", + " return {\n", + " 'pixel_values': pixel_values,\n", + " 'labels': labels,\n", + " }\n", + "\n", + "data_collator = DataCollator(processor = processor)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Loading the `ViT` model and the `AdapterPlusConfig`\n", + "\n", + "Here we load the `vit-base-patch16-224-in21k` model, the same backbone used in the `AdapterPlus` paper. We will load the model using the `adapters` library's `ViTAdapterModel` and add the corresponding `AdapterPlusConfig`. 
To read more about the config, you can check out the docs page [here](https://docs.adapterhub.ml/methods#bottleneck-adapters) under `AdapterPlusConfig`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from adapters import ViTAdapterModel\n", + "from adapters import AdapterPlusConfig\n", + "\n", + "model = ViTAdapterModel.from_pretrained(model_name_or_path)\n", + "config = AdapterPlusConfig(original_ln_after=True)\n", + "\n", + "model.add_adapter(\"adapterplus_config\", config)\n", + "model.add_image_classification_head(\"adapterplus_config\", num_labels=num_classes)\n", + "model.train_adapter(\"adapterplus_config\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "================================================================================\n", + "Name Architecture #Param %Param Active Train\n", + "--------------------------------------------------------------------------------\n", + "adapterplus_config bottleneck 165,984 0.192 1 1\n", + "--------------------------------------------------------------------------------\n", + "Full model 86,389,248 100.000 0\n", + "================================================================================\n" + ] + } + ], + "source": [ + "print(model.adapter_summary())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Evaluation Metrics\n", + "\n", + "We'll use accuracy as our main metric to evaluate the perforce of the reft model on the `cifar100` dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import evaluate\n", + "accuracy = evaluate.load(\"accuracy\")\n", + "\n", + "def compute_metrics(p):\n", + " return accuracy.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training\n", + "\n", + "Now we are ready to train our model. The same set of hyper-parameters that were used in the original paper will be re-used, except for the number of training epochs that the model will be trained on. You can always adjust the number of epochs yourself or any other hyperparameter in the notebook." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "from adapters import AdapterTrainer\n", + "from transformers import TrainingArguments\n", + "\n", + "training_args = TrainingArguments(\n", + " output_dir='./training_results',\n", + " eval_strategy='epoch',\n", + " learning_rate=10e-3,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", + " num_train_epochs=5,\n", + " weight_decay=10e-4,\n", + " report_to = \"none\",\n", + " remove_unused_columns=False,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = AdapterTrainer(\n", + " model=model,\n", + " args=training_args,\n", + " data_collator=data_collator,\n", + " train_dataset=train_dataset,\n", + " eval_dataset=eval_dataset,\n", + " tokenizer=processor,\n", + " compute_metrics = compute_metrics\n", + ")\n", + "\n", + "trainer.train()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [7820/7820 2:44:24, Epoch 10/10]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Epoch | Training Loss | Validation Loss | Accuracy
1     | 0.523500      | 0.351386        | 0.897700
2     | 0.264200      | 0.329842        | 0.906500
3     | 0.184700      | 0.324673        | 0.911400
4     | 0.135400      | 0.347025        | 0.909100
5     | 0.114700      | 0.349541        | 0.910900
6     | 0.074300      | 0.370867        | 0.909600
7     | 0.058100      | 0.373732        | 0.912400
8     | 0.039300      | 0.376220        | 0.913900
9     | 0.028200      | 0.381238        | 0.913300
10    | 0.021600      | 0.380825        | 0.912400

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "TrainOutput(global_step=7820, training_loss=0.13645367293101748, metrics={'train_runtime': 9867.7777, 'train_samples_per_second': 50.67, 'train_steps_per_second': 0.792, 'total_flos': 3.9121684697088e+19, 'train_loss': 0.13645367293101748, 'epoch': 10.0})" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer = AdapterTrainer(\n", + " model=model,\n", + " args=training_args,\n", + " data_collator=data_collator,\n", + " train_dataset=train_dataset,\n", + " eval_dataset=eval_dataset,\n", + " tokenizer=processor,\n", + " compute_metrics = compute_metrics\n", + ")\n", + "\n", + "trainer.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Inference\n", + "\n", + "Now, we'll use our `adapters` trained model to classify some new images!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from torch.nn import Softmax\n", + "#select a random sample from the evaluation dataset\n", + "image = eval_dataset.select([0])\n", + "logits = model(image['pixel_values'].squeeze(0))\n", + "softmax = Softmax(0)\n", + "prediction = torch.argmax(softmax(logits.logits.squeeze()))\n", + "\n", + "prediction" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Our prediction is the 49th class which corresponds to the 'mountain' label" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Saving the adapter model\n", + "\n", + "If you would like to save your model or push it to HuggingFace, you can always do so with the below code. Make sure to sign in before you do" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from huggingface_hub import notebook_login\n", + "\n", + "notebook_login()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.push_adapter_to_hub(\n", + " \"cifar100-adapterplus_config\",\n", + " \"adapterplus_config\",\n", + " datasets_tag=\"cifar100\"\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "adapter_hub", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/setup.py b/setup.py index 3e4a448379..6fd87f9ca9 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,7 @@ "sphinx-multiversion==0.2.4", "timeout-decorator", "torch", + "torchvision", "transformers~=4.45.2", ] @@ -81,8 +82,8 @@ def deps_list(*pkgs): extras["sklearn"] = deps_list("scikit-learn") extras["torch"] = deps_list("torch", "accelerate") - extras["sentencepiece"] = deps_list("sentencepiece", "protobuf") +extras["torchvision"] = deps_list("torchvision") extras["testing"] = deps_list( "pytest", "pytest-rich", @@ -104,6 +105,7 @@ def deps_list(*pkgs): "beautifulsoup4", "pillow", "accelerate", + "torchvision", ) extras["quality"] = deps_list("black", "datasets", "isort", "flake8", "GitPython") diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py index d64211495b..a917828e72 100644 --- a/src/adapters/__init__.py +++ b/src/adapters/__init__.py @@ -41,6 +41,7 
@@ "DEFAULT_ADAPTERFUSION_CONFIG", "AdapterConfig", "AdapterFusionConfig", + "AdapterPlusConfig", "BnConfig", "CompacterConfig", "CompacterPlusPlusConfig", @@ -157,6 +158,7 @@ DEFAULT_ADAPTERFUSION_CONFIG, AdapterConfig, AdapterFusionConfig, + AdapterPlusConfig, BnConfig, CompacterConfig, CompacterPlusPlusConfig, diff --git a/src/adapters/configuration/adapter_config.py b/src/adapters/configuration/adapter_config.py index 33521fc71b..0f2eec2162 100644 --- a/src/adapters/configuration/adapter_config.py +++ b/src/adapters/configuration/adapter_config.py @@ -156,13 +156,14 @@ class BnConfig(AdapterConfig): ln_after (:obj:`bool`, optional): If True, add a new layer normalization after the adapter bottleneck. Defaults to False. init_weights (:obj:`str`, optional): Initialization method for the weights of the adapter modules. - Currently, this can be either "bert" (default) or "mam_adapter". + Currently, this can be either "bert" (default) or "mam_adapter" or "houlsby". is_parallel (:obj:`bool`, optional): If True, apply adapter transformations in parallel. By default (False), sequential application is used. scaling (:obj:`float` or :obj:`str`, optional): Scaling factor to use for scaled addition of adapter outputs as done by He et al. (2021). Can be either a - constant factor (float) or the string "learned", in which case the scaling factor is learned. Defaults to - 1.0. + constant factor (float), or the string "learned", in which case the scaling factor is learned, or the string + "channel", in which case we initialize a scaling vector of the channel shape that is then learned. + Defaults to 1.0. use_gating (:obj:`bool`, optional): Place a trainable gating module besides the added parameter module to control module activation. This is e.g. used for UniPELT. Defaults to False. @@ -213,6 +214,10 @@ class BnConfig(AdapterConfig): phm_bias (:obj:`bool`, optional): If True the down and up projection PHMLayer has a bias term. If `phm_layer` is False this is ignored. Defaults to True + stochastic_depth (:obj:`float`, optional): + This value specifies the probability of the model dropping entire layers during + training. This parameter should be only used for vision based tasks involving + residual networks. """ # Required options @@ -250,6 +255,7 @@ class BnConfig(AdapterConfig): hypercomplex_nonlinearity: Optional[str] = "glorot-uniform" phm_rank: Optional[int] = 1 phm_bias: Optional[bool] = True + stochastic_depth: Optional[float] = 0.0 # We want to emulate a simple form of immutability while keeping the ability to add custom attributes. # Therefore, we don't allow changing attribute values if set once. @@ -364,6 +370,24 @@ class ParBnConfig(BnConfig): scaling: Union[float, str] = 4.0 +@dataclass(eq=False) +class AdapterPlusConfig(BnConfig): + """ + The AdapterPlus config architecture proposed by Jan-Martin O, Steitz and Stefan Roth. 
See https://arxiv.org/pdf/2406.06820 + """ + + original_ln_after: bool = False + residual_before_ln: bool = True + stochastic_depth: float = 0.1 + init_weights: str = "houlsby" + scaling: Union[float, str] = "channel" + + mh_adapter: bool = False + output_adapter: bool = True + reduction_factor: Union[float, Mapping] = 96 + non_linearity: str = "gelu" + + @dataclass(eq=False) class PrefixTuningConfig(AdapterConfig): """ diff --git a/src/adapters/methods/modeling.py b/src/adapters/methods/modeling.py index a51f5ef078..d1e19c257b 100644 --- a/src/adapters/methods/modeling.py +++ b/src/adapters/methods/modeling.py @@ -4,6 +4,7 @@ from torch import nn from transformers.activations import get_activation +from transformers.utils.import_utils import is_torchvision_available from ..configuration import AdapterFusionConfig, BnConfig from ..context import ForwardContext @@ -99,6 +100,8 @@ def __init__( self.scaling = config["scaling"] elif config["scaling"] == "learned": self.scaling = nn.Parameter(torch.ones(1)) + elif config["scaling"] == "channel": + self.scaling = nn.Parameter(torch.ones(input_size)) else: raise ValueError("Unknown scaling type: {}".format(config["scaling"])) @@ -126,9 +129,25 @@ def __init__( nn.init.zeros_(self.adapter_up.bias) if self.use_gating: self.gate.apply(self.init_bert_weights) + elif config["init_weights"] == "houlsby": + for layer in self.adapter_down: + if isinstance(layer, nn.Linear) or isinstance(layer, PHMLayer): + nn.init.trunc_normal_(layer.weight, mean=0, std=1e-2, a=-2 * 1e-2, b=2 * 1e-2) + nn.init.zeros_(layer.bias) + + nn.init.trunc_normal_(self.adapter_up.weight, mean=0, std=1e-2, a=-2 * 1e-2, b=2 * 1e-2) + nn.init.zeros_(self.adapter_up.bias) else: raise ValueError("Unknown init_weights type: {}".format(config["init_weights"])) + if config["stochastic_depth"] > 0.0: + if is_torchvision_available(): + from torchvision.ops.stochastic_depth import StochasticDepth + + self.DropPath = StochasticDepth(p=config["stochastic_depth"], mode="row") + else: + raise ImportError("stochastic_depth requires the package torchvision, but it is not installed") + def pre_forward( self, hidden_states, @@ -176,6 +195,8 @@ def forward(self, x, residual_input, output_gating=False): down = self.adapter_down(x) up = self.adapter_up(down) + if hasattr(self, "DropPath"): + up = self.DropPath(up) up = up * self.scaling output = self.dropout(up) @@ -364,7 +385,6 @@ def __init__( self.reduction = self.T / 1000.0 def forward(self, query, key, value, residual, output_attentions: bool = False): - if self.config["residual_before"]: value += residual[:, :, None, :].repeat(1, 1, value.size(2), 1)
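Usage sketch (illustrative only, not part of the diff): the adapter/head names and the base checkpoint below are examples, and the explicit `BnConfig` simply spells out the defaults that `AdapterPlusConfig` sets in this PR.

```python
# Illustrative sketch: how the AdapterPlusConfig added in this PR can be used
# with a ViT backbone, mirroring the accompanying notebook.
from adapters import AdapterPlusConfig, BnConfig, ViTAdapterModel

model = ViTAdapterModel.from_pretrained("google/vit-base-patch16-224-in21k")

# AdapterPlusConfig bundles the paper's settings: a single bottleneck adapter
# per layer (mh_adapter=False, output_adapter=True), reduction_factor=96,
# GELU non-linearity, channel-wise learned scaling, Houlsby-style init and
# stochastic depth of 0.1.
config = AdapterPlusConfig()

# The same setup expressed through a plain BnConfig, using the options
# extended in this PR.
equivalent = BnConfig(
    mh_adapter=False,
    output_adapter=True,
    reduction_factor=96,
    non_linearity="gelu",
    original_ln_after=False,
    residual_before_ln=True,
    scaling="channel",        # learnable scaling vector of the channel/input size
    init_weights="houlsby",   # truncated normal, std 1e-2, zero bias
    stochastic_depth=0.1,     # requires torchvision; vision/residual models only
)

model.add_adapter("adapter_plus", config=config)
model.add_image_classification_head("adapter_plus", num_labels=100)
model.train_adapter("adapter_plus")
```

Arguments passed to `AdapterPlusConfig(...)` override these defaults, as the notebook does with `original_ln_after=True`.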