diff --git a/.gitignore b/.gitignore
index b403149cc0..80ed06450d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -74,7 +74,7 @@ instance/
# Sphinx documentation
docs/_build/
-docs/_build/
+adapter_docs/_build/
# PyBuilder
target/
diff --git a/docs/classes/adapter_config.rst b/docs/classes/adapter_config.rst
index 911a3fade2..91a9f506a0 100644
--- a/docs/classes/adapter_config.rst
+++ b/docs/classes/adapter_config.rst
@@ -55,6 +55,13 @@ IA3Config
:members:
:inherited-members: Mapping
+PromptTuningConfig
+~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: adapters.PromptTuningConfig
+ :members:
+ :inherited-members: Mapping
+
Combined configurations
~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/docs/methods.md b/docs/methods.md
index 04cebbaaac..06ad700e68 100644
--- a/docs/methods.md
+++ b/docs/methods.md
@@ -267,3 +267,31 @@ model.reset_adapter()
_Papers:_
- [Few-Shot Parameter-Efficient Fine-Tuning is Better and Cheaper than In-Context Learning](https://arxiv.org/pdf/2205.05638.pdf) (Liu et al., 2022)
+
+## Prompt Tuning
+Prompt Tuning is an efficient fine-tuning technique proposed by Lester et al. (2021). It adds tunable tokens, called soft prompts, that are prepended to the input text.
+First, the input sequence $\{x_1, x_2, \dots, x_n\}$ is embedded, resulting in the matrix $X_e \in \mathbb{R}^{n \times e}$, where $e$ is the dimension of
+the embedding space. The soft prompt of length $p$ is represented as $P_e \in \mathbb{R}^{p \times e}$.
+$P_e$ and $X_e$ are concatenated, forming the input of the following encoder or decoder:
+
+$$
+\left[P_e; X_e\right] \in \mathbb{R}^{\left(p + n\right) \times e}
+$$
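+
+For intuition, this concatenation is a plain tensor operation on the embedded batch. The following sketch only illustrates the shape arithmetic; all sizes are placeholders, not defaults:
+
+```python
+import torch
+
+batch_size, n, p, e = 4, 128, 10, 768
+X_e = torch.randn(batch_size, n, e)  # embedded input sequence
+P_e = torch.randn(batch_size, p, e)  # soft prompt, replicated over the batch
+combined = torch.cat([P_e, X_e], dim=1)  # shape: (batch_size, p + n, e)
+```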
+
+The `PromptTuningConfig` has the following properties:
+- `prompt_length`: to set the soft prompt length $p$
+- `prompt_init`: to set the weight initialization method, which is either "random_uniform" or "from_string". With "from_string", each prompt token is initialized with an embedding drawn from the model’s vocabulary.
+    - `prompt_init_text`: the text used for initialization if `prompt_init="from_string"`
+- `combine`: to define whether the soft prompt should be added before the embedded input sequence or after the BOS token
+
+To add Prompt Tuning to your model, you can use the predefined `PromptTuningConfig`:
+```python
+from adapters import PromptTuningConfig
+
+config = PromptTuningConfig(prompt_length=10)
+model.add_adapter("dummy", config=config)
+```
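+
+As a further sketch, the soft prompt can also be initialized from a text string via `prompt_init="from_string"`. The adapter name and the initialization text below are placeholders; note that the implementation tokenizes the text with the tokenizer referenced by the model config's `tokenizer_name_or_path`:
+
+```python
+from adapters import PromptTuningConfig
+
+# Initialize the 10 prompt tokens from the tokenized text
+# (the tokens are repeated or truncated to match prompt_length).
+config = PromptTuningConfig(
+    prompt_length=10,
+    prompt_init="from_string",
+    prompt_init_text="Classify the sentiment of this review:",
+)
+model.add_adapter("dummy", config=config)
+model.train_adapter("dummy")
+```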
+
+_Papers:_
+- [The Power of Scale for Parameter-Efficient Prompt Tuning](https://aclanthology.org/2021.emnlp-main.243/) (Lester et al., 2021)
+
diff --git a/docs/model_overview.md b/docs/model_overview.md
index 8198ea64d0..a5ba7c4e8c 100644
--- a/docs/model_overview.md
+++ b/docs/model_overview.md
@@ -10,28 +10,28 @@ The table below further shows which model architectures support which adaptation
E.g., for BERT, this means adapters provides a ``BertAdapterModel`` class, but you can also use ``BertModel``, ``BertForSequenceClassification`` etc. together with adapters.
```
-| Model | (Bottleneck)<br> Adapters | Prefix<br> Tuning | LoRA | Compacter | Adapter<br> Fusion | Invertible<br> Adapters | Parallel<br> block |
-| --------------------------------------- | -| - | - | - | - | - | - |
-| [ALBERT](classes/models/albert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| [BART](classes/models/bart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| [BEIT](classes/models/beit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | | |
-| [BERT-Generation](classes/models/bert-generation.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| [BERT](classes/models/bert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| [CLIP](classes/models/clip.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
-| [DeBERTa](classes/models/deberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| [DeBERTa-v2](classes/models/debertaV2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| [DistilBERT](classes/models/distilbert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| [Electra](classes/models/electra.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| [Encoder Decoder](classes/models/encoderdecoder.html) | (*) | (*) | (*) | (*) | (*) | (*) | |
-| [GPT-2](classes/models/gpt2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| [GPT-J](classes/models/gptj.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| [Llama](classes/models/llama.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| [MBart](classes/models/mbart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| [RoBERTa](classes/models/roberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| [T5](classes/models/t5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| [ViT](classes/models/vit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| [XLM-RoBERTa](classes/models/xlmroberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| [X-MOD](classes/models/xmod.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| Model | (Bottleneck)<br> Adapters | Prefix<br> Tuning | LoRA | Compacter | Adapter<br> Fusion | Invertible<br> Adapters | Parallel<br> block | Prompt<br> Tuning |
+| --------------------------------------- | -| - | - | - | - | - | - |- |
+| [ALBERT](classes/models/albert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| [BART](classes/models/bart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
+| [BEIT](classes/models/beit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | | | ✅ |
+| [BERT-Generation](classes/models/bert-generation.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| [BERT](classes/models/bert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| [CLIP](classes/models/clip.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | |
+| [DeBERTa](classes/models/deberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| [DeBERTa-v2](classes/models/debertaV2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| [DistilBERT](classes/models/distilbert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| [Electra](classes/models/electra.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| [Encoder Decoder](classes/models/encoderdecoder.html) | (*) | (*) | (*) | (*) | (*) | (*) | | |
+| [GPT-2](classes/models/gpt2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
+| [GPT-J](classes/models/gptj.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
+| [Llama](classes/models/llama.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
+| [MBart](classes/models/mbart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
+| [RoBERTa](classes/models/roberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| [T5](classes/models/t5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
+| [ViT](classes/models/vit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| [XLM-RoBERTa](classes/models/xlmroberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| [X-MOD](classes/models/xmod.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
(*) If the used encoder and decoder model class are supported.
diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py
index a02be1d420..2b1524734d 100644
--- a/src/adapters/__init__.py
+++ b/src/adapters/__init__.py
@@ -54,6 +54,7 @@
"ModelAdaptersConfig",
"ParBnConfig",
"PrefixTuningConfig",
+ "PromptTuningConfig",
"SeqBnConfig",
"SeqBnInvConfig",
"StaticAdapterFusionConfig",
@@ -161,6 +162,7 @@
ModelAdaptersConfig,
ParBnConfig,
PrefixTuningConfig,
+ PromptTuningConfig,
SeqBnConfig,
SeqBnInvConfig,
StaticAdapterFusionConfig,
diff --git a/src/adapters/configuration/adapter_config.py b/src/adapters/configuration/adapter_config.py
index 22b6c8101f..63039a8459 100644
--- a/src/adapters/configuration/adapter_config.py
+++ b/src/adapters/configuration/adapter_config.py
@@ -84,6 +84,8 @@ def _get_config_class(config_dict):
cls_new = LoRAConfig
elif architecture == "union":
cls_new = ConfigUnion
+ elif architecture == "prompt_tuning":
+ cls_new = PromptTuningConfig
else:
cls_new = BnConfig
@@ -395,6 +397,33 @@ class PrefixTuningConfig(AdapterConfig):
shared_gating: bool = True
+@dataclass(eq=False)
+class PromptTuningConfig(AdapterConfig):
+ """
+ The Prompt Tuning architecture proposed by Lester et al. (2021). See https://arxiv.org/pdf/2104.08691.pdf
+
+ Args:
+ prompt_length (int): The number of tokens in the prompt.
+ Defaults to 10.
+ prompt_init (str): The initialization method for the prompt. Can be either "random_uniform" or "from_string".
+ Defaults to "random_uniform".
+ prompt_init_text (str): The text to use for prompt initialization if prompt_init="from_string".
+ random_uniform_scale (float): The scale of the random uniform initialization if prompt_init="random_uniform".
+ Defaults to 0.5 as in the paper.
+ combine (str):
+ The method used to combine the prompt with the input. Can be either "prefix" or "prefix_after_bos".
+ Defaults to "prefix".
+ """
+
+ architecture: str = "prompt_tuning"
+
+ prompt_length: int = 10
+ prompt_init: str = "random_uniform"
+ prompt_init_text: Optional[str] = None
+ random_uniform_scale: float = 0.5
+ combine: str = "prefix"
+
+
@dataclass(eq=False)
class LoRAConfig(AdapterConfig):
"""
@@ -612,6 +641,7 @@ def __init__(
"compacter": CompacterConfig(),
"prefix_tuning": PrefixTuningConfig(),
"prefix_tuning_flat": PrefixTuningConfig(flat=True),
+ "prompt_tuning": PromptTuningConfig(),
"lora": LoRAConfig(),
"ia3": IA3Config(),
"mam": MAMConfig(),
diff --git a/src/adapters/context.py b/src/adapters/context.py
index 784ed579ea..70e685d037 100644
--- a/src/adapters/context.py
+++ b/src/adapters/context.py
@@ -78,7 +78,13 @@ class ForwardContext:
# thread-local storage that holds a stack of active contexts
storage = threading.local()
- context_attributes = ["adapter_gating_scores", "adapter_fusion_attentions", "adapter_input_parallelized"]
+ context_attributes = [
+ "adapter_gating_scores",
+ "adapter_fusion_attentions",
+ "adapter_input_parallelized",
+ ]
+ # Additional attributes that are used internally but not exposed to the user:
+ # - prompt_tokens_length: length of the prompt tokens
def __init__(self, model, *args, **kwargs):
# If the model has a method ``forward_context()``, use it to create the context.
@@ -102,6 +108,8 @@ def wrap(cls, f):
def wrapper_func(self, *args, **kwargs):
if self.adapters_config is not None:
with cls(self, *args, **kwargs) as ctx:
+ # whether to output the context attributes
+ output_context = kwargs.pop("output_context", False)
kwargs = {
k: v for k, v in kwargs.items() if k.replace("output_", "") not in cls.context_attributes
}
@@ -116,7 +124,14 @@ def wrapper_func(self, *args, **kwargs):
for attr in cls.context_attributes:
if getattr(ctx, "output_" + attr, False):
results[attr] = dict(getattr(ctx, attr))
- return results
+
+ if output_context:
+ context_dict = ctx.__dict__
+
+ if output_context:
+ return results, context_dict
+ else:
+ return results
else:
return f(self, *args, **kwargs)
diff --git a/src/adapters/heads/base.py b/src/adapters/heads/base.py
index b3116a7b9a..d45897df10 100644
--- a/src/adapters/heads/base.py
+++ b/src/adapters/heads/base.py
@@ -326,6 +326,19 @@ def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=Fal
labels = kwargs.pop("labels", None)
if labels is not None:
loss_fct = CrossEntropyLoss()
+ # adjust labels for prompt tuning
+ if kwargs.get("prompt_tokens_length", 0) > 0:
+ prompt_length = kwargs.get("prompt_tokens_length")
+ prompt_labels = torch.full(
+ (labels.shape[0], prompt_length), loss_fct.ignore_index, dtype=torch.long, device=labels.device
+ )
+ labels = torch.cat((prompt_labels, labels), dim=-1)
+ if attention_mask is not None:
+ attention_mask = torch.cat(
+ (torch.ones_like(prompt_labels, dtype=torch.long, device=labels.device), attention_mask),
+ dim=-1,
+ )
+
# Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
@@ -763,7 +776,14 @@ def _get_used_heads(self, head_name: str = None):
return head_modules
def forward_head(
- self, all_outputs, head_name=None, cls_output=None, attention_mask=None, return_dict=False, **kwargs
+ self,
+ all_outputs,
+ head_name=None,
+ cls_output=None,
+ attention_mask=None,
+ return_dict=False,
+ context=None,
+ **kwargs
):
"""
The forward pass through a prediction head configuration. There are three ways to specify the used prediction
@@ -811,6 +831,12 @@ def _get_head_input(outputs, cls_out, batch):
if inv_adapter:
kwargs["invertible_adapter"] = inv_adapter
+ # Set prompt tokens length
+ if context is not None:
+ prompt_tokens_length = context.get("prompt_tokens_length", None)
+ if prompt_tokens_length is not None:
+ kwargs["prompt_tokens_length"] = prompt_tokens_length
+
if isinstance(self.active_head, BatchSplit):
if sum(self.active_head.batch_sizes) != all_outputs[0].size()[0]:
raise ValueError(
diff --git a/src/adapters/heads/language_modeling.py b/src/adapters/heads/language_modeling.py
index 3e0cda610a..bf91e5be08 100644
--- a/src/adapters/heads/language_modeling.py
+++ b/src/adapters/heads/language_modeling.py
@@ -1,3 +1,4 @@
+import torch
import torch.nn as nn
from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast, MaskedLMOutput, Seq2SeqLMOutput
@@ -118,6 +119,15 @@ def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=Fal
labels = labels[..., 1:].contiguous()
else:
logits_for_loss = lm_logits
+
+ # adjust labels for prompt tuning
+ if kwargs.get("prompt_tokens_length", 0) > 0:
+ prompt_length = kwargs.get("prompt_tokens_length")
+ prompt_labels = torch.full(
+ (labels.shape[0], prompt_length), loss_fct.ignore_index, dtype=torch.long, device=labels.device
+ )
+ labels = torch.cat((prompt_labels, labels), dim=-1)
+
loss = loss_fct(logits_for_loss.view(-1, self.config["vocab_size"]), labels.view(-1))
if return_dict:
diff --git a/src/adapters/loading.py b/src/adapters/loading.py
index f189a120a8..8e22cfc128 100644
--- a/src/adapters/loading.py
+++ b/src/adapters/loading.py
@@ -307,6 +307,7 @@ def filter_func(self, adapter_name):
or ".prefix_tunings.{}.".format(adapter_name) in x
or ".prefix_gates.{}.".format(adapter_name) in x
or ".loras.{}.".format(adapter_name) in x
+ or ".prompt_tunings.{}.".format(adapter_name) in x
)
# This dict maps the original weight names to the currently used equivalents.
diff --git a/src/adapters/methods/adapter_layer_base.py b/src/adapters/methods/adapter_layer_base.py
index 79d18500ec..2489d445b2 100644
--- a/src/adapters/methods/adapter_layer_base.py
+++ b/src/adapters/methods/adapter_layer_base.py
@@ -346,7 +346,7 @@ def compose_batch_split(self, adapter_setup: BatchSplit, state: NamedTuple, lvl:
# sequentially feed different parts of the blown-up batch into different adapters
children_states = []
for i, child in enumerate(adapter_setup):
- # compute ids of sequences thet should be passed to the ith adapter
+ # compute ids of sequences that should be passed to the ith adapter
batch_idx = (
sum(adapter_setup.batch_sizes[:i]),
sum(adapter_setup.batch_sizes[: i + 1]),
diff --git a/src/adapters/methods/prompt_tuning.py b/src/adapters/methods/prompt_tuning.py
new file mode 100644
index 0000000000..fd9db60e0a
--- /dev/null
+++ b/src/adapters/methods/prompt_tuning.py
@@ -0,0 +1,228 @@
+# https://github.com/google-research/prompt-tuning/blob/main/prompt_tuning/train/prompts.py
+
+import math
+from typing import Callable, Dict, List, Union
+
+import numpy as np
+import torch
+from torch import nn
+
+from transformers import AutoTokenizer
+from transformers.configuration_utils import PretrainedConfig
+
+from ..composition import AdapterCompositionBlock
+from ..configuration import ModelAdaptersConfig, PromptTuningConfig
+from ..context import ForwardContext
+from .adapter_layer_base import AdapterLayerBase
+
+
+class PromptTuning(nn.Module):
+ """Generate a Prompt and concatenate it with the input.
+
+ This is the training time version of prompting a model. Calling the injected `prompt` module will generate your
+ unbatched prompt. This model then replicates it for the batched input and concatenates them together.
+
+ Attributes:
+ prompt: The module that actually generates the unbatched prompt.
+ combine: A function that combines the prompt and the embedded input.
+ """
+
+ prompt: nn.Module
+ combination_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
+
+ def __init__(
+ self,
+ adapter_name: str,
+ prompt_tuning_config: PromptTuningConfig,
+ model_config: PretrainedConfig,
+ base_model_embeddings: nn.Module,
+ ):
+ super().__init__()
+
+ self.name = adapter_name
+ self.model_config = model_config
+ self.prompt_tuning_config = prompt_tuning_config
+
+ embedding_size = getattr(model_config, "embedding_size", model_config.hidden_size)
+
+ self.prompt_embedding = nn.Embedding(
+ num_embeddings=prompt_tuning_config.prompt_length, embedding_dim=embedding_size
+ )
+ # Initialize prompt tokens
+ self.prompt_tokens = torch.arange(prompt_tuning_config.prompt_length).long()
+
+ self._init_prompt_embedding(base_model_embeddings)
+
+ if prompt_tuning_config.combine == "prefix":
+ self.combination_fn = lambda prompt, embedded_input: torch.cat([prompt, embedded_input], dim=1)
+ elif prompt_tuning_config.combine == "prefix_after_bos":
+ self.combination_fn = lambda prompt, embedded_input: torch.cat(
+ [embedded_input[:, 0, np.newaxis], prompt, embedded_input[:, 1:]], dim=1
+ )
+ else:
+ raise ValueError(
+ f"Unknown combination function: {prompt_tuning_config.combine}. "
+ "Must be one of 'prefix' or 'prefix_after_bos'."
+ )
+
+ def _init_prompt_embedding(self, base_model_embeddings: nn.Module) -> None:
+ if self.prompt_tuning_config.prompt_init == "random_uniform":
+ nn.init.uniform_(
+ self.prompt_embedding.weight,
+ a=-self.prompt_tuning_config.random_uniform_scale,
+ b=self.prompt_tuning_config.random_uniform_scale,
+ )
+
+ elif self.prompt_tuning_config.prompt_init == "from_string":
+ tokenizer = AutoTokenizer.from_pretrained(self.model_config.tokenizer_name_or_path)
+ prompt_length = self.prompt_tuning_config.prompt_length
+ prompt_text = self.prompt_tuning_config.prompt_init_text
+ if prompt_text is None:
+ raise ValueError("Prompt text must be provided when using prompt_init='from_string'.")
+
+ tokenized_prompt_text: list[int] = tokenizer(prompt_text)["input_ids"] # type: ignore
+
+ # If the prompt text tokens are shorter than the prompt length, we repeat the prompt text tokens until we reach the prompt length
+ if len(tokenized_prompt_text) < prompt_length:
+ num_reps = math.ceil(prompt_length / len(tokenized_prompt_text))
+ tokenized_prompt_text = tokenized_prompt_text * num_reps
+
+ # Adjust length of prompt text tokens to match prompt_length
+ tokenized_prompt_text = tokenized_prompt_text[:prompt_length]
+
+ # Initialize prompt embedding with tokenized prompt text
+ word_embedding_weights = base_model_embeddings(torch.LongTensor(tokenized_prompt_text)).detach().clone()
+ word_embedding_weights = word_embedding_weights.to(torch.float32)
+ self.prompt_embedding.weight = nn.Parameter(word_embedding_weights)
+
+ else:
+ raise ValueError(f"Unknown prompt initialization: {self.prompt_tuning_config.prompt_init}")
+
+ def forward(self, embedded_input):
+ # Compute prompt embedding
+ self.prompt_tokens = self.prompt_tokens.to(embedded_input.device)
+ prompt = self.prompt_embedding(self.prompt_tokens)
+
+ # Prompt to batch size
+ batch_size = embedded_input.shape[0]
+ prompt = torch.tile(torch.unsqueeze(prompt, dim=0), [batch_size] + [1 for _ in prompt.shape])
+
+ # Merge prompt and input
+ output = self.combination_fn(prompt, embedded_input)
+
+ # Adapt attention mask
+ prefix_attention_mask_length = self.prompt_tuning_config.prompt_length
+
+ return output, prefix_attention_mask_length
+
+
+class PromptTuningLayer(AdapterLayerBase, nn.Module):
+ """
+ Prompt Tuning implementation.
+
+ Args:
+ model_config: The model configuration.
+ adapters_config: The adapter configuration.
+ base_model_embeddings:
+ The embedding layer of the base model (used to initialize the prompt embedding if
+ prompt_init='from_string').
+ """
+
+ adapter_modules_name = "prompt_tunings"
+
+ def __init__(
+ self,
+ model_config: PretrainedConfig,
+ adapters_config: ModelAdaptersConfig,
+ base_model_embeddings: nn.Module,
+ ):
+ super().__init__()
+ self.model_config = model_config
+ self.adapters_config = adapters_config
+ self.base_model_embeddings = base_model_embeddings
+ self.prompt_tunings = nn.ModuleDict()
+
+ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
+ # ignore layer_idx as prompt tunings are only added after the embedding layer
+ prompt_tuning_config = self.adapters_config.match(
+ adapter_name,
+ config_type=PromptTuningConfig,
+ )
+
+ if prompt_tuning_config is not None:
+ adapter = PromptTuning(
+ adapter_name=adapter_name,
+ prompt_tuning_config=prompt_tuning_config, # type: ignore
+ model_config=self.model_config,
+ base_model_embeddings=self.base_model_embeddings,
+ )
+ adapter.train(self.training) # make sure training mode is consistent
+ self.prompt_tunings[adapter_name] = adapter
+ return True
+
+ return False
+
+ def average_adapter(self, adapter_name: str, input_adapters: Dict[str, float]) -> bool:
+ # add new adapter
+ if self.add_adapter(adapter_name, -1):
+ # average weights
+ avg_state_dict = {}
+ for name, weight in input_adapters.items():
+ if name in self.prompt_tunings:
+ module = self.prompt_tunings[name]
+ for k, v in module.state_dict().items():
+ if k in avg_state_dict:
+ avg_state_dict[k] += weight * v
+ else:
+ avg_state_dict[k] = weight * v
+ else:
+ self.delete_adapter(adapter_name) # clean up before raising error
+ raise ValueError("Adapter {} not found.".format(name))
+ # load averaged weights
+ self.prompt_tunings[adapter_name].load_state_dict(avg_state_dict)
+ return True
+
+ return False
+
+ def delete_adapter(self, adapter_name: str):
+ if adapter_name in self.prompt_tunings:
+ del self.prompt_tunings[adapter_name]
+
+ def add_fusion_layer(self, adapter_names: Union[List, str]):
+ pass # not applicable to prompt tuning
+
+ def delete_fusion_layer(self, adapter_names: Union[List, str]):
+ pass # not applicable to prompt tuning
+
+ def enable_adapters(self, adapter_setup: AdapterCompositionBlock, unfreeze_adapters: bool, unfreeze_fusion: bool):
+ if unfreeze_adapters:
+ for prompt_tuning_name in adapter_setup.flatten():
+ if prompt_tuning_name in self.prompt_tunings:
+ for param in self.prompt_tunings[prompt_tuning_name].parameters():
+ param.requires_grad = True
+
+ def freeze_adapter(self, adapter_name: str, freeze: bool = True):
+ if adapter_name in self.prompt_tunings:
+ self.prompt_tunings[adapter_name].train(not freeze)
+ for param in self.prompt_tunings[adapter_name].parameters():
+ param.requires_grad = not freeze
+
+ def get_adapter(self, adapter_name):
+ if adapter_name in self.prompt_tunings:
+ return self.prompt_tunings[adapter_name]
+ else:
+ return None
+
+ def forward(self, hidden_states: torch.Tensor):
+ prefix_attention_mask_length = None
+ adapter_setup = self.get_active_setup()
+ if adapter_setup is not None and len(adapter_setup) > 0:
+ first_adapter = adapter_setup.first()
+ if first_adapter in self.prompt_tunings:
+ hidden_states, prefix_attention_mask_length = self.prompt_tunings[first_adapter](hidden_states)
+
+ context = ForwardContext.get_context()
+ if context is not None:
+ context.prompt_tokens_length = prefix_attention_mask_length
+
+ return hidden_states
diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py
index 078086ece8..a4a1b8d17c 100644
--- a/src/adapters/model_mixin.py
+++ b/src/adapters/model_mixin.py
@@ -5,7 +5,7 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from os.path import join
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
@@ -22,6 +22,7 @@
from .methods.lora import LoRALayer
from .methods.modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock, init_shared_parameters
from .methods.prefix_tuning import PrefixTuningLayer, PrefixTuningPool
+from .methods.prompt_tuning import PromptTuningLayer
from .utils import EMBEDDING_FILE, TOKENIZER_PATH, get_adapter_config_hash, inherit_doc
from .wrappers.configuration import SUBMODEL_NAMES, init_adapters_config
@@ -40,22 +41,6 @@ def init_adapters(self, model_config, adapters_config, **kwargs):
if hasattr(super(), "init_adapters"):
super().init_adapters(self.config, self.adapters_config, **kwargs)
- self.hook_after_embeddings(self._hook_fn)
-
- def _hook_fn(self, module, args, output):
- new_output = self.invertible_adapters_forward(output)
- return new_output
-
- def hook_after_embeddings(self, hook_fn: Callable):
- """
- Hook a function to be called after the embeddings have been computed. The default implementation does nothing.
- Override this method to add a hook.
-
- Args:
- hook_fn (Callable): The function to be called after the embeddings have been computed.
- """
- pass
-
def add_invertible_adapter(self, adapter_name: str) -> bool:
"""
Adds an invertible adapter module for the adapter with the given name. If the given adapter does not specify an
@@ -132,13 +117,31 @@ def enable_invertible_adapters(self, adapter_names):
def invertible_adapters_forward(self, hidden_states, rev=False):
# TODO: Currently no fusion over invertible adapters, takes only very first language adapter position
- if self.adapters_config.active_setup is not None and len(self.adapters_config.active_setup) > 0:
- first_adapter = self.adapters_config.active_setup.first()
+ adapter_setup = self._get_active_setup()
+ if adapter_setup is not None and len(adapter_setup) > 0:
+ first_adapter = adapter_setup.first()
if first_adapter in self.invertible_adapters:
hidden_states = self.invertible_adapters[first_adapter](hidden_states, rev=rev)
-
return hidden_states
+ def _get_active_setup(self):
+ if hasattr(self, "adapters_config"):
+ # First check current context before falling back to defined setup
+ context = AdapterSetup.get_context()
+ if context is not None:
+ adapter_setup = context.adapter_setup
+ else:
+ adapter_setup = self.adapters_config.active_setup
+ else:
+ adapter_setup = None
+ skip_adapters = adapter_setup is None or (
+ self.adapters_config.skip_layers is not None and self.layer_idx in self.adapters_config.skip_layers
+ )
+ if not skip_adapters and (len(adapter_setup.flatten()) > 0):
+ return adapter_setup
+ else:
+ return None
+
class InvertibleAdaptersWrapperMixin:
"""
@@ -364,6 +367,10 @@ def loaded_embeddings(self):
class ModelAdaptersMixin(PushAdapterToHubMixin, ABC):
"""Mixin for transformer models adding support for loading/ saving adapters."""
+ add_base_adapters = False
+ support_prompt_tuning = True # If False, the prompt tuning layer is not added to the model; if True, it is added whenever add_base_adapters is True.
+ _tied_weights_keys = ["prompt_tuning.base_model_embeddings.*"]
+
def __init__(self, config, *args, **kwargs):
super().__init__(config, *args, **kwargs)
@@ -400,6 +407,11 @@ def init_adapters(self, model_config, adapters_config, add_prefix_tuning_pool=Tr
self.base_model.prefix_tuning = PrefixTuningPool(self.config, self.adapters_config)
self.apply_to_adapter_layers(lambda i, layer: self._link_prefix_to_pool(layer))
+ # Add Prompt Tuning
+ if self.add_base_adapters:
+ if self.support_prompt_tuning:
+ self.prompt_tuning = PromptTuningLayer(model_config, self.adapters_config, self.get_input_embeddings())
+
# Initialize adapters from config
for adapter_name in self.adapters_config:
self._add_adapter_weights(adapter_name)
@@ -430,12 +442,23 @@ def apply_to_adapter_layers(self, fn):
if isinstance(module, AdapterLayerBase):
fn(i, module)
+ def apply_to_basemodel_childs(self, fn):
+ """
+ Applies a function to all direct children of the model if they are an instance of AdapterLayerBase.
+ """
+ if self.add_base_adapters:
+ for module in self.base_model.children():
+ if isinstance(module, AdapterLayerBase):
+ # These children don't have a layer index, so we pass -1
+ fn(-1, module)
+
def train_adapter(self, adapter_setup: Union[list, AdapterCompositionBlock], train_embeddings=False):
"""Sets the model into mode for training the given adapters."""
self.train()
self.freeze_model(True)
adapter_setup = parse_composition(adapter_setup)
self.apply_to_adapter_layers(lambda i, layer: layer.enable_adapters(adapter_setup, True, False))
+ self.apply_to_basemodel_childs(lambda i, child: child.enable_adapters(adapter_setup, True, False))
for adapter_name in adapter_setup:
if adapter_name in self.base_model.shared_parameters:
for param in self.base_model.shared_parameters[adapter_name].values():
@@ -462,6 +485,7 @@ def train_adapter_fusion(self, adapter_setup: Union[list, AdapterCompositionBloc
self.freeze_model(True)
adapter_setup = parse_composition(adapter_setup)
self.apply_to_adapter_layers(lambda i, layer: layer.enable_adapters(adapter_setup, unfreeze_adapters, True))
+ self.apply_to_basemodel_childs(lambda i, child: child.enable_adapters(adapter_setup, unfreeze_adapters, True))
# use the adapters to be trained by default in every forward pass
self.set_active_adapters(adapter_setup)
# TODO implement fusion for invertible adapters
@@ -545,6 +569,8 @@ def add_adapter(self, adapter_name: str, config=None, overwrite_ok: bool = False
def _add_adapter_weights(self, adapter_name: str):
"""Helper method that performs the actual parameter additions when adding a new adapter."""
self.apply_to_adapter_layers(lambda i, layer: layer.add_adapter(adapter_name, i))
+ self.apply_to_basemodel_childs(lambda i, child: child.add_adapter(adapter_name, i))
+
# PHM Layer
if self.adapters_config.match(adapter_name, BnConfig, location_key="phm_layer"):
adapter_module = list(self.get_adapter(adapter_name)[0].values())[0]
@@ -624,6 +650,7 @@ def add_adapter_fusion(
self.delete_adapter_fusion(adapter_names)
self.adapters_config.add_fusion(adapter_names, config=config)
self.apply_to_adapter_layers(lambda i, layer: layer.add_fusion_layer(adapter_names))
+ self.apply_to_basemodel_childs(lambda i, child: child.add_fusion_layer(adapter_names))
if set_active:
if not isinstance(adapter_names, list):
adapter_names = adapter_names.split(",")
@@ -641,11 +668,13 @@ def delete_adapter(self, adapter_name: str):
return
del self.adapters_config.adapters[adapter_name]
self.apply_to_adapter_layers(lambda i, layer: layer.delete_adapter(adapter_name))
+ self.apply_to_basemodel_childs(lambda i, child: child.delete_adapter(adapter_name))
# PHM Layer
if adapter_name in self.base_model.shared_parameters:
del self.base_model.shared_parameters[adapter_name]
if isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin):
self.delete_invertible_adapter(adapter_name)
+
# Reset active adapters if this was the only active adapter
if self.active_adapters == Stack(adapter_name):
self.active_adapters = None
@@ -671,6 +700,7 @@ def delete_adapter_fusion(self, adapter_names: Union[Fuse, list, str]):
return
del self.adapters_config.fusions[adapter_fusion_name]
self.apply_to_adapter_layers(lambda i, layer: layer.delete_fusion_layer(adapter_fusion_name))
+ self.apply_to_basemodel_childs(lambda i, child: child.delete_fusion_layer(adapter_fusion_name))
# Reset active adapters if this was the active setup
if self.active_adapters == adapter_names:
self.active_adapters = None
@@ -966,6 +996,11 @@ def get_adapter(self, name) -> dict:
) and name in self.invertible_adapters:
destination[-1]["invertible"] = self.invertible_adapters[name]
+ if self.support_prompt_tuning:
+ prompt_tuning = self.prompt_tuning.get_adapter(name)
+ if prompt_tuning is not None:
+ destination[-1]["prompt"] = prompt_tuning
+
# use a custom index to ensure numbering is from 0 to N layers
for i, (_, layer) in enumerate(self.iter_layers()):
for module in layer.modules():
@@ -1115,6 +1150,7 @@ def average_adapter(
input_adapters = {name: weight / sum_weights for name, weight in zip(adapter_list, weights)}
try:
self.apply_to_adapter_layers(lambda i, layer: layer.average_adapter(adapter_name, input_adapters))
+ self.apply_to_basemodel_childs(lambda i, child: child.average_adapter(adapter_name, input_adapters))
# PHM Layer
if self.adapters_config.match(adapter_name, BnConfig, location_key="phm_layer"):
self._average_shared_parameters(adapter_name, input_adapters)
@@ -1226,6 +1262,16 @@ def save_pretrained(
@inherit_doc
class ModelBaseAdaptersMixin(ModelAdaptersMixin):
+ add_base_adapters = True
+
+ def post_embedding_forward(self, module, args, embedding_output):
+ if isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin):
+ embedding_output = self.invertible_adapters_forward(embedding_output)
+
+ embedding_output = self.prompt_tuning.forward(embedding_output)
+
+ return embedding_output
+
@ForwardContext.wrap
def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)
diff --git a/src/adapters/models/albert/adapter_model.py b/src/adapters/models/albert/adapter_model.py
index 9a1f45ed2e..8261e68760 100644
--- a/src/adapters/models/albert/adapter_model.py
+++ b/src/adapters/models/albert/adapter_model.py
@@ -64,7 +64,7 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- outputs = self.albert(
+ outputs, context = self.albert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -77,7 +77,10 @@ def forward(
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
+ output_context=True,
)
+ # required e.g. for prompt tuning in all models
+ kwargs["context"] = context
# BERT & RoBERTa & ALBERT return the pooled output as second item, we don't need that in these heads
if not return_dict:
diff --git a/src/adapters/models/albert/mixin_albert.py b/src/adapters/models/albert/mixin_albert.py
index ff9ef19fe3..4d84abd5f6 100644
--- a/src/adapters/models/albert/mixin_albert.py
+++ b/src/adapters/models/albert/mixin_albert.py
@@ -1,4 +1,4 @@
-from typing import Callable, Iterable, Tuple
+from typing import Iterable, Tuple
import torch.nn as nn
@@ -51,6 +51,8 @@ def init_adapters(self, model_config, adapters_config):
for _, layer in self.iter_layers():
self._set_layer_hook_for_parallel(layer)
+ self.embeddings.register_forward_hook(self.post_embedding_forward)
+
def _set_layer_hook_for_parallel(self, layer: nn.Module):
def hook(module, input):
adjust_tensors_for_parallel_(input[0], input[1])
@@ -64,6 +66,3 @@ def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for albertLayer in albertLayerGroup.albert_layers:
yield i, albertLayer
i += 1
-
- def hook_after_embeddings(self, hook_fn: Callable):
- return self.embeddings.register_forward_hook(hook_fn)
diff --git a/src/adapters/models/albert/modeling_albert.py b/src/adapters/models/albert/modeling_albert.py
index 7f5294cad7..f620240c63 100644
--- a/src/adapters/models/albert/modeling_albert.py
+++ b/src/adapters/models/albert/modeling_albert.py
@@ -24,6 +24,7 @@
from transformers.pytorch_utils import apply_chunking_to_forward
from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
+from ...utils import prefix_attention_mask
from .mixin_albert import AlbertAttentionAdaptersMixin, AlbertEncoderLayerAdaptersMixin
@@ -35,6 +36,8 @@ def forward(
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
+ attention_mask = prefix_attention_mask(attention_mask) # type: ignore
+
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
diff --git a/src/adapters/models/bart/adapter_model.py b/src/adapters/models/bart/adapter_model.py
index 3fc3dfd73c..ddb94e6fe9 100644
--- a/src/adapters/models/bart/adapter_model.py
+++ b/src/adapters/models/bart/adapter_model.py
@@ -26,7 +26,10 @@
"BART Model with the option to add multiple flexible prediction heads on top.", BART_START_DOCSTRING
)
class BartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, BartPretrainedModel):
- _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+ _tied_weights_keys = [
+ "encoder.embed_tokens.weight",
+ "decoder.embed_tokens.weight",
+ ]
def __init__(self, config: BartConfig, **kwargs):
super().__init__(config, **kwargs)
@@ -76,7 +79,7 @@ def forward(
if "labels" in kwargs or "start_positions" in kwargs and "end_positions" in kwargs:
use_cache = False
- outputs = self.model(
+ outputs, context = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
@@ -95,7 +98,10 @@ def forward(
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
+ output_context=True,
)
+ # required e.g. for prompt tuning in all models
+ kwargs["context"] = context
# sequence classification based on last token in sequence
x = outputs[0] # last hidden state
if input_ids is not None and x.shape[1] == input_ids.shape[1]:
diff --git a/src/adapters/models/bart/mixin_bart.py b/src/adapters/models/bart/mixin_bart.py
index 28e7b3ac77..d269d72b43 100644
--- a/src/adapters/models/bart/mixin_bart.py
+++ b/src/adapters/models/bart/mixin_bart.py
@@ -1,4 +1,4 @@
-from typing import Callable, Iterable, Optional, Tuple
+from typing import Iterable, Optional, Tuple
import torch
import torch.nn as nn
@@ -34,6 +34,7 @@ class BartEncoderLayerAdaptersMixin:
"""Adds adapters to the BartEncoderLayer module of BART."""
def init_adapters(self, model_config, adapters_config):
+ self.adapters_config = adapters_config
# Wrap layers for LoRA
self.fc1 = LoRALinear.wrap(self.fc1, "intermediate", model_config, adapters_config)
self.fc2 = LoRALinear.wrap(self.fc2, "output", model_config, adapters_config)
@@ -58,8 +59,7 @@ def init_adapters(self, model_config, adapters_config):
class BartEncoderAdaptersMixin(InvertibleAdaptersMixin):
"""Adds adapters to the BartEncoder module of BART."""
- def hook_after_embeddings(self, hook_fn: Callable):
- return self.layernorm_embedding.register_forward_hook(hook_fn)
+ pass
class BartDecoderAdaptersMixin:
@@ -76,6 +76,11 @@ class BartModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersWrapperMi
"""Adds adapters to the BartModel class."""
invertible_adapters_base_name = "encoder"
+ support_prompt_tuning = False
+
+ def init_adapters(self, model_config, adapters_config):
+ super().init_adapters(model_config, adapters_config)
+ self.encoder.layernorm_embedding.register_forward_hook(self.post_embedding_forward)
def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
if hasattr(self, "encoder"):
@@ -87,6 +92,11 @@ def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(self.decoder.layers):
yield i, layer
+ def post_embedding_forward(self, module, args, embedding_output):
+ embedding_output = self.invertible_adapters_forward(embedding_output)
+ # Prompt tuning not yet supported
+ return embedding_output
+
class BartDecoderWrapperAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelBaseAdaptersMixin):
"""Adds adapters to the BartDecoderWrapper class."""
diff --git a/src/adapters/models/beit/adapter_model.py b/src/adapters/models/beit/adapter_model.py
index 22d3e6aa4d..ceeda7b82c 100644
--- a/src/adapters/models/beit/adapter_model.py
+++ b/src/adapters/models/beit/adapter_model.py
@@ -47,7 +47,7 @@ def forward(
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- outputs = self.beit(
+ outputs, context = self.beit(
pixel_values,
bool_masked_pos=bool_masked_pos,
head_mask=head_mask,
@@ -57,7 +57,10 @@ def forward(
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
+ output_context=True,
)
+ # required e.g. for prompt tuning in all models
+ kwargs["context"] = context
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
diff --git a/src/adapters/models/beit/mixin_beit.py b/src/adapters/models/beit/mixin_beit.py
index 536048e669..608d9c5cc4 100644
--- a/src/adapters/models/beit/mixin_beit.py
+++ b/src/adapters/models/beit/mixin_beit.py
@@ -47,6 +47,7 @@ class BeitModelAdaptersMixin(ModelBaseAdaptersMixin):
def init_adapters(self, model_config, adapters_config):
super().init_adapters(model_config, adapters_config)
+ self.embeddings.register_forward_hook(self.post_embedding_forward)
def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(self.encoder.layer):
diff --git a/src/adapters/models/bert/adapter_model.py b/src/adapters/models/bert/adapter_model.py
index 4ff1aaf61a..02ad9411c4 100644
--- a/src/adapters/models/bert/adapter_model.py
+++ b/src/adapters/models/bert/adapter_model.py
@@ -66,7 +66,7 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- outputs = self.bert(
+ outputs, context = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -79,7 +79,10 @@ def forward(
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
+ output_context=True,
)
+ # required e.g. for prompt tuning in all models
+ kwargs["context"] = context
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]
diff --git a/src/adapters/models/bert/mixin_bert.py b/src/adapters/models/bert/mixin_bert.py
index 3cf5a6e1ff..4aa079240c 100644
--- a/src/adapters/models/bert/mixin_bert.py
+++ b/src/adapters/models/bert/mixin_bert.py
@@ -1,5 +1,5 @@
import logging
-from typing import Callable, Iterable, Tuple
+from typing import Iterable, Tuple
import torch.nn as nn
@@ -77,6 +77,9 @@ def init_adapters(self, model_config, adapters_config):
for _, layer in self.iter_layers():
self._set_layer_hook_for_parallel(layer)
+ # Register hook for post embedding forward
+ self.embeddings.register_forward_hook(self.post_embedding_forward)
+
def _set_layer_hook_for_parallel(self, layer: nn.Module):
def hook(module, input):
adjust_tensors_for_parallel_(input[0], input[1])
@@ -87,6 +90,3 @@ def hook(module, input):
def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(self.encoder.layer):
yield i, layer
-
- def hook_after_embeddings(self, hook_fn: Callable):
- return self.embeddings.register_forward_hook(hook_fn)
diff --git a/src/adapters/models/bert/modeling_bert.py b/src/adapters/models/bert/modeling_bert.py
index 692605610a..ea60b6f5dc 100644
--- a/src/adapters/models/bert/modeling_bert.py
+++ b/src/adapters/models/bert/modeling_bert.py
@@ -26,6 +26,7 @@
from transformers.models.bert.modeling_bert import BertOutput, BertSelfAttention, BertSelfOutput
from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
+from ...utils import prefix_attention_mask
from .mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin
@@ -40,6 +41,8 @@ def forward(
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
+ attention_mask = prefix_attention_mask(attention_mask) # type: ignore
+
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
diff --git a/src/adapters/models/bert_generation/adapter_model.py b/src/adapters/models/bert_generation/adapter_model.py
index c251af1517..1fe0152a6a 100644
--- a/src/adapters/models/bert_generation/adapter_model.py
+++ b/src/adapters/models/bert_generation/adapter_model.py
@@ -62,7 +62,7 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- outputs = self.bert(
+ outputs, context = self.bert(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -78,7 +78,10 @@ def forward(
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
+ output_context=True,
)
+ # required e.g. for prompt tuning in all models
+ kwargs["context"] = context
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]
diff --git a/src/adapters/models/bert_generation/modeling_bert_generation.py b/src/adapters/models/bert_generation/modeling_bert_generation.py
index 8381ccf2bb..f0ef9a35ed 100644
--- a/src/adapters/models/bert_generation/modeling_bert_generation.py
+++ b/src/adapters/models/bert_generation/modeling_bert_generation.py
@@ -28,6 +28,7 @@
)
from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
+from ...utils import prefix_attention_mask
from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin
@@ -52,6 +53,8 @@ def forward(
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
+ attention_mask = prefix_attention_mask(attention_mask) # type: ignore
+
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
diff --git a/src/adapters/models/clip/adapter_model.py b/src/adapters/models/clip/adapter_model.py
index 6191cd3001..5aa15d417b 100644
--- a/src/adapters/models/clip/adapter_model.py
+++ b/src/adapters/models/clip/adapter_model.py
@@ -18,6 +18,8 @@
@add_start_docstrings(CLIP_START_DOCSTRING)
class CLIPAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, CLIPPreTrainedModel):
+ _tied_weights_keys = [] # needs to be empty since CLIP does not yet support prompt tuning
+
def __init__(self, config):
super().__init__(config)
@@ -44,7 +46,7 @@ def forward(
output_adapter_fusion_attentions=False,
**kwargs
):
- outputs = self.clip(
+ outputs, context = self.clip(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
@@ -56,7 +58,10 @@ def forward(
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
+ output_context=True,
)
+ # required e.g. for prompt tuning in all models
+ kwargs["context"] = context
if head or AdapterSetup.get_context_head_setup() or self.active_head:
head_outputs = self.forward_head(
diff --git a/src/adapters/models/clip/mixin_clip.py b/src/adapters/models/clip/mixin_clip.py
index 02469974f5..020d88f57b 100644
--- a/src/adapters/models/clip/mixin_clip.py
+++ b/src/adapters/models/clip/mixin_clip.py
@@ -1,4 +1,4 @@
-from typing import Callable, Iterable, Tuple
+from typing import Iterable, Tuple
import torch.nn as nn
@@ -60,14 +60,23 @@ def hook(module, input):
class CLIPTextTransformerAdaptersMixin(InvertibleAdaptersMixin):
"""Adds adapters to the CLIPTextTransformer module of CLIP."""
- def hook_after_embeddings(self, hook_fn: Callable):
- return self.embeddings.register_forward_hook(hook_fn)
+ def init_adapters(self, model_config, adapters_config):
+ super().init_adapters(model_config, adapters_config)
+
+ # Register hook for post embedding forward
+ self.embeddings.register_forward_hook(self.post_embedding_forward)
+
+ def post_embedding_forward(self, module, args, embedding_output):
+ embedding_output = self.invertible_adapters_forward(embedding_output)
+ # Prompt tuning not yet supported
+ return embedding_output
class CLIPTextModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersWrapperMixin, ModelBaseAdaptersMixin):
"""Adds adapters to the CLIPTextModel class."""
invertible_adapters_base_name = "text_model"
+ support_prompt_tuning = False
def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(self.text_model.encoder.layers):
@@ -77,6 +86,8 @@ def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
class CLIPVisionModelAdaptersMixin(ModelBaseAdaptersMixin):
"""Adds adapters to the a CLIPVisionModel class."""
+ support_prompt_tuning = False
+
def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(self.vision_model.encoder.layers):
yield i, layer
@@ -86,6 +97,7 @@ class CLIPModelAdaptersMixin(EmbeddingAdaptersWrapperMixin, InvertibleAdaptersWr
"""Adds adapters to the CLIPModel class."""
invertible_adapters_base_name = "text_model"
+ support_prompt_tuning = False
def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(self.text_model.encoder.layers):
diff --git a/src/adapters/models/deberta/adapter_model.py b/src/adapters/models/deberta/adapter_model.py
index 0d74e58593..4b44991d66 100644
--- a/src/adapters/models/deberta/adapter_model.py
+++ b/src/adapters/models/deberta/adapter_model.py
@@ -57,7 +57,7 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- outputs = self.deberta(
+ outputs, context = self.deberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -69,7 +69,10 @@ def forward(
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
+ output_context=True,
)
+ # required e.g. for prompt tuning in all models
+ kwargs["context"] = context
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]
diff --git a/src/adapters/models/deberta/modeling_deberta.py b/src/adapters/models/deberta/modeling_deberta.py
index 71b7f9dc2a..1feca72b4a 100644
--- a/src/adapters/models/deberta/modeling_deberta.py
+++ b/src/adapters/models/deberta/modeling_deberta.py
@@ -25,6 +25,7 @@
)
from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
+from ...utils import prefix_attention_mask
from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfOutputAdaptersMixin
from .mixin_deberta import DebertaSelfAttentionAdaptersMixin
@@ -94,6 +95,9 @@ def forward(
"""
+ attention_mask = prefix_attention_mask(attention_mask, dim=3, prefix_value=1) # type: ignore
+ attention_mask = prefix_attention_mask(attention_mask, dim=2, prefix_value=1) # type: ignore
+
if query_states is None:
qp = self.in_proj(hidden_states) # .split(self.all_head_size, dim=-1)
query_layer, key_layer, value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1)
diff --git a/src/adapters/models/deberta_v2/adapter_model.py b/src/adapters/models/deberta_v2/adapter_model.py
index bc2f6e6ed2..a980d99177 100644
--- a/src/adapters/models/deberta_v2/adapter_model.py
+++ b/src/adapters/models/deberta_v2/adapter_model.py
@@ -60,7 +60,7 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- outputs = self.deberta(
+ outputs, context = self.deberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -72,7 +72,10 @@ def forward(
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
+ output_context=True,
)
+ # required e.g. for prompt tuning in all models
+ kwargs["context"] = context
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]
diff --git a/src/adapters/models/deberta_v2/modeling_deberta_v2.py b/src/adapters/models/deberta_v2/modeling_deberta_v2.py
index aa8945000f..56d6fec448 100644
--- a/src/adapters/models/deberta_v2/modeling_deberta_v2.py
+++ b/src/adapters/models/deberta_v2/modeling_deberta_v2.py
@@ -25,6 +25,7 @@
)
from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
+from ...utils import prefix_attention_mask
from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfOutputAdaptersMixin
from .mixin_deberta_v2 import DebertaV2SelfAttentionAdaptersMixin
@@ -88,9 +89,10 @@ def forward(
rel_embeddings (`torch.FloatTensor`):
The embedding of relative distances. It's a tensor of shape [\\(2 \\times
\\text{max_relative_positions}\\), *hidden_size*].
-
-
"""
+ attention_mask = prefix_attention_mask(attention_mask, dim=3, prefix_value=1) # type: ignore
+ attention_mask = prefix_attention_mask(attention_mask, dim=2, prefix_value=1) # type: ignore
+
if query_states is None:
query_states = hidden_states
query_layer = self.transpose_for_scores_extended(self.query_proj(query_states), self.num_attention_heads)
diff --git a/src/adapters/models/distilbert/adapter_model.py b/src/adapters/models/distilbert/adapter_model.py
index 9a2294ac89..ee7fb57bf9 100644
--- a/src/adapters/models/distilbert/adapter_model.py
+++ b/src/adapters/models/distilbert/adapter_model.py
@@ -85,7 +85,7 @@ def forward(
else None
)
- distilbert_output = self.distilbert(
+ distilbert_output, context = self.distilbert(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
@@ -96,7 +96,10 @@ def forward(
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
+ output_context=True,
)
+ # required e.g. for prompt tuning in all models
+ kwargs["context"] = context
outputs = self.forward_head(
distilbert_output, head_name=head, attention_mask=attention_mask, return_dict=return_dict, **kwargs
diff --git a/src/adapters/models/distilbert/mixin_distilbert.py b/src/adapters/models/distilbert/mixin_distilbert.py
index 111733c2f0..3301cba949 100644
--- a/src/adapters/models/distilbert/mixin_distilbert.py
+++ b/src/adapters/models/distilbert/mixin_distilbert.py
@@ -1,4 +1,4 @@
-from typing import Callable, Iterable, Tuple
+from typing import Iterable, Tuple
import torch.nn as nn
@@ -44,15 +44,10 @@ def forward(self, *args, **kwargs):
class DistilBertModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin):
"""Adds adapters to the DistilBert module."""
+ def init_adapters(self, model_config, adapters_config):
+ super().init_adapters(model_config, adapters_config)
+ self.embeddings.register_forward_hook(self.post_embedding_forward)
+
def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(self.transformer.layer):
yield i, layer
-
- def _hook_fn(self, module, input):
- new_input = self.invertible_adapters_forward(input)
- return new_input
-
- def hook_after_embeddings(self, hook_fn: Callable):
- # PyTorch's built-in pre-forward hook does not pass the input ids.
- # Therefore, we need to use a custom hook.
- self.transformer.pre_forward_fn = hook_fn
diff --git a/src/adapters/models/distilbert/modeling_distilbert.py b/src/adapters/models/distilbert/modeling_distilbert.py
index e0aee4e1b9..cbd501942c 100644
--- a/src/adapters/models/distilbert/modeling_distilbert.py
+++ b/src/adapters/models/distilbert/modeling_distilbert.py
@@ -28,6 +28,7 @@
from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention, TransformerBlock
from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel
+from ...utils import prefix_attention_mask
from .mixin_distilbert import DistilBertMultiHeadSelfAttentionMixin, DistilBertTransfomerBlockAdaptersMixin
@@ -121,6 +122,8 @@ def forward(
torch.tensor(bs, seq_length, dim) The output of the transformer block contextualization.
"""
adjust_tensors_for_parallel_(x, attn_mask)
+ attn_mask = prefix_attention_mask(attn_mask, dim=1, prefix_value=1) # type: ignore
+
# Self-Attention
sa_output = self.attention(
query=x,
diff --git a/src/adapters/models/electra/adapter_model.py b/src/adapters/models/electra/adapter_model.py
index 2d7994d3a4..6dbd02569f 100644
--- a/src/adapters/models/electra/adapter_model.py
+++ b/src/adapters/models/electra/adapter_model.py
@@ -66,7 +66,7 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- outputs = self.electra(
+ outputs, context = self.electra(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -79,7 +79,10 @@ def forward(
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
+ output_context=True,
)
+ # required e.g. for prompt tuning in all models
+ kwargs["context"] = context
head_inputs = outputs
diff --git a/src/adapters/models/electra/modeling_electra.py b/src/adapters/models/electra/modeling_electra.py
index cbe4277ec9..53b5ed29ac 100644
--- a/src/adapters/models/electra/modeling_electra.py
+++ b/src/adapters/models/electra/modeling_electra.py
@@ -7,6 +7,7 @@
from transformers.models.electra.modeling_electra import ElectraOutput, ElectraSelfAttention, ElectraSelfOutput
from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
+from ...utils import prefix_attention_mask
from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin
@@ -21,6 +22,8 @@ def forward(
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
+ attention_mask = prefix_attention_mask(attention_mask) # type: ignore
+
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
diff --git a/src/adapters/models/encoder_decoder/mixin_encoder_decoder.py b/src/adapters/models/encoder_decoder/mixin_encoder_decoder.py
index ba2df24a84..50257d1536 100644
--- a/src/adapters/models/encoder_decoder/mixin_encoder_decoder.py
+++ b/src/adapters/models/encoder_decoder/mixin_encoder_decoder.py
@@ -17,6 +17,8 @@ class EncoderDecoderModelAdaptersMixin(
):
"""Adds adapters to the EncoderDecoderModel class."""
+ support_prompt_tuning = False
+
def init_adapters(self, model_config, adapters_config):
if not isinstance(self.encoder, ModelAdaptersMixin) or not isinstance(self.decoder, ModelAdaptersMixin):
return
diff --git a/src/adapters/models/gpt2/adapter_model.py b/src/adapters/models/gpt2/adapter_model.py
index b4e1b53d54..cc5709b53e 100644
--- a/src/adapters/models/gpt2/adapter_model.py
+++ b/src/adapters/models/gpt2/adapter_model.py
@@ -33,6 +33,8 @@
GPT2_START_DOCSTRING,
)
class GPT2AdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, GPT2PreTrainedModel):
+ _tied_weights_keys = [] # needs to be empty since GPT2 does not yet support prompt tuning
+
def __init__(self, config):
super().__init__(config)
self.transformer = GPT2Model(config)
@@ -68,7 +70,7 @@ def forward(
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- outputs = self.transformer(
+ outputs, context = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
@@ -85,7 +87,10 @@ def forward(
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
+ output_context=True,
)
+ # required e.g. for prompt tuning in all models
+ kwargs["context"] = context
batch_size = outputs[0].shape[0]
diff --git a/src/adapters/models/gpt2/mixin_gpt2.py b/src/adapters/models/gpt2/mixin_gpt2.py
index ce88136a92..bd142eb470 100644
--- a/src/adapters/models/gpt2/mixin_gpt2.py
+++ b/src/adapters/models/gpt2/mixin_gpt2.py
@@ -1,4 +1,4 @@
-from typing import Callable, Iterable, Tuple
+from typing import Iterable, Tuple
import torch.nn as nn
@@ -54,9 +54,19 @@ def init_adapters(self, model_config, adapters_config):
class GPT2ModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin):
+ support_prompt_tuning = False
+
+ def init_adapters(self, model_config, adapters_config):
+ super().init_adapters(model_config, adapters_config)
+
+ # Register hook for post embedding forward
+ self.drop.register_forward_hook(self.post_embedding_forward)
+
def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(self.base_model.h):
yield i, layer
- def hook_after_embeddings(self, hook_fn: Callable):
- return self.drop.register_forward_hook(hook_fn)
+ def post_embedding_forward(self, module, args, embedding_output):
+ embedding_output = self.invertible_adapters_forward(embedding_output)
+ # Prompt tuning not yet supported
+ return embedding_output
diff --git a/src/adapters/models/gptj/adapter_model.py b/src/adapters/models/gptj/adapter_model.py
index 625cd0febc..a4bd8f32a1 100644
--- a/src/adapters/models/gptj/adapter_model.py
+++ b/src/adapters/models/gptj/adapter_model.py
@@ -33,6 +33,8 @@
GPTJ_START_DOCSTRING,
)
class GPTJAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, GPTJPreTrainedModel):
+ _tied_weights_keys = [] # needs to be empty since GPT-J does not yet support prompt tuning
+
def __init__(self, config):
super().__init__(config)
self.transformer = GPTJModel(config)
@@ -66,7 +68,7 @@ def forward(
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- outputs = self.transformer(
+ outputs, context = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
@@ -81,7 +83,10 @@ def forward(
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
+ output_context=True,
)
+ # required e.g. for prompt tuning in all models
+ kwargs["context"] = context
batch_size = outputs[0].shape[0]
diff --git a/src/adapters/models/gptj/mixin_gptj.py b/src/adapters/models/gptj/mixin_gptj.py
index 7e4e771cba..cc20e63cc3 100644
--- a/src/adapters/models/gptj/mixin_gptj.py
+++ b/src/adapters/models/gptj/mixin_gptj.py
@@ -1,4 +1,4 @@
-from typing import Callable, Iterable, Tuple
+from typing import Iterable, Tuple
import torch.nn as nn
@@ -38,12 +38,19 @@ def init_adapters(self, model_config, adapters_config):
class GPTJModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin):
+ support_prompt_tuning = False
+
def init_adapters(self, model_config, adapters_config):
super().init_adapters(model_config, adapters_config)
+ # Register hook for post embedding forward
+ self.drop.register_forward_hook(self.post_embedding_forward)
+
def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(self.base_model.h):
yield i, layer
- def hook_after_embeddings(self, hook_fn: Callable):
- return self.drop.register_forward_hook(hook_fn)
+ def post_embedding_forward(self, module, args, embedding_output):
+ embedding_output = self.invertible_adapters_forward(embedding_output)
+ # Prompt tuning not yet supported
+ return embedding_output
diff --git a/src/adapters/models/llama/adapter_model.py b/src/adapters/models/llama/adapter_model.py
index 43cec8abbf..7b9ce69083 100644
--- a/src/adapters/models/llama/adapter_model.py
+++ b/src/adapters/models/llama/adapter_model.py
@@ -32,6 +32,8 @@
LLAMA_START_DOCSTRING,
)
class LlamaAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, LlamaPreTrainedModel):
+ _tied_weights_keys = [] # needs to be empty since LLaMA does not yet support prompt tuning
+
def __init__(self, config):
super().__init__(config)
self.model = LlamaModel(config)
@@ -68,7 +70,7 @@ def forward(
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- outputs = self.model(
+ outputs, context = self.model(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
@@ -81,7 +83,10 @@ def forward(
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
+ output_context=True,
)
+ # required e.g. for prompt tuning in all models
+ kwargs["context"] = context
batch_size = outputs[0].shape[0]
diff --git a/src/adapters/models/llama/mixin_llama.py b/src/adapters/models/llama/mixin_llama.py
index 3caf66e544..aae339433c 100644
--- a/src/adapters/models/llama/mixin_llama.py
+++ b/src/adapters/models/llama/mixin_llama.py
@@ -1,4 +1,4 @@
-from typing import Callable, Iterable, Tuple
+from typing import Iterable, Tuple
import torch.nn as nn
@@ -28,9 +28,19 @@ def init_adapters(self, model_config, adapters_config):
class LlamaModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin):
+ support_prompt_tuning = False
+
+ def init_adapters(self, model_config, adapters_config):
+ super().init_adapters(model_config, adapters_config)
+
+ # Register hook for post embedding forward
+ self.embed_tokens.register_forward_hook(self.post_embedding_forward)
+
def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(self.layers):
yield i, layer
- def hook_after_embeddings(self, hook_fn: Callable):
- return self.embed_tokens.register_forward_hook(hook_fn)
+ def post_embedding_forward(self, module, args, embedding_output):
+ embedding_output = self.invertible_adapters_forward(embedding_output)
+ # Prompt tuning not yet supported
+ return embedding_output
diff --git a/src/adapters/models/mbart/adapter_model.py b/src/adapters/models/mbart/adapter_model.py
index ae86c35c17..5b57eb2cb0 100644
--- a/src/adapters/models/mbart/adapter_model.py
+++ b/src/adapters/models/mbart/adapter_model.py
@@ -26,7 +26,10 @@
"MBART Model with the option to add multiple flexible prediction heads on top.", MBART_START_DOCSTRING
)
class MBartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, MBartPreTrainedModel):
- _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+ _tied_weights_keys = [
+ "encoder.embed_tokens.weight",
+ "decoder.embed_tokens.weight",
+ ]
def __init__(self, config: MBartConfig, **kwargs):
super().__init__(config, **kwargs)
@@ -76,7 +79,7 @@ def forward(
if "labels" in kwargs or "start_positions" in kwargs and "end_positions" in kwargs:
use_cache = False
- outputs = self.model(
+ outputs, context = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
@@ -95,7 +98,10 @@ def forward(
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
+ output_context=True,
)
+ # required e.g. for prompt tuning in all models
+ kwargs["context"] = context
# sequence classification based on last token in sequence
x = outputs[0] # last hidden state
if input_ids is not None and x.shape[1] == input_ids.shape[1]:
diff --git a/src/adapters/models/roberta/adapter_model.py b/src/adapters/models/roberta/adapter_model.py
index 13bf8b8102..3a08f33639 100644
--- a/src/adapters/models/roberta/adapter_model.py
+++ b/src/adapters/models/roberta/adapter_model.py
@@ -66,7 +66,7 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- outputs = self.roberta(
+ outputs, context = self.roberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -79,7 +79,10 @@ def forward(
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
+ output_context=True,
)
+ # required e.g. for prompt tuning in all models
+ kwargs["context"] = context
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]
diff --git a/src/adapters/models/roberta/modeling_roberta.py b/src/adapters/models/roberta/modeling_roberta.py
index e33b7e7ca3..8a79d4effb 100644
--- a/src/adapters/models/roberta/modeling_roberta.py
+++ b/src/adapters/models/roberta/modeling_roberta.py
@@ -25,6 +25,7 @@
from transformers.models.roberta.modeling_roberta import RobertaOutput, RobertaSelfAttention, RobertaSelfOutput
from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
+from ...utils import prefix_attention_mask
from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin
@@ -40,6 +41,8 @@ def forward(
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
+ attention_mask = prefix_attention_mask(attention_mask) # type: ignore
+
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
diff --git a/src/adapters/models/t5/adapter_model.py b/src/adapters/models/t5/adapter_model.py
index 66441727c7..d981815bd9 100644
--- a/src/adapters/models/t5/adapter_model.py
+++ b/src/adapters/models/t5/adapter_model.py
@@ -22,7 +22,10 @@
@add_start_docstrings("T5 Model with the option to add multiple flexible prediction heads on top.", T5_START_DOCSTRING)
class T5AdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, T5PreTrainedModel):
- _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+ _tied_weights_keys = [
+ "encoder.embed_tokens.weight",
+ "decoder.embed_tokens.weight",
+ ]
_keys_to_ignore_on_load_unexpected = [
r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
@@ -82,7 +85,7 @@ def forward(
# decoder_input_ids from input_ids if no decoder_input_ids are provided
decoder_input_ids = self._shift_right(input_ids)
- model_output = self.transformer(
+ model_output, context = self.transformer(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
@@ -101,7 +104,10 @@ def forward(
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
+ output_context=True,
)
+ # required e.g. for prompt tuning in all models
+ kwargs["context"] = context
sequence_output = model_output[0]
# ToDo move head to device for parallel forward pass
diff --git a/src/adapters/models/t5/mixin_t5.py b/src/adapters/models/t5/mixin_t5.py
index 244f5d4335..1aa6227d88 100644
--- a/src/adapters/models/t5/mixin_t5.py
+++ b/src/adapters/models/t5/mixin_t5.py
@@ -83,11 +83,17 @@ def init_adapters(self, model_config, adapters_config):
if not self.is_decoder:
InvertibleAdaptersMixin.init_adapters(self, self.config, adapters_config)
+ def post_embedding_forward(self, embedding_output):
+ embedding_output = self.invertible_adapters_forward(embedding_output)
+ # Prompt tuning not yet supported
+ return embedding_output
+
class T5ModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersWrapperMixin, ModelBaseAdaptersMixin):
"""Adds adapters to the T5Model class."""
invertible_adapters_base_name = "encoder"
+ support_prompt_tuning = False
def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
global_i = 0
diff --git a/src/adapters/models/t5/modeling_t5.py b/src/adapters/models/t5/modeling_t5.py
index 19064f58b2..b366b9cebe 100644
--- a/src/adapters/models/t5/modeling_t5.py
+++ b/src/adapters/models/t5/modeling_t5.py
@@ -356,7 +356,7 @@ def forward(
hidden_states = self.dropout(inputs_embeds)
if not self.is_decoder:
- hidden_states = self.invertible_adapters_forward(hidden_states)
+ hidden_states = self.post_embedding_forward(hidden_states)
for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
layer_head_mask = head_mask[i]
diff --git a/src/adapters/models/vit/adapter_model.py b/src/adapters/models/vit/adapter_model.py
index 33eaaf2ea0..254a5ab0d7 100644
--- a/src/adapters/models/vit/adapter_model.py
+++ b/src/adapters/models/vit/adapter_model.py
@@ -47,7 +47,7 @@ def forward(
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- outputs = self.vit(
+ outputs, context = self.vit(
pixel_values,
head_mask=head_mask,
output_attentions=output_attentions,
@@ -57,7 +57,10 @@ def forward(
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
+ output_context=True,
)
+ # required e.g. for prompt tuning in all models
+ kwargs["context"] = context
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
diff --git a/src/adapters/models/vit/mixin_vit.py b/src/adapters/models/vit/mixin_vit.py
index 2f9962a9d8..9b4a92e45b 100644
--- a/src/adapters/models/vit/mixin_vit.py
+++ b/src/adapters/models/vit/mixin_vit.py
@@ -52,6 +52,9 @@ class ViTModelAdaptersMixin(ModelBaseAdaptersMixin):
def init_adapters(self, model_config, adapters_config):
super().init_adapters(model_config, adapters_config)
+ # Register hook for post embedding forward
+ self.embeddings.register_forward_hook(self.post_embedding_forward)
+
def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(self.encoder.layer):
yield i, layer
diff --git a/src/adapters/models/xlm_roberta/adapter_model.py b/src/adapters/models/xlm_roberta/adapter_model.py
index ab1ca81f79..33963d5f1e 100644
--- a/src/adapters/models/xlm_roberta/adapter_model.py
+++ b/src/adapters/models/xlm_roberta/adapter_model.py
@@ -68,7 +68,7 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- outputs = self.roberta(
+ outputs, context = self.roberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -81,7 +81,10 @@ def forward(
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
+ output_context=True,
)
+ # required e.g. for prompt tuning in all models
+ kwargs["context"] = context
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]
diff --git a/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py b/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py
index 5f18c9f70e..959f75c0e7 100644
--- a/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py
+++ b/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py
@@ -29,6 +29,7 @@
)
from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
+from ...utils import prefix_attention_mask
from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin
@@ -44,6 +45,8 @@ def forward(
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
+ attention_mask = prefix_attention_mask(attention_mask) # type: ignore
+
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
diff --git a/src/adapters/models/xmod/adapter_model.py b/src/adapters/models/xmod/adapter_model.py
index 31ca7acd3b..d61578f158 100644
--- a/src/adapters/models/xmod/adapter_model.py
+++ b/src/adapters/models/xmod/adapter_model.py
@@ -73,7 +73,7 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- outputs = self.roberta(
+ outputs, context = self.roberta(
input_ids,
lang_ids=lang_ids,
attention_mask=attention_mask,
@@ -87,7 +87,10 @@ def forward(
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
+ output_context=True,
)
+ # required e.g. for prompt tuning in all models
+ kwargs["context"] = context
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]
diff --git a/src/adapters/models/xmod/mixin_xmod.py b/src/adapters/models/xmod/mixin_xmod.py
index eac7e4b418..bef4371f34 100644
--- a/src/adapters/models/xmod/mixin_xmod.py
+++ b/src/adapters/models/xmod/mixin_xmod.py
@@ -1,4 +1,4 @@
-from typing import Callable, Iterable, Tuple
+from typing import Iterable, Tuple
import torch.nn as nn
@@ -25,6 +25,9 @@ def init_adapters(self, model_config, adapters_config):
for _, layer in self.iter_layers():
del layer.output.adapter_modules
+ # Register hook for post embedding forward
+ self.embeddings.register_forward_hook(self.post_embedding_forward)
+
def _set_layer_hook_for_parallel(self, layer: nn.Module):
def hook(module, input):
# hook[1] is lang_ids tensor
@@ -37,9 +40,6 @@ def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(self.encoder.layer):
yield i, layer
- def hook_after_embeddings(self, hook_fn: Callable):
- return self.embeddings.register_forward_hook(hook_fn)
-
def forward(self, *args, **kwargs):
if "lang_ids" in kwargs and kwargs["lang_ids"] is not None:
raise ValueError(
diff --git a/src/adapters/models/xmod/modeling_xmod.py b/src/adapters/models/xmod/modeling_xmod.py
index 4a2269fbae..e91131e90e 100644
--- a/src/adapters/models/xmod/modeling_xmod.py
+++ b/src/adapters/models/xmod/modeling_xmod.py
@@ -24,6 +24,7 @@
from transformers.models.xmod.modeling_xmod import XmodOutput, XmodSelfAttention, XmodSelfOutput
from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
+from ...utils import prefix_attention_mask
from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin
@@ -39,6 +40,8 @@ def forward(
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
+ attention_mask = prefix_attention_mask(attention_mask) # type: ignore
+
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
diff --git a/src/adapters/utils.py b/src/adapters/utils.py
index cca9ee0e6b..0e3b20cabe 100644
--- a/src/adapters/utils.py
+++ b/src/adapters/utils.py
@@ -21,6 +21,8 @@
from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile
+import torch
+
import requests
from filelock import FileLock
from huggingface_hub import HfApi, HfFolder, snapshot_download
@@ -36,6 +38,7 @@
from transformers.utils.hub import torch_cache_home
from . import __version__
+from .context import ForwardContext
logger = logging.getLogger(__name__)
@@ -287,7 +290,6 @@ def get_from_cache(
# Prevent parallel downloads of the same file with a lock.
lock_path = cache_path + ".lock"
with FileLock(lock_path):
-
# If the download just completed while the lock was activated.
if os.path.exists(cache_path) and not force_download:
# Even if returning early like here, the lock will be released.
@@ -819,3 +821,44 @@ def get_adapter_info(adapter_id: str, source: str = "ah") -> Optional[AdapterInf
return None
else:
raise ValueError("Please specify either 'ah' or 'hf' as source.")
+
+
+def prefix_attention_mask(attention_mask, dim: int = 3, prefix_value: int = 0):
+ """
+ Adds a prefix to an attention mask. The length of the prefix is determined by the `prompt_tokens_length`
+ attribute in the ForwardContext.
+
+ Args:
+ attention_mask:
+ The attention mask to add the prefix to.
+ dim (int):
+ The dimension along which to concatenate the prefix_attention_mask. Defaults to 3.
+ prefix_value (int):
+ The value to use for the prefix_attention_mask. Defaults to 0; however, some models, e.g. DistilBert, use
+ a different value. BERT-like models invert their extended_attention_mask, hence they use 0 as the value for
+ tokens that are not masked. This inversion is usually done in the model's forward method in one of two ways:
+ 1) by calling self.invert_attention_mask, as BERT does, or 2) by doing the inversion manually, as e.g. ALBERT
+ does: `extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min`
+ """
+
+ forward_context = ForwardContext.get_context()
+
+ if (
+ attention_mask is not None
+ and forward_context is not None
+ and getattr(forward_context, "prompt_tokens_length", None) is not None
+ ):
+ # Create a prefix mask of the required shape, filled with prefix_value
+ ones_shape = list(attention_mask.shape)
+ ones_shape[dim] = forward_context.prompt_tokens_length
+
+ prefix_attention_mask = torch.full(
+ ones_shape,
+ prefix_value,
+ dtype=attention_mask.dtype,
+ ).to(attention_mask.device)
+
+ # Concatenate the prefix_attention_mask along the specified dimension
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=dim)
+
+ return attention_mask
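When the ForwardContext carries a `prompt_tokens_length`, the helper simply prepends that many columns of `prefix_value` along `dim` so the attention mask also covers the soft-prompt tokens. A standalone re-creation of that concatenation step on a toy mask (shapes and values chosen purely for illustration):

```python
import torch

attention_mask = torch.ones(2, 5, dtype=torch.long)  # (batch, seq_len), 1 = attend
prompt_tokens_length, prefix_value, dim = 3, 1, 1     # e.g. a DistilBert-style raw mask

prefix_shape = list(attention_mask.shape)
prefix_shape[dim] = prompt_tokens_length
prefix = torch.full(prefix_shape, prefix_value, dtype=attention_mask.dtype)

extended_mask = torch.cat((prefix, attention_mask), dim=dim)
print(extended_mask.shape)  # torch.Size([2, 8]): the prompt tokens are now unmasked
```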
diff --git a/tests_adapters/composition/test_adapter_composition.py b/tests_adapters/composition/test_adapter_composition.py
index 42e1f64c1b..e1922dc5d7 100644
--- a/tests_adapters/composition/test_adapter_composition.py
+++ b/tests_adapters/composition/test_adapter_composition.py
@@ -5,7 +5,7 @@
import adapters
from adapters import IA3Config, LoRAConfig, PrefixTuningConfig, SeqBnConfig
from adapters.composition import Average, BatchSplit, Fuse, Parallel, Split, Stack, parse_composition
-from tests.test_modeling_common import ids_tensor
+from tests_adapters.test_adapter import ids_tensor
from transformers import BertConfig, BertForSequenceClassification
from transformers.testing_utils import require_torch, torch_device
diff --git a/tests_adapters/methods/__init__.py b/tests_adapters/methods/__init__.py
index f40a688e58..b1cbe52de4 100644
--- a/tests_adapters/methods/__init__.py
+++ b/tests_adapters/methods/__init__.py
@@ -22,4 +22,5 @@
from .test_ia3 import IA3TestMixin
from .test_lora import LoRATestMixin
from .test_prefix_tuning import PrefixTuningTestMixin
+from .test_prompt_tuning import PromptTuningTestMixin
from .test_unipelt import UniPELTTestMixin
diff --git a/tests_adapters/methods/test_adapter_common.py b/tests_adapters/methods/test_adapter_common.py
index 616e6a99e8..5c543dadca 100644
--- a/tests_adapters/methods/test_adapter_common.py
+++ b/tests_adapters/methods/test_adapter_common.py
@@ -27,7 +27,6 @@
@require_torch
class BottleneckAdapterTestMixin(AdapterMethodBaseTestMixin):
-
adapter_configs_to_test = [
(SeqBnConfig(), ["adapters.{name}."]),
(MAMConfig(), ["adapters.{name}.", "prefix_tunings.{name}."]),
@@ -211,6 +210,14 @@ def test_adapter_forward(self):
with self.subTest(model_class=model.__class__.__name__, config=adapter_config.__class__.__name__):
self.run_forward_test(model, adapter_config)
+ def test_invertible_adapter_forward(self):
+ model = self.get_model()
+ model.eval()
+
+ for adapter_config, _ in self.inv_adapter_configs_to_test:
+ with self.subTest(model_class=model.__class__.__name__, config=adapter_config.__class__.__name__):
+ self.run_forward_test(model, adapter_config)
+
def test_load_adapter(self):
self.run_load_test(SeqBnConfig())
diff --git a/tests_adapters/methods/test_prompt_tuning.py b/tests_adapters/methods/test_prompt_tuning.py
new file mode 100644
index 0000000000..d0b12d259c
--- /dev/null
+++ b/tests_adapters/methods/test_prompt_tuning.py
@@ -0,0 +1,36 @@
+from adapters import PromptTuningConfig
+from transformers.testing_utils import require_torch
+
+from .base import AdapterMethodBaseTestMixin
+
+
+@require_torch
+class PromptTuningTestMixin(AdapterMethodBaseTestMixin):
+ def test_add_prompt_tuning(self):
+ model = self.get_model()
+ self.run_add_test(model, PromptTuningConfig(prompt_length=10), ["prompt_tunings.{name}."])
+
+ def test_average_prompt_tuning(self):
+ model = self.get_model()
+ self.run_average_test(model, PromptTuningConfig(prompt_length=10), ["prompt_tunings.{name}."])
+
+ def test_delete_prompt_tuning(self):
+ model = self.get_model()
+ self.run_delete_test(model, PromptTuningConfig(prompt_length=10), ["prompt_tunings.{name}."])
+
+ def test_get_prompt_tuning(self):
+ model = self.get_model()
+ self.run_get_test(model, PromptTuningConfig(prompt_length=10), 1)
+
+ def test_forward_prompt_tuning(self):
+ model = self.get_model()
+ self.run_forward_test(model, PromptTuningConfig(prompt_length=10))
+
+ def test_load_prompt_tuning(self):
+ self.run_load_test(PromptTuningConfig(prompt_length=10))
+
+ def test_load_full_model_prompt_tuning(self):
+ self.run_full_model_load_test(PromptTuningConfig(prompt_length=10))
+
+ def test_train_prompt_tuning(self):
+ self.run_train_test(PromptTuningConfig(prompt_length=10), ["prompt_tunings.{name}."])
diff --git a/tests_adapters/test_adapter.py b/tests_adapters/test_adapter.py
index d5e8d2754c..84c45c67dc 100644
--- a/tests_adapters/test_adapter.py
+++ b/tests_adapters/test_adapter.py
@@ -9,10 +9,29 @@
from transformers.testing_utils import torch_device
+global_rng = random.Random()
+
+
def make_config(config_class, **kwargs):
return staticmethod(lambda: config_class(**kwargs))
+def ids_tensor(shape, vocab_size, rng=None, name=None):
+ # Creates a random integer tensor of the given shape with values in [0, vocab_size)
+ if rng is None:
+ rng = global_rng
+
+ total_dims = 1
+ for dim in shape:
+ total_dims *= dim
+
+ values = []
+ for _ in range(total_dims):
+ values.append(rng.randint(0, vocab_size - 1))
+
+ return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous()
+
+
class AdapterTestBase:
# If not overridden by subclass, AutoAdapterModel should be used.
model_class = AutoAdapterModel
diff --git a/tests_adapters/test_adapter_config.py b/tests_adapters/test_adapter_config.py
index 5667213683..2bce31c7e3 100644
--- a/tests_adapters/test_adapter_config.py
+++ b/tests_adapters/test_adapter_config.py
@@ -31,6 +31,7 @@ def test_config_immutable(self):
def set_attr(config: AdapterConfig):
config.non_linearity = "dummy"
config.r = -1 # for LoRA
+ config.prompt_length = -1 # for PromptTuning
for config in ADAPTER_CONFIG_MAP.values():
if isinstance(config, ConfigUnion):
diff --git a/tests_adapters/test_adapter_custom_head.py b/tests_adapters/test_adapter_custom_head.py
index ea37b52325..b68662bfc6 100644
--- a/tests_adapters/test_adapter_custom_head.py
+++ b/tests_adapters/test_adapter_custom_head.py
@@ -5,10 +5,11 @@
from adapters import AutoAdapterModel
from adapters.heads import ClassificationHead, PredictionHead
-from tests.test_modeling_common import ids_tensor
from transformers import AutoConfig
from transformers.testing_utils import require_torch, torch_device
+from .test_adapter import ids_tensor
+
class CustomHead(PredictionHead):
def __init__(
diff --git a/tests_adapters/test_adapter_hub.py b/tests_adapters/test_adapter_hub.py
index 28a47136aa..7ebf9acb6c 100644
--- a/tests_adapters/test_adapter_hub.py
+++ b/tests_adapters/test_adapter_hub.py
@@ -7,7 +7,6 @@
from adapters import ADAPTER_CONFIG_MAP, AdapterConfig, BertAdapterModel, get_adapter_config_hash
from adapters.trainer import AdapterTrainer as Trainer
from adapters.utils import find_in_index
-from tests.test_modeling_common import ids_tensor
from transformers import ( # get_adapter_config_hash,
AutoModel,
AutoTokenizer,
@@ -19,6 +18,8 @@
)
from transformers.testing_utils import require_torch, torch_device
+from .test_adapter import ids_tensor
+
SAMPLE_INDEX = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/hub-index.sample.json")
diff --git a/tests_adapters/test_albert.py b/tests_adapters/test_albert.py
index 29f8a2b583..054dd31278 100644
--- a/tests_adapters/test_albert.py
+++ b/tests_adapters/test_albert.py
@@ -11,6 +11,7 @@
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ PromptTuningTestMixin,
UniPELTTestMixin,
)
from .test_adapter import AdapterTestBase, make_config
@@ -42,6 +43,7 @@ class AlbertAdapterTest(
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ PromptTuningTestMixin,
UniPELTTestMixin,
EmbeddingTestMixin,
AdapterFusionModelTestMixin,
diff --git a/tests_adapters/test_beit.py b/tests_adapters/test_beit.py
index b2014c6c00..1e83c9e529 100644
--- a/tests_adapters/test_beit.py
+++ b/tests_adapters/test_beit.py
@@ -9,6 +9,7 @@
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ PromptTuningTestMixin,
UniPELTTestMixin,
)
from .test_adapter import VisionAdapterTestBase, make_config
@@ -38,6 +39,7 @@ class BeitAdapterTest(
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ PromptTuningTestMixin,
UniPELTTestMixin,
AdapterFusionModelTestMixin,
CompabilityTestMixin,
diff --git a/tests_adapters/test_bert.py b/tests_adapters/test_bert.py
index 702e68ab9d..b4e67bc811 100644
--- a/tests_adapters/test_bert.py
+++ b/tests_adapters/test_bert.py
@@ -11,6 +11,7 @@
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ PromptTuningTestMixin,
UniPELTTestMixin,
)
from .test_adapter import AdapterTestBase, make_config
@@ -40,6 +41,7 @@ class BertAdapterTest(
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ PromptTuningTestMixin,
UniPELTTestMixin,
EmbeddingTestMixin,
AdapterFusionModelTestMixin,
diff --git a/tests_adapters/test_bert_generation.py b/tests_adapters/test_bert_generation.py
index feb821ca0b..44cbd25f8e 100644
--- a/tests_adapters/test_bert_generation.py
+++ b/tests_adapters/test_bert_generation.py
@@ -12,6 +12,7 @@
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ PromptTuningTestMixin,
UniPELTTestMixin,
)
from .test_adapter import AdapterTestBase, make_config
@@ -84,6 +85,7 @@ class BertGenerationAdapterTest(
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ PromptTuningTestMixin,
UniPELTTestMixin,
EmbeddingTestMixin,
AdapterFusionModelTestMixin,
diff --git a/tests_adapters/test_deberta.py b/tests_adapters/test_deberta.py
index 96be88d26a..7fd80322ec 100644
--- a/tests_adapters/test_deberta.py
+++ b/tests_adapters/test_deberta.py
@@ -11,6 +11,7 @@
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ PromptTuningTestMixin,
UniPELTTestMixin,
)
from .test_adapter import AdapterTestBase, make_config
@@ -47,6 +48,7 @@ class DebertaAdapterTest(
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ PromptTuningTestMixin,
UniPELTTestMixin,
EmbeddingTestMixin,
ParallelTrainingMixin,
diff --git a/tests_adapters/test_debertaV2.py b/tests_adapters/test_debertaV2.py
index bc436d996d..b2d564c2e4 100644
--- a/tests_adapters/test_debertaV2.py
+++ b/tests_adapters/test_debertaV2.py
@@ -11,6 +11,7 @@
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ PromptTuningTestMixin,
UniPELTTestMixin,
)
from .test_adapter import AdapterTestBase, make_config
@@ -47,6 +48,7 @@ class DebertaV2AdapterTest(
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ PromptTuningTestMixin,
UniPELTTestMixin,
EmbeddingTestMixin,
ParallelTrainingMixin,
diff --git a/tests_adapters/test_distilbert.py b/tests_adapters/test_distilbert.py
index d401fe220b..2634b390a5 100644
--- a/tests_adapters/test_distilbert.py
+++ b/tests_adapters/test_distilbert.py
@@ -11,6 +11,7 @@
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ PromptTuningTestMixin,
UniPELTTestMixin,
)
from .test_adapter import AdapterTestBase, make_config
@@ -40,6 +41,7 @@ class DistilBertAdapterTest(
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ PromptTuningTestMixin,
UniPELTTestMixin,
EmbeddingTestMixin,
CompabilityTestMixin,
diff --git a/tests_adapters/test_electra.py b/tests_adapters/test_electra.py
index 5566e7be0d..a5c005f509 100644
--- a/tests_adapters/test_electra.py
+++ b/tests_adapters/test_electra.py
@@ -11,6 +11,7 @@
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ PromptTuningTestMixin,
UniPELTTestMixin,
)
from .test_adapter import AdapterTestBase, make_config
@@ -41,6 +42,7 @@ class ElectraAdapterTest(
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ PromptTuningTestMixin,
UniPELTTestMixin,
EmbeddingTestMixin,
AdapterFusionModelTestMixin,
diff --git a/tests_adapters/test_roberta.py b/tests_adapters/test_roberta.py
index 5ccbb53eee..6d105ceaec 100644
--- a/tests_adapters/test_roberta.py
+++ b/tests_adapters/test_roberta.py
@@ -11,6 +11,7 @@
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ PromptTuningTestMixin,
UniPELTTestMixin,
)
from .test_adapter import AdapterTestBase, make_config
@@ -40,6 +41,7 @@ class RobertaAdapterTest(
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ PromptTuningTestMixin,
UniPELTTestMixin,
AdapterFusionModelTestMixin,
CompabilityTestMixin,
diff --git a/tests_adapters/test_vit.py b/tests_adapters/test_vit.py
index d84d8523ca..2de1b34300 100644
--- a/tests_adapters/test_vit.py
+++ b/tests_adapters/test_vit.py
@@ -10,6 +10,7 @@
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ PromptTuningTestMixin,
UniPELTTestMixin,
)
from .test_adapter import VisionAdapterTestBase, make_config
@@ -39,6 +40,7 @@ class ViTAdapterTest(
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ PromptTuningTestMixin,
UniPELTTestMixin,
AdapterFusionModelTestMixin,
CompabilityTestMixin,
diff --git a/tests_adapters/test_xlm_roberta.py b/tests_adapters/test_xlm_roberta.py
index f46d9543ac..96268302f7 100644
--- a/tests_adapters/test_xlm_roberta.py
+++ b/tests_adapters/test_xlm_roberta.py
@@ -9,6 +9,7 @@
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ PromptTuningTestMixin,
UniPELTTestMixin,
)
from .test_adapter import AdapterTestBase, make_config
@@ -36,6 +37,7 @@ class XLMRobertaAdapterTest(
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ PromptTuningTestMixin,
UniPELTTestMixin,
AdapterFusionModelTestMixin,
XLMRobertaAdapterTestBase,
diff --git a/tests_adapters/test_xmod.py b/tests_adapters/test_xmod.py
index 450c84231d..a8cd02c4d9 100644
--- a/tests_adapters/test_xmod.py
+++ b/tests_adapters/test_xmod.py
@@ -10,6 +10,7 @@
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ PromptTuningTestMixin,
UniPELTTestMixin,
)
from .test_adapter import AdapterTestBase, make_config
@@ -41,6 +42,7 @@ class XmodAdapterTest(
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ PromptTuningTestMixin,
UniPELTTestMixin,
AdapterFusionModelTestMixin,
CompabilityTestMixin,