Add ReFT (LoReFT, NoReFT, DiReFT) (#705)
This PR integrates multiple ReFT variants as new adapter methods.

Paper: https://arxiv.org/pdf/2404.03592
Original code: https://github.com/stanfordnlp/pyreft

## Changes

- Add the ReFT module implementation via `ReftLayer`, integrated into all models supported by Adapters. Integration is done via the `init_reft()` method & a PyTorch hook.
- Add a new `ReftConfig` base config class with three default instances: `LoReftConfig`, `NoReftConfig`, and `DiReftConfig` (a usage sketch follows this list).
- Method documentation can be found here:
https://github.com/adapter-hub/adapters/blob/6c19ea06c143621a735226e477bf772068e55be3/docs/methods.md#reft
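
A minimal usage sketch of the new methods, assuming a plain Hugging Face model prepared with `adapters.init()` (the adapter name `glue_reft` is an illustrative placeholder):

```python
# Sketch: add and train a LoReFT adapter on roberta-base
import adapters
from adapters import LoReftConfig
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=2)
adapters.init(model)  # enable adapter support on the plain Transformers model

model.add_adapter("glue_reft", config=LoReftConfig())
model.train_adapter("glue_reft")  # freeze base weights, train only the ReFT parameters
```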

## Compatibility

Verified that pyreft and Adapters produce the same outputs at inference by converting pyreft checkpoints to Adapters checkpoints (tested settings: LoReFT, NoReFT, DiReFT, weight tying, prefix and suffix positions, rank; mostly using `roberta-base`).

Script for testing & checkpoint conversion here:
https://github.com/calpt/pyreft/blob/main/compatibility.py.
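
For illustration, such a check essentially compares the outputs of both models on identical inputs. The snippet below is a hypothetical sketch of that comparison, not the linked script; `model_a` and `model_b` stand for the pyreft and Adapters models:

```python
# Hypothetical output-equivalence check; the actual conversion logic lives in the linked script
import torch


def outputs_match(model_a, model_b, batch, atol=1e-5):
    model_a.eval()
    model_b.eval()
    with torch.no_grad():
        out_a = model_a(**batch).last_hidden_state
        out_b = model_b(**batch).last_hidden_state
    return torch.allclose(out_a, out_b, atol=atol)
```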

## Evaluation

RoBERTa-base with LoReFT on GLUE, using hyperparameters similar to those in the paper:

Task | Score
--- | ---
CoLA (Matthews Corr.) | 53.95
MNLI (Acc.) | 83.23
MRPC (F1) | 91.70
QNLI (Acc.) | 90.94
QQP (Acc.) | 86.82
RTE (Acc.) | 76.53
SST-2 (Acc.) | 93.81
STS-B (Spearman Corr.) | 88.99

## Todos

- [x] Modeling implementations
- [x] Add test methods
- [x] Make all checks pass
- [x] Add documentation
- [x] Make sure the implementation produces the same outputs as the original code
- [x] Sanity check training runs
calpt authored Jul 1, 2024
1 parent ac74998 commit d8c991f
Showing 39 changed files with 616 additions and 315 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -156,6 +156,7 @@ Currently, adapters integrates all architectures and methods listed below:
| UniPELT | [Mao et al. (2022)](https://arxiv.org/pdf/2110.07577.pdf) | [Docs](https://docs.adapterhub.ml/method_combinations.html#unipelt) |
| Prompt Tuning | [Lester et al. (2021)](https://aclanthology.org/2021.emnlp-main.243/) | [Docs](https://docs.adapterhub.ml/methods.html#prompt-tuning) |
| QLoRA | [Dettmers et al. (2023)](https://arxiv.org/pdf/2305.14314.pdf) | [Notebook](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/QLoRA_Llama_Finetuning.ipynb) |
| ReFT | [Wu et al. (2024)](https://arxiv.org/pdf/2404.03592) | [Docs](https://docs.adapterhub.ml/methods.html#reft) |

## Supported Models

17 changes: 17 additions & 0 deletions docs/classes/adapter_config.rst
@@ -62,6 +62,23 @@ PromptTuningConfig
:members:
:inherited-members: Mapping


ReFT
~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: adapters.ReftConfig
:members:
:inherited-members: Mapping

.. autoclass:: adapters.LoReftConfig
:members:

.. autoclass:: adapters.NoReftConfig
:members:

.. autoclass:: adapters.DiReftConfig
:members:

Combined configurations
~~~~~~~~~~~~~~~~~~~~~~~~

47 changes: 47 additions & 0 deletions docs/methods.md
@@ -295,3 +295,50 @@ model.add_adapter("dummy", config=config)
_Papers:_
- [The Power of Scale for Parameter-Efficient Prompt Tuning](https://aclanthology.org/2021.emnlp-main.243/) (Lester et al., 2021)

## ReFT

_Configuration class_: [`ReftConfig`](adapters.ReftConfig)

Representation Fine-Tuning (ReFT), as first proposed by [Wu et al. (2024)](https://arxiv.org/pdf/2404.03592), leverages so-called interventions to adapt the pre-trained representations of a language model.
Within the context of ReFT, these interventions can intuitively be thought of as adapter modules placed after each Transformer layer.
In its most general form, an intervention function $\Phi$ can thus be defined as follows:

$$
\Phi(h) = h + R^T (W h + b - R h)
$$

Here, $R \in \mathbb{R}^{r \times d}$ and $W \in \mathbb{R}^{r \times d}$ are low-rank matrices of rank $r$, and $b \in \mathbb{R}^{r}$ is a bias vector.
$h$ is the hidden state output by the layer at a single sequence position, i.e. interventions can be applied independently at each position.
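
For intuition, the general intervention can be written out directly in PyTorch. The snippet below is only a sketch of the formula above (it is not the library's actual `ReftLayer` implementation); the orthogonality constraint is illustrated via a QR decomposition:

```python
# Sketch of the general intervention Phi(h) = h + R^T (W h + b - R h)
import torch


def intervention(h, R, W, b, subtract_projection=True, orthogonal=False):
    # h: (d,) hidden state at one position; R, W: (r, d); b: (r,)
    if orthogonal:
        # LoReFT case: constrain the rows of R to be orthonormal
        R = torch.linalg.qr(R.T).Q.T
    update = W @ h + b
    if subtract_projection:  # DiReFT drops this subtraction
        update = update - R @ h
    return h + R.T @ update
```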

Based on this general form, the ReFT paper proposes multiple instantiations of ReFT methods supported by _Adapters_:

- **LoReFT** enforces orthogonality of rows in $R$. Defined via [`LoReftConfig`](adapters.LoReftConfig) or via the `orthogonality` attribute as in the following example:
```python
config = ReftConfig(
layers="all", prefix_positions=3, suffix_positions=0, r=1, orthogonality=True
) # equivalent to LoReftConfig()
```

- **NoReFT** does not enforce orthogonality in $R$. Defined via [`NoReftConfig`](adapters.NoReftConfig) or equivalently:
```python
config = ReftConfig(
layers="all", prefix_positions=3, suffix_positions=0, r=1, orthogonality=False
) # equivalent to NoReftConfig()
```

- **DiReFT** does not enforce orthogonality in $R$ and additionally removes the subtraction of $R h$ from the intervention. Defined via [`DiReftConfig`](adapters.DiReftConfig) or equivalently:
```python
config = ReftConfig(
layers="all", prefix_positions=3, suffix_positions=0, r=1, orthogonality=False, subtract_projection=False
) # equivalent to DiReftConfig()
```

In addition, _Adapters_ supports configuring multiple hyperparameters tuned in the ReFT paper via `ReftConfig`, including the following (a combined example follows this list):
- `prefix_positions`: the number of prefix positions to add interventions to
- `suffix_positions`: the number of suffix positions to add interventions to
- `layers`: the layers to intervene on; either `"all"` or a list of layer IDs
- `tied_weights`: whether to share intervention parameters between prefix and suffix positions
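
For example, several of these options can be combined in a single custom config (a sketch; assumes `model` has already been prepared with `adapters.init()`, and the layer indices are purely illustrative):

```python
config = ReftConfig(
    layers=[8, 9, 10, 11],  # intervene only on the last four layers of a 12-layer model
    prefix_positions=3,
    suffix_positions=3,
    r=4,
    orthogonality=True,
    tied_weights=True,  # share intervention parameters between prefix and suffix positions
)
model.add_adapter("custom_reft", config=config)
```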

_Papers:_

* [ReFT: Representation Finetuning for Language Models](https://arxiv.org/pdf/2404.03592) (Wu et al., 2024)
44 changes: 22 additions & 22 deletions docs/model_overview.md
@@ -10,29 +10,29 @@ The table below further shows which model architectures support which adaptation
E.g., for BERT, this means adapters provides a ``BertAdapterModel`` class, but you can also use ``BertModel``, ``BertForSequenceClassification`` etc. together with adapters.
```

| Model | (Bottleneck)<br> Adapters | Prefix<br> Tuning | LoRA | Compacter | Adapter<br> Fusion | Invertible<br> Adapters | Parallel<br> block | Prompt<br> Tuning |
| Model | (Bottleneck)<br> Adapters | Prefix<br> Tuning | LoRA | Compacter | Adapter<br> Fusion | Invertible<br> Adapters | Parallel<br> block | Prompt<br> Tuning | ReFT |
| --------------------------------------- | -| - | - | - | - | - | - |- |
| [ALBERT](classes/models/albert.html) |||||||||
| [BART](classes/models/bart.html) |||||||| |
| [BEIT](classes/models/beit.html) |||||| | ||
| [BERT-Generation](classes/models/bert-generation.html) |||||||||
| [BERT](classes/models/bert.html) |||||||||
| [CLIP](classes/models/clip.html) ||||||| | |
| [DeBERTa](classes/models/deberta.html) |||||||||
| [DeBERTa-v2](classes/models/debertaV2.html) |||||||||
| [DistilBERT](classes/models/distilbert.html) |||||||||
| [Electra](classes/models/electra.html) |||||||||
| [Encoder Decoder](classes/models/encoderdecoder.html) | (*) | (*) | (*) | (*) | (*) | (*) | | |
| [GPT-2](classes/models/gpt2.html) |||||||| |
| [GPT-J](classes/models/gptj.html) |||||||| |
| [Llama](classes/models/llama.html) |||||||| |
| [MBart](classes/models/mbart.html) |||||||| |
| [MT5](classes/models/mt5.html) |||||||| |
| [RoBERTa](classes/models/roberta.html) |||||||||
| [T5](classes/models/t5.html) |||||||| |
| [ViT](classes/models/vit.html) |||||||||
| [XLM-RoBERTa](classes/models/xlmroberta.html) |||||||||
| [X-MOD](classes/models/xmod.html) |||||||||
| [ALBERT](classes/models/albert.html) ||||||||||
| [BART](classes/models/bart.html) |||||||| ||
| [BEIT](classes/models/beit.html) |||||| | |||
| [BERT-Generation](classes/models/bert-generation.html) ||||||||||
| [BERT](classes/models/bert.html) ||||||||||
| [CLIP](classes/models/clip.html) ||||||| | ||
| [DeBERTa](classes/models/deberta.html) ||||||||||
| [DeBERTa-v2](classes/models/debertaV2.html) ||||||||||
| [DistilBERT](classes/models/distilbert.html) ||||||||||
| [Electra](classes/models/electra.html) ||||||||||
| [Encoder Decoder](classes/models/encoderdecoder.html) | (*) | (*) | (*) | (*) | (*) | (*) | | | (*) |
| [GPT-2](classes/models/gpt2.html) |||||||| ||
| [GPT-J](classes/models/gptj.html) |||||||| ||
| [Llama](classes/models/llama.html) |||||||| ||
| [MBart](classes/models/mbart.html) |||||||| ||
| [MT5](classes/models/mt5.html) |||||||| ||
| [RoBERTa](classes/models/roberta.html) ||||||||||
| [T5](classes/models/t5.html) |||||||| ||
| [ViT](classes/models/vit.html) ||||||||||
| [XLM-RoBERTa](classes/models/xlmroberta.html) ||||||||||
| [X-MOD](classes/models/xmod.html) ||||||||||

(*) If the used encoder and decoder model class are supported.

5 changes: 4 additions & 1 deletion docs/overview.md
@@ -58,7 +58,10 @@ Identifiers and configuration classes are explained in more detail in the [next
| `ia3` | `IA3Config()` | [IA³](methods.html#ia-3) |
| `mam` | `MAMConfig()` | [Mix-and-Match Adapters](method_combinations.html#mix-and-match-adapters) |
| `unipelt` | `UniPELTConfig()` | [UniPELT](method_combinations.html#unipelt) |
| `prompt_tuning` | `PromptTuningConfig()` | [Prompt Tuning](methods.html#prompt-tuning)
| `prompt_tuning` | `PromptTuningConfig()` | [Prompt Tuning](methods.html#prompt-tuning) |
| `loreft` | `LoReftConfig()` | [ReFT](methods.html#reft) |
| `noreft` | `NoReftConfig()` | [ReFT](methods.html#reft) |
| `direft` | `DiReftConfig()` | [ReFT](methods.html#reft) |
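
For example, the new identifiers can be passed directly as config strings (a minimal sketch; `model` is any adapters-initialized model):

```python
model.add_adapter("my_reft", config="loreft")  # same as config=LoReftConfig()
```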

## Configuration

8 changes: 8 additions & 0 deletions src/adapters/__init__.py
@@ -45,16 +45,20 @@
"CompacterConfig",
"CompacterPlusPlusConfig",
"ConfigUnion",
"DiReftConfig",
"DoubleSeqBnConfig",
"DoubleSeqBnInvConfig",
"DynamicAdapterFusionConfig",
"IA3Config",
"LoRAConfig",
"LoReftConfig",
"MAMConfig",
"ModelAdaptersConfig",
"NoReftConfig",
"ParBnConfig",
"PrefixTuningConfig",
"PromptTuningConfig",
"ReftConfig",
"SeqBnConfig",
"SeqBnInvConfig",
"StaticAdapterFusionConfig",
@@ -154,16 +158,20 @@
CompacterConfig,
CompacterPlusPlusConfig,
ConfigUnion,
DiReftConfig,
DoubleSeqBnConfig,
DoubleSeqBnInvConfig,
DynamicAdapterFusionConfig,
IA3Config,
LoRAConfig,
LoReftConfig,
MAMConfig,
ModelAdaptersConfig,
NoReftConfig,
ParBnConfig,
PrefixTuningConfig,
PromptTuningConfig,
ReftConfig,
SeqBnConfig,
SeqBnInvConfig,
StaticAdapterFusionConfig,
84 changes: 83 additions & 1 deletion src/adapters/configuration/adapter_config.py
@@ -1,7 +1,7 @@
import logging
from collections.abc import Mapping
from dataclasses import FrozenInstanceError, asdict, dataclass, field, replace
from typing import List, Optional, Union
from typing import List, Literal, Optional, Union

from ..utils import resolve_adapter_config

@@ -86,6 +86,8 @@ def _get_config_class(config_dict):
cls_new = ConfigUnion
elif architecture == "prompt_tuning":
cls_new = PromptTuningConfig
elif architecture == "reft":
cls_new = ReftConfig
else:
cls_new = BnConfig

@@ -497,6 +499,83 @@ class IA3Config(LoRAConfig):
use_gating: bool = False


@dataclass(eq=False)
class ReftConfig(AdapterConfig):
"""
Base class for Representation Fine-Tuning (ReFT) methods proposed in Wu et al. (2024). See https://arxiv.org/pdf/2404.03592.
ReFT methods have in common that they add "interventions" after selected model layers and at selected sequence positions to adapt the representations produced by module outputs.
Args:
layers (Union[Literal["all"], List[int]]): The IDs of the layers where interventions should be added.
If "all", interventions are added after all layers (default).
prefix_positions (int): The number of prefix positions to add interventions to.
suffix_positions (int): The number of suffix positions to add interventions to.
r (int): The rank of the intervention layer.
orthogonality (bool): If True, enforce an orthogonality constraint for the projection matrix.
tied_weights (bool): If True, share intervention parameters between prefix and suffix positions in each layer.
subtract_projection (bool): If True, subtract the projection of the input.
dropout (float): The dropout rate used in the intervention layer.
non_linearity (str): The activation function used in the intervention layer.
"""

layers: Union[Literal["all"], List[int]]
prefix_positions: int
suffix_positions: int
r: int
orthogonality: bool
tied_weights: bool = False
subtract_projection = True
dropout: float = 0.05
non_linearity: Optional[str] = None

architecture: str = "reft"

output_reft: bool = True


@dataclass(eq=False)
class LoReftConfig(ReftConfig):
"""
Low-Rank Linear Subspace ReFT method proposed in Wu et al. (2024). See https://arxiv.org/pdf/2404.03592.
"""

layers: Union[Literal["all"], List[int]] = "all"
prefix_positions: int = 3
suffix_positions: int = 0
r: int = 1
orthogonality: bool = True
tied_weights: bool = False


@dataclass(eq=False)
class NoReftConfig(ReftConfig):
"""
Variation of LoReft without orthogonality constraint.
"""

layers: Union[Literal["all"], List[int]] = "all"
prefix_positions: int = 3
suffix_positions: int = 0
r: int = 1
orthogonality: bool = False
tied_weights: bool = False


@dataclass(eq=False)
class DiReftConfig(ReftConfig):
"""
Variation of LoReft without orthogonality constraint and projection subtraction as proposed in Wu et al. (2024). See https://arxiv.org/pdf/2404.03592.
"""

layers: Union[Literal["all"], List[int]] = "all"
prefix_positions: int = 3
suffix_positions: int = 0
r: int = 1
orthogonality: bool = False
tied_weights: bool = False
subtract_projection = False


class ConfigUnion(AdapterConfig):
"""
Composes multiple adaptation method configurations into one. This class can be used to define complex adaptation
@@ -650,6 +729,9 @@ def __init__(
"prompt_tuning": PromptTuningConfig(),
"lora": LoRAConfig(),
"ia3": IA3Config(),
"loreft": LoReftConfig(),
"noreft": NoReftConfig(),
"direft": DiReftConfig(),
"mam": MAMConfig(),
"unipelt": UniPELTConfig(),
}
4 changes: 4 additions & 0 deletions src/adapters/loading.py
@@ -348,6 +348,7 @@ def filter_func(self, adapter_name):
or ".prefix_tunings.{}.".format(adapter_name) in x
or ".prefix_gates.{}.".format(adapter_name) in x
or ".loras.{}.".format(adapter_name) in x
or ".refts.{}.".format(adapter_name) in x
or ".prompt_tunings.{}.".format(adapter_name) in x
)

@@ -393,6 +394,7 @@ def rename_func(self, old_name, new_name):
.replace(".prefix_tunings.{}.".format(old_name), ".prefix_tunings.{}.".format(new_name))
.replace(".prefix_gates.{}.".format(old_name), ".prefix_gates.{}.".format(new_name))
.replace(".loras.{}.".format(old_name), ".loras.{}.".format(new_name))
.replace(".refts.{}.".format(old_name), ".refts.{}.".format(new_name))
)

def save_to_state_dict(self, name: str):
@@ -446,6 +448,8 @@ def save(self, save_directory, name, meta_dict=None):

adapter_config = self.model.adapters_config.get(name)

self.model.apply_to_adapter_layers(lambda _, layer: layer.pre_save_adapters())

config_dict = build_full_config(
adapter_config,
self.model.config,