diff --git a/docs/adapter_composition.md b/docs/adapter_composition.md
index 05f85f3fdb..e8b6bf3c10 100644
--- a/docs/adapter_composition.md
+++ b/docs/adapter_composition.md
@@ -42,14 +42,16 @@ The following table gives an overview on the supported composition blocks and th
| Block | Bottleneck Adapters | Prefix Tuning | Compacter | LoRA | (IA)³ |
| --- | --- | --- | --- | --- | --- |
-| [`Stack`](#stack) | ✅ | ✅ | ✅ | | |
+| [`Stack`](#stack) | ✅ | ✅ | ✅ | ✅(*) | ✅(*) |
| [`Fuse`](#fuse) | ✅ | | ✅ | | |
| [`Split`](#split) | ✅ | | ✅ | | |
-| [`BatchSplit`](#batchsplit) | ✅ | ✅ | ✅ | | |
-| [`Parallel`](#parallel) | ✅ | ✅ | ✅ | | |
-| [Output averaging](#output-averaging) | ✅ | | ✅ | | |
+| [`BatchSplit`](#batchsplit) | ✅ | ✅ | ✅ | ✅(*) | ✅(*) |
+| [`Parallel`](#parallel) | ✅ | ✅ | ✅ | ✅(*) | ✅(*) |
+| [Output averaging](#output-averaging) | ✅ | | ✅ | ✅(*) | ✅(*) |
| [Parameter averaging](#parameter-averaging) | ✅ | ✅ | ✅ | ✅ | ✅ |
+(*) Not supported for DeBERTa (v1) and GPT-2, whose merged attention layers do not yet support composition.
+
Next, we present all composition blocks in more detail.
## `Stack`
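
For example, with the changes above, two LoRA adapters can now be run in a `Parallel` block just like bottleneck adapters. A minimal sketch (the checkpoint and adapter names are placeholders):

```python
import adapters
import adapters.composition as ac
from adapters import LoRAConfig
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
adapters.init(model)  # enable adapter support on the vanilla HF model

# Two independent LoRA adapters
model.add_adapter("a", config=LoRAConfig())
model.add_adapter("b", config=LoRAConfig())

# Both adapters see the same input batch; the output batch is replicated
# once per parallel channel (here: 2x the input batch size).
model.set_active_adapters(ac.Parallel("a", "b"))
```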
diff --git a/docs/index.rst b/docs/index.rst
index fdddf228ec..b3685d8d28 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -94,7 +94,6 @@ Currently, we support the PyTorch versions of all models as listed on the `Model
classes/adapter_config
classes/model_adapters_config
- classes/adapter_modules
classes/adapter_layer
classes/model_mixins
classes/adapter_training
diff --git a/src/adapters/composition.py b/src/adapters/composition.py
index 5899b113d6..937fae2685 100644
--- a/src/adapters/composition.py
+++ b/src/adapters/composition.py
@@ -1,6 +1,8 @@
import itertools
from collections.abc import Sequence
-from typing import List, Optional, Set, Union
+from typing import List, Optional, Set, Tuple, Union
+
+import torch
class AdapterCompositionBlock(Sequence):
@@ -242,3 +244,16 @@ def adjust_tensors_for_parallel_(hidden_states, *tensors):
repeats[0] = hidden_states.shape[0] // tensor.shape[0]
new_tensor = tensor.repeat(*repeats)
tensor.set_(new_tensor)
+
+
+def match_attn_matrices_for_parallel(query, key, value) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Matches the shapes of query, key and value matrices for parallel composition.
+ """
+ max_bsz = max(query.shape[0], key.shape[0], value.shape[0])
+
+ query = query.repeat(max_bsz // query.shape[0], *([1] * len(query.shape[1:])))
+ key = key.repeat(max_bsz // key.shape[0], *([1] * len(key.shape[1:])))
+ value = value.repeat(max_bsz // value.shape[0], *([1] * len(value.shape[1:])))
+
+ return query, key, value
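
A rough usage sketch of the new helper together with the existing `adjust_tensors_for_parallel`, using toy tensor shapes chosen purely for illustration:

```python
import torch

from adapters.composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel

# Toy shapes: (batch, heads, seq_len, head_dim). Assume key/value were already
# replicated for 2 parallel channels while query still has the original batch size.
query = torch.randn(2, 4, 8, 16)
key = torch.randn(4, 4, 8, 16)
value = torch.randn(4, 4, 8, 16)
attention_mask = torch.zeros(2, 1, 1, 8)

query, key, value = match_attn_matrices_for_parallel(query, key, value)
(attention_mask,) = adjust_tensors_for_parallel(query, attention_mask)

assert query.shape[0] == key.shape[0] == value.shape[0] == attention_mask.shape[0] == 4
```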
diff --git a/src/adapters/methods/adapter_layer_base.py b/src/adapters/methods/adapter_layer_base.py
index b89b75cb14..79d18500ec 100644
--- a/src/adapters/methods/adapter_layer_base.py
+++ b/src/adapters/methods/adapter_layer_base.py
@@ -150,10 +150,13 @@ class ComposableAdapterLayerBase(AdapterLayerBase):
Base class for all adapter methods that support composition.
Make sure the 'adapter_modules_name' and 'supported_compositions' attributes as well as all abstract methods are
- overriden in derived classes.
+ overridden in derived classes. 'allow_multi_parallelize' can be set to True to allow inputs to be parallelized
+ independently multiple times. This is useful when there are multiple parallel input flows through an adapter layer
+ (e.g. in LoRA).
"""
supported_compositions = []
+ allow_multi_parallelize = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -382,15 +385,23 @@ def compose_parallel(self, adapter_setup: Parallel, state: NamedTuple, lvl: int
orig_batch_size = self._bsz(state)
state = self.repeat(state, adapter_setup.parallel_channels)
context.adapters_parallelized = True
+ context.original_batch_size = orig_batch_size
else:
+ bsz = self._bsz(state)
+ # If the input was already parallelized, we can parallelize it again.
+ # This is useful e.g. for LoRA, where attention matrices are parallelized independently.
+ if self.allow_multi_parallelize and bsz == getattr(context, "original_batch_size", -1):
+ state = self.repeat(state, adapter_setup.parallel_channels)
+ orig_batch_size = bsz
# The base model should handle replication of input.
# Therefore, we assume the (replicated) input batch to be divisible by the number of parallel channels.
- if self._bsz(state) % adapter_setup.parallel_channels != 0:
+ elif bsz % adapter_setup.parallel_channels != 0:
raise ValueError(
"The total input batch size in a Parallel adapter block must be divisible by the number of"
" parallel channels."
)
- orig_batch_size = self._bsz(state) // adapter_setup.parallel_channels
+ else:
+ orig_batch_size = bsz // adapter_setup.parallel_channels
state = self.pre_block(adapter_setup, state)
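
The batch-size bookkeeping above boils down to the following sketch (the channel and batch numbers are illustrative assumptions):

```python
# Illustrative numbers: a batch of 2 sequences and a Parallel block with 3 channels.
original_batch_size = 2
parallel_channels = 3

# The first composable layer that sees the input replicates it and records the
# original batch size in the forward context.
replicated_bsz = original_batch_size * parallel_channels  # -> 6
context_original_batch_size = original_batch_size

# Later layers normally receive the already replicated batch (6). A layer with
# allow_multi_parallelize=True (e.g. LoRA, where the attention matrices are
# parallelized independently) may still see the original batch of 2 and is
# allowed to replicate it again instead of raising.
incoming_bsz = 2
if incoming_bsz == context_original_batch_size:
    incoming_bsz *= parallel_channels

assert incoming_bsz == replicated_bsz == 6
```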
diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py
index a4c66c830c..db987a7853 100644
--- a/src/adapters/methods/lora.py
+++ b/src/adapters/methods/lora.py
@@ -3,8 +3,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
+import logging
import math
-from typing import Dict, List, Union
+from typing import Dict, List, NamedTuple, Optional, Union
import torch
import torch.nn as nn
@@ -13,9 +14,12 @@
from transformers.configuration_utils import PretrainedConfig
from transformers.pytorch_utils import Conv1D
-from ..composition import AdapterCompositionBlock
+from ..composition import AdapterCompositionBlock, Average, BatchSplit, Parallel, Stack
from ..configuration import LoRAConfig, ModelAdaptersConfig
-from .adapter_layer_base import AdapterLayerBase
+from .adapter_layer_base import AdapterLayerBase, ComposableAdapterLayerBase
+
+
+logger = logging.getLogger(__name__)
class LoRA(nn.Module):
@@ -27,6 +31,7 @@ def __init__(
gating_heads: int = 1,
):
super().__init__()
+ assert config.composition_mode == "add", "LoRA module only supports composition_mode='add'."
self.r = config.r
self.lora_alpha = config.alpha
self.composition_mode = config.composition_mode
@@ -39,58 +44,126 @@ def __init__(
self.lora_dropout = lambda x: x
# Actual trainable parameters
- if self.r > 1 and self.composition_mode == "scale":
- raise ValueError("Can only use composition_mode='scale' when r == 1.")
- if self.r > 0:
- if self.composition_mode == "add":
- self.lora_A = nn.Parameter(torch.zeros(lora_A_shape))
- self.lora_B = nn.Parameter(torch.zeros(lora_B_shape))
- self.scaling = self.lora_alpha / self.r
-
- if self.use_gating:
- self.gate = nn.Linear(lora_A_shape[-1], gating_heads)
-
- if config.init_weights == "lora":
- # initialize A the same way as the default for nn.Linear and B to zero
- if self.composition_mode == "add":
- nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
- nn.init.zeros_(self.lora_B)
- if self.use_gating:
- nn.init.normal_(self.gate.weight, std=0.02)
- elif config.init_weights == "bert":
- if self.composition_mode == "add":
- nn.init.normal_(self.lora_A, std=0.02)
- nn.init.normal_(self.lora_B, std=0.02)
- if self.use_gating:
- nn.init.normal_(self.gate.weight, std=0.02)
- elif config.init_weights == "ia3":
- if self.composition_mode == "add":
- nn.init.ones_(self.lora_A)
- nn.init.ones_(self.lora_B)
- if self.use_gating:
- nn.init.normal_(self.gate.weight, std=0.02)
- else:
- raise ValueError("Unknown init_weights type: {}".format(config.init_weights))
+ self.lora_A = nn.Parameter(torch.zeros(lora_A_shape))
+ self.lora_B = nn.Parameter(torch.zeros(lora_B_shape))
+ self.scaling = self.lora_alpha / self.r
+
+ # For compatibility with (IA)^3, allow all init_weights types here.
+ # Usually should be "lora".
+ if config.init_weights == "lora":
+ # initialize A the same way as the default for nn.Linear and B to zero
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_B)
+ elif config.init_weights == "bert":
+ nn.init.normal_(self.lora_A, std=0.02)
+ nn.init.normal_(self.lora_B, std=0.02)
+ elif config.init_weights == "ia3":
+ nn.init.ones_(self.lora_A)
+ nn.init.ones_(self.lora_B)
+ else:
+ raise ValueError("Unknown init_weights type: {}".format(config.init_weights))
+
+ if self.use_gating:
+ self.gate = nn.Linear(lora_A_shape[-1], gating_heads)
+ nn.init.normal_(self.gate.weight, std=0.02)
+
+ @property
+ def delta_w(self) -> torch.Tensor:
+ return self.lora_B @ self.lora_A
def com(self, weights: torch.Tensor, added: torch.Tensor, scaling=None) -> torch.Tensor:
"""Performs the composition operation between existing and injected weights."""
if scaling is None:
scaling = self.scaling
- if self.composition_mode == "add":
- return weights + added * scaling
- elif self.composition_mode == "scale":
- return weights * (added * scaling)
+ return weights + added * scaling
+
+ def com_inv(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor:
+ """Inverts the composition operation between existing and injected weights."""
+ return weights - added * self.scaling
+
+ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor):
+ if hidden_states is None:
+ hidden_states = layer_input
+ hidden_states = self.lora_dropout(hidden_states) @ torch.t(self.lora_A) @ torch.t(self.lora_B)
+ if self.use_gating:
+ gate = torch.sigmoid(self.gate(layer_input))
+ gate = torch.mean(gate, dim=1).unsqueeze(-1)
+ hidden_states = hidden_states * gate
+ else:
+ gate = None
+
+ return hidden_states, gate
+
+
+class IA3(nn.Module):
+ def __init__(
+ self,
+ lora_A_shape,
+ lora_B_shape,
+ config: LoRAConfig,
+ gating_heads: int = 1,
+ ):
+ super().__init__()
+ assert config.composition_mode == "scale", "IA3 module only supports composition_mode='scale'."
+ if config.r > 1:
+ raise ValueError("Can only use composition_mode='scale' when r == 1.")
+ self.r = config.r
+ self.lora_alpha = config.alpha
+ self.composition_mode = config.composition_mode
+ self.attn_matrices = config.attn_matrices
+ self.use_gating = config.use_gating
+ # Optional dropout
+ if config.dropout > 0.0:
+ raise ValueError("IA3 module does not support dropout.")
+
+ # Actual trainable parameters
+ self.lora_B = nn.Parameter(torch.zeros(lora_B_shape))
+ self.scaling = self.lora_alpha
+
+ # For compatibility with LoRA, allow all init_weights types here.
+ # Usually should be "ia3".
+ if config.init_weights == "lora":
+ logger.warning("(IA)^3 module initialized with LoRA zero init. Ignore if this is intended.")
+ nn.init.zeros_(self.lora_B)
+ elif config.init_weights == "bert":
+ nn.init.normal_(self.lora_B, std=0.02)
+ elif config.init_weights == "ia3":
+ nn.init.ones_(self.lora_B)
else:
- raise ValueError("Invalid composition mode.")
+ raise ValueError("Unknown init_weights type: {}".format(config.init_weights))
+
+ if self.use_gating:
+ self.gate = nn.Linear(lora_A_shape[-1], gating_heads)
+ nn.init.normal_(self.gate.weight, std=0.02)
+
+ @property
+ def delta_w(self) -> torch.Tensor:
+ return self.lora_B
+
+ def com(self, weights: torch.Tensor, added: torch.Tensor, scaling=None) -> torch.Tensor:
+ """Performs the composition operation between existing and injected weights."""
+ if scaling is None:
+ scaling = self.scaling
+ return weights * (added * scaling)
def com_inv(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor:
"""Inverts the composition operation between existing and injected weights."""
- if self.composition_mode == "add":
- return weights - added * self.scaling
- elif self.composition_mode == "scale":
- return weights / (added * self.scaling)
+ return weights / (added * self.scaling)
+
+ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor):
+ scaling_vector = self.lora_B.view(1, 1, -1).repeat(layer_input.shape[0], 1, 1)
+ if hidden_states is None:
+ hidden_states = scaling_vector
+ else:
+ hidden_states = hidden_states * scaling_vector
+ if self.use_gating:
+ gate = torch.sigmoid(self.gate(layer_input))
+ gate = torch.mean(gate, dim=1).unsqueeze(-1)
+ hidden_states = hidden_states * gate
else:
- raise ValueError("Invalid composition mode.")
+ gate = None
+
+ return hidden_states, gate
class LoRALayer(AdapterLayerBase):
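
Splitting the previous class into `LoRA` (composition_mode="add") and `IA3` (composition_mode="scale") makes the merge math explicit: LoRA contributes an additive low-rank update delta_w = B @ A scaled by alpha / r, while (IA)³ contributes a multiplicative scaling vector. A minimal numeric sketch of `com`/`com_inv` under assumed toy shapes:

```python
import torch

W = torch.randn(6, 4)  # frozen base weight, shape (out_features, in_features)

# LoRA: additive update, delta_w = B @ A with shapes (out, r) @ (r, in), scaling = alpha / r
A = torch.randn(2, 4)
B = torch.randn(6, 2)
scaling = 8 / 2                                       # e.g. alpha=8, r=2
merged = W + (B @ A) * scaling                        # LoRA.com
assert torch.allclose(merged - (B @ A) * scaling, W)  # LoRA.com_inv round-trips

# (IA)^3: multiplicative scaling, delta_w = b with shape (out, 1), r must be 1, scaling = alpha
b = torch.rand(6, 1) + 0.5
merged = W * (b * 1.0)                                # IA3.com with alpha=1
assert torch.allclose(merged / (b * 1.0), W)          # IA3.com_inv round-trips
```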
@@ -107,7 +180,7 @@ def __init__(
self.merged = False
- def get_n_heads(self, lora: Union[LoRA, LoRAConfig]):
+ def get_n_heads(self, lora: Union[LoRA, IA3, LoRAConfig]):
return 1
def _check_lora_location(self, config: LoRAConfig):
@@ -125,7 +198,13 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
location_key=self.location_key,
)
if lora_config is not None and self._check_lora_location(lora_config):
- lora = LoRA(
+ if lora_config.composition_mode == "add":
+ lora_cls = LoRA
+ elif lora_config.composition_mode == "scale":
+ lora_cls = IA3
+ else:
+ raise ValueError(f"Unknown composition_mode: {lora_config.composition_mode}")
+ lora = lora_cls(
*self._get_lora_shapes(lora_config),
lora_config,
gating_heads=self.get_n_heads(lora_config),
@@ -188,9 +267,25 @@ def get_adapter(self, adapter_name: str) -> nn.Module:
return None
-class Linear(LoRALayer, nn.Linear):
+class LoRAState(NamedTuple):
+ """Models the input and output states of a LoRA layer.
+
+ Args:
+ layer_input (torch.Tensor): The input states to the adapted layer.
+ hidden_states (Optional[torch.Tensor]):
+ The hidden states of the adaptation module. These can be None before passing through the first LoRA/IA3
+ module.
+ layer_output (torch.Tensor): The output states of the original layer without adaptation.
"""
- LoRA implementation for Linear layer.
+
+ layer_input: torch.Tensor
+ hidden_states: Optional[torch.Tensor]
+ layer_output: torch.Tensor
+
+
+class LoRALinear(LoRALayer, ComposableAdapterLayerBase, nn.Linear):
+ """
+ LoRA implementation for Linear layer. This layer supports composition.
Args:
fan_in_fan_out (bool, optional):
@@ -199,6 +294,9 @@ class Linear(LoRALayer, nn.Linear):
"""
+ supported_compositions = [Stack, BatchSplit, Average, Parallel]
+ allow_multi_parallelize = True
+
def __init__(
self,
in_features: int,
@@ -267,36 +365,17 @@ def _check_lora_location(self, config: LoRAConfig):
def _get_lora_shapes(self, config: LoRAConfig):
return (config.r, self.in_features), (self.out_features, config.r)
- def reset_adapter(self):
- def T(w):
- return torch.t(w) if self.fan_in_fan_out else w
+ def maybe_t(self, w):
+ return torch.t(w) if self.fan_in_fan_out else w
+ def reset_adapter(self):
if self.merged:
lora = self.loras[self.merged]
# Make sure that the weights are not merged
- if lora.r > 0:
- if lora.composition_mode == "scale":
- delta_w = T(lora.lora_B)
- else:
- delta_w = T(lora.lora_B @ lora.lora_A)
- self.weight.data = lora.com_inv(self.weight.data, delta_w)
+ delta_w = self.maybe_t(lora.delta_w)
+ self.weight.data = lora.com_inv(self.weight.data, delta_w)
self.merged = None
- def _compute_adapted_weight(self, lora, scaling=None):
- def T(w):
- return torch.t(w) if self.fan_in_fan_out else w
-
- weight = self.weight
- # Merge the weights and mark it
- if lora.r > 0:
- if lora.composition_mode == "scale":
- delta_w = T(lora.lora_B)
- else:
- delta_w = T(lora.lora_B @ lora.lora_A)
- weight = lora.com(weight, delta_w, scaling=scaling)
-
- return weight
-
def merge_adapter(self, name: str):
if name in self.loras:
if self.merged == name:
@@ -305,44 +384,74 @@ def merge_adapter(self, name: str):
lora = self.loras[name]
if lora.use_gating:
raise ValueError("Cannot merge LoRA layer with gating.")
- self.weight.data = self._compute_adapted_weight(lora)
+ delta_w = self.maybe_t(lora.delta_w)
+ self.weight.data = lora.com(self.weight.data, delta_w)
self.merged = name
elif self.merged != name:
raise ValueError("LoRALayer already has a merged LoRA module. Please reset it first.")
- def forward(self, x: torch.Tensor):
- def T(w):
- return torch.transpose(w, -2, -1) if self.fan_in_fan_out else w
+ def vslice(self, state: LoRAState, slice_obj: slice) -> LoRAState:
+ return LoRAState(
+ state.layer_input[slice_obj],
+ state.hidden_states[slice_obj] if state.hidden_states is not None else None,
+ state.layer_output[slice_obj],
+ )
+
+ def pad_and_concat(self, states: List[LoRAState]) -> LoRAState:
+ return LoRAState(
+ torch.cat([s.layer_input for s in states], dim=0),
+ torch.cat([s.hidden_states for s in states], dim=0) if states[0].hidden_states is not None else None,
+ torch.cat([s.layer_output for s in states], dim=0),
+ )
+
+ def repeat(self, state: LoRAState, channels: int) -> LoRAState:
+ return LoRAState(
+ state.layer_input.repeat(channels, 1, 1),
+ state.hidden_states.repeat(channels, 1, 1) if state.hidden_states is not None else None,
+ state.layer_output.repeat(channels, 1, 1),
+ )
+
+ def mean(self, states: List[LoRAState], weights: torch.Tensor) -> LoRAState:
+ return LoRAState(
+ states[0].layer_input,
+ torch.mean(torch.stack([s.hidden_states for s in states], dim=0) * weights, dim=0)
+ if states[0].hidden_states is not None
+ else None,
+ states[0].layer_output,
+ )
+
+ def compose_single(self, adapter_setup: str, state: LoRAState, lvl: int = 0) -> LoRAState:
+ lora = self.loras[adapter_setup]
+ hidden_states, gate = lora(state.hidden_states, state.layer_input)
+ if gate is not None:
+ self._store_gating_score(adapter_setup, gate)
+
+ return state._replace(hidden_states=hidden_states)
+
+ def forward(self, input_states: torch.Tensor):
+ weight = torch.transpose(self.weight, -2, -1) if self.fan_in_fan_out else self.weight
+ # result shape: <batch_size> x <seq_len> x <head_dim>
+ layer_output = F.linear(input_states, weight, bias=self.bias)
if not self.merged:
adapter_setup = self.get_active_setup()
if adapter_setup is not None:
- if len(adapter_setup) == 1:
- lora = self.loras[adapter_setup[0]]
- # result shape: <batch_size> x <seq_len> x <head_dim>
- result = F.linear(x, T(self.weight), bias=self.bias)
- if lora.r > 0:
- if lora.composition_mode == "scale":
- delta_w = lora.lora_B.view(1, 1, -1)
- else:
- delta_w = lora.lora_dropout(x) @ torch.t(lora.lora_A) @ torch.t(lora.lora_B)
- if lora.use_gating:
- gate = torch.sigmoid(lora.gate(x))
- gate = torch.mean(gate, dim=1).unsqueeze(-1)
- self._store_gating_score(adapter_setup[0], gate)
- else:
- gate = None
- result = lora.com(result, delta_w, scaling=gate)
- return result
- else:
- raise ValueError(f"Invalid adapter setup. Cannot use {adapter_setup} with LoRA.")
+ state = LoRAState(input_states, None, layer_output)
+ state = self.compose(adapter_setup, state)
+ _, hidden_states, layer_output = state
- return F.linear(x, T(self.weight), bias=self.bias)
+ last_lora = self.loras[adapter_setup.last()]
+ layer_output = last_lora.com(
+ layer_output, hidden_states, scaling=1.0
+ ) # scaling already applied in compose
+
+ return layer_output
-class MergedLinear(LoRALayer, nn.Linear):
+class LoRAMergedLinear(LoRALayer, nn.Linear):
"""
- LoRA implementation for merged attention layer layer.
+ LoRA implementation for merged attention layers, as used by some models (e.g. GPT-2). This layer
+ currently does not support composition.
Args:
fan_in_fan_out (bool, optional):
@@ -395,7 +504,7 @@ def wrap(
return new_module
- def get_n_heads(self, lora: Union[LoRA, LoRAConfig]):
+ def get_n_heads(self, lora: Union[LoRA, IA3, LoRAConfig]):
return len(set(lora.attn_matrices))
def _get_lora_shapes(self, config: LoRAConfig):
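
Since `LoRALinear` now derives from `ComposableAdapterLayerBase` with `supported_compositions = [Stack, BatchSplit, Average, Parallel]`, blocks such as `BatchSplit` can route different slices of a batch through different LoRA adapters. A rough sketch (the checkpoint and adapter names are placeholders):

```python
import adapters
import adapters.composition as ac
from adapters import LoRAConfig
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
adapters.init(model)

model.add_adapter("a", config=LoRAConfig())
model.add_adapter("b", config=LoRAConfig())

# The first 4 examples of each batch go through adapter "a", the remaining 4 through
# adapter "b"; the total input batch size must equal sum(batch_sizes).
model.set_active_adapters(ac.BatchSplit("a", "b", batch_sizes=[4, 4]))
```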
diff --git a/src/adapters/models/albert/mixin_albert.py b/src/adapters/models/albert/mixin_albert.py
index 21534980af..ff9ef19fe3 100644
--- a/src/adapters/models/albert/mixin_albert.py
+++ b/src/adapters/models/albert/mixin_albert.py
@@ -4,7 +4,7 @@
from ...composition import adjust_tensors_for_parallel_
from ...methods.bottleneck import BottleneckLayer
-from ...methods.lora import Linear as LoRALinear
+from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin
diff --git a/src/adapters/models/albert/modeling_albert.py b/src/adapters/models/albert/modeling_albert.py
index df3e7523f0..7f5294cad7 100644
--- a/src/adapters/models/albert/modeling_albert.py
+++ b/src/adapters/models/albert/modeling_albert.py
@@ -23,7 +23,7 @@
from transformers.models.albert.modeling_albert import AlbertAttention, AlbertLayer
from transformers.pytorch_utils import apply_chunking_to_forward
-from ...composition import adjust_tensors_for_parallel
+from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
from .mixin_albert import AlbertAttentionAdaptersMixin, AlbertEncoderLayerAdaptersMixin
@@ -42,6 +42,8 @@ def forward(
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
+ query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer)
+ (attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask)
key_layer, value_layer, attention_mask = self.prefix_tuning(
key_layer, value_layer, hidden_states, attention_mask
diff --git a/src/adapters/models/bart/mixin_bart.py b/src/adapters/models/bart/mixin_bart.py
index 5ef20aaa86..28e7b3ac77 100644
--- a/src/adapters/models/bart/mixin_bart.py
+++ b/src/adapters/models/bart/mixin_bart.py
@@ -5,7 +5,7 @@
from ...composition import adjust_tensors_for_parallel
from ...methods.bottleneck import BottleneckLayer
-from ...methods.lora import Linear as LoRALinear
+from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import (
EmbeddingAdaptersMixin,
diff --git a/src/adapters/models/bart/modeling_bart.py b/src/adapters/models/bart/modeling_bart.py
index cb15b385bd..28bf37bd7c 100644
--- a/src/adapters/models/bart/modeling_bart.py
+++ b/src/adapters/models/bart/modeling_bart.py
@@ -21,7 +21,7 @@
from transformers.models.bart.modeling_bart import BartAttention, BartDecoderLayer, BartEncoderLayer
-from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_
+from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel
from .mixin_bart import BartAttentionAdaptersMixin, BartDecoderLayerAdaptersMixin, BartEncoderLayerAdaptersMixin
@@ -74,6 +74,11 @@ def forward(
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ query_states, key_states, value_states = match_attn_matrices_for_parallel(
+ query_states, key_states, value_states
+ )
+ (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)
+
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
diff --git a/src/adapters/models/beit/mixin_beit.py b/src/adapters/models/beit/mixin_beit.py
index 2c129f085c..536048e669 100644
--- a/src/adapters/models/beit/mixin_beit.py
+++ b/src/adapters/models/beit/mixin_beit.py
@@ -3,7 +3,7 @@
import torch.nn as nn
from ...methods.bottleneck import BottleneckLayer
-from ...methods.lora import Linear as LoRALinear
+from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import ModelBaseAdaptersMixin
diff --git a/src/adapters/models/bert/mixin_bert.py b/src/adapters/models/bert/mixin_bert.py
index e97c9dd988..3cf5a6e1ff 100644
--- a/src/adapters/models/bert/mixin_bert.py
+++ b/src/adapters/models/bert/mixin_bert.py
@@ -5,7 +5,7 @@
from ...composition import adjust_tensors_for_parallel_
from ...methods.bottleneck import BottleneckLayer
-from ...methods.lora import Linear as LoRALinear
+from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin
diff --git a/src/adapters/models/bert/modeling_bert.py b/src/adapters/models/bert/modeling_bert.py
index 539dc74ebf..692605610a 100644
--- a/src/adapters/models/bert/modeling_bert.py
+++ b/src/adapters/models/bert/modeling_bert.py
@@ -25,7 +25,7 @@
from transformers.models.bert.modeling_bert import BertOutput, BertSelfAttention, BertSelfOutput
-from ...composition import adjust_tensors_for_parallel
+from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
from .mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin
@@ -66,6 +66,8 @@ def forward(
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
+ query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer)
+ (attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask)
use_cache = past_key_value is not None
if self.is_decoder:
diff --git a/src/adapters/models/bert_generation/modeling_bert_generation.py b/src/adapters/models/bert_generation/modeling_bert_generation.py
index 8f083fe295..8381ccf2bb 100644
--- a/src/adapters/models/bert_generation/modeling_bert_generation.py
+++ b/src/adapters/models/bert_generation/modeling_bert_generation.py
@@ -27,7 +27,7 @@
BertGenerationSelfOutput,
)
-from ...composition import adjust_tensors_for_parallel
+from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin
@@ -78,6 +78,8 @@ def forward(
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
+ query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer)
+ (attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask)
use_cache = past_key_value is not None
if self.is_decoder:
diff --git a/src/adapters/models/clip/mixin_clip.py b/src/adapters/models/clip/mixin_clip.py
index 36eae84b0f..02469974f5 100644
--- a/src/adapters/models/clip/mixin_clip.py
+++ b/src/adapters/models/clip/mixin_clip.py
@@ -4,7 +4,7 @@
from ...composition import adjust_tensors_for_parallel_
from ...methods.bottleneck import BottleneckLayer
-from ...methods.lora import Linear as LoRALinear
+from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import (
EmbeddingAdaptersMixin,
diff --git a/src/adapters/models/deberta/mixin_deberta.py b/src/adapters/models/deberta/mixin_deberta.py
index cee8530f02..d9907de36d 100644
--- a/src/adapters/models/deberta/mixin_deberta.py
+++ b/src/adapters/models/deberta/mixin_deberta.py
@@ -1,4 +1,4 @@
-from ...methods.lora import MergedLinear as LoRAMergedLinear
+from ...methods.lora import LoRAMergedLinear
from ...methods.prefix_tuning import PrefixTuningLayer
diff --git a/src/adapters/models/deberta/modeling_deberta.py b/src/adapters/models/deberta/modeling_deberta.py
index 8197c19fb6..71b7f9dc2a 100644
--- a/src/adapters/models/deberta/modeling_deberta.py
+++ b/src/adapters/models/deberta/modeling_deberta.py
@@ -24,7 +24,7 @@
XSoftmax,
)
-from ...composition import adjust_tensors_for_parallel
+from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfOutputAdaptersMixin
from .mixin_deberta import DebertaSelfAttentionAdaptersMixin
@@ -113,6 +113,9 @@ def linear(w, b, x):
k, v = [linear(qkvw[i], qkvb[i], hidden_states.to(dtype=qkvw[i].dtype)) for i in range(1, 3)]
query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]]
+ query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer)
+ (attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask)
+
query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :])
diff --git a/src/adapters/models/deberta_v2/mixin_deberta_v2.py b/src/adapters/models/deberta_v2/mixin_deberta_v2.py
index f60e8788fb..3a33fdf84c 100644
--- a/src/adapters/models/deberta_v2/mixin_deberta_v2.py
+++ b/src/adapters/models/deberta_v2/mixin_deberta_v2.py
@@ -1,4 +1,4 @@
-from ...methods.lora import Linear as LoRALinear
+from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
diff --git a/src/adapters/models/deberta_v2/modeling_deberta_v2.py b/src/adapters/models/deberta_v2/modeling_deberta_v2.py
index 082e77a721..aa8945000f 100644
--- a/src/adapters/models/deberta_v2/modeling_deberta_v2.py
+++ b/src/adapters/models/deberta_v2/modeling_deberta_v2.py
@@ -24,7 +24,7 @@
XSoftmax,
)
-from ...composition import adjust_tensors_for_parallel
+from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfOutputAdaptersMixin
from .mixin_deberta_v2 import DebertaV2SelfAttentionAdaptersMixin
@@ -97,6 +97,9 @@ def forward(
key_layer = self.transpose_for_scores_extended(self.key_proj(hidden_states), self.num_attention_heads)
value_layer = self.transpose_for_scores_extended(self.value_proj(hidden_states), self.num_attention_heads)
+ query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer)
+ (attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask)
+
orig_key_layer = key_layer.contiguous() # save this for relative attention
key_layer, value_layer, attention_mask = self.prefix_tuning(
key_layer, value_layer, hidden_states, attention_mask, False
diff --git a/src/adapters/models/distilbert/mixin_distilbert.py b/src/adapters/models/distilbert/mixin_distilbert.py
index 44bcbb0b16..111733c2f0 100644
--- a/src/adapters/models/distilbert/mixin_distilbert.py
+++ b/src/adapters/models/distilbert/mixin_distilbert.py
@@ -3,7 +3,7 @@
import torch.nn as nn
from ...methods.bottleneck import BottleneckLayer
-from ...methods.lora import Linear as LoRALinear
+from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin
diff --git a/src/adapters/models/distilbert/modeling_distilbert.py b/src/adapters/models/distilbert/modeling_distilbert.py
index 6dfb62eb1c..e0aee4e1b9 100644
--- a/src/adapters/models/distilbert/modeling_distilbert.py
+++ b/src/adapters/models/distilbert/modeling_distilbert.py
@@ -27,7 +27,7 @@
from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention, TransformerBlock
-from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_
+from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel
from .mixin_distilbert import DistilBertMultiHeadSelfAttentionMixin, DistilBertTransfomerBlockAdaptersMixin
@@ -70,6 +70,9 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head)
v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head)
+ q, k, v = match_attn_matrices_for_parallel(q, k, v)
+ (mask,) = adjust_tensors_for_parallel(q, mask)
+
k, v, mask = self.prefix_tuning(k, v, value, mask, invert_mask=False)
bs = k.size(0) # reset for Parallel block
(q,) = adjust_tensors_for_parallel(k, q)
diff --git a/src/adapters/models/electra/modeling_electra.py b/src/adapters/models/electra/modeling_electra.py
index 35552782ce..cbe4277ec9 100644
--- a/src/adapters/models/electra/modeling_electra.py
+++ b/src/adapters/models/electra/modeling_electra.py
@@ -6,7 +6,7 @@
from transformers.models.electra.modeling_electra import ElectraOutput, ElectraSelfAttention, ElectraSelfOutput
-from ...composition import adjust_tensors_for_parallel
+from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin
@@ -47,6 +47,8 @@ def forward(
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
+ query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer)
+ (attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask)
use_cache = past_key_value is not None
if self.is_decoder:
diff --git a/src/adapters/models/gpt2/mixin_gpt2.py b/src/adapters/models/gpt2/mixin_gpt2.py
index e86c2967a9..ce88136a92 100644
--- a/src/adapters/models/gpt2/mixin_gpt2.py
+++ b/src/adapters/models/gpt2/mixin_gpt2.py
@@ -3,8 +3,7 @@
import torch.nn as nn
from ...methods.bottleneck import BottleneckLayer
-from ...methods.lora import Linear as LoRALinear
-from ...methods.lora import MergedLinear as LoRAMergedLinear
+from ...methods.lora import LoRALinear, LoRAMergedLinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin
diff --git a/src/adapters/models/gptj/mixin_gptj.py b/src/adapters/models/gptj/mixin_gptj.py
index 333c1b9358..7e4e771cba 100644
--- a/src/adapters/models/gptj/mixin_gptj.py
+++ b/src/adapters/models/gptj/mixin_gptj.py
@@ -3,7 +3,7 @@
import torch.nn as nn
from ...methods.bottleneck import BottleneckLayer
-from ...methods.lora import Linear as LoRALinear
+from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin
diff --git a/src/adapters/models/gptj/modeling_gptj.py b/src/adapters/models/gptj/modeling_gptj.py
index 453f0c9b6d..700e919a17 100644
--- a/src/adapters/models/gptj/modeling_gptj.py
+++ b/src/adapters/models/gptj/modeling_gptj.py
@@ -22,7 +22,7 @@
from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, apply_rotary_pos_emb, get_embed_positions
from transformers.utils.import_utils import is_torch_fx_proxy
-from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_
+from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel
from .mixin_gptj import GPTJAttentionAdaptersMixin, GPTJDecoderBlockAdaptersMixin
@@ -44,6 +44,9 @@ def forward(
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)
+ query, key, value = match_attn_matrices_for_parallel(query, key, value)
+ (attention_mask,) = adjust_tensors_for_parallel(query, attention_mask)
+
query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)
diff --git a/src/adapters/models/llama/mixin_llama.py b/src/adapters/models/llama/mixin_llama.py
index 22223edaf4..3caf66e544 100644
--- a/src/adapters/models/llama/mixin_llama.py
+++ b/src/adapters/models/llama/mixin_llama.py
@@ -3,7 +3,7 @@
import torch.nn as nn
from ...methods.bottleneck import BottleneckLayer
-from ...methods.lora import Linear as LoRALinear
+from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin
diff --git a/src/adapters/models/llama/modeling_llama.py b/src/adapters/models/llama/modeling_llama.py
index f16b65e9c6..3b22e5ae13 100644
--- a/src/adapters/models/llama/modeling_llama.py
+++ b/src/adapters/models/llama/modeling_llama.py
@@ -25,7 +25,11 @@
import torch.utils.checkpoint
from torch import nn
-from adapters.composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_
+from adapters.composition import (
+ adjust_tensors_for_parallel,
+ adjust_tensors_for_parallel_,
+ match_attn_matrices_for_parallel,
+)
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from transformers.utils import logging
@@ -53,6 +57,11 @@ def forward(
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ query_states, key_states, value_states = match_attn_matrices_for_parallel(
+ query_states, key_states, value_states
+ )
+ (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)
+
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
diff --git a/src/adapters/models/mbart/modeling_mbart.py b/src/adapters/models/mbart/modeling_mbart.py
index 5c43212a28..0f8f0d5335 100644
--- a/src/adapters/models/mbart/modeling_mbart.py
+++ b/src/adapters/models/mbart/modeling_mbart.py
@@ -21,7 +21,7 @@
from transformers.models.mbart.modeling_mbart import MBartAttention, MBartDecoderLayer, MBartEncoderLayer
-from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_
+from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel
from ..bart.mixin_bart import BartAttentionAdaptersMixin, BartDecoderLayerAdaptersMixin, BartEncoderLayerAdaptersMixin
@@ -74,6 +74,11 @@ def forward(
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ query_states, key_states, value_states = match_attn_matrices_for_parallel(
+ query_states, key_states, value_states
+ )
+ (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)
+
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
diff --git a/src/adapters/models/roberta/modeling_roberta.py b/src/adapters/models/roberta/modeling_roberta.py
index 47a8ed35a9..e33b7e7ca3 100644
--- a/src/adapters/models/roberta/modeling_roberta.py
+++ b/src/adapters/models/roberta/modeling_roberta.py
@@ -24,7 +24,7 @@
from transformers.models.roberta.modeling_roberta import RobertaOutput, RobertaSelfAttention, RobertaSelfOutput
-from ...composition import adjust_tensors_for_parallel
+from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin
@@ -66,6 +66,8 @@ def forward(
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
+ query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer)
+ (attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask)
use_cache = past_key_value is not None
if self.is_decoder:
diff --git a/src/adapters/models/t5/mixin_t5.py b/src/adapters/models/t5/mixin_t5.py
index 832dfd185d..244f5d4335 100644
--- a/src/adapters/models/t5/mixin_t5.py
+++ b/src/adapters/models/t5/mixin_t5.py
@@ -4,7 +4,7 @@
import torch.nn as nn
from ...methods.bottleneck import BottleneckLayer
-from ...methods.lora import Linear as LoRALinear
+from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import (
EmbeddingAdaptersMixin,
diff --git a/src/adapters/models/t5/modeling_t5.py b/src/adapters/models/t5/modeling_t5.py
index 3440a4bb73..19064f58b2 100644
--- a/src/adapters/models/t5/modeling_t5.py
+++ b/src/adapters/models/t5/modeling_t5.py
@@ -28,7 +28,7 @@
)
from transformers.utils import logging
-from ...composition import adjust_tensors_for_parallel
+from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
from .mixin_t5 import (
T5AttentionAdaptersMixin,
T5CrossAttentionLayerAdaptersMixin,
@@ -128,6 +128,11 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)
+ query_states, key_states, value_states = match_attn_matrices_for_parallel(
+ query_states, key_states, value_states
+ )
+ (mask,) = adjust_tensors_for_parallel(query_states, mask)
+
present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
key_states, value_states, mask = self.prefix_tuning(key_states, value_states, hidden_states, mask)
diff --git a/src/adapters/models/vit/mixin_vit.py b/src/adapters/models/vit/mixin_vit.py
index 07598ad8ae..2f9962a9d8 100644
--- a/src/adapters/models/vit/mixin_vit.py
+++ b/src/adapters/models/vit/mixin_vit.py
@@ -3,7 +3,7 @@
import torch.nn as nn
from ...methods.bottleneck import BottleneckLayer
-from ...methods.lora import Linear as LoRALinear
+from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import ModelBaseAdaptersMixin
diff --git a/src/adapters/models/vit/modeling_vit.py b/src/adapters/models/vit/modeling_vit.py
index bb0fadd2ca..f8c02bd931 100644
--- a/src/adapters/models/vit/modeling_vit.py
+++ b/src/adapters/models/vit/modeling_vit.py
@@ -22,7 +22,7 @@
import torch.utils.checkpoint
from torch import nn
-from adapters.composition import adjust_tensors_for_parallel
+from adapters.composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
from transformers.models.vit.modeling_vit import ViTLayer, ViTOutput, ViTSelfAttention
from .mixin_vit import ViTLayerAdaptersMixin, ViTOutputAdaptersMixin, ViTSelfAttentionAdaptersMixin
@@ -38,6 +38,8 @@ def forward(
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
+ query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer)
+
key_layer, value_layer, _ = self.prefix_tuning(key_layer, value_layer, hidden_states)
(query_layer,) = adjust_tensors_for_parallel(key_layer, query_layer)
diff --git a/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py b/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py
index a8d22284b7..5f18c9f70e 100644
--- a/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py
+++ b/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py
@@ -28,7 +28,7 @@
XLMRobertaSelfOutput,
)
-from ...composition import adjust_tensors_for_parallel
+from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin
@@ -70,6 +70,8 @@ def forward(
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
+ query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer)
+ (attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask)
use_cache = past_key_value is not None
if self.is_decoder:
diff --git a/src/adapters/models/xmod/modeling_xmod.py b/src/adapters/models/xmod/modeling_xmod.py
index b772321667..4a2269fbae 100644
--- a/src/adapters/models/xmod/modeling_xmod.py
+++ b/src/adapters/models/xmod/modeling_xmod.py
@@ -23,7 +23,7 @@
from transformers.models.xmod.modeling_xmod import XmodOutput, XmodSelfAttention, XmodSelfOutput
-from ...composition import adjust_tensors_for_parallel
+from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin
@@ -65,6 +65,8 @@ def forward(
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
+ query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer)
+ (attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask)
use_cache = past_key_value is not None
if self.is_decoder:
diff --git a/tests_adapters/composition/test_adapter_composition.py b/tests_adapters/composition/test_adapter_composition.py
index 2670488cb9..42e1f64c1b 100644
--- a/tests_adapters/composition/test_adapter_composition.py
+++ b/tests_adapters/composition/test_adapter_composition.py
@@ -3,7 +3,7 @@
import torch
import adapters
-from adapters import PrefixTuningConfig, SeqBnConfig
+from adapters import IA3Config, LoRAConfig, PrefixTuningConfig, SeqBnConfig
from adapters.composition import Average, BatchSplit, Fuse, Parallel, Split, Stack, parse_composition
from tests.test_modeling_common import ids_tensor
from transformers import BertConfig, BertForSequenceClassification
@@ -140,9 +140,9 @@ def test_parallel(self):
model.set_active_adapters(Parallel("a", "b", "c", "d"))
inputs = {}
- inputs["input_ids"] = ids_tensor((1, 128), 1000)
+ inputs["input_ids"] = ids_tensor((2, 10), 1000)
logits = model(**inputs).logits
- self.assertEqual(logits.shape, (4, 2))
+ self.assertEqual(logits.shape, (8, 2))
def test_nested_parallel(self):
if Parallel in self.unsupported_blocks or Stack in self.unsupported_blocks:
@@ -152,7 +152,7 @@ def test_nested_parallel(self):
model.set_active_adapters(Stack("a", Parallel(Stack("b", "c"), "d")))
inputs = {}
- inputs["input_ids"] = ids_tensor((1, 128), 1000)
+ inputs["input_ids"] = ids_tensor((1, 10), 1000)
logits = model(**inputs).logits
self.assertEqual(logits.shape, (2, 2))
@@ -234,3 +234,17 @@ class PrefixTuningCompositionTest(AdapterCompositionTest):
def get_adapter_config(self):
return PrefixTuningConfig()
+
+
+class LoRACompositionTest(AdapterCompositionTest):
+ unsupported_blocks = [Split, Fuse]
+
+ def get_adapter_config(self):
+ return LoRAConfig(init_weights="bert")
+
+
+class IA3CompositionTest(AdapterCompositionTest):
+ unsupported_blocks = [Split, Fuse]
+
+ def get_adapter_config(self):
+ return IA3Config()