Refactor adapter composition implementation #591

Merged
merged 11 commits on Oct 29, 2023
10 changes: 5 additions & 5 deletions docs/adapter_composition.md
@@ -160,10 +160,10 @@ In the example, `attention_scores` holds a dictionary of the following form:
Splitting the input between two adapters using the 'Split' block.
```

The `Split` block can be used to split an input sequence between two adapters.
This is done by specifying a split index, at which the sequences should be divided.
The `Split` block can be used to split an input sequence between multiple adapters.
This is done by specifying the lengths of the splits into which the sequences should be divided.
In the following example, we split each input sequence between adapters `g` and `h`.
For each sequence, all tokens from 0 up to 63 are forwarded through `g` while all tokens beginning at index 64 are forwarded through `h`:
For each sequence, all tokens from 0 up to 63 are forwarded through `g` while the next 64 tokens are forwarded through `h`:

```python
import adapters.composition as ac
@@ -173,7 +173,7 @@ import adapters.composition as ac
model.add_adapter("g")
model.add_adapter("h")

model.active_adapters = ac.Split("g", "h", split_index=64)
model.active_adapters = ac.Split("g", "h", splits=[64, 64])
```
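
Since the refactored block accepts a variable number of adapters, the same pattern should extend beyond two adapters as well; the following is only a sketch, with illustrative adapter names and split lengths:

```python
import adapters.composition as ac

model.add_adapter("g")
model.add_adapter("h")
model.add_adapter("i")

# For 128-token sequences: the first 64 tokens go through "g",
# the next 32 through "h" and the remaining 32 through "i".
model.active_adapters = ac.Split("g", "h", "i", splits=[64, 32, 32])
```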

## `BatchSplit`
@@ -286,7 +286,7 @@ E.g., we can nest a `Split` block within a `Stack` of adapters:
```python
import adapters.composition as ac

model.active_adapters = ac.Stack("a", ac.Split("b", "c", split_index=60))
model.active_adapters = ac.Stack("a", ac.Split("b", "c", splits=60))
```

However, combinations of adapter composition blocks cannot be arbitrarily deep. All currently supported possibilities are visualized in the table below.
4 changes: 2 additions & 2 deletions docs/classes/adapter_layer.rst
@@ -1,5 +1,5 @@
AdapterLayer
BottleneckLayer
=======================

.. autoclass:: adapters.AdapterLayer
.. autoclass:: adapters.BottleneckLayer
:members:
2 changes: 1 addition & 1 deletion docs/contributing/adding_adapter_methods.md
@@ -32,7 +32,7 @@ Thus, each adapter method implementation at least should provide two classes:
including methods for adding, enabling and deleting adapter weights.
- Most importantly, the module classes deriving from this base class should implement the forward pass through an adaptation component.
- The concrete implementation of these classes heavily depends on the specifics of the adapter method.
For a reference implementation, have a look at `AdapterLayer` for bottleneck adapters.
For a reference implementation, have a look at `BottleneckLayer` for bottleneck adapters.
- To actually make use of the newly implemented classes, it's finally necessary to integrate the forward calls to the modules in the actual model implementations.
- This, again, is highly dependent on how the adapter method interacts with the base model classes. Typically, module classes can be integrated either via mixins (see `src/transformers/adapters/mixins`) or directly as submodules of the respective model components.
- The model class integration has to be repeated for each supported Transformer model, as they typically don't share a codebase. At this point it is often important to consider where the adapters need to be added to the transformer model and whether there is an implementation that does not require more copying of classes than the current implementation.
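
As a rough illustration of the adaptation component mentioned above, the sketch below shows the canonical bottleneck forward pass (down-projection, non-linearity, up-projection, residual connection). The class and attribute names are made up for this example and do not mirror the library's actual `BottleneckLayer`:

```python
import torch
import torch.nn as nn


class TinyBottleneckAdapter(nn.Module):
    """Illustrative bottleneck adapter: down-project, apply a non-linearity, up-project, add residual."""

    def __init__(self, hidden_size: int, reduction_factor: int = 16):
        super().__init__()
        bottleneck_size = hidden_size // reduction_factor
        self.down = nn.Linear(hidden_size, bottleneck_size)
        self.non_linearity = nn.ReLU()
        self.up = nn.Linear(bottleneck_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The adapter output is added back onto the residual stream.
        return hidden_states + self.up(self.non_linearity(self.down(hidden_states)))
```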
4 changes: 2 additions & 2 deletions docs/contributing/adding_adapters_to_a_model.md
@@ -27,8 +27,8 @@ Now that we have discussed the purpose of every file in `src/adapters/models/<mo
- To figure out which classes to change, think about where to insert LoRA, Prefix Tuning, and bottleneck adapters.
- You can use similar model implementations for guidance.
- Often, existing mixins of another class can be reused. E.g. `BertLayer`, `RobertaLayer`, `XLMRobertaLayer`, `DebertaLayer`, `DebertaV2Layer` and `BertGenerationLayer` (all models derived from BERT) use the `BertLayerAdaptersMixin`.
- To additionally support Prefix Tuning, it's necessary to apply the forward call to the `PrefixTuningShim` module in the respective attention layer (see step 3 for how to modify the code of an Hugging Face class).
- Make sure the calls to `adapter_layer_forward()` are added in the right places.
- To additionally support Prefix Tuning, it's necessary to apply the forward call to the `PrefixTuningLayer` module in the respective attention layer (see step 3 for how to modify the code of a Hugging Face class).
- Make sure the calls to `bottleneck_layer_forward()` are added in the right places.
- The mixin for the whole base model class (e.g., `BertModel`) should derive from `ModelBaseAdaptersMixin` and (if possible) `EmbeddingAdaptersMixin` and/or `InvertibleAdaptersMixin`. This mixin should at least implement the `iter_layers()` method but might require additional modifications depending on the architecture.
- If the model is a combination of different models, such as the EncoderDecoderModel, use `ModelUsingSubmodelsAdaptersMixin` instead of `ModelBaseAdaptersMixin`.
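
For a BERT-style encoder, `iter_layers()` can be as simple as enumerating the encoder's layer modules. This is only a sketch: the mixin name is invented, the import path for `ModelBaseAdaptersMixin` is assumed, and it presumes the transformer blocks live under `self.encoder.layer`:

```python
from typing import Iterable, Tuple

import torch.nn as nn

from adapters.model_mixin import ModelBaseAdaptersMixin  # import path assumed


class MyModelAdaptersMixin(ModelBaseAdaptersMixin):
    """Sketch: expose the transformer layers so adapter methods can hook into them."""

    def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
        # Assumes a BERT-like layout where the blocks sit in `self.encoder.layer`.
        for i, layer in enumerate(self.encoder.layer):
            yield i, layer
```
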
3. **Copied functions:**
4 changes: 2 additions & 2 deletions src/adapters/__init__.py
@@ -78,7 +78,7 @@
"Seq2SeqLMHead",
"TaggingHead",
],
"layer": ["AdapterLayer", "AdapterLayerBase"],
"methods.adapter_layer_base": ["AdapterLayerBase"],
"model_mixin": [
"EmbeddingAdaptersMixin",
"InvertibleAdaptersMixin",
@@ -182,7 +182,7 @@
Seq2SeqLMHead,
TaggingHead,
)
from .layer import AdapterLayer, AdapterLayerBase
from .methods.adapter_layer_base import AdapterLayerBase
from .model_mixin import (
EmbeddingAdaptersMixin,
InvertibleAdaptersMixin,
9 changes: 3 additions & 6 deletions src/adapters/composition.py
@@ -73,12 +73,9 @@ def name(self):


class Split(AdapterCompositionBlock):
    def __init__(self, left: str, right: str, split_index: int):
        super().__init__(left, right)
        assert split_index > 0
        self.left = left
        self.right = right
        self.split_index = split_index
    def __init__(self, *split_adapters: List[Union[AdapterCompositionBlock, str]], splits: Union[List[int], int]):
        super().__init__(*split_adapters)
        self.splits = splits if isinstance(splits, list) else [splits] * len(split_adapters)
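        # For illustration: with three adapters, `Split("g", "h", "i", splits=64)`
        # results in self.splits == [64, 64, 64], while an explicit list such as
        # `splits=[50, 50, 28]` is kept as given.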


class BatchSplit(AdapterCompositionBlock):
2 changes: 1 addition & 1 deletion src/adapters/heads/base.py
@@ -20,8 +20,8 @@

from ..composition import AdapterCompositionBlock, BatchSplit, Parallel, parse_heads_from_composition
from ..context import AdapterSetup, ForwardContext
from ..methods.modeling import Activation_Function_Class
from ..model_mixin import ModelWithHeadsAdaptersMixin
from ..modeling import Activation_Function_Class


logger = logging.getLogger(__name__)
2 changes: 1 addition & 1 deletion src/adapters/heads/language_modeling.py
@@ -2,7 +2,7 @@

from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast, MaskedLMOutput, Seq2SeqLMOutput

from ..modeling import Activation_Function_Class
from ..methods.modeling import Activation_Function_Class
from .base import PredictionHead

