From 84f4f04542d0bea162f653ebe1dfe8d3db3db1da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Sun, 19 Nov 2023 19:43:50 +0100 Subject: [PATCH] Add Prompt Tuning (#595) This PR adds support for Prompt Tuning (https://aclanthology.org/2021.emnlp-main.243/) --------- Co-authored-by: calpt --- .gitignore | 2 +- docs/classes/adapter_config.rst | 7 + docs/methods.md | 28 +++ docs/model_overview.md | 44 ++-- src/adapters/__init__.py | 2 + src/adapters/configuration/adapter_config.py | 30 +++ src/adapters/context.py | 19 +- src/adapters/heads/base.py | 28 ++- src/adapters/heads/language_modeling.py | 10 + src/adapters/loading.py | 1 + src/adapters/methods/adapter_layer_base.py | 2 +- src/adapters/methods/prompt_tuning.py | 228 ++++++++++++++++++ src/adapters/model_mixin.py | 86 +++++-- src/adapters/models/albert/adapter_model.py | 5 +- src/adapters/models/albert/mixin_albert.py | 7 +- src/adapters/models/albert/modeling_albert.py | 3 + src/adapters/models/bart/adapter_model.py | 10 +- src/adapters/models/bart/mixin_bart.py | 16 +- src/adapters/models/beit/adapter_model.py | 5 +- src/adapters/models/beit/mixin_beit.py | 1 + src/adapters/models/bert/adapter_model.py | 5 +- src/adapters/models/bert/mixin_bert.py | 8 +- src/adapters/models/bert/modeling_bert.py | 3 + .../models/bert_generation/adapter_model.py | 5 +- .../modeling_bert_generation.py | 3 + src/adapters/models/clip/adapter_model.py | 7 +- src/adapters/models/clip/mixin_clip.py | 18 +- src/adapters/models/deberta/adapter_model.py | 5 +- .../models/deberta/modeling_deberta.py | 4 + .../models/deberta_v2/adapter_model.py | 5 +- .../models/deberta_v2/modeling_deberta_v2.py | 6 +- .../models/distilbert/adapter_model.py | 5 +- .../models/distilbert/mixin_distilbert.py | 15 +- .../models/distilbert/modeling_distilbert.py | 3 + src/adapters/models/electra/adapter_model.py | 5 +- .../models/electra/modeling_electra.py | 3 + .../encoder_decoder/mixin_encoder_decoder.py | 2 + src/adapters/models/gpt2/adapter_model.py | 7 +- src/adapters/models/gpt2/mixin_gpt2.py | 16 +- src/adapters/models/gptj/adapter_model.py | 7 +- src/adapters/models/gptj/mixin_gptj.py | 13 +- src/adapters/models/llama/adapter_model.py | 7 +- src/adapters/models/llama/mixin_llama.py | 16 +- src/adapters/models/mbart/adapter_model.py | 10 +- src/adapters/models/roberta/adapter_model.py | 5 +- .../models/roberta/modeling_roberta.py | 3 + src/adapters/models/t5/adapter_model.py | 10 +- src/adapters/models/t5/mixin_t5.py | 6 + src/adapters/models/t5/modeling_t5.py | 2 +- src/adapters/models/vit/adapter_model.py | 5 +- src/adapters/models/vit/mixin_vit.py | 3 + .../models/xlm_roberta/adapter_model.py | 5 +- .../xlm_roberta/modeling_xlm_roberta.py | 3 + src/adapters/models/xmod/adapter_model.py | 5 +- src/adapters/models/xmod/mixin_xmod.py | 8 +- src/adapters/models/xmod/modeling_xmod.py | 3 + src/adapters/utils.py | 45 +++- .../composition/test_adapter_composition.py | 2 +- tests_adapters/methods/__init__.py | 1 + tests_adapters/methods/test_adapter_common.py | 9 +- tests_adapters/methods/test_prompt_tuning.py | 36 +++ tests_adapters/test_adapter.py | 19 ++ tests_adapters/test_adapter_config.py | 1 + tests_adapters/test_adapter_custom_head.py | 3 +- tests_adapters/test_adapter_hub.py | 3 +- tests_adapters/test_albert.py | 2 + tests_adapters/test_beit.py | 2 + tests_adapters/test_bert.py | 2 + tests_adapters/test_bert_generation.py | 2 + tests_adapters/test_deberta.py | 2 + tests_adapters/test_debertaV2.py | 2 + tests_adapters/test_distilbert.py | 2 + 
tests_adapters/test_electra.py | 2 + tests_adapters/test_roberta.py | 2 + tests_adapters/test_vit.py | 2 + tests_adapters/test_xlm_roberta.py | 2 + tests_adapters/test_xmod.py | 2 + 77 files changed, 799 insertions(+), 114 deletions(-) create mode 100644 src/adapters/methods/prompt_tuning.py create mode 100644 tests_adapters/methods/test_prompt_tuning.py diff --git a/.gitignore b/.gitignore index b403149cc0..80ed06450d 100644 --- a/.gitignore +++ b/.gitignore @@ -74,7 +74,7 @@ instance/ # Sphinx documentation docs/_build/ -docs/_build/ +adapter_docs/_build/ # PyBuilder target/ diff --git a/docs/classes/adapter_config.rst b/docs/classes/adapter_config.rst index 911a3fade2..91a9f506a0 100644 --- a/docs/classes/adapter_config.rst +++ b/docs/classes/adapter_config.rst @@ -55,6 +55,13 @@ IA3Config :members: :inherited-members: Mapping +PromptTuningConfig +~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: adapters.PromptTuningConfig + :members: + :inherited-members: Mapping + Combined configurations ~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/methods.md b/docs/methods.md index 04cebbaaac..06ad700e68 100644 --- a/docs/methods.md +++ b/docs/methods.md @@ -267,3 +267,31 @@ model.reset_adapter() _Papers:_ - [Few-Shot Parameter-Efficient Fine-Tuning is Better and Cheaper than In-Context Learning](https://arxiv.org/pdf/2205.05638.pdf) (Liu et al., 2022) + +## Prompt Tuning +Prompt Tuning is an efficient fine-tuning technique proposed by Lester et al. (2021). Prompt tuning adds tunable tokens, called soft-prompts, that are prepended to the input text. +First, the input sequence ${x_1, x_2, \dots, x_n }$ gets embedded, resulting in the matrix $X_e \in \mathbb{R}^{n \times e}$ where $e$ is the dimension of +the embedding space. The soft-prompts with length $p$ are represented as $P_e \in \mathbb{R}^{p \times e}$. +$P_e$ and $X_e$ get concatenated, forming the input of the following encoder or decoder: + +$$ +\left[P_e; X_e\right] \in \mathbb{R}^{\left(p + n\right) \times e} +$$ + +The `PromptTuningConfig` has the properties: +- `prompt_length`: to set the soft-prompts length $p$ +- `prompt_init`: to set the weight initialisation method, which is either "random_uniform" or "from_string" to initialize each prompt token with an embedding drawn from the model’s vocabulary. + - `prompt_init_text` as the text use for initialisation if `prompt_init="from_string"` +- `combine`: To define if the prefix should be added before the embedded input sequence or after the BOS token + +To add Prompt Tuning to your model, you can use the predefined configs: +```python +from adapters import PromptTuningConfig + +config = PromptTuningConfig(prompt_length=10) +model.add_adapter("dummy", config=config) +``` + +_Papers:_ +- [The Power of Scale for Parameter-Efficient Prompt Tuning](https://aclanthology.org/2021.emnlp-main.243/) (Lester et al., 2021) + diff --git a/docs/model_overview.md b/docs/model_overview.md index 8198ea64d0..a5ba7c4e8c 100644 --- a/docs/model_overview.md +++ b/docs/model_overview.md @@ -10,28 +10,28 @@ The table below further shows which model architectures support which adaptation E.g., for BERT, this means adapters provides a ``BertAdapterModel`` class, but you can also use ``BertModel``, ``BertForSequenceClassification`` etc. together with adapters. ``` -| Model | (Bottleneck)
Adapters | Prefix<br>
Tuning | LoRA | Compacter | Adapter<br>
Fusion | Invertible<br>
Adapters | Parallel<br>
block | -| --------------------------------------- | -| - | - | - | - | - | - | -| [ALBERT](classes/models/albert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [BART](classes/models/bart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [BEIT](classes/models/beit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | | | -| [BERT-Generation](classes/models/bert-generation.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [BERT](classes/models/bert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [CLIP](classes/models/clip.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | -| [DeBERTa](classes/models/deberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [DeBERTa-v2](classes/models/debertaV2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [DistilBERT](classes/models/distilbert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [Electra](classes/models/electra.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [Encoder Decoder](classes/models/encoderdecoder.html) | (*) | (*) | (*) | (*) | (*) | (*) | | -| [GPT-2](classes/models/gpt2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [GPT-J](classes/models/gptj.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [Llama](classes/models/llama.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [MBart](classes/models/mbart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [RoBERTa](classes/models/roberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [T5](classes/models/t5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [ViT](classes/models/vit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [XLM-RoBERTa](classes/models/xlmroberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [X-MOD](classes/models/xmod.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| Model | (Bottleneck)
Adapters | Prefix<br>
Tuning | LoRA | Compacter | Adapter<br>
Fusion | Invertible<br>
Adapters | Parallel<br>
block | Prompt<br>
Tuning | +| --------------------------------------- | -| - | - | - | - | - | - |- | +| [ALBERT](classes/models/albert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [BART](classes/models/bart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | +| [BEIT](classes/models/beit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | | | ✅ | +| [BERT-Generation](classes/models/bert-generation.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [BERT](classes/models/bert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [CLIP](classes/models/clip.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | +| [DeBERTa](classes/models/deberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [DeBERTa-v2](classes/models/debertaV2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [DistilBERT](classes/models/distilbert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [Electra](classes/models/electra.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [Encoder Decoder](classes/models/encoderdecoder.html) | (*) | (*) | (*) | (*) | (*) | (*) | | | +| [GPT-2](classes/models/gpt2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | +| [GPT-J](classes/models/gptj.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | +| [Llama](classes/models/llama.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | +| [MBart](classes/models/mbart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | +| [RoBERTa](classes/models/roberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [T5](classes/models/t5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | +| [ViT](classes/models/vit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [XLM-RoBERTa](classes/models/xlmroberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [X-MOD](classes/models/xmod.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | (*) If the used encoder and decoder model class are supported. diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py index a02be1d420..2b1524734d 100644 --- a/src/adapters/__init__.py +++ b/src/adapters/__init__.py @@ -54,6 +54,7 @@ "ModelAdaptersConfig", "ParBnConfig", "PrefixTuningConfig", + "PromptTuningConfig", "SeqBnConfig", "SeqBnInvConfig", "StaticAdapterFusionConfig", @@ -161,6 +162,7 @@ ModelAdaptersConfig, ParBnConfig, PrefixTuningConfig, + PromptTuningConfig, SeqBnConfig, SeqBnInvConfig, StaticAdapterFusionConfig, diff --git a/src/adapters/configuration/adapter_config.py b/src/adapters/configuration/adapter_config.py index 22b6c8101f..63039a8459 100644 --- a/src/adapters/configuration/adapter_config.py +++ b/src/adapters/configuration/adapter_config.py @@ -84,6 +84,8 @@ def _get_config_class(config_dict): cls_new = LoRAConfig elif architecture == "union": cls_new = ConfigUnion + elif architecture == "prompt_tuning": + cls_new = PromptTuningConfig else: cls_new = BnConfig @@ -395,6 +397,33 @@ class PrefixTuningConfig(AdapterConfig): shared_gating: bool = True +@dataclass(eq=False) +class PromptTuningConfig(AdapterConfig): + """ + The Prompt Tuning architecture proposed by Lester et al. (2021). See https://arxiv.org/pdf/2104.08691.pdf + + Args: + prompt_length (int): The number of tokens in the prompt. + Defaults to 10. + prompt_init (str): The initialization method for the prompt. Can be either "random_uniform" or "from_string". + Defaults to "random_uniform". + prompt_init_text (str): The text to use for prompt initialization if prompt_init="from_string". + random_uniform_scale (float): The scale of the random uniform initialization if prompt_init="random_uniform". + Defaults to 0.5 as in the paper. + combine (str): + The method used to combine the prompt with the input. Can be either "prefix" or "prefix_after_bos". + Defaults to "prefix". 
+ """ + + architecture: str = "prompt_tuning" + + prompt_length: int = 10 + prompt_init: str = "random_uniform" + prompt_init_text: Optional[str] = None + random_uniform_scale = 0.5 + combine: str = "prefix" + + @dataclass(eq=False) class LoRAConfig(AdapterConfig): """ @@ -612,6 +641,7 @@ def __init__( "compacter": CompacterConfig(), "prefix_tuning": PrefixTuningConfig(), "prefix_tuning_flat": PrefixTuningConfig(flat=True), + "prompt_tuning": PromptTuningConfig(), "lora": LoRAConfig(), "ia3": IA3Config(), "mam": MAMConfig(), diff --git a/src/adapters/context.py b/src/adapters/context.py index 784ed579ea..70e685d037 100644 --- a/src/adapters/context.py +++ b/src/adapters/context.py @@ -78,7 +78,13 @@ class ForwardContext: # thread-local storage that holds a stack of active contexts storage = threading.local() - context_attributes = ["adapter_gating_scores", "adapter_fusion_attentions", "adapter_input_parallelized"] + context_attributes = [ + "adapter_gating_scores", + "adapter_fusion_attentions", + "adapter_input_parallelized", + ] + # Additional used attributes not exposed to the user + # - prompt_tokens_length: length of the prompt tokens def __init__(self, model, *args, **kwargs): # If the model has a method ``forward_context()``, use it to create the context. @@ -102,6 +108,8 @@ def wrap(cls, f): def wrapper_func(self, *args, **kwargs): if self.adapters_config is not None: with cls(self, *args, **kwargs) as ctx: + # whether to output the context attributes + output_context = kwargs.pop("output_context", False) kwargs = { k: v for k, v in kwargs.items() if k.replace("output_", "") not in cls.context_attributes } @@ -116,7 +124,14 @@ def wrapper_func(self, *args, **kwargs): for attr in cls.context_attributes: if getattr(ctx, "output_" + attr, False): results[attr] = dict(getattr(ctx, attr)) - return results + + if output_context: + context_dict = ctx.__dict__ + + if output_context: + return results, context_dict + else: + return results else: return f(self, *args, **kwargs) diff --git a/src/adapters/heads/base.py b/src/adapters/heads/base.py index b3116a7b9a..d45897df10 100644 --- a/src/adapters/heads/base.py +++ b/src/adapters/heads/base.py @@ -326,6 +326,19 @@ def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=Fal labels = kwargs.pop("labels", None) if labels is not None: loss_fct = CrossEntropyLoss() + # adjust labels for prompt tuning + if kwargs.get("prompt_tokens_length", 0) > 0: + prompt_length = kwargs.get("prompt_tokens_length") + prompt_labels = torch.full( + (labels.shape[0], prompt_length), loss_fct.ignore_index, dtype=torch.long, device=labels.device + ) + labels = torch.cat((prompt_labels, labels), dim=-1) + if attention_mask is not None: + attention_mask = torch.cat( + (torch.ones_like(prompt_labels, dtype=torch.long, device=labels.device), attention_mask), + dim=-1, + ) + # Only keep active parts of the loss if attention_mask is not None: active_loss = attention_mask.view(-1) == 1 @@ -763,7 +776,14 @@ def _get_used_heads(self, head_name: str = None): return head_modules def forward_head( - self, all_outputs, head_name=None, cls_output=None, attention_mask=None, return_dict=False, **kwargs + self, + all_outputs, + head_name=None, + cls_output=None, + attention_mask=None, + return_dict=False, + context=None, + **kwargs ): """ The forward pass through a prediction head configuration. 
There are three ways to specify the used prediction @@ -811,6 +831,12 @@ def _get_head_input(outputs, cls_out, batch): if inv_adapter: kwargs["invertible_adapter"] = inv_adapter + # Set prompt tokens length + if context is not None: + prompt_tokens_length = context.get("prompt_tokens_length", None) + if prompt_tokens_length is not None: + kwargs["prompt_tokens_length"] = prompt_tokens_length + if isinstance(self.active_head, BatchSplit): if sum(self.active_head.batch_sizes) != all_outputs[0].size()[0]: raise ValueError( diff --git a/src/adapters/heads/language_modeling.py b/src/adapters/heads/language_modeling.py index 3e0cda610a..bf91e5be08 100644 --- a/src/adapters/heads/language_modeling.py +++ b/src/adapters/heads/language_modeling.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast, MaskedLMOutput, Seq2SeqLMOutput @@ -118,6 +119,15 @@ def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=Fal labels = labels[..., 1:].contiguous() else: logits_for_loss = lm_logits + + # adjust labels for prompt tuning + if kwargs.get("prompt_tokens_length", 0) > 0: + prompt_length = kwargs.get("prompt_tokens_length") + prompt_labels = torch.full( + (labels.shape[0], prompt_length), loss_fct.ignore_index, dtype=torch.long, device=labels.device + ) + labels = torch.cat((prompt_labels, labels), dim=-1) + loss = loss_fct(logits_for_loss.view(-1, self.config["vocab_size"]), labels.view(-1)) if return_dict: diff --git a/src/adapters/loading.py b/src/adapters/loading.py index f189a120a8..8e22cfc128 100644 --- a/src/adapters/loading.py +++ b/src/adapters/loading.py @@ -307,6 +307,7 @@ def filter_func(self, adapter_name): or ".prefix_tunings.{}.".format(adapter_name) in x or ".prefix_gates.{}.".format(adapter_name) in x or ".loras.{}.".format(adapter_name) in x + or ".prompt_tunings.{}.".format(adapter_name) in x ) # This dict maps the original weight names to the currently used equivalents. 
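Editor's note: for reviewers who want to try the feature end to end, the pieces above fit together as follows: `PromptTuningConfig` is registered under the `"prompt_tuning"` architecture key, the prompt embedding is prepended right after the embedding layer, and `filter_func` ensures weights matching `.prompt_tunings.{name}.` are saved and loaded with the adapter. Below is a minimal usage sketch, not part of the patch itself; the checkpoint and adapter name are placeholders, and it assumes the library's standard `adapters.init()` / `train_adapter()` workflow.

```python
import adapters
from adapters import PromptTuningConfig
from transformers import AutoModelForSequenceClassification

# Placeholder checkpoint and adapter name, chosen only for illustration
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
adapters.init(model)  # attach adapter support, including the new prompt tuning layer

# 10 soft-prompt tokens, randomly initialized and prepended to the embedded input
config = PromptTuningConfig(prompt_length=10, prompt_init="random_uniform", combine="prefix")
model.add_adapter("sentiment_prompt", config=config)  # config="prompt_tuning" also resolves via the config map above

# Freeze the base model; only the prompt embedding remains trainable
model.train_adapter("sentiment_prompt")
model.set_active_adapters("sentiment_prompt")

# Saved weights match the ".prompt_tunings.sentiment_prompt." pattern from filter_func above
model.save_adapter("./sentiment_prompt", "sentiment_prompt")
```

During the forward pass, the new `PromptTuningLayer` (introduced in the `prompt_tuning.py` module that follows) writes `prompt_tokens_length` into the `ForwardContext`; the `heads/base.py` and `language_modeling.py` hunks above read that value to pad labels and attention masks for the prepended prompt tokens.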
diff --git a/src/adapters/methods/adapter_layer_base.py b/src/adapters/methods/adapter_layer_base.py index 79d18500ec..2489d445b2 100644 --- a/src/adapters/methods/adapter_layer_base.py +++ b/src/adapters/methods/adapter_layer_base.py @@ -346,7 +346,7 @@ def compose_batch_split(self, adapter_setup: BatchSplit, state: NamedTuple, lvl: # sequentially feed different parts of the blown-up batch into different adapters children_states = [] for i, child in enumerate(adapter_setup): - # compute ids of sequences thet should be passed to the ith adapter + # compute ids of sequences that should be passed to the ith adapter batch_idx = ( sum(adapter_setup.batch_sizes[:i]), sum(adapter_setup.batch_sizes[: i + 1]), diff --git a/src/adapters/methods/prompt_tuning.py b/src/adapters/methods/prompt_tuning.py new file mode 100644 index 0000000000..fd9db60e0a --- /dev/null +++ b/src/adapters/methods/prompt_tuning.py @@ -0,0 +1,228 @@ +# https://github.com/google-research/prompt-tuning/blob/main/prompt_tuning/train/prompts.py + +import math +from typing import Callable, Dict, List, Union + +import numpy as np +import torch +from torch import nn + +from transformers import AutoTokenizer +from transformers.configuration_utils import PretrainedConfig + +from ..composition import AdapterCompositionBlock +from ..configuration import ModelAdaptersConfig, PromptTuningConfig +from ..context import ForwardContext +from .adapter_layer_base import AdapterLayerBase + + +class PromptTuning(nn.Module): + """Generate a Prompt and concatenate it with the input. + + This is the training time version of prompting a model. Calling the injected `prompt` module will generate your + unbatched prompt. This model then replicates it for the batched input and concatenates them together. + + Attributes: + prompt: The module that actually generates the unbatched prompt. + combine: A function that combines the prompt and the embedded input. + """ + + prompt: nn.Module + combination_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] + + def __init__( + self, + adapter_name: str, + prompt_tuning_config: PromptTuningConfig, + model_config: PretrainedConfig, + base_model_embeddings: nn.Module, + ): + super().__init__() + + self.name = adapter_name + self.model_config = model_config + self.prompt_tuning_config = prompt_tuning_config + + embedding_size = getattr(model_config, "embedding_size", model_config.hidden_size) + + self.prompt_embedding = nn.Embedding( + num_embeddings=prompt_tuning_config.prompt_length, embedding_dim=embedding_size + ) + # Initialize prompt tokens + self.prompt_tokens = torch.arange(prompt_tuning_config.prompt_length).long() + + self._init_prompt_embedding(base_model_embeddings) + + if prompt_tuning_config.combine == "prefix": + self.combination_fn = lambda prompt, embedded_input: torch.cat([prompt, embedded_input], dim=1) + elif prompt_tuning_config.combine == "prefix_after_bos": + self.combination_fn = lambda prompt, embedded_input: torch.cat( + [embedded_input[:, 0, np.newaxis], prompt, embedded_input[:, 1:]], dim=1 + ) + else: + raise ValueError( + f"Unknown combination function: {prompt_tuning_config.combine}. " + "Must be one of 'prefix' or 'prefix_after_bos'." 
+ ) + + def _init_prompt_embedding(self, base_model_embeddings: nn.Module) -> None: + if self.prompt_tuning_config.prompt_init == "random_uniform": + nn.init.uniform_( + self.prompt_embedding.weight, + a=-self.prompt_tuning_config.random_uniform_scale, + b=self.prompt_tuning_config.random_uniform_scale, + ) + + elif self.prompt_tuning_config.prompt_init == "from_string": + tokenizer = AutoTokenizer.from_pretrained(self.model_config.tokenizer_name_or_path) + prompt_length = self.prompt_tuning_config.prompt_length + prompt_text = self.prompt_tuning_config.prompt_init_text + if prompt_text is None: + raise ValueError("Prompt text must be provided when using prompt_init='from_string'.") + + tokenized_prompt_text: list[int] = tokenizer(prompt_text)["input_ids"] # type: ignore + + # If the prompt text tokens are shorter than the prompt length, we repeat the prompt text tokens until we reach the prompt length + if len(tokenized_prompt_text) < prompt_length: + num_reps = math.ceil(prompt_length / len(tokenized_prompt_text)) + tokenized_prompt_text = tokenized_prompt_text * num_reps + + # Adjust length of prompt text tokens to match prompt_length + tokenized_prompt_text = tokenized_prompt_text[:prompt_length] + + # Initialize prompt embedding with tokenized prompt text + word_embedding_weights = base_model_embeddings(torch.LongTensor(tokenized_prompt_text)).detach().clone() + word_embedding_weights = word_embedding_weights.to(torch.float32) + self.prompt_embedding.weight = nn.Parameter(word_embedding_weights) + + else: + raise ValueError(f"Unknown prompt initialization: {self.prompt_tuning_config.prompt_init}") + + def forward(self, embedded_input): + # Compute prompt embedding + self.prompt_tokens = self.prompt_tokens.to(embedded_input.device) + prompt = self.prompt_embedding(self.prompt_tokens) + + # Prompt to batch size + batch_size = embedded_input.shape[0] + prompt = torch.tile(torch.unsqueeze(prompt, dim=0), [batch_size] + [1 for _ in prompt.shape]) + + # Merge prompt and input + output = self.combination_fn(prompt, embedded_input) + + # Adapt attention mask + prefix_attention_mask_length = self.prompt_tuning_config.prompt_length + + return output, prefix_attention_mask_length + + +class PromptTuningLayer(AdapterLayerBase, nn.Module): + """ + Prompt Tuning implementation. + + Args: + model_config: The model configuration. + adapters_config: The adapter configuration. + base_model_embeddings: + The embedding layer of the base model (used to initialize the prompt embedding if + prompt_init='from_string'). 
+ """ + + adapter_modules_name = "prompt_tunings" + + def __init__( + self, + model_config: PretrainedConfig, + adapters_config: ModelAdaptersConfig, + base_model_embeddings: nn.Module, + ): + super().__init__() + self.model_config = model_config + self.adapters_config = adapters_config + self.base_model_embeddings = base_model_embeddings + self.prompt_tunings = nn.ModuleDict() + + def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: + # ignore layer_idx as prompt tunings are only added after the embedding layer + prompt_tuning_config = self.adapters_config.match( + adapter_name, + config_type=PromptTuningConfig, + ) + + if prompt_tuning_config is not None: + adapter = PromptTuning( + adapter_name=adapter_name, + prompt_tuning_config=prompt_tuning_config, # type: ignore + model_config=self.model_config, + base_model_embeddings=self.base_model_embeddings, + ) + adapter.train(self.training) # make sure training mode is consistent + self.prompt_tunings[adapter_name] = adapter + return True + + return False + + def average_adapter(self, adapter_name: str, input_adapters: Dict[str, float]) -> bool: + # add new adapter + if self.add_adapter(adapter_name, -1): + # average weights + avg_state_dict = {} + for name, weight in input_adapters.items(): + if name in self.prompt_tunings: + module = self.prompt_tunings[name] + for k, v in module.state_dict().items(): + if k in avg_state_dict: + avg_state_dict[k] += weight * v + else: + avg_state_dict[k] = weight * v + else: + self.delete_adapter(adapter_name) # clean up before raising error + raise ValueError("Adapter {} not found.".format(name)) + # load averaged weights + self.prompt_tunings[adapter_name].load_state_dict(avg_state_dict) + return True + + return False + + def delete_adapter(self, adapter_name: str): + if adapter_name in self.prompt_tunings: + del self.prompt_tunings[adapter_name] + + def add_fusion_layer(self, adapter_names: Union[List, str]): + pass # not applicable to prompt tuning + + def delete_fusion_layer(self, adapter_names: Union[List, str]): + pass # not applicable to prompt tuning + + def enable_adapters(self, adapter_setup: AdapterCompositionBlock, unfreeze_adapters: bool, unfreeze_fusion: bool): + if unfreeze_adapters: + for prompt_tuning_name in adapter_setup.flatten(): + if prompt_tuning_name in self.prompt_tunings: + for param in self.prompt_tunings[prompt_tuning_name].parameters(): + param.requires_grad = True + + def freeze_adapter(self, adapter_name: str, freeze: bool = True): + if adapter_name in self.prompt_tunings: + self.prompt_tunings[adapter_name].train(not freeze) + for param in self.prompt_tunings[adapter_name].parameters(): + param.requires_grad = not freeze + + def get_adapter(self, adapter_name): + if adapter_name in self.prompt_tunings: + return self.prompt_tunings[adapter_name] + else: + return None + + def forward(self, hidden_states: torch.Tensor): + prefix_attention_mask_length = None + adapter_setup = self.get_active_setup() + if adapter_setup is not None and len(adapter_setup) > 0: + first_adapter = adapter_setup.first() + if first_adapter in self.prompt_tunings: + hidden_states, prefix_attention_mask_length = self.prompt_tunings[first_adapter](hidden_states) + + context = ForwardContext.get_context() + if context is not None: + context.prompt_tokens_length = prefix_attention_mask_length + + return hidden_states diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 078086ece8..a4a1b8d17c 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ 
-5,7 +5,7 @@ from abc import ABC, abstractmethod from collections import defaultdict from os.path import join -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -22,6 +22,7 @@ from .methods.lora import LoRALayer from .methods.modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock, init_shared_parameters from .methods.prefix_tuning import PrefixTuningLayer, PrefixTuningPool +from .methods.prompt_tuning import PromptTuningLayer from .utils import EMBEDDING_FILE, TOKENIZER_PATH, get_adapter_config_hash, inherit_doc from .wrappers.configuration import SUBMODEL_NAMES, init_adapters_config @@ -40,22 +41,6 @@ def init_adapters(self, model_config, adapters_config, **kwargs): if hasattr(super(), "init_adapters"): super().init_adapters(self.config, self.adapters_config, **kwargs) - self.hook_after_embeddings(self._hook_fn) - - def _hook_fn(self, module, args, output): - new_output = self.invertible_adapters_forward(output) - return new_output - - def hook_after_embeddings(self, hook_fn: Callable): - """ - Hook a function to be called after the embeddings have been computed. The default implementation does nothing. - Override this method to add a hook. - - Args: - hook_fn (Callable): The function to be called after the embeddings have been computed. - """ - pass - def add_invertible_adapter(self, adapter_name: str) -> bool: """ Adds an invertible adapter module for the adapter with the given name. If the given adapter does not specify an @@ -132,13 +117,31 @@ def enable_invertible_adapters(self, adapter_names): def invertible_adapters_forward(self, hidden_states, rev=False): # TODO: Currently no fusion over invertible adapters, takes only very first language adapter position - if self.adapters_config.active_setup is not None and len(self.adapters_config.active_setup) > 0: - first_adapter = self.adapters_config.active_setup.first() + adapter_setup = self._get_active_setup() + if adapter_setup is not None and len(adapter_setup) > 0: + first_adapter = adapter_setup.first() if first_adapter in self.invertible_adapters: hidden_states = self.invertible_adapters[first_adapter](hidden_states, rev=rev) - return hidden_states + def _get_active_setup(self): + if hasattr(self, "adapters_config"): + # First check current context before falling back to defined setup + context = AdapterSetup.get_context() + if context is not None: + adapter_setup = context.adapter_setup + else: + adapter_setup = self.adapters_config.active_setup + else: + adapter_setup = None + skip_adapters = adapter_setup is None or ( + self.adapters_config.skip_layers is not None and self.layer_idx in self.adapters_config.skip_layers + ) + if not skip_adapters and (len(adapter_setup.flatten()) > 0): + return adapter_setup + else: + return None + class InvertibleAdaptersWrapperMixin: """ @@ -364,6 +367,10 @@ def loaded_embeddings(self): class ModelAdaptersMixin(PushAdapterToHubMixin, ABC): """Mixin for transformer models adding support for loading/ saving adapters.""" + add_base_adapters = False + support_prompt_tuning = True # If False, the prompt tuning layer is not added to the model. If True, the prompt tuning layer is added if add_base_adapters is True. 
+ _tied_weights_keys = ["prompt_tuning.base_model_embeddings.*"] + def __init__(self, config, *args, **kwargs): super().__init__(config, *args, **kwargs) @@ -400,6 +407,11 @@ def init_adapters(self, model_config, adapters_config, add_prefix_tuning_pool=Tr self.base_model.prefix_tuning = PrefixTuningPool(self.config, self.adapters_config) self.apply_to_adapter_layers(lambda i, layer: self._link_prefix_to_pool(layer)) + # Add Prompt Tuning + if self.add_base_adapters: + if self.support_prompt_tuning: + self.prompt_tuning = PromptTuningLayer(model_config, self.adapters_config, self.get_input_embeddings()) + # Initialize adapters from config for adapter_name in self.adapters_config: self._add_adapter_weights(adapter_name) @@ -430,12 +442,23 @@ def apply_to_adapter_layers(self, fn): if isinstance(module, AdapterLayerBase): fn(i, module) + def apply_to_basemodel_childs(self, fn): + """ + Applies a function to all direct childs of the model if they are a instance of AdapterLayerBase. + """ + if self.add_base_adapters: + for module in self.base_model.children(): + if isinstance(module, AdapterLayerBase): + # These childs don't have a layer index so we pass -1 + fn(-1, module) + def train_adapter(self, adapter_setup: Union[list, AdapterCompositionBlock], train_embeddings=False): """Sets the model into mode for training the given adapters.""" self.train() self.freeze_model(True) adapter_setup = parse_composition(adapter_setup) self.apply_to_adapter_layers(lambda i, layer: layer.enable_adapters(adapter_setup, True, False)) + self.apply_to_basemodel_childs(lambda i, child: child.enable_adapters(adapter_setup, True, False)) for adapter_name in adapter_setup: if adapter_name in self.base_model.shared_parameters: for param in self.base_model.shared_parameters[adapter_name].values(): @@ -462,6 +485,7 @@ def train_adapter_fusion(self, adapter_setup: Union[list, AdapterCompositionBloc self.freeze_model(True) adapter_setup = parse_composition(adapter_setup) self.apply_to_adapter_layers(lambda i, layer: layer.enable_adapters(adapter_setup, unfreeze_adapters, True)) + self.apply_to_basemodel_childs(lambda i, child: child.enable_adapters(adapter_setup, unfreeze_adapters, True)) # use the adapters to be trained by default in every forward pass self.set_active_adapters(adapter_setup) # TODO implement fusion for invertible adapters @@ -545,6 +569,8 @@ def add_adapter(self, adapter_name: str, config=None, overwrite_ok: bool = False def _add_adapter_weights(self, adapter_name: str): """Helper method that performs the actual parameter additions when adding a new adapter.""" self.apply_to_adapter_layers(lambda i, layer: layer.add_adapter(adapter_name, i)) + self.apply_to_basemodel_childs(lambda i, child: child.add_adapter(adapter_name, i)) + # PHM Layer if self.adapters_config.match(adapter_name, BnConfig, location_key="phm_layer"): adapter_module = list(self.get_adapter(adapter_name)[0].values())[0] @@ -624,6 +650,7 @@ def add_adapter_fusion( self.delete_adapter_fusion(adapter_names) self.adapters_config.add_fusion(adapter_names, config=config) self.apply_to_adapter_layers(lambda i, layer: layer.add_fusion_layer(adapter_names)) + self.apply_to_basemodel_childs(lambda i, child: child.add_fusion_layer(adapter_names)) if set_active: if not isinstance(adapter_names, list): adapter_names = adapter_names.split(",") @@ -641,11 +668,13 @@ def delete_adapter(self, adapter_name: str): return del self.adapters_config.adapters[adapter_name] self.apply_to_adapter_layers(lambda i, layer: layer.delete_adapter(adapter_name)) + 
self.apply_to_basemodel_childs(lambda i, child: child.delete_adapter(adapter_name)) # PHM Layer if adapter_name in self.base_model.shared_parameters: del self.base_model.shared_parameters[adapter_name] if isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin): self.delete_invertible_adapter(adapter_name) + # Reset active adapters if this was the only active adapter if self.active_adapters == Stack(adapter_name): self.active_adapters = None @@ -671,6 +700,7 @@ def delete_adapter_fusion(self, adapter_names: Union[Fuse, list, str]): return del self.adapters_config.fusions[adapter_fusion_name] self.apply_to_adapter_layers(lambda i, layer: layer.delete_fusion_layer(adapter_fusion_name)) + self.apply_to_basemodel_childs(lambda i, child: child.delete_fusion_layer(adapter_fusion_name)) # Reset active adapters if this was the active setup if self.active_adapters == adapter_names: self.active_adapters = None @@ -966,6 +996,11 @@ def get_adapter(self, name) -> dict: ) and name in self.invertible_adapters: destination[-1]["invertible"] = self.invertible_adapters[name] + if self.support_prompt_tuning: + prompt_tuning = self.prompt_tuning.get_adapter(name) + if prompt_tuning is not None: + destination[-1]["prompt"] = prompt_tuning + # use a custom index to ensure numbering is from 0 to N layers for i, (_, layer) in enumerate(self.iter_layers()): for module in layer.modules(): @@ -1115,6 +1150,7 @@ def average_adapter( input_adapters = {name: weight / sum_weights for name, weight in zip(adapter_list, weights)} try: self.apply_to_adapter_layers(lambda i, layer: layer.average_adapter(adapter_name, input_adapters)) + self.apply_to_basemodel_childs(lambda i, child: child.average_adapter(adapter_name, input_adapters)) # PHM Layer if self.adapters_config.match(adapter_name, BnConfig, location_key="phm_layer"): self._average_shared_parameters(adapter_name, input_adapters) @@ -1226,6 +1262,16 @@ def save_pretrained( @inherit_doc class ModelBaseAdaptersMixin(ModelAdaptersMixin): + add_base_adapters = True + + def post_embedding_forward(self, module, args, embedding_output): + if isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin): + embedding_output = self.invertible_adapters_forward(embedding_output) + + embedding_output = self.prompt_tuning.forward(embedding_output) + + return embedding_output + @ForwardContext.wrap def forward(self, *args, **kwargs): return super().forward(*args, **kwargs) diff --git a/src/adapters/models/albert/adapter_model.py b/src/adapters/models/albert/adapter_model.py index 9a1f45ed2e..8261e68760 100644 --- a/src/adapters/models/albert/adapter_model.py +++ b/src/adapters/models/albert/adapter_model.py @@ -64,7 +64,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.albert( + outputs, context = self.albert( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -77,7 +77,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. 
for prompt tuning in all models + kwargs["context"] = context # BERT & RoBERTa & ALBERT return the pooled output as second item, we don't need that in these heads if not return_dict: diff --git a/src/adapters/models/albert/mixin_albert.py b/src/adapters/models/albert/mixin_albert.py index ff9ef19fe3..4d84abd5f6 100644 --- a/src/adapters/models/albert/mixin_albert.py +++ b/src/adapters/models/albert/mixin_albert.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, Tuple +from typing import Iterable, Tuple import torch.nn as nn @@ -51,6 +51,8 @@ def init_adapters(self, model_config, adapters_config): for _, layer in self.iter_layers(): self._set_layer_hook_for_parallel(layer) + self.embeddings.register_forward_hook(self.post_embedding_forward) + def _set_layer_hook_for_parallel(self, layer: nn.Module): def hook(module, input): adjust_tensors_for_parallel_(input[0], input[1]) @@ -64,6 +66,3 @@ def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for albertLayer in albertLayerGroup.albert_layers: yield i, albertLayer i += 1 - - def hook_after_embeddings(self, hook_fn: Callable): - return self.embeddings.register_forward_hook(hook_fn) diff --git a/src/adapters/models/albert/modeling_albert.py b/src/adapters/models/albert/modeling_albert.py index 7f5294cad7..f620240c63 100644 --- a/src/adapters/models/albert/modeling_albert.py +++ b/src/adapters/models/albert/modeling_albert.py @@ -24,6 +24,7 @@ from transformers.pytorch_utils import apply_chunking_to_forward from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel +from ...utils import prefix_attention_mask from .mixin_albert import AlbertAttentionAdaptersMixin, AlbertEncoderLayerAdaptersMixin @@ -35,6 +36,8 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, output_attentions: bool = False, ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + attention_mask = prefix_attention_mask(attention_mask) # type: ignore + mixed_query_layer = self.query(hidden_states) mixed_key_layer = self.key(hidden_states) mixed_value_layer = self.value(hidden_states) diff --git a/src/adapters/models/bart/adapter_model.py b/src/adapters/models/bart/adapter_model.py index 3fc3dfd73c..ddb94e6fe9 100644 --- a/src/adapters/models/bart/adapter_model.py +++ b/src/adapters/models/bart/adapter_model.py @@ -26,7 +26,10 @@ "BART Model with the option to add multiple flexible prediction heads on top.", BART_START_DOCSTRING ) class BartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, BartPretrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + ] def __init__(self, config: BartConfig, **kwargs): super().__init__(config, **kwargs) @@ -76,7 +79,7 @@ def forward( if "labels" in kwargs or "start_positions" in kwargs and "end_positions" in kwargs: use_cache = False - outputs = self.model( + outputs, context = self.model( input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, @@ -95,7 +98,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. 
for prompt tuning in all models + kwargs["context"] = context # sequence classification based on last token in sequence x = outputs[0] # last hidden state if input_ids is not None and x.shape[1] == input_ids.shape[1]: diff --git a/src/adapters/models/bart/mixin_bart.py b/src/adapters/models/bart/mixin_bart.py index 28e7b3ac77..d269d72b43 100644 --- a/src/adapters/models/bart/mixin_bart.py +++ b/src/adapters/models/bart/mixin_bart.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, Optional, Tuple +from typing import Iterable, Optional, Tuple import torch import torch.nn as nn @@ -34,6 +34,7 @@ class BartEncoderLayerAdaptersMixin: """Adds adapters to the BartEncoderLayer module of BART.""" def init_adapters(self, model_config, adapters_config): + self.adapters_config = adapters_config # Wrap layers for LoRA self.fc1 = LoRALinear.wrap(self.fc1, "intermediate", model_config, adapters_config) self.fc2 = LoRALinear.wrap(self.fc2, "output", model_config, adapters_config) @@ -58,8 +59,7 @@ def init_adapters(self, model_config, adapters_config): class BartEncoderAdaptersMixin(InvertibleAdaptersMixin): """Adds adapters to the BartEncoder module of BART.""" - def hook_after_embeddings(self, hook_fn: Callable): - return self.layernorm_embedding.register_forward_hook(hook_fn) + pass class BartDecoderAdaptersMixin: @@ -76,6 +76,11 @@ class BartModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersWrapperMi """Adds adapters to the BartModel class.""" invertible_adapters_base_name = "encoder" + support_prompt_tuning = False + + def init_adapters(self, model_config, adapters_config): + super().init_adapters(model_config, adapters_config) + self.encoder.layernorm_embedding.register_forward_hook(self.post_embedding_forward) def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: if hasattr(self, "encoder"): @@ -87,6 +92,11 @@ def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.decoder.layers): yield i, layer + def post_embedding_forward(self, module, args, embedding_output): + embedding_output = self.invertible_adapters_forward(embedding_output) + # Prompt tuning not yet supported + return embedding_output + class BartDecoderWrapperAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelBaseAdaptersMixin): """Adds adapters to the BartDecoderWrapper class.""" diff --git a/src/adapters/models/beit/adapter_model.py b/src/adapters/models/beit/adapter_model.py index 22d3e6aa4d..ceeda7b82c 100644 --- a/src/adapters/models/beit/adapter_model.py +++ b/src/adapters/models/beit/adapter_model.py @@ -47,7 +47,7 @@ def forward( ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.beit( + outputs, context = self.beit( pixel_values, bool_masked_pos=bool_masked_pos, head_mask=head_mask, @@ -57,7 +57,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. 
for prompt tuning in all models + kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: diff --git a/src/adapters/models/beit/mixin_beit.py b/src/adapters/models/beit/mixin_beit.py index 536048e669..608d9c5cc4 100644 --- a/src/adapters/models/beit/mixin_beit.py +++ b/src/adapters/models/beit/mixin_beit.py @@ -47,6 +47,7 @@ class BeitModelAdaptersMixin(ModelBaseAdaptersMixin): def init_adapters(self, model_config, adapters_config): super().init_adapters(model_config, adapters_config) + self.embeddings.register_forward_hook(self.post_embedding_forward) def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.encoder.layer): diff --git a/src/adapters/models/bert/adapter_model.py b/src/adapters/models/bert/adapter_model.py index 4ff1aaf61a..02ad9411c4 100644 --- a/src/adapters/models/bert/adapter_model.py +++ b/src/adapters/models/bert/adapter_model.py @@ -66,7 +66,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.bert( + outputs, context = self.bert( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -79,7 +79,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: head_inputs = (outputs[0],) + outputs[2:] diff --git a/src/adapters/models/bert/mixin_bert.py b/src/adapters/models/bert/mixin_bert.py index 3cf5a6e1ff..4aa079240c 100644 --- a/src/adapters/models/bert/mixin_bert.py +++ b/src/adapters/models/bert/mixin_bert.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, Iterable, Tuple +from typing import Iterable, Tuple import torch.nn as nn @@ -77,6 +77,9 @@ def init_adapters(self, model_config, adapters_config): for _, layer in self.iter_layers(): self._set_layer_hook_for_parallel(layer) + # Register hook for post embedding forward + self.embeddings.register_forward_hook(self.post_embedding_forward) + def _set_layer_hook_for_parallel(self, layer: nn.Module): def hook(module, input): adjust_tensors_for_parallel_(input[0], input[1]) @@ -87,6 +90,3 @@ def hook(module, input): def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.encoder.layer): yield i, layer - - def hook_after_embeddings(self, hook_fn: Callable): - return self.embeddings.register_forward_hook(hook_fn) diff --git a/src/adapters/models/bert/modeling_bert.py b/src/adapters/models/bert/modeling_bert.py index 692605610a..ea60b6f5dc 100644 --- a/src/adapters/models/bert/modeling_bert.py +++ b/src/adapters/models/bert/modeling_bert.py @@ -26,6 +26,7 @@ from transformers.models.bert.modeling_bert import BertOutput, BertSelfAttention, BertSelfOutput from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel +from ...utils import prefix_attention_mask from .mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin @@ -40,6 +41,8 @@ def forward( past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: + attention_mask = 
prefix_attention_mask(attention_mask) # type: ignore + mixed_query_layer = self.query(hidden_states) # If this is instantiated as a cross-attention module, the keys diff --git a/src/adapters/models/bert_generation/adapter_model.py b/src/adapters/models/bert_generation/adapter_model.py index c251af1517..1fe0152a6a 100644 --- a/src/adapters/models/bert_generation/adapter_model.py +++ b/src/adapters/models/bert_generation/adapter_model.py @@ -62,7 +62,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.bert( + outputs, context = self.bert( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -78,7 +78,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: head_inputs = (outputs[0],) + outputs[2:] diff --git a/src/adapters/models/bert_generation/modeling_bert_generation.py b/src/adapters/models/bert_generation/modeling_bert_generation.py index 8381ccf2bb..f0ef9a35ed 100644 --- a/src/adapters/models/bert_generation/modeling_bert_generation.py +++ b/src/adapters/models/bert_generation/modeling_bert_generation.py @@ -28,6 +28,7 @@ ) from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel +from ...utils import prefix_attention_mask from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin @@ -52,6 +53,8 @@ def forward( past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: + attention_mask = prefix_attention_mask(attention_mask) # type: ignore + mixed_query_layer = self.query(hidden_states) # If this is instantiated as a cross-attention module, the keys diff --git a/src/adapters/models/clip/adapter_model.py b/src/adapters/models/clip/adapter_model.py index 6191cd3001..5aa15d417b 100644 --- a/src/adapters/models/clip/adapter_model.py +++ b/src/adapters/models/clip/adapter_model.py @@ -18,6 +18,8 @@ @add_start_docstrings(CLIP_START_DOCSTRING) class CLIPAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, CLIPPreTrainedModel): + _tied_weights_keys = [] # needs to be empty since CLIP does not yet support prompt tuning + def __init__(self, config): super().__init__(config) @@ -44,7 +46,7 @@ def forward( output_adapter_fusion_attentions=False, **kwargs ): - outputs = self.clip( + outputs, context = self.clip( input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, @@ -56,7 +58,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. 
for prompt tuning in all models + kwargs["context"] = context if head or AdapterSetup.get_context_head_setup() or self.active_head: head_outputs = self.forward_head( diff --git a/src/adapters/models/clip/mixin_clip.py b/src/adapters/models/clip/mixin_clip.py index 02469974f5..020d88f57b 100644 --- a/src/adapters/models/clip/mixin_clip.py +++ b/src/adapters/models/clip/mixin_clip.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, Tuple +from typing import Iterable, Tuple import torch.nn as nn @@ -60,14 +60,23 @@ def hook(module, input): class CLIPTextTransformerAdaptersMixin(InvertibleAdaptersMixin): """Adds adapters to the CLIPTextTransformer module of CLIP.""" - def hook_after_embeddings(self, hook_fn: Callable): - return self.embeddings.register_forward_hook(hook_fn) + def init_adapters(self, model_config, adapters_config): + super().init_adapters(model_config, adapters_config) + + # Register hook for post embedding forward + self.embeddings.register_forward_hook(self.post_embedding_forward) + + def post_embedding_forward(self, module, args, embedding_output): + embedding_output = self.invertible_adapters_forward(embedding_output) + # Prompt tuning not yet supported + return embedding_output class CLIPTextModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersWrapperMixin, ModelBaseAdaptersMixin): """Adds adapters to the CLIPTextModel class.""" invertible_adapters_base_name = "text_model" + support_prompt_tuning = False def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.text_model.encoder.layers): @@ -77,6 +86,8 @@ def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: class CLIPVisionModelAdaptersMixin(ModelBaseAdaptersMixin): """Adds adapters to the a CLIPVisionModel class.""" + support_prompt_tuning = False + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.vision_model.encoder.layers): yield i, layer @@ -86,6 +97,7 @@ class CLIPModelAdaptersMixin(EmbeddingAdaptersWrapperMixin, InvertibleAdaptersWr """Adds adapters to the CLIPModel class.""" invertible_adapters_base_name = "text_model" + support_prompt_tuning = False def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.text_model.encoder.layers): diff --git a/src/adapters/models/deberta/adapter_model.py b/src/adapters/models/deberta/adapter_model.py index 0d74e58593..4b44991d66 100644 --- a/src/adapters/models/deberta/adapter_model.py +++ b/src/adapters/models/deberta/adapter_model.py @@ -57,7 +57,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.deberta( + outputs, context = self.deberta( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -69,7 +69,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. 
for prompt tuning in all models + kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: head_inputs = (outputs[0],) + outputs[2:] diff --git a/src/adapters/models/deberta/modeling_deberta.py b/src/adapters/models/deberta/modeling_deberta.py index 71b7f9dc2a..1feca72b4a 100644 --- a/src/adapters/models/deberta/modeling_deberta.py +++ b/src/adapters/models/deberta/modeling_deberta.py @@ -25,6 +25,7 @@ ) from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel +from ...utils import prefix_attention_mask from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfOutputAdaptersMixin from .mixin_deberta import DebertaSelfAttentionAdaptersMixin @@ -94,6 +95,9 @@ def forward( """ + attention_mask = prefix_attention_mask(attention_mask, dim=3, prefix_value=1) # type: ignore + attention_mask = prefix_attention_mask(attention_mask, dim=2, prefix_value=1) # type: ignore + if query_states is None: qp = self.in_proj(hidden_states) # .split(self.all_head_size, dim=-1) query_layer, key_layer, value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1) diff --git a/src/adapters/models/deberta_v2/adapter_model.py b/src/adapters/models/deberta_v2/adapter_model.py index bc2f6e6ed2..a980d99177 100644 --- a/src/adapters/models/deberta_v2/adapter_model.py +++ b/src/adapters/models/deberta_v2/adapter_model.py @@ -60,7 +60,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.deberta( + outputs, context = self.deberta( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -72,7 +72,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: head_inputs = (outputs[0],) + outputs[2:] diff --git a/src/adapters/models/deberta_v2/modeling_deberta_v2.py b/src/adapters/models/deberta_v2/modeling_deberta_v2.py index aa8945000f..56d6fec448 100644 --- a/src/adapters/models/deberta_v2/modeling_deberta_v2.py +++ b/src/adapters/models/deberta_v2/modeling_deberta_v2.py @@ -25,6 +25,7 @@ ) from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel +from ...utils import prefix_attention_mask from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfOutputAdaptersMixin from .mixin_deberta_v2 import DebertaV2SelfAttentionAdaptersMixin @@ -88,9 +89,10 @@ def forward( rel_embeddings (`torch.FloatTensor`): The embedding of relative distances. It's a tensor of shape [\\(2 \\times \\text{max_relative_positions}\\), *hidden_size*]. 
- - """ + attention_mask = prefix_attention_mask(attention_mask, dim=3, prefix_value=1) # type: ignore + attention_mask = prefix_attention_mask(attention_mask, dim=2, prefix_value=1) # type: ignore + if query_states is None: query_states = hidden_states query_layer = self.transpose_for_scores_extended(self.query_proj(query_states), self.num_attention_heads) diff --git a/src/adapters/models/distilbert/adapter_model.py b/src/adapters/models/distilbert/adapter_model.py index 9a2294ac89..ee7fb57bf9 100644 --- a/src/adapters/models/distilbert/adapter_model.py +++ b/src/adapters/models/distilbert/adapter_model.py @@ -85,7 +85,7 @@ def forward( else None ) - distilbert_output = self.distilbert( + distilbert_output, context = self.distilbert( input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, @@ -96,7 +96,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context outputs = self.forward_head( distilbert_output, head_name=head, attention_mask=attention_mask, return_dict=return_dict, **kwargs diff --git a/src/adapters/models/distilbert/mixin_distilbert.py b/src/adapters/models/distilbert/mixin_distilbert.py index 111733c2f0..3301cba949 100644 --- a/src/adapters/models/distilbert/mixin_distilbert.py +++ b/src/adapters/models/distilbert/mixin_distilbert.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, Tuple +from typing import Iterable, Tuple import torch.nn as nn @@ -44,15 +44,10 @@ def forward(self, *args, **kwargs): class DistilBertModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin): """Adds adapters to the DistilBert module.""" + def init_adapters(self, model_config, adapters_config): + super().init_adapters(model_config, adapters_config) + self.embeddings.register_forward_hook(self.post_embedding_forward) + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.transformer.layer): yield i, layer - - def _hook_fn(self, module, input): - new_input = self.invertible_adapters_forward(input) - return new_input - - def hook_after_embeddings(self, hook_fn: Callable): - # PyTorch's built-in pre-forward hook does not pass the input ids. - # Therefore, we need to use a custom hook. - self.transformer.pre_forward_fn = hook_fn diff --git a/src/adapters/models/distilbert/modeling_distilbert.py b/src/adapters/models/distilbert/modeling_distilbert.py index e0aee4e1b9..cbd501942c 100644 --- a/src/adapters/models/distilbert/modeling_distilbert.py +++ b/src/adapters/models/distilbert/modeling_distilbert.py @@ -28,6 +28,7 @@ from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention, TransformerBlock from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel +from ...utils import prefix_attention_mask from .mixin_distilbert import DistilBertMultiHeadSelfAttentionMixin, DistilBertTransfomerBlockAdaptersMixin @@ -121,6 +122,8 @@ def forward( torch.tensor(bs, seq_length, dim) The output of the transformer block contextualization. 
""" adjust_tensors_for_parallel_(x, attn_mask) + attn_mask = prefix_attention_mask(attn_mask, dim=1, prefix_value=1) # type: ignore + # Self-Attention sa_output = self.attention( query=x, diff --git a/src/adapters/models/electra/adapter_model.py b/src/adapters/models/electra/adapter_model.py index 2d7994d3a4..6dbd02569f 100644 --- a/src/adapters/models/electra/adapter_model.py +++ b/src/adapters/models/electra/adapter_model.py @@ -66,7 +66,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.electra( + outputs, context = self.electra( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -79,7 +79,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context head_inputs = outputs diff --git a/src/adapters/models/electra/modeling_electra.py b/src/adapters/models/electra/modeling_electra.py index cbe4277ec9..53b5ed29ac 100644 --- a/src/adapters/models/electra/modeling_electra.py +++ b/src/adapters/models/electra/modeling_electra.py @@ -7,6 +7,7 @@ from transformers.models.electra.modeling_electra import ElectraOutput, ElectraSelfAttention, ElectraSelfOutput from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel +from ...utils import prefix_attention_mask from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin @@ -21,6 +22,8 @@ def forward( past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: + attention_mask = prefix_attention_mask(attention_mask) # type: ignore + mixed_query_layer = self.query(hidden_states) # If this is instantiated as a cross-attention module, the keys diff --git a/src/adapters/models/encoder_decoder/mixin_encoder_decoder.py b/src/adapters/models/encoder_decoder/mixin_encoder_decoder.py index ba2df24a84..50257d1536 100644 --- a/src/adapters/models/encoder_decoder/mixin_encoder_decoder.py +++ b/src/adapters/models/encoder_decoder/mixin_encoder_decoder.py @@ -17,6 +17,8 @@ class EncoderDecoderModelAdaptersMixin( ): """Adds adapters to the EncoderDecoderModel class.""" + support_prompt_tuning = False + def init_adapters(self, model_config, adapters_config): if not isinstance(self.encoder, ModelAdaptersMixin) or not isinstance(self.decoder, ModelAdaptersMixin): return diff --git a/src/adapters/models/gpt2/adapter_model.py b/src/adapters/models/gpt2/adapter_model.py index b4e1b53d54..cc5709b53e 100644 --- a/src/adapters/models/gpt2/adapter_model.py +++ b/src/adapters/models/gpt2/adapter_model.py @@ -33,6 +33,8 @@ GPT2_START_DOCSTRING, ) class GPT2AdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, GPT2PreTrainedModel): + _tied_weights_keys = [] # needs to be empty since GPT2 does not yet support prompt tuning + def __init__(self, config): super().__init__(config) self.transformer = GPT2Model(config) @@ -68,7 +70,7 @@ def forward( ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.transformer( + outputs, context = self.transformer( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, @@ -85,7 +87,10 @@ def forward( 
output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context batch_size = outputs[0].shape[0] diff --git a/src/adapters/models/gpt2/mixin_gpt2.py b/src/adapters/models/gpt2/mixin_gpt2.py index ce88136a92..bd142eb470 100644 --- a/src/adapters/models/gpt2/mixin_gpt2.py +++ b/src/adapters/models/gpt2/mixin_gpt2.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, Tuple +from typing import Iterable, Tuple import torch.nn as nn @@ -54,9 +54,19 @@ def init_adapters(self, model_config, adapters_config): class GPT2ModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin): + support_prompt_tuning = False + + def init_adapters(self, model_config, adapters_config): + super().init_adapters(model_config, adapters_config) + + # Register hook for post embedding forward + self.drop.register_forward_hook(self.post_embedding_forward) + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.base_model.h): yield i, layer - def hook_after_embeddings(self, hook_fn: Callable): - return self.drop.register_forward_hook(hook_fn) + def post_embedding_forward(self, module, args, embedding_output): + embedding_output = self.invertible_adapters_forward(embedding_output) + # Prompt tuning not yet supported + return embedding_output diff --git a/src/adapters/models/gptj/adapter_model.py b/src/adapters/models/gptj/adapter_model.py index 625cd0febc..a4bd8f32a1 100644 --- a/src/adapters/models/gptj/adapter_model.py +++ b/src/adapters/models/gptj/adapter_model.py @@ -33,6 +33,8 @@ GPTJ_START_DOCSTRING, ) class GPTJAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, GPTJPreTrainedModel): + _tied_weights_keys = [] # needs to be empty since GPT-J does not yet support prompt tuning + def __init__(self, config): super().__init__(config) self.transformer = GPTJModel(config) @@ -66,7 +68,7 @@ def forward( ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.transformer( + outputs, context = self.transformer( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, @@ -81,7 +83,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. 
for prompt tuning in all models + kwargs["context"] = context batch_size = outputs[0].shape[0] diff --git a/src/adapters/models/gptj/mixin_gptj.py b/src/adapters/models/gptj/mixin_gptj.py index 7e4e771cba..cc20e63cc3 100644 --- a/src/adapters/models/gptj/mixin_gptj.py +++ b/src/adapters/models/gptj/mixin_gptj.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, Tuple +from typing import Iterable, Tuple import torch.nn as nn @@ -38,12 +38,19 @@ def init_adapters(self, model_config, adapters_config): class GPTJModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin): + support_prompt_tuning = False + def init_adapters(self, model_config, adapters_config): super().init_adapters(model_config, adapters_config) + # Register hook for post embedding forward + self.drop.register_forward_hook(self.post_embedding_forward) + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.base_model.h): yield i, layer - def hook_after_embeddings(self, hook_fn: Callable): - return self.drop.register_forward_hook(hook_fn) + def post_embedding_forward(self, module, args, embedding_output): + embedding_output = self.invertible_adapters_forward(embedding_output) + # Prompt tuning not yet supported + return embedding_output diff --git a/src/adapters/models/llama/adapter_model.py b/src/adapters/models/llama/adapter_model.py index 43cec8abbf..7b9ce69083 100644 --- a/src/adapters/models/llama/adapter_model.py +++ b/src/adapters/models/llama/adapter_model.py @@ -32,6 +32,8 @@ LLAMA_START_DOCSTRING, ) class LlamaAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, LlamaPreTrainedModel): + _tied_weights_keys = [] # needs to be empty since LLaMA does not yet support prompt tuning + def __init__(self, config): super().__init__(config) self.model = LlamaModel(config) @@ -68,7 +70,7 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs, context = self.model( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, @@ -81,7 +83,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. 
for prompt tuning in all models + kwargs["context"] = context batch_size = outputs[0].shape[0] diff --git a/src/adapters/models/llama/mixin_llama.py b/src/adapters/models/llama/mixin_llama.py index 3caf66e544..aae339433c 100644 --- a/src/adapters/models/llama/mixin_llama.py +++ b/src/adapters/models/llama/mixin_llama.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, Tuple +from typing import Iterable, Tuple import torch.nn as nn @@ -28,9 +28,19 @@ def init_adapters(self, model_config, adapters_config): class LlamaModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin): + support_prompt_tuning = False + + def init_adapters(self, model_config, adapters_config): + super().init_adapters(model_config, adapters_config) + + # Register hook for post embedding forward + self.embed_tokens.register_forward_hook(self.post_embedding_forward) + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.layers): yield i, layer - def hook_after_embeddings(self, hook_fn: Callable): - return self.embed_tokens.register_forward_hook(hook_fn) + def post_embedding_forward(self, module, args, embedding_output): + embedding_output = self.invertible_adapters_forward(embedding_output) + # Prompt tuning not yet supported + return embedding_output diff --git a/src/adapters/models/mbart/adapter_model.py b/src/adapters/models/mbart/adapter_model.py index ae86c35c17..5b57eb2cb0 100644 --- a/src/adapters/models/mbart/adapter_model.py +++ b/src/adapters/models/mbart/adapter_model.py @@ -26,7 +26,10 @@ "MBART Model with the option to add multiple flexible prediction heads on top.", MBART_START_DOCSTRING ) class MBartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, MBartPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + ] def __init__(self, config: MBartConfig, **kwargs): super().__init__(config, **kwargs) @@ -76,7 +79,7 @@ def forward( if "labels" in kwargs or "start_positions" in kwargs and "end_positions" in kwargs: use_cache = False - outputs = self.model( + outputs, context = self.model( input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, @@ -95,7 +98,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. 
for prompt tuning in all models + kwargs["context"] = context # sequence classification based on last token in sequence x = outputs[0] # last hidden state if input_ids is not None and x.shape[1] == input_ids.shape[1]: diff --git a/src/adapters/models/roberta/adapter_model.py b/src/adapters/models/roberta/adapter_model.py index 13bf8b8102..3a08f33639 100644 --- a/src/adapters/models/roberta/adapter_model.py +++ b/src/adapters/models/roberta/adapter_model.py @@ -66,7 +66,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.roberta( + outputs, context = self.roberta( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -79,7 +79,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: head_inputs = (outputs[0],) + outputs[2:] diff --git a/src/adapters/models/roberta/modeling_roberta.py b/src/adapters/models/roberta/modeling_roberta.py index e33b7e7ca3..8a79d4effb 100644 --- a/src/adapters/models/roberta/modeling_roberta.py +++ b/src/adapters/models/roberta/modeling_roberta.py @@ -25,6 +25,7 @@ from transformers.models.roberta.modeling_roberta import RobertaOutput, RobertaSelfAttention, RobertaSelfOutput from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel +from ...utils import prefix_attention_mask from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin @@ -40,6 +41,8 @@ def forward( past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: + attention_mask = prefix_attention_mask(attention_mask) # type: ignore + mixed_query_layer = self.query(hidden_states) # If this is instantiated as a cross-attention module, the keys diff --git a/src/adapters/models/t5/adapter_model.py b/src/adapters/models/t5/adapter_model.py index 66441727c7..d981815bd9 100644 --- a/src/adapters/models/t5/adapter_model.py +++ b/src/adapters/models/t5/adapter_model.py @@ -22,7 +22,10 @@ @add_start_docstrings("T5 Model with the option to add multiple flexible prediction heads on top.", T5_START_DOCSTRING) class T5AdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, T5PreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + ] _keys_to_ignore_on_load_unexpected = [ r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", @@ -82,7 +85,7 @@ def forward( # decoder_input_ids from input_ids if no decoder_input_ids are provided decoder_input_ids = self._shift_right(input_ids) - model_output = self.transformer( + model_output, context = self.transformer( input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, @@ -101,7 +104,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. 
for prompt tuning in all models + kwargs["context"] = context sequence_output = model_output[0] # ToDo move head to device for parallel forward pass diff --git a/src/adapters/models/t5/mixin_t5.py b/src/adapters/models/t5/mixin_t5.py index 244f5d4335..1aa6227d88 100644 --- a/src/adapters/models/t5/mixin_t5.py +++ b/src/adapters/models/t5/mixin_t5.py @@ -83,11 +83,17 @@ def init_adapters(self, model_config, adapters_config): if not self.is_decoder: InvertibleAdaptersMixin.init_adapters(self, self.config, adapters_config) + def post_embedding_forward(self, embedding_output): + embedding_output = self.invertible_adapters_forward(embedding_output) + # Prompt tuning not yet supported + return embedding_output + class T5ModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersWrapperMixin, ModelBaseAdaptersMixin): """Adds adapters to the T5Model class.""" invertible_adapters_base_name = "encoder" + support_prompt_tuning = False def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: global_i = 0 diff --git a/src/adapters/models/t5/modeling_t5.py b/src/adapters/models/t5/modeling_t5.py index 19064f58b2..b366b9cebe 100644 --- a/src/adapters/models/t5/modeling_t5.py +++ b/src/adapters/models/t5/modeling_t5.py @@ -356,7 +356,7 @@ def forward( hidden_states = self.dropout(inputs_embeds) if not self.is_decoder: - hidden_states = self.invertible_adapters_forward(hidden_states) + hidden_states = self.post_embedding_forward(hidden_states) for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): layer_head_mask = head_mask[i] diff --git a/src/adapters/models/vit/adapter_model.py b/src/adapters/models/vit/adapter_model.py index 33eaaf2ea0..254a5ab0d7 100644 --- a/src/adapters/models/vit/adapter_model.py +++ b/src/adapters/models/vit/adapter_model.py @@ -47,7 +47,7 @@ def forward( ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.vit( + outputs, context = self.vit( pixel_values, head_mask=head_mask, output_attentions=output_attentions, @@ -57,7 +57,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. 
for prompt tuning in all models + kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: diff --git a/src/adapters/models/vit/mixin_vit.py b/src/adapters/models/vit/mixin_vit.py index 2f9962a9d8..9b4a92e45b 100644 --- a/src/adapters/models/vit/mixin_vit.py +++ b/src/adapters/models/vit/mixin_vit.py @@ -52,6 +52,9 @@ class ViTModelAdaptersMixin(ModelBaseAdaptersMixin): def init_adapters(self, model_config, adapters_config): super().init_adapters(model_config, adapters_config) + # Register hook for post embedding forward + self.embeddings.register_forward_hook(self.post_embedding_forward) + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.encoder.layer): yield i, layer diff --git a/src/adapters/models/xlm_roberta/adapter_model.py b/src/adapters/models/xlm_roberta/adapter_model.py index ab1ca81f79..33963d5f1e 100644 --- a/src/adapters/models/xlm_roberta/adapter_model.py +++ b/src/adapters/models/xlm_roberta/adapter_model.py @@ -68,7 +68,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.roberta( + outputs, context = self.roberta( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -81,7 +81,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: head_inputs = (outputs[0],) + outputs[2:] diff --git a/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py b/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py index 5f18c9f70e..959f75c0e7 100644 --- a/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py @@ -29,6 +29,7 @@ ) from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel +from ...utils import prefix_attention_mask from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin @@ -44,6 +45,8 @@ def forward( past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: + attention_mask = prefix_attention_mask(attention_mask) # type: ignore + mixed_query_layer = self.query(hidden_states) # If this is instantiated as a cross-attention module, the keys diff --git a/src/adapters/models/xmod/adapter_model.py b/src/adapters/models/xmod/adapter_model.py index 31ca7acd3b..d61578f158 100644 --- a/src/adapters/models/xmod/adapter_model.py +++ b/src/adapters/models/xmod/adapter_model.py @@ -73,7 +73,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.roberta( + outputs, context = self.roberta( input_ids, lang_ids=lang_ids, attention_mask=attention_mask, @@ -87,7 +87,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. 
for prompt tuning in all models + kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: head_inputs = (outputs[0],) + outputs[2:] diff --git a/src/adapters/models/xmod/mixin_xmod.py b/src/adapters/models/xmod/mixin_xmod.py index eac7e4b418..bef4371f34 100644 --- a/src/adapters/models/xmod/mixin_xmod.py +++ b/src/adapters/models/xmod/mixin_xmod.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, Tuple +from typing import Iterable, Tuple import torch.nn as nn @@ -25,6 +25,9 @@ def init_adapters(self, model_config, adapters_config): for _, layer in self.iter_layers(): del layer.output.adapter_modules + # Register hook for post embedding forward + self.embeddings.register_forward_hook(self.post_embedding_forward) + def _set_layer_hook_for_parallel(self, layer: nn.Module): def hook(module, input): # hook[1] is lang_ids tensor @@ -37,9 +40,6 @@ def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.encoder.layer): yield i, layer - def hook_after_embeddings(self, hook_fn: Callable): - return self.embeddings.register_forward_hook(hook_fn) - def forward(self, *args, **kwargs): if "lang_ids" in kwargs and kwargs["lang_ids"] is not None: raise ValueError( diff --git a/src/adapters/models/xmod/modeling_xmod.py b/src/adapters/models/xmod/modeling_xmod.py index 4a2269fbae..e91131e90e 100644 --- a/src/adapters/models/xmod/modeling_xmod.py +++ b/src/adapters/models/xmod/modeling_xmod.py @@ -24,6 +24,7 @@ from transformers.models.xmod.modeling_xmod import XmodOutput, XmodSelfAttention, XmodSelfOutput from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel +from ...utils import prefix_attention_mask from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin @@ -39,6 +40,8 @@ def forward( past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: + attention_mask = prefix_attention_mask(attention_mask) # type: ignore + mixed_query_layer = self.query(hidden_states) # If this is instantiated as a cross-attention module, the keys diff --git a/src/adapters/utils.py b/src/adapters/utils.py index cca9ee0e6b..0e3b20cabe 100644 --- a/src/adapters/utils.py +++ b/src/adapters/utils.py @@ -21,6 +21,8 @@ from urllib.parse import urlparse from zipfile import ZipFile, is_zipfile +import torch + import requests from filelock import FileLock from huggingface_hub import HfApi, HfFolder, snapshot_download @@ -36,6 +38,7 @@ from transformers.utils.hub import torch_cache_home from . import __version__ +from .context import ForwardContext logger = logging.getLogger(__name__) @@ -287,7 +290,6 @@ def get_from_cache( # Prevent parallel downloads of the same file with a lock. lock_path = cache_path + ".lock" with FileLock(lock_path): - # If the download just completed while the lock was activated. if os.path.exists(cache_path) and not force_download: # Even if returning early like here, the lock will be released. @@ -819,3 +821,44 @@ def get_adapter_info(adapter_id: str, source: str = "ah") -> Optional[AdapterInf return None else: raise ValueError("Please specify either 'ah' or 'hf' as source.") + + +def prefix_attention_mask(attention_mask, dim: int = 3, prefix_value: int = 0): + """ + Adds a prefix to an attention mask. The length of the prefix is determined by the `prefix_attention_mask_length` + attribute in the ForwardContext. 
+ + Args: + attention_mask: + The attention mask to add the prefix to. + dim (int): + The dimension along which to concatenate the prefix_attention_mask. Defaults to 3. + prefix_value (int): + The value to use for the prefix_attention_mask. Defaults to 0, however some models, e.g. DistilBert, use + different values. BERT like models invert their extended_attention_mask, hence they use 0 as value for not + masked tokens. This inversion is usually done in the forward method of the model in 2 different ways: + 1) by calling self.invert_attention_mask, as BERT does 2) by doing the inversion manually, e.g. ALBERT + does: `extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min` + """ + + forward_context = ForwardContext.get_context() + + if ( + attention_mask is not None + and forward_context is not None + and getattr(forward_context, "prompt_tokens_length", None) is not None + ): + # Create a tensor of ones with the desired shape + ones_shape = list(attention_mask.shape) + ones_shape[dim] = forward_context.prompt_tokens_length + + prefix_attention_mask = torch.full( + ones_shape, + prefix_value, + dtype=attention_mask.dtype, + ).to(attention_mask.device) + + # Concatenate the prefix_attention_mask along the specified dimension + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=dim) + + return attention_mask diff --git a/tests_adapters/composition/test_adapter_composition.py b/tests_adapters/composition/test_adapter_composition.py index 42e1f64c1b..e1922dc5d7 100644 --- a/tests_adapters/composition/test_adapter_composition.py +++ b/tests_adapters/composition/test_adapter_composition.py @@ -5,7 +5,7 @@ import adapters from adapters import IA3Config, LoRAConfig, PrefixTuningConfig, SeqBnConfig from adapters.composition import Average, BatchSplit, Fuse, Parallel, Split, Stack, parse_composition -from tests.test_modeling_common import ids_tensor +from tests_adapters.test_adapter import ids_tensor from transformers import BertConfig, BertForSequenceClassification from transformers.testing_utils import require_torch, torch_device diff --git a/tests_adapters/methods/__init__.py b/tests_adapters/methods/__init__.py index f40a688e58..b1cbe52de4 100644 --- a/tests_adapters/methods/__init__.py +++ b/tests_adapters/methods/__init__.py @@ -22,4 +22,5 @@ from .test_ia3 import IA3TestMixin from .test_lora import LoRATestMixin from .test_prefix_tuning import PrefixTuningTestMixin +from .test_prompt_tuning import PromptTuningTestMixin from .test_unipelt import UniPELTTestMixin diff --git a/tests_adapters/methods/test_adapter_common.py b/tests_adapters/methods/test_adapter_common.py index 616e6a99e8..5c543dadca 100644 --- a/tests_adapters/methods/test_adapter_common.py +++ b/tests_adapters/methods/test_adapter_common.py @@ -27,7 +27,6 @@ @require_torch class BottleneckAdapterTestMixin(AdapterMethodBaseTestMixin): - adapter_configs_to_test = [ (SeqBnConfig(), ["adapters.{name}."]), (MAMConfig(), ["adapters.{name}.", "prefix_tunings.{name}."]), @@ -211,6 +210,14 @@ def test_adapter_forward(self): with self.subTest(model_class=model.__class__.__name__, config=adapter_config.__class__.__name__): self.run_forward_test(model, adapter_config) + def test_invertible_adapter_forward(self): + model = self.get_model() + model.eval() + + for adapter_config, _ in self.inv_adapter_configs_to_test: + with self.subTest(model_class=model.__class__.__name__, config=adapter_config.__class__.__name__): + self.run_forward_test(model, adapter_config) + def test_load_adapter(self): 
self.run_load_test(SeqBnConfig()) diff --git a/tests_adapters/methods/test_prompt_tuning.py b/tests_adapters/methods/test_prompt_tuning.py new file mode 100644 index 0000000000..d0b12d259c --- /dev/null +++ b/tests_adapters/methods/test_prompt_tuning.py @@ -0,0 +1,36 @@ +from adapters import PromptTuningConfig +from transformers.testing_utils import require_torch + +from .base import AdapterMethodBaseTestMixin + + +@require_torch +class PromptTuningTestMixin(AdapterMethodBaseTestMixin): + def test_add_prompt_tuning(self): + model = self.get_model() + self.run_add_test(model, PromptTuningConfig(prompt_length=10), ["prompt_tunings.{name}."]) + + def test_average_prompt_tuning(self): + model = self.get_model() + self.run_average_test(model, PromptTuningConfig(prompt_length=10), ["prompt_tunings.{name}."]) + + def test_delete_prompt_tuning(self): + model = self.get_model() + self.run_delete_test(model, PromptTuningConfig(prompt_length=10), ["prompt_tunings.{name}."]) + + def test_get_prompt_tuning(self): + model = self.get_model() + self.run_get_test(model, PromptTuningConfig(prompt_length=10), 1) + + def test_forward_prompt_tuning(self): + model = self.get_model() + self.run_forward_test(model, PromptTuningConfig(prompt_length=10)) + + def test_load_prompt_tuning(self): + self.run_load_test(PromptTuningConfig(prompt_length=10)) + + def test_load_full_model_prompt_tuning(self): + self.run_full_model_load_test(PromptTuningConfig(prompt_length=10)) + + def test_train_prompt_tuning(self): + self.run_train_test(PromptTuningConfig(prompt_length=10), ["prompt_tunings.{name}."]) diff --git a/tests_adapters/test_adapter.py b/tests_adapters/test_adapter.py index d5e8d2754c..84c45c67dc 100644 --- a/tests_adapters/test_adapter.py +++ b/tests_adapters/test_adapter.py @@ -9,10 +9,29 @@ from transformers.testing_utils import torch_device +global_rng = random.Random() + + def make_config(config_class, **kwargs): return staticmethod(lambda: config_class(**kwargs)) +def ids_tensor(shape, vocab_size, rng=None, name=None): + # Creates a random int32 tensor of the shape within the vocab size + if rng is None: + rng = global_rng + + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(rng.randint(0, vocab_size - 1)) + + return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous() + + class AdapterTestBase: # If not overriden by subclass, AutoModel should be used. 
model_class = AutoAdapterModel diff --git a/tests_adapters/test_adapter_config.py b/tests_adapters/test_adapter_config.py index 5667213683..2bce31c7e3 100644 --- a/tests_adapters/test_adapter_config.py +++ b/tests_adapters/test_adapter_config.py @@ -31,6 +31,7 @@ def test_config_immutable(self): def set_attr(config: AdapterConfig): config.non_linearity = "dummy" config.r = -1 # for LoRA + config.prompt_length = -1 # for PromptTuning for config in ADAPTER_CONFIG_MAP.values(): if isinstance(config, ConfigUnion): diff --git a/tests_adapters/test_adapter_custom_head.py b/tests_adapters/test_adapter_custom_head.py index ea37b52325..b68662bfc6 100644 --- a/tests_adapters/test_adapter_custom_head.py +++ b/tests_adapters/test_adapter_custom_head.py @@ -5,10 +5,11 @@ from adapters import AutoAdapterModel from adapters.heads import ClassificationHead, PredictionHead -from tests.test_modeling_common import ids_tensor from transformers import AutoConfig from transformers.testing_utils import require_torch, torch_device +from .test_adapter import ids_tensor + class CustomHead(PredictionHead): def __init__( diff --git a/tests_adapters/test_adapter_hub.py b/tests_adapters/test_adapter_hub.py index 28a47136aa..7ebf9acb6c 100644 --- a/tests_adapters/test_adapter_hub.py +++ b/tests_adapters/test_adapter_hub.py @@ -7,7 +7,6 @@ from adapters import ADAPTER_CONFIG_MAP, AdapterConfig, BertAdapterModel, get_adapter_config_hash from adapters.trainer import AdapterTrainer as Trainer from adapters.utils import find_in_index -from tests.test_modeling_common import ids_tensor from transformers import ( # get_adapter_config_hash, AutoModel, AutoTokenizer, @@ -19,6 +18,8 @@ ) from transformers.testing_utils import require_torch, torch_device +from .test_adapter import ids_tensor + SAMPLE_INDEX = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/hub-index.sample.json") diff --git a/tests_adapters/test_albert.py b/tests_adapters/test_albert.py index 29f8a2b583..054dd31278 100644 --- a/tests_adapters/test_albert.py +++ b/tests_adapters/test_albert.py @@ -11,6 +11,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -42,6 +43,7 @@ class AlbertAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, EmbeddingTestMixin, AdapterFusionModelTestMixin, diff --git a/tests_adapters/test_beit.py b/tests_adapters/test_beit.py index b2014c6c00..1e83c9e529 100644 --- a/tests_adapters/test_beit.py +++ b/tests_adapters/test_beit.py @@ -9,6 +9,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import VisionAdapterTestBase, make_config @@ -38,6 +39,7 @@ class BeitAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, AdapterFusionModelTestMixin, CompabilityTestMixin, diff --git a/tests_adapters/test_bert.py b/tests_adapters/test_bert.py index 702e68ab9d..b4e67bc811 100644 --- a/tests_adapters/test_bert.py +++ b/tests_adapters/test_bert.py @@ -11,6 +11,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -40,6 +41,7 @@ class BertAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, EmbeddingTestMixin, AdapterFusionModelTestMixin, diff --git a/tests_adapters/test_bert_generation.py 
b/tests_adapters/test_bert_generation.py index feb821ca0b..44cbd25f8e 100644 --- a/tests_adapters/test_bert_generation.py +++ b/tests_adapters/test_bert_generation.py @@ -12,6 +12,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -84,6 +85,7 @@ class BertGenerationAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, EmbeddingTestMixin, AdapterFusionModelTestMixin, diff --git a/tests_adapters/test_deberta.py b/tests_adapters/test_deberta.py index 96be88d26a..7fd80322ec 100644 --- a/tests_adapters/test_deberta.py +++ b/tests_adapters/test_deberta.py @@ -11,6 +11,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -47,6 +48,7 @@ class DebertaAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, EmbeddingTestMixin, ParallelTrainingMixin, diff --git a/tests_adapters/test_debertaV2.py b/tests_adapters/test_debertaV2.py index bc436d996d..b2d564c2e4 100644 --- a/tests_adapters/test_debertaV2.py +++ b/tests_adapters/test_debertaV2.py @@ -11,6 +11,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -47,6 +48,7 @@ class DebertaV2AdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, EmbeddingTestMixin, ParallelTrainingMixin, diff --git a/tests_adapters/test_distilbert.py b/tests_adapters/test_distilbert.py index d401fe220b..2634b390a5 100644 --- a/tests_adapters/test_distilbert.py +++ b/tests_adapters/test_distilbert.py @@ -11,6 +11,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -40,6 +41,7 @@ class DistilBertAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, EmbeddingTestMixin, CompabilityTestMixin, diff --git a/tests_adapters/test_electra.py b/tests_adapters/test_electra.py index 5566e7be0d..a5c005f509 100644 --- a/tests_adapters/test_electra.py +++ b/tests_adapters/test_electra.py @@ -11,6 +11,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -41,6 +42,7 @@ class ElectraAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, EmbeddingTestMixin, AdapterFusionModelTestMixin, diff --git a/tests_adapters/test_roberta.py b/tests_adapters/test_roberta.py index 5ccbb53eee..6d105ceaec 100644 --- a/tests_adapters/test_roberta.py +++ b/tests_adapters/test_roberta.py @@ -11,6 +11,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -40,6 +41,7 @@ class RobertaAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, AdapterFusionModelTestMixin, CompabilityTestMixin, diff --git a/tests_adapters/test_vit.py b/tests_adapters/test_vit.py index d84d8523ca..2de1b34300 100644 --- a/tests_adapters/test_vit.py +++ b/tests_adapters/test_vit.py @@ -10,6 +10,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) 
from .test_adapter import VisionAdapterTestBase, make_config @@ -39,6 +40,7 @@ class ViTAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, AdapterFusionModelTestMixin, CompabilityTestMixin, diff --git a/tests_adapters/test_xlm_roberta.py b/tests_adapters/test_xlm_roberta.py index f46d9543ac..96268302f7 100644 --- a/tests_adapters/test_xlm_roberta.py +++ b/tests_adapters/test_xlm_roberta.py @@ -9,6 +9,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -36,6 +37,7 @@ class XLMRobertaAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, AdapterFusionModelTestMixin, XLMRobertaAdapterTestBase, diff --git a/tests_adapters/test_xmod.py b/tests_adapters/test_xmod.py index 450c84231d..a8cd02c4d9 100644 --- a/tests_adapters/test_xmod.py +++ b/tests_adapters/test_xmod.py @@ -10,6 +10,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -41,6 +42,7 @@ class XmodAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, AdapterFusionModelTestMixin, CompabilityTestMixin,