From 152992cf2e06587d52aabdb0a3fb0d489229941e Mon Sep 17 00:00:00 2001 From: calpt Date: Sun, 17 Sep 2023 15:44:58 +0200 Subject: [PATCH 01/10] Move adapter implementations to new folder & modules. AdapterLayer -> BottleneckLayer. --- docs/classes/adapter_layer.rst | 4 +- docs/contributing/adding_adapter_methods.md | 2 +- src/adapters/__init__.py | 4 +- src/adapters/heads/base.py | 2 +- src/adapters/heads/language_modeling.py | 2 +- src/adapters/methods/__init__.py | 0 src/adapters/methods/adapter_layer_base.py | 96 ++++++++++++++++ .../{layer.py => methods/bottleneck.py} | 103 ++---------------- src/adapters/{ => methods}/lora.py | 6 +- src/adapters/{ => methods}/modeling.py | 4 +- src/adapters/{ => methods}/prefix_tuning.py | 8 +- src/adapters/model_mixin.py | 11 +- src/adapters/models/albert/mixin_albert.py | 10 +- src/adapters/models/bart/mixin_bart.py | 12 +- src/adapters/models/beit/mixin_beit.py | 10 +- src/adapters/models/bert/mixin_bert.py | 14 +-- src/adapters/models/clip/mixin_clip.py | 10 +- src/adapters/models/deberta/mixin_deberta.py | 4 +- .../models/deberta_v2/mixin_deberta_v2.py | 4 +- .../models/distilbert/mixin_distilbert.py | 10 +- src/adapters/models/gpt2/mixin_gpt2.py | 12 +- src/adapters/models/gptj/mixin_gptj.py | 10 +- src/adapters/models/llama/mixin_llama.py | 10 +- src/adapters/models/t5/mixin_t5.py | 12 +- src/adapters/models/vit/mixin_vit.py | 10 +- 25 files changed, 189 insertions(+), 181 deletions(-) create mode 100644 src/adapters/methods/__init__.py create mode 100644 src/adapters/methods/adapter_layer_base.py rename src/adapters/{layer.py => methods/bottleneck.py} (88%) rename src/adapters/{ => methods}/lora.py (99%) rename src/adapters/{ => methods}/modeling.py (99%) rename src/adapters/{ => methods}/prefix_tuning.py (99%) diff --git a/docs/classes/adapter_layer.rst b/docs/classes/adapter_layer.rst index 2b54475994..d76d13dd52 100644 --- a/docs/classes/adapter_layer.rst +++ b/docs/classes/adapter_layer.rst @@ -1,5 +1,5 @@ -AdapterLayer +BottleneckLayer ======================= -.. autoclass:: adapters.AdapterLayer +.. autoclass:: adapters.BottleneckLayer :members: diff --git a/docs/contributing/adding_adapter_methods.md b/docs/contributing/adding_adapter_methods.md index e968c90b7c..de3f1937e3 100644 --- a/docs/contributing/adding_adapter_methods.md +++ b/docs/contributing/adding_adapter_methods.md @@ -32,7 +32,7 @@ Thus, each adapter method implementation at least should provide two classes: including methods for adding, enabling and deleting adapter weights. - Most importantly, the module classes deriving from this base class should implement the forward pass through an adaptation component. - The concrete implementation of these classes heavily depends on the specifics of the adapter method. - For a reference implementation, have a look at `AdapterLayer` for bottleneck adapters. + For a reference implementation, have a look at `BottleneckLayer` for bottleneck adapters. - To actually make use of the newly implemented classes, it's finally necessary to integrate the forward calls to the modules in the actual model implementations. - This, again, is highly dependent on how the adapter method interacts with the base model classes. Typically, module classes can be integrated either via mixins (see `src/transformers/adapters/mixins`) or directly as submodules of the respective model components. - The model class integration has to be repeated for each supported Transformer model, as they typically don't share a codebase. At this point it is often important to consider where the adapters need to be added to the transformer model and whether there is an implementation that does not require more copying of classes than the current implementation. diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py index dd68037057..768e4ef082 100644 --- a/src/adapters/__init__.py +++ b/src/adapters/__init__.py @@ -78,7 +78,7 @@ "Seq2SeqLMHead", "TaggingHead", ], - "layer": ["AdapterLayer", "AdapterLayerBase"], + "methods.adapter_layer_base": ["AdapterLayerBase"], "model_mixin": [ "EmbeddingAdaptersMixin", "InvertibleAdaptersMixin", @@ -182,7 +182,7 @@ Seq2SeqLMHead, TaggingHead, ) - from .layer import AdapterLayer, AdapterLayerBase + from .methods.adapter_layer_base import AdapterLayerBase from .model_mixin import ( EmbeddingAdaptersMixin, InvertibleAdaptersMixin, diff --git a/src/adapters/heads/base.py b/src/adapters/heads/base.py index dd43a4e658..d9c7386fec 100644 --- a/src/adapters/heads/base.py +++ b/src/adapters/heads/base.py @@ -20,8 +20,8 @@ from ..composition import AdapterCompositionBlock, BatchSplit, Parallel, parse_heads_from_composition from ..context import AdapterSetup, ForwardContext +from ..methods.modeling import Activation_Function_Class from ..model_mixin import ModelWithHeadsAdaptersMixin -from ..modeling import Activation_Function_Class logger = logging.getLogger(__name__) diff --git a/src/adapters/heads/language_modeling.py b/src/adapters/heads/language_modeling.py index 7e6ec95ccb..3e0cda610a 100644 --- a/src/adapters/heads/language_modeling.py +++ b/src/adapters/heads/language_modeling.py @@ -2,7 +2,7 @@ from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast, MaskedLMOutput, Seq2SeqLMOutput -from ..modeling import Activation_Function_Class +from ..methods.modeling import Activation_Function_Class from .base import PredictionHead diff --git a/src/adapters/methods/__init__.py b/src/adapters/methods/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/adapters/methods/adapter_layer_base.py b/src/adapters/methods/adapter_layer_base.py new file mode 100644 index 0000000000..e6a95256c6 --- /dev/null +++ b/src/adapters/methods/adapter_layer_base.py @@ -0,0 +1,96 @@ +from abc import ABCMeta, abstractmethod +from typing import Dict, List, Union + +import numpy as np +from torch import nn + +from ..composition import AdapterCompositionBlock +from ..context import AdapterSetup, ForwardContext + + +# We don't inherit from ABC because __slots__ changes object layout +class AdapterLayerBase(metaclass=ABCMeta): + """ + Base class for all adaptation methods that require per-layer modules. + """ + + @property + def layer_idx(self): + return getattr(self, "_layer_idx", -1) + + @layer_idx.setter + def layer_idx(self, layer_idx): + idx = getattr(self, "_layer_idx", layer_idx) + assert idx == layer_idx + setattr(self, "_layer_idx", idx) + + def get_active_setup(self, module_dict): + if hasattr(self, "adapters_config"): + # First check current context before falling back to defined setup + context = AdapterSetup.get_context() + if context is not None: + adapter_setup = context.adapter_setup + else: + adapter_setup = self.adapters_config.active_setup + else: + adapter_setup = None + skip_adapters = adapter_setup is None or ( + self.adapters_config.skip_layers is not None and self.layer_idx in self.adapters_config.skip_layers + ) + if not skip_adapters and (len(set(module_dict.keys()) & adapter_setup.flatten()) > 0): + return adapter_setup + else: + return None + + def _store_gating_score(self, adapter_name, gating_score): + context = ForwardContext.get_context() + if context.output_adapter_gating_scores: + gating_cache = context.adapter_gating_scores + if self.layer_idx not in gating_cache[adapter_name]: + gating_cache[adapter_name][self.layer_idx] = {} + gating_score = gating_score.detach().squeeze().cpu().numpy() + if len(gating_score.shape) == 0: + gating_score = np.expand_dims(gating_score, axis=0) + cache_score = gating_cache[adapter_name][self.layer_idx].get(self.location_key, None) + if cache_score is not None: + gating_cache[adapter_name][self.layer_idx][self.location_key] = np.column_stack( + (cache_score, gating_score) + ) + else: + gating_cache[adapter_name][self.layer_idx][self.location_key] = gating_score + + def _store_fusion_attentions(self, fusion_name, attentions): + context = ForwardContext.get_context() + if context.output_adapter_fusion_attentions: + attention_cache = context.adapter_fusion_attentions + if self.layer_idx not in attention_cache[fusion_name]: + attention_cache[fusion_name][self.layer_idx] = {} + attention_cache[fusion_name][self.layer_idx][self.location_key] = attentions + + @abstractmethod + def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: + raise NotImplementedError() + + @abstractmethod + def average_adapter(self, adapter_name: str, input_adapters: Dict[str, float]) -> bool: + raise NotImplementedError() + + @abstractmethod + def delete_adapter(self, adapter_name: str): + raise NotImplementedError() + + @abstractmethod + def add_fusion_layer(self, adapter_names: Union[List, str]): + raise NotImplementedError() + + @abstractmethod + def delete_fusion_layer(self, adapter_names: Union[List, str]): + raise NotImplementedError() + + @abstractmethod + def enable_adapters(self, adapter_setup: AdapterCompositionBlock, unfreeze_adapters: bool, unfreeze_fusion: bool): + raise NotImplementedError() + + @abstractmethod + def get_adapter(self, adapter_name: str) -> nn.Module: + raise NotImplementedError() diff --git a/src/adapters/layer.py b/src/adapters/methods/bottleneck.py similarity index 88% rename from src/adapters/layer.py rename to src/adapters/methods/bottleneck.py index 32b57915a5..d3c1ab8f8c 100644 --- a/src/adapters/layer.py +++ b/src/adapters/methods/bottleneck.py @@ -1,11 +1,9 @@ -from abc import ABCMeta, abstractmethod from typing import Dict, List, Mapping, Union -import numpy as np import torch from torch import nn -from .composition import ( +from ..composition import ( AdapterCompositionBlock, Average, BatchSplit, @@ -15,100 +13,13 @@ Stack, adjust_tensors_for_parallel, ) -from .configuration import BnConfig -from .context import AdapterSetup, ForwardContext +from ..configuration import BnConfig +from ..context import ForwardContext +from .adapter_layer_base import AdapterLayerBase from .modeling import Adapter, BertFusion, ParallelAdapter -# We don't inherit from ABC because __slots__ changes object layout -class AdapterLayerBase(metaclass=ABCMeta): - """ - Base class for all adaptation methods that require per-layer modules. - """ - - @property - def layer_idx(self): - return getattr(self, "_layer_idx", -1) - - @layer_idx.setter - def layer_idx(self, layer_idx): - idx = getattr(self, "_layer_idx", layer_idx) - assert idx == layer_idx - setattr(self, "_layer_idx", idx) - - def get_active_setup(self, module_dict): - if hasattr(self, "adapters_config"): - # First check current context before falling back to defined setup - context = AdapterSetup.get_context() - if context is not None: - adapter_setup = context.adapter_setup - else: - adapter_setup = self.adapters_config.active_setup - else: - adapter_setup = None - skip_adapters = adapter_setup is None or ( - self.adapters_config.skip_layers is not None and self.layer_idx in self.adapters_config.skip_layers - ) - if not skip_adapters and (len(set(module_dict.keys()) & adapter_setup.flatten()) > 0): - return adapter_setup - else: - return None - - def _store_gating_score(self, adapter_name, gating_score): - context = ForwardContext.get_context() - if context.output_adapter_gating_scores: - gating_cache = context.adapter_gating_scores - if self.layer_idx not in gating_cache[adapter_name]: - gating_cache[adapter_name][self.layer_idx] = {} - gating_score = gating_score.detach().squeeze().cpu().numpy() - if len(gating_score.shape) == 0: - gating_score = np.expand_dims(gating_score, axis=0) - cache_score = gating_cache[adapter_name][self.layer_idx].get(self.location_key, None) - if cache_score is not None: - gating_cache[adapter_name][self.layer_idx][self.location_key] = np.column_stack( - (cache_score, gating_score) - ) - else: - gating_cache[adapter_name][self.layer_idx][self.location_key] = gating_score - - def _store_fusion_attentions(self, fusion_name, attentions): - context = ForwardContext.get_context() - if context.output_adapter_fusion_attentions: - attention_cache = context.adapter_fusion_attentions - if self.layer_idx not in attention_cache[fusion_name]: - attention_cache[fusion_name][self.layer_idx] = {} - attention_cache[fusion_name][self.layer_idx][self.location_key] = attentions - - @abstractmethod - def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: - raise NotImplementedError() - - @abstractmethod - def average_adapter(self, adapter_name: str, input_adapters: Dict[str, float]) -> bool: - raise NotImplementedError() - - @abstractmethod - def delete_adapter(self, adapter_name: str): - raise NotImplementedError() - - @abstractmethod - def add_fusion_layer(self, adapter_names: Union[List, str]): - raise NotImplementedError() - - @abstractmethod - def delete_fusion_layer(self, adapter_names: Union[List, str]): - raise NotImplementedError() - - @abstractmethod - def enable_adapters(self, adapter_setup: AdapterCompositionBlock, unfreeze_adapters: bool, unfreeze_fusion: bool): - raise NotImplementedError() - - @abstractmethod - def get_adapter(self, adapter_name: str) -> nn.Module: - raise NotImplementedError() - - -class AdapterLayer(AdapterLayerBase, nn.Module): +class BottleneckLayer(AdapterLayerBase, nn.Module): def __init__(self, location_key: str): super().__init__() self.location_key = location_key @@ -638,8 +549,8 @@ def adapter_average_output(self, adapter_setup: Average, hidden_states, input_te def adapter_layer_forward(self, hidden_states, residual_input, layer_norm): """Forward pass through the adapter layer. - NOTE: This method should only be called if the calling module directly inherits from AdapterLayer. Otherwise, - call the regular forward() method. + NOTE: This method should only be called if the calling module directly inherits from BottleneckLayer. + Otherwise, call the regular forward() method. Args: hidden_states (torch.Tensor): Input hidden states to the adapter layer. diff --git a/src/adapters/lora.py b/src/adapters/methods/lora.py similarity index 99% rename from src/adapters/lora.py rename to src/adapters/methods/lora.py index 3549e7a8fc..977fe8ae88 100644 --- a/src/adapters/lora.py +++ b/src/adapters/methods/lora.py @@ -13,9 +13,9 @@ from transformers.configuration_utils import PretrainedConfig from transformers.pytorch_utils import Conv1D -from .composition import AdapterCompositionBlock -from .configuration import LoRAConfig, ModelAdaptersConfig -from .layer import AdapterLayerBase +from ..composition import AdapterCompositionBlock +from ..configuration import LoRAConfig, ModelAdaptersConfig +from .adapter_layer_base import AdapterLayerBase class LoRA(nn.Module): diff --git a/src/adapters/modeling.py b/src/adapters/methods/modeling.py similarity index 99% rename from src/adapters/modeling.py rename to src/adapters/methods/modeling.py index b61419069e..6b265e21f2 100644 --- a/src/adapters/modeling.py +++ b/src/adapters/methods/modeling.py @@ -5,8 +5,8 @@ from transformers.activations import get_activation -from .configuration import AdapterFusionConfig, BnConfig -from .context import ForwardContext +from ..configuration import AdapterFusionConfig, BnConfig +from ..context import ForwardContext class Activation_Function_Class(nn.Module): diff --git a/src/adapters/prefix_tuning.py b/src/adapters/methods/prefix_tuning.py similarity index 99% rename from src/adapters/prefix_tuning.py rename to src/adapters/methods/prefix_tuning.py index af9c57e03f..06ebd83f50 100644 --- a/src/adapters/prefix_tuning.py +++ b/src/adapters/methods/prefix_tuning.py @@ -7,10 +7,10 @@ from transformers import PretrainedConfig from transformers.modeling_utils import ModuleUtilsMixin -from .composition import AdapterCompositionBlock, BatchSplit, Parallel, Stack, adjust_tensors_for_parallel -from .configuration import ModelAdaptersConfig, PrefixTuningConfig -from .context import AdapterSetup, ForwardContext -from .layer import AdapterLayerBase +from ..composition import AdapterCompositionBlock, BatchSplit, Parallel, Stack, adjust_tensors_for_parallel +from ..configuration import ModelAdaptersConfig, PrefixTuningConfig +from ..context import AdapterSetup, ForwardContext +from .adapter_layer_base import AdapterLayerBase from .modeling import Activation_Function_Class diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 33b57230da..7e6652eab0 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -16,11 +16,12 @@ from .configuration import ADAPTER_CONFIG_MAP, AdapterConfigBase, AdapterFusionConfig, BnConfig from .context import AdapterSetup, ForwardContext from .hub_mixin import PushAdapterToHubMixin -from .layer import AdapterLayer, AdapterLayerBase from .loading import AdapterFusionLoader, AdapterLoader, PredictionHeadLoader, WeightsLoader -from .lora import LoRALayer -from .modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock, init_shared_parameters -from .prefix_tuning import PrefixTuningPool, PrefixTuningShim +from .methods.adapter_layer_base import AdapterLayerBase +from .methods.bottleneck import BottleneckLayer +from .methods.lora import LoRALayer +from .methods.modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock, init_shared_parameters +from .methods.prefix_tuning import PrefixTuningPool, PrefixTuningShim from .utils import EMBEDDING_FILE, TOKENIZER_PATH, get_adapter_config_hash, inherit_doc from .wrappers.configuration import SUBMODEL_NAMES, init_adapters_config @@ -933,7 +934,7 @@ def get_fusion_regularization_loss(self): target = torch.zeros((self.config.hidden_size, self.config.hidden_size)).fill_diagonal_(1.0).to(self.device) for i, layer in self.iter_layers(): for module in layer.modules(): - if isinstance(module, AdapterLayer): + if isinstance(module, BottleneckLayer): for _, layer_fusion in module.adapter_fusion_layer.items(): if hasattr(layer_fusion, "value") and layer_fusion.value.weight.requires_grad: layer_reg_loss = 0.01 * (target - layer_fusion.value.weight).pow(2).sum() diff --git a/src/adapters/models/albert/mixin_albert.py b/src/adapters/models/albert/mixin_albert.py index fa2340687e..45beebafb6 100644 --- a/src/adapters/models/albert/mixin_albert.py +++ b/src/adapters/models/albert/mixin_albert.py @@ -3,10 +3,10 @@ import torch.nn as nn from ...composition import adjust_tensors_for_parallel_ -from ...layer import AdapterLayer -from ...lora import Linear as LoRALinear +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import Linear as LoRALinear +from ...methods.prefix_tuning import PrefixTuningShim from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin -from ...prefix_tuning import PrefixTuningShim class AlbertAttentionAdaptersMixin: @@ -18,7 +18,7 @@ def init_adapters(self, model_config, adapters_config): self.key = LoRALinear.wrap(self.key, "selfattn", model_config, adapters_config, attn_key="k") self.value = LoRALinear.wrap(self.value, "selfattn", model_config, adapters_config, attn_key="v") - self.attention_adapters = AdapterLayer("mh_adapter") + self.attention_adapters = BottleneckLayer("mh_adapter") self.prefix_tuning = PrefixTuningShim( self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config @@ -36,7 +36,7 @@ def init_adapters(self, model_config, adapters_config): # Set location keys for prefix tuning self.location_key = "output_adapter" - self.output_adapters = AdapterLayer("output_adapter") + self.output_adapters = BottleneckLayer("output_adapter") self.attention.location_key = "self" diff --git a/src/adapters/models/bart/mixin_bart.py b/src/adapters/models/bart/mixin_bart.py index e050d66940..32da93aea0 100644 --- a/src/adapters/models/bart/mixin_bart.py +++ b/src/adapters/models/bart/mixin_bart.py @@ -4,8 +4,9 @@ import torch.nn as nn from ...composition import adjust_tensors_for_parallel -from ...layer import AdapterLayer -from ...lora import Linear as LoRALinear +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import Linear as LoRALinear +from ...methods.prefix_tuning import PrefixTuningShim from ...model_mixin import ( EmbeddingAdaptersMixin, EmbeddingAdaptersWrapperMixin, @@ -13,7 +14,6 @@ InvertibleAdaptersWrapperMixin, ModelBaseAdaptersMixin, ) -from ...prefix_tuning import PrefixTuningShim class BartAttentionAdaptersMixin: @@ -40,8 +40,8 @@ def init_adapters(self, model_config, adapters_config): # Set attention layer location key for prefix tuning self.self_attn.location_key = "encoder" - self.attention_adapters = AdapterLayer("mh_adapter") - self.output_adapters = AdapterLayer("output_adapter") + self.attention_adapters = BottleneckLayer("mh_adapter") + self.output_adapters = BottleneckLayer("output_adapter") class BartDecoderLayerAdaptersMixin(BartEncoderLayerAdaptersMixin): @@ -52,7 +52,7 @@ def init_adapters(self, model_config, adapters_config): # Set attention layer location key for prefix tuning self.self_attn.location_key = "self" self.encoder_attn.location_key = "cross" - self.cross_attention_adapters = AdapterLayer("cross_adapter") + self.cross_attention_adapters = BottleneckLayer("cross_adapter") class BartEncoderAdaptersMixin(InvertibleAdaptersMixin): diff --git a/src/adapters/models/beit/mixin_beit.py b/src/adapters/models/beit/mixin_beit.py index a54611507d..eb36a7ff6c 100644 --- a/src/adapters/models/beit/mixin_beit.py +++ b/src/adapters/models/beit/mixin_beit.py @@ -2,10 +2,10 @@ import torch.nn as nn -from ...layer import AdapterLayer -from ...lora import Linear as LoRALinear +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import Linear as LoRALinear +from ...methods.prefix_tuning import PrefixTuningShim from ...model_mixin import ModelBaseAdaptersMixin -from ...prefix_tuning import PrefixTuningShim class BeitSelfAttentionAdaptersMixin: @@ -38,8 +38,8 @@ class BeitLayerAdaptersMixin: """Adds adapters to the BeitLayer module.""" def init_adapters(self, model_config, adapters_config): - self.attention_adapters = AdapterLayer("mh_adapter") - self.output_adapters = AdapterLayer("output_adapter") + self.attention_adapters = BottleneckLayer("mh_adapter") + self.output_adapters = BottleneckLayer("output_adapter") class BeitModelAdaptersMixin(ModelBaseAdaptersMixin): diff --git a/src/adapters/models/bert/mixin_bert.py b/src/adapters/models/bert/mixin_bert.py index 2d715fb993..1ae358b451 100644 --- a/src/adapters/models/bert/mixin_bert.py +++ b/src/adapters/models/bert/mixin_bert.py @@ -4,10 +4,10 @@ import torch.nn as nn from ...composition import adjust_tensors_for_parallel_ -from ...layer import AdapterLayer -from ...lora import Linear as LoRALinear +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import Linear as LoRALinear +from ...methods.prefix_tuning import PrefixTuningShim from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin -from ...prefix_tuning import PrefixTuningShim logger = logging.getLogger(__name__) @@ -27,8 +27,8 @@ def init_adapters(self, model_config, adapters_config): ) -# For backwards compatibility, BertSelfOutput inherits directly from AdapterLayer -class BertSelfOutputAdaptersMixin(AdapterLayer): +# For backwards compatibility, BertSelfOutput inherits directly from BottleneckLayer +class BertSelfOutputAdaptersMixin(BottleneckLayer): """Adds adapters to the BertSelfOutput module.""" def __init__(self): @@ -39,8 +39,8 @@ def init_adapters(self, model_config, adapters_config): super().init_adapters(model_config, adapters_config) -# For backwards compatibility, BertOutput inherits directly from AdapterLayer -class BertOutputAdaptersMixin(AdapterLayer): +# For backwards compatibility, BertOutput inherits directly from BottleneckLayer +class BertOutputAdaptersMixin(BottleneckLayer): """Adds adapters to the BertOutput module.""" def __init__(self): diff --git a/src/adapters/models/clip/mixin_clip.py b/src/adapters/models/clip/mixin_clip.py index 3eb8c8bbb0..f9a09d20a5 100644 --- a/src/adapters/models/clip/mixin_clip.py +++ b/src/adapters/models/clip/mixin_clip.py @@ -3,8 +3,9 @@ import torch.nn as nn from ...composition import adjust_tensors_for_parallel_ -from ...layer import AdapterLayer -from ...lora import Linear as LoRALinear +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import Linear as LoRALinear +from ...methods.prefix_tuning import PrefixTuningShim from ...model_mixin import ( EmbeddingAdaptersMixin, EmbeddingAdaptersWrapperMixin, @@ -12,7 +13,6 @@ InvertibleAdaptersWrapperMixin, ModelBaseAdaptersMixin, ) -from ...prefix_tuning import PrefixTuningShim class CLIPAttentionAdaptersMixin: @@ -35,8 +35,8 @@ def init_adapters(self, model_config, adapters_config): self.mlp.fc1 = LoRALinear.wrap(self.mlp.fc1, "intermediate", model_config, adapters_config) self.mlp.fc2 = LoRALinear.wrap(self.mlp.fc2, "output", model_config, adapters_config) - self.attention_adapters = AdapterLayer("mh_adapter") - self.output_adapters = AdapterLayer("output_adapter") + self.attention_adapters = BottleneckLayer("mh_adapter") + self.output_adapters = BottleneckLayer("output_adapter") class CLIPEncoderAdaptersMixin: diff --git a/src/adapters/models/deberta/mixin_deberta.py b/src/adapters/models/deberta/mixin_deberta.py index 0407e59a82..ffe698fb55 100644 --- a/src/adapters/models/deberta/mixin_deberta.py +++ b/src/adapters/models/deberta/mixin_deberta.py @@ -1,5 +1,5 @@ -from ...lora import MergedLinear as LoRAMergedLinear -from ...prefix_tuning import PrefixTuningShim +from ...methods.lora import MergedLinear as LoRAMergedLinear +from ...methods.prefix_tuning import PrefixTuningShim class DebertaSelfAttentionAdaptersMixin: diff --git a/src/adapters/models/deberta_v2/mixin_deberta_v2.py b/src/adapters/models/deberta_v2/mixin_deberta_v2.py index 3b4e01aa2f..0213604695 100644 --- a/src/adapters/models/deberta_v2/mixin_deberta_v2.py +++ b/src/adapters/models/deberta_v2/mixin_deberta_v2.py @@ -1,5 +1,5 @@ -from ...lora import Linear as LoRALinear -from ...prefix_tuning import PrefixTuningShim +from ...methods.lora import Linear as LoRALinear +from ...methods.prefix_tuning import PrefixTuningShim class DebertaV2SelfAttentionAdaptersMixin: diff --git a/src/adapters/models/distilbert/mixin_distilbert.py b/src/adapters/models/distilbert/mixin_distilbert.py index 582543b765..12fc757da2 100644 --- a/src/adapters/models/distilbert/mixin_distilbert.py +++ b/src/adapters/models/distilbert/mixin_distilbert.py @@ -2,10 +2,10 @@ import torch.nn as nn -from ...layer import AdapterLayer -from ...lora import Linear as LoRALinear +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import Linear as LoRALinear +from ...methods.prefix_tuning import PrefixTuningShim from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin -from ...prefix_tuning import PrefixTuningShim class DistilBertMultiHeadSelfAttentionMixin: @@ -28,8 +28,8 @@ def init_adapters(self, model_config, adapters_config): self.ffn.lin1 = LoRALinear.wrap(self.ffn.lin1, "intermediate", model_config, adapters_config) self.ffn.lin2 = LoRALinear.wrap(self.ffn.lin2, "output", model_config, adapters_config) - self.attention_adapters = AdapterLayer("mh_adapter") - self.output_adapters = AdapterLayer("output_adapter") + self.attention_adapters = BottleneckLayer("mh_adapter") + self.output_adapters = BottleneckLayer("output_adapter") class DistilBertTransformerAdaptersMixin: diff --git a/src/adapters/models/gpt2/mixin_gpt2.py b/src/adapters/models/gpt2/mixin_gpt2.py index b3cbf12219..c6ded5fd39 100644 --- a/src/adapters/models/gpt2/mixin_gpt2.py +++ b/src/adapters/models/gpt2/mixin_gpt2.py @@ -2,11 +2,11 @@ import torch.nn as nn -from ...layer import AdapterLayer -from ...lora import Linear as LoRALinear -from ...lora import MergedLinear as LoRAMergedLinear +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import Linear as LoRALinear +from ...methods.lora import MergedLinear as LoRAMergedLinear +from ...methods.prefix_tuning import PrefixTuningShim from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin -from ...prefix_tuning import PrefixTuningShim class GPT2AttentionAdaptersMixin: @@ -50,8 +50,8 @@ def init_adapters(self, model_config, adapters_config): no_init_bias=True, ) - self.attention_adapters = AdapterLayer("mh_adapter") - self.output_adapters = AdapterLayer("output_adapter") + self.attention_adapters = BottleneckLayer("mh_adapter") + self.output_adapters = BottleneckLayer("output_adapter") class GPT2ModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin): diff --git a/src/adapters/models/gptj/mixin_gptj.py b/src/adapters/models/gptj/mixin_gptj.py index d05880fbbe..5816684955 100644 --- a/src/adapters/models/gptj/mixin_gptj.py +++ b/src/adapters/models/gptj/mixin_gptj.py @@ -2,10 +2,10 @@ import torch.nn as nn -from ...layer import AdapterLayer -from ...lora import Linear as LoRALinear +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import Linear as LoRALinear +from ...methods.prefix_tuning import PrefixTuningShim from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin -from ...prefix_tuning import PrefixTuningShim class GPTJAttentionAdaptersMixin: @@ -33,8 +33,8 @@ class GPTJDecoderBlockAdaptersMixin: """Adds adapters to the TransformerBlock module of GPTJ.""" def init_adapters(self, model_config, adapters_config): - self.attention_adapters = AdapterLayer("mh_adapter") - self.output_adapters = AdapterLayer("output_adapter") + self.attention_adapters = BottleneckLayer("mh_adapter") + self.output_adapters = BottleneckLayer("output_adapter") class GPTJModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin): diff --git a/src/adapters/models/llama/mixin_llama.py b/src/adapters/models/llama/mixin_llama.py index 4627e02593..a927a63129 100644 --- a/src/adapters/models/llama/mixin_llama.py +++ b/src/adapters/models/llama/mixin_llama.py @@ -2,10 +2,10 @@ import torch.nn as nn -from ...layer import AdapterLayer -from ...lora import Linear as LoRALinear +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import Linear as LoRALinear +from ...methods.prefix_tuning import PrefixTuningShim from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin -from ...prefix_tuning import PrefixTuningShim class LlamaAttentionMixin: @@ -23,8 +23,8 @@ def init_adapters(self, model_config, adapters_config): self.mlp.down_proj = LoRALinear.wrap(self.mlp.down_proj, "intermediate", model_config, adapters_config) self.mlp.up_proj = LoRALinear.wrap(self.mlp.up_proj, "output", model_config, adapters_config) - self.attention_adapters = AdapterLayer("mh_adapter") - self.output_adapters = AdapterLayer("output_adapter") + self.attention_adapters = BottleneckLayer("mh_adapter") + self.output_adapters = BottleneckLayer("output_adapter") class LlamaModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin): diff --git a/src/adapters/models/t5/mixin_t5.py b/src/adapters/models/t5/mixin_t5.py index 3917c78366..f7df0ead3f 100644 --- a/src/adapters/models/t5/mixin_t5.py +++ b/src/adapters/models/t5/mixin_t5.py @@ -2,8 +2,9 @@ import torch.nn as nn -from ...layer import AdapterLayer -from ...lora import Linear as LoRALinear +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import Linear as LoRALinear +from ...methods.prefix_tuning import PrefixTuningShim from ...model_mixin import ( EmbeddingAdaptersMixin, InvertibleAdaptersMixin, @@ -11,7 +12,6 @@ ModelBaseAdaptersMixin, ModelWithHeadsAdaptersMixin, ) -from ...prefix_tuning import PrefixTuningShim class T5AttentionAdaptersMixin: @@ -28,7 +28,7 @@ def init_adapters(self, model_config, adapters_config): ) -class T5SelfAttentionLayerAdaptersMixin(AdapterLayer): +class T5SelfAttentionLayerAdaptersMixin(BottleneckLayer): def __init__(self): super().__init__("mh_adapter", None) @@ -37,7 +37,7 @@ def init_adapters(self, model_config, adapters_config): super().init_adapters(model_config, adapters_config) -class T5CrossAttentionLayerAdaptersMixin(AdapterLayer): +class T5CrossAttentionLayerAdaptersMixin(BottleneckLayer): def __init__(self): super().__init__("cross_adapter", None) @@ -47,7 +47,7 @@ def init_adapters(self, model_config, adapters_config): super().init_adapters(model_config, adapters_config) -class T5FFLayerAdaptersMixin(AdapterLayer): +class T5FFLayerAdaptersMixin(BottleneckLayer): def __init__(self): super().__init__("output_adapter", None) diff --git a/src/adapters/models/vit/mixin_vit.py b/src/adapters/models/vit/mixin_vit.py index 5b540245bd..4b602ef992 100644 --- a/src/adapters/models/vit/mixin_vit.py +++ b/src/adapters/models/vit/mixin_vit.py @@ -2,10 +2,10 @@ import torch.nn as nn -from ...layer import AdapterLayer -from ...lora import Linear as LoRALinear +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import Linear as LoRALinear +from ...methods.prefix_tuning import PrefixTuningShim from ...model_mixin import ModelBaseAdaptersMixin -from ...prefix_tuning import PrefixTuningShim class ViTSelfAttentionAdaptersMixin: @@ -32,7 +32,7 @@ class ViTOutputAdaptersMixin: """Adds adapters to the ViTOutput module.""" def init_adapters(self, model_config, adapters_config): - self.output_adapters = AdapterLayer("output_adapter") + self.output_adapters = BottleneckLayer("output_adapter") # Wrap layers for LoRA self.dense = LoRALinear.wrap(self.dense, "output", model_config, adapters_config) @@ -43,7 +43,7 @@ class ViTLayerAdaptersMixin: """Adds adapters to the ViTSelfOutput module.""" def init_adapters(self, model_config, adapters_config): - self.attention_adapters = AdapterLayer("mh_adapter") + self.attention_adapters = BottleneckLayer("mh_adapter") class ViTModelAdaptersMixin(ModelBaseAdaptersMixin): From 63220354a48e4bfc50696e16489cddc313c94b53 Mon Sep 17 00:00:00 2001 From: calpt Date: Sun, 8 Oct 2023 01:03:24 +0200 Subject: [PATCH 02/10] Move basic composition handling into `ComposableAdapterLayerBase` --- src/adapters/composition.py | 9 +- src/adapters/methods/adapter_layer_base.py | 340 ++++++++++- src/adapters/methods/bottleneck.py | 528 +++++------------- src/adapters/methods/lora.py | 6 +- src/adapters/methods/prefix_tuning.py | 449 ++++----------- .../composition/test_adapter_composition.py | 8 +- 6 files changed, 607 insertions(+), 733 deletions(-) diff --git a/src/adapters/composition.py b/src/adapters/composition.py index b4ecab8179..5899b113d6 100644 --- a/src/adapters/composition.py +++ b/src/adapters/composition.py @@ -73,12 +73,9 @@ def name(self): class Split(AdapterCompositionBlock): - def __init__(self, left: str, right: str, split_index: int): - super().__init__(left, right) - assert split_index > 0 - self.left = left - self.right = right - self.split_index = split_index + def __init__(self, *split_adapters: List[Union[AdapterCompositionBlock, str]], splits: Union[List[int], int]): + super().__init__(*split_adapters) + self.splits = splits if isinstance(splits, list) else [splits] * len(split_adapters) class BatchSplit(AdapterCompositionBlock): diff --git a/src/adapters/methods/adapter_layer_base.py b/src/adapters/methods/adapter_layer_base.py index e6a95256c6..7a70a4d764 100644 --- a/src/adapters/methods/adapter_layer_base.py +++ b/src/adapters/methods/adapter_layer_base.py @@ -1,10 +1,11 @@ from abc import ABCMeta, abstractmethod -from typing import Dict, List, Union +from typing import Collection, Dict, List, NamedTuple, Union import numpy as np +import torch from torch import nn -from ..composition import AdapterCompositionBlock +from ..composition import ALLOWED_NESTINGS, AdapterCompositionBlock, Average, BatchSplit, Fuse, Parallel, Split, Stack from ..context import AdapterSetup, ForwardContext @@ -12,8 +13,16 @@ class AdapterLayerBase(metaclass=ABCMeta): """ Base class for all adaptation methods that require per-layer modules. + + Make sure the 'adapter_modules_name' attribute is overriden in derived classes. """ + adapter_modules_name = "" + + @property + def adapter_modules(self) -> Collection: + return getattr(self, self.adapter_modules_name) + @property def layer_idx(self): return getattr(self, "_layer_idx", -1) @@ -24,7 +33,7 @@ def layer_idx(self, layer_idx): assert idx == layer_idx setattr(self, "_layer_idx", idx) - def get_active_setup(self, module_dict): + def get_active_setup(self): if hasattr(self, "adapters_config"): # First check current context before falling back to defined setup context = AdapterSetup.get_context() @@ -37,7 +46,7 @@ def get_active_setup(self, module_dict): skip_adapters = adapter_setup is None or ( self.adapters_config.skip_layers is not None and self.layer_idx in self.adapters_config.skip_layers ) - if not skip_adapters and (len(set(module_dict.keys()) & adapter_setup.flatten()) > 0): + if not skip_adapters and (len(set(self.adapter_modules.keys()) & adapter_setup.flatten()) > 0): return adapter_setup else: return None @@ -94,3 +103,326 @@ def enable_adapters(self, adapter_setup: AdapterCompositionBlock, unfreeze_adapt @abstractmethod def get_adapter(self, adapter_name: str) -> nn.Module: raise NotImplementedError() + + +class ComposableAdapterLayerBase(AdapterLayerBase): + """ + Base class for all adapter methods that support composition. + + Make sure the 'adapter_modules_name' and 'supported_compositions' attributes as well as all abstract methods are + overriden in derived classes. + """ + + supported_compositions = [] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.composition_to_func_map = { + Stack: self.compose_stack, + Fuse: self.compose_fuse, + Split: self.compose_split, + BatchSplit: self.compose_batch_split, + Parallel: self.compose_parallel, + Average: self.compose_average, + } + + # START CUSTOMIZABLE METHODS # + # The following methods should be implemented in derived classes. + + def _bsz(self, state: NamedTuple) -> int: + """ + Returns the batch size of the given state. + """ + return state[0].shape[0] + + def pre_block(self, adapter_setup: Union[AdapterCompositionBlock, str], state: NamedTuple) -> NamedTuple: + """ + Optional state pre-processing method which is invoked before passing the state to the first child block of a + composition. By default, this method does not contain any logic. E.g. used for bottleneck adapters to implement + residuals and LNs. + + Args: + adapter_setup (Union[AdapterCompositionBlock, str]): The current composition or single adapter. + state (NamedTuple): The current state. + + Returns: + NamedTuple: The pre-processed state. + """ + return state + + @abstractmethod + def vslice(self, state: NamedTuple, slice_obj: slice) -> NamedTuple: + """Slices the given state along the batch size (vertical) dimension. + This is e.g. used by the BatchSplit and Parallel composition blocks. IMPORTANT: Has to be implemented by all + derived classes. + + Args: + state (NamedTuple): The state to be sliced. + slice_obj (slice): The slice object. + + Returns: + NamedTuple: The sliced state. + """ + raise NotImplementedError() + + @abstractmethod + def pad_and_concat(self, states: List[NamedTuple]) -> NamedTuple: + """Concatenates the given states along the batch size dimension. + Pads the states before concatenation if necessary. This is e.g. used by the BatchSplit and Parallel composition + blocks. IMPORTANT: Has to be implemented by all derived classes. + + Args: + states (List[NamedTuple]): The states to be concatenated. + + Returns: + NamedTuple: The concatenated state. + """ + raise NotImplementedError() + + @abstractmethod + def repeat(self, state: NamedTuple, channels: int) -> NamedTuple: + """Repeats the given state along the batch size dimension for the given number of times. + This is e.g. used by the Parallel composition block. IMPORTANT: Has to be implemented by all derived classes. + + Args: + state (NamedTuple): The state to be repeated. + channels (int): The number of times the state should be repeated. + + Returns: + NamedTuple: The repeated state. + """ + raise NotImplementedError() + + @abstractmethod + def mean(self, states: List[NamedTuple], weights: torch.Tensor) -> NamedTuple: + """Averages the given states along the batch size dimension by the given weights. + This is e.g. used by the Average composition block. IMPORTANT: Has to be implemented by all derived classes. + + Args: + states (List[NamedTuple]): The states to be averaged. + weights (torch.Tensor): The averaging weights. + + Returns: + NamedTuple: The averaged state. + """ + raise NotImplementedError() + + @abstractmethod + def compose_single(self, adapter_setup: str, state: NamedTuple, lvl: int = 0) -> NamedTuple: + """Forwards the given state through the given single adapter. + + Args: + adapter_setup (str): The name of the adapter. + state (NamedTuple): The state to be forwarded. + lvl (int, optional): The composition depth. Defaults to 0. + + Returns: + NamedTuple: The state after forwarding through the adapter. + """ + raise NotImplementedError() + + # END CUSTOMIZABLE METHODS # + + def check_composition_valid(self, parent: AdapterCompositionBlock, child: AdapterCompositionBlock, lvl: int): + """Checks whether the given composition is valid. + + Args: + parent (AdapterCompositionBlock): The parent composition block. + child (AdapterCompositionBlock): The child composition block. + lvl (int): The composition depth. + + Raises: + ValueError: If the composition is invalid. + """ + # Break if setup is too deep + if lvl >= 1: + raise ValueError( + "Specified adapter setup is too deep. Cannot have {} at level {}".format(child.__class__.__name__, lvl) + ) + elif type(child) not in ALLOWED_NESTINGS[type(parent)]: + raise ValueError( + "Cannot nest {} inside {}. Only the following nestings are allowed: {}".format( + child.__class__.__name__, + parent.__class__.__name__, + ", ".join([t.__name__ for t in ALLOWED_NESTINGS[type(parent)]]), + ) + ) + + def compose_stack(self, adapter_setup: Stack, state: NamedTuple, lvl: int = 0) -> NamedTuple: + """ + For sequentially stacking multiple adapters. + """ + for i, adapter_stack_layer in enumerate(adapter_setup): + if isinstance(adapter_stack_layer, AdapterCompositionBlock): + self.check_composition_valid(adapter_setup, adapter_stack_layer, lvl) + composition_func = self.composition_to_func_map[type(adapter_stack_layer)] + state = composition_func(adapter_stack_layer, state, lvl=lvl + 1) + elif adapter_stack_layer in self.adapter_modules: + state = self.pre_block(adapter_stack_layer, state) + state = self.compose_single(adapter_stack_layer, state, lvl=lvl + 1) + else: + raise ValueError( + "Invalid adapter setup: {} is not a valid adapter name or composition block.".format( + adapter_stack_layer.__class__.__name__ + ) + ) + + return state + + def compose_fuse(self, adapter_setup: Fuse, state: NamedTuple, lvl: int = 0): + """ + For fusing multiple adapters using adapter fusion. NOTE: This method has no default implementation. + """ + # Fuse is currently only applicable to bottleneck adapters, thus don't provide a default implementation + raise NotImplementedError() + + def compose_split(self, adapter_setup: Split, state: NamedTuple, lvl: int = 0): + """ + For splitting to multiple adapters along the sequence length dimension. NOTE: This method has no default + implementation. + """ + # Split is currently only applicable to bottleneck adapters, thus don't provide a default implementation + raise NotImplementedError() + + def compose_batch_split(self, adapter_setup: BatchSplit, state: NamedTuple, lvl: int = 0): + """ + For splitting to multiple adapters along the batch size dimension. + """ + if sum(adapter_setup.batch_sizes) != self._bsz(state): + raise IndexError( + "The given batch has a size of {} which is not equal to the sum of batch_sizes {}".format( + self._bsz(state), adapter_setup.batch_sizes + ) + ) + + state = self.pre_block(adapter_setup, state) + + # sequentially feed different parts of the blown-up batch into different adapters + children_states = [] + for i, child in enumerate(adapter_setup): + # compute ids of sequences thet should be passed to the ith adapter + batch_idx = ( + sum(adapter_setup.batch_sizes[:i]), + sum(adapter_setup.batch_sizes[: i + 1]), + ) + if isinstance(child, AdapterCompositionBlock): + self.check_composition_valid(adapter_setup, child, lvl) + composition_func = self.composition_to_func_map[type(child)] + child_state = composition_func( + child, + self.vslice(state, slice(*batch_idx)), + lvl=lvl + 1, + ) + children_states.append(child_state) + elif child in self.adapter_modules: + child_state = self.compose_single( + child, + self.vslice(state, slice(*batch_idx)), + lvl=lvl + 1, + ) + children_states.append(child_state) + else: + children_states.append(self.vslice(state, slice(*batch_idx))) + + # concatenate all outputs and return + state = self.pad_and_concat(children_states) + return state + + def compose_parallel(self, adapter_setup: Parallel, state: NamedTuple, lvl: int = 0): + """ + For parallel execution of the adapters on the same input. This means that the input is repeated N times before + feeding it to the adapters (where N is the number of adapters). + """ + + context = ForwardContext.get_context() + if not context.adapters_parallelized: + orig_batch_size = self._bsz(state) + state = self.repeat(state, adapter_setup.parallel_channels) + context.adapters_parallelized = True + else: + # The base model should handle replication of input. + # Therefore, we assume the (replicated) input batch to be divisible by the number of parallel channels. + if self._bsz(state) % adapter_setup.parallel_channels != 0: + raise ValueError( + "The total input batch size in a Parallel adapter block must be divisible by the number of" + " parallel channels." + ) + orig_batch_size = self._bsz(state) // adapter_setup.parallel_channels + + state = self.pre_block(adapter_setup, state) + + # sequentially feed different parts of the blown-up batch into different adapters + children_states = [] + for i, child in enumerate(adapter_setup): + if isinstance(child, AdapterCompositionBlock): + self.check_composition_valid(adapter_setup, child, lvl) + composition_func = self.composition_to_func_map[type(child)] + child_state = composition_func( + child, + self.vslice(state, slice(i * orig_batch_size, (i + 1) * orig_batch_size)), + lvl=lvl + 1, + ) + children_states.append(child_state) + elif child in self.adapter_modules: + child_state = self.compose_single( + child, + self.vslice(state, slice(i * orig_batch_size, (i + 1) * orig_batch_size)), + lvl=lvl + 1, + ) + children_states.append(child_state) + else: + children_states.append(self.vslice(state, slice(i * orig_batch_size, (i + 1) * orig_batch_size))) + + # concatenate all outputs and return + state = self.pad_and_concat(children_states) + return state + + def compose_average(self, adapter_setup: Average, state: NamedTuple, lvl: int = 0): + """ + For averaging the output representations of multiple adapters. + """ + + state = self.pre_block(adapter_setup, state) + + children_states = [] + for i, child in enumerate(adapter_setup): + if isinstance(child, AdapterCompositionBlock): + self.check_composition_valid(adapter_setup, child, lvl) + composition_func = self.composition_to_func_map[type(child)] + child_state = composition_func(child, state, lvl=lvl + 1) + children_states.append(child_state) + elif child in self.adapter_modules: + child_state = self.compose_single(child, state, lvl=lvl + 1) + children_states.append(child_state) + else: + pass + + weights = torch.tensor(adapter_setup.weights).unsqueeze(1).unsqueeze(1).to(state[0].device) + state = self.mean(children_states, weights) + + return state + + def compose(self, adapter_setup: Union[AdapterCompositionBlock, str], state: NamedTuple) -> NamedTuple: + """The main composition forward method which recursively calls the composition blocks forward methods. + This method should be called by the forward method of the derived class. + + Args: + adapter_setup (Union[AdapterCompositionBlock, str]): The adapter setup to be used. + state (NamedTuple): The current state. + + Returns: + NamedTuple: The state after forwarding through the adapter setup. + """ + if isinstance(adapter_setup, AdapterCompositionBlock): + composition_func = self.composition_to_func_map[type(adapter_setup)] + state = composition_func(adapter_setup, state, lvl=0) + elif adapter_setup in self.adapter_modules: + state = self.compose_single(adapter_setup, state, lvl=0) + else: + raise ValueError( + "Invalid adapter setup: {} is not a valid adapter name or composition block.".format( + adapter_setup.__class__.__name__ + ) + ) + + return state diff --git a/src/adapters/methods/bottleneck.py b/src/adapters/methods/bottleneck.py index d3c1ab8f8c..b802e960e2 100644 --- a/src/adapters/methods/bottleneck.py +++ b/src/adapters/methods/bottleneck.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Mapping, Union +from typing import Dict, List, Mapping, NamedTuple, Optional, Union import torch from torch import nn @@ -15,11 +15,34 @@ ) from ..configuration import BnConfig from ..context import ForwardContext -from .adapter_layer_base import AdapterLayerBase +from .adapter_layer_base import ComposableAdapterLayerBase from .modeling import Adapter, BertFusion, ParallelAdapter -class BottleneckLayer(AdapterLayerBase, nn.Module): +class BottleneckState(NamedTuple): + """ + Models the input and output states of a bottleneck adapter layer. + + Args: + hidden_states (torch.Tensor): The layer input/ output hidden states. + input_tensor (torch.Tensor): The Transformer sub-block residual connection inputs. + adapter_residual (torch.Tensor): The adapter residual connection inputs. + layer_norm (torch.nn.Module, optional): The Transformer layer norm module. + bottleneck_up (torch.Tensor, optional): + The up-projected bottleneck MLP output. This is only for Fuse compositions. + """ + + hidden_states: torch.Tensor + input_tensor: torch.Tensor + adapter_residual: torch.Tensor + layer_norm: Optional[torch.nn.Module] + bottleneck_up: Optional[torch.Tensor] = None + + +class BottleneckLayer(ComposableAdapterLayerBase, nn.Module): + adapter_modules_name = "adapters" + supported_compositions = [Stack, Fuse, Split, Parallel, BatchSplit, Average] + def __init__(self, location_key: str): super().__init__() self.location_key = location_key @@ -150,64 +173,70 @@ def get_adapter(self, adapter_name: str): else: return None - def adapter_stack(self, adapter_setup: Stack, hidden_states, input_tensor, layer_norm, lvl=0): - """ - Forwards the given input through the given stack of adapters. - """ - for i, adapter_stack_layer in enumerate(adapter_setup): - # Break if setup is too deep - if isinstance(adapter_stack_layer, AdapterCompositionBlock) and lvl >= 1: - raise ValueError( - "Specified adapter setup is too deep. Cannot have {} at level {}".format( - adapter_stack_layer.__class__.__name__, lvl - ) - ) - # Case 1: We have a nested fusion layer -> call fusion method - if isinstance(adapter_stack_layer, Fuse): - hidden_states = self.adapter_fusion( - adapter_stack_layer, hidden_states, input_tensor, layer_norm, lvl=lvl + 1 - ) - # Case 2: We have a nested split layer -> call split method - elif isinstance(adapter_stack_layer, Split): - hidden_states = self.adapter_split( - adapter_stack_layer, hidden_states, input_tensor, layer_norm, lvl=lvl + 1 - ) - # Case 3: We have a nested parallel layer -> call parallel method - elif isinstance(adapter_stack_layer, Parallel): - hidden_states, input_tensor = self.adapter_parallel( - adapter_stack_layer, hidden_states, input_tensor, layer_norm, lvl=lvl + 1 - ) - # Case 4: We have a nested batch split block -> call batchsplit method - elif isinstance(adapter_stack_layer, BatchSplit): - hidden_states = self.adapter_batchsplit( - adapter_stack_layer, hidden_states, input_tensor, layer_norm, lvl=lvl + 1 - ) - # Case 5: We have a nested average block -> call average method - elif isinstance(adapter_stack_layer, Average): - hidden_states = self.adapter_average_output( - adapter_stack_layer, hidden_states, input_tensor, layer_norm, lvl=lvl + 1 - ) - # Case 6: We have a single adapter which is part of this module -> forward pass - elif adapter_stack_layer in self.adapters: - adapter_layer = self.adapters[adapter_stack_layer] - hidden_states, _, residual = adapter_layer.pre_forward(hidden_states, input_tensor, layer_norm) - context = ForwardContext.get_context() - layer_output = adapter_layer( - hidden_states, residual_input=residual, output_gating=context.output_adapter_gating_scores - ) - hidden_states, up = layer_output[0], layer_output[2] - self._store_gating_score(adapter_stack_layer, layer_output[-1]) - # as this stack might be part of a fusion block, return the adapter up-projection output here - # together with the final output (with potential residuals & norms) if we reached the last block of the stack - if i == len(adapter_setup) - 1: - return hidden_states, up, input_tensor - # Case X: No adapter which is part of this module -> ignore - - # If we got here, we either had another nested composition block - # or no adapter was found. In both cases, we don't need to set the second return value for fusion - return hidden_states, None, input_tensor - - def adapter_fusion(self, adapter_setup: Fuse, hidden_states, input_tensor, layer_norm, lvl=0): + def pre_block(self, adapter_setup: Union[AdapterCompositionBlock, str], state: BottleneckState) -> BottleneckState: + if isinstance(adapter_setup, AdapterCompositionBlock): + adapter_name = adapter_setup.first() + else: + adapter_name = adapter_setup + first_adapter = self.adapters[adapter_name] + hidden_states, _, residual = first_adapter.pre_forward( + state.hidden_states, state.input_tensor, state.layer_norm + ) + + return state._replace(hidden_states=hidden_states, adapter_residual=residual) + + def vslice(self, state: BottleneckState, slice_obj: slice) -> BottleneckState: + return BottleneckState( + state.hidden_states[slice_obj], + state.input_tensor[slice_obj], + state.adapter_residual[slice_obj], + state.layer_norm, + state.bottleneck_up[slice_obj] if state.bottleneck_up is not None else None, + ) + + def pad_and_concat(self, states: List[BottleneckState]) -> BottleneckState: + return BottleneckState( + torch.cat([state.hidden_states for state in states], dim=0), + torch.cat([state.input_tensor for state in states], dim=0), + torch.cat([state.adapter_residual for state in states], dim=0), + states[0].layer_norm, + torch.cat([state.bottleneck_up for state in states], dim=0) + if states[0].bottleneck_up is not None + else None, + ) + + def repeat(self, state: BottleneckState, channels: int) -> BottleneckState: + return BottleneckState( + state.hidden_states.repeat(channels, 1, 1), + state.input_tensor.repeat(channels, 1, 1), + state.adapter_residual.repeat(channels, 1, 1), + state.layer_norm, + state.bottleneck_up.repeat(channels, 1, 1) if state.bottleneck_up is not None else None, + ) + + def mean(self, states: List[NamedTuple], weights: torch.Tensor) -> NamedTuple: + return BottleneckState( + torch.mean(torch.stack([s.hidden_states for s in states], 0) * weights, dim=0), + states[0].input_tensor, + states[0].adapter_residual, + states[0].layer_norm, + states[0].bottleneck_up, + ) + + def compose_single(self, adapter_setup: str, state: BottleneckState, lvl: int = 0) -> BottleneckState: + adapter_layer = self.adapters[adapter_setup] + context = ForwardContext.get_context() + layer_output = adapter_layer( + state.hidden_states, + residual_input=state.adapter_residual, + output_gating=context.output_adapter_gating_scores, + ) + hidden_states, up = layer_output[0], layer_output[2] + self._store_gating_score(adapter_setup, layer_output[-1]) + + return BottleneckState(hidden_states, state.input_tensor, state.adapter_residual, state.layer_norm, up) + + def compose_fuse(self, adapter_setup: Fuse, state: BottleneckState, lvl: int = 0): """ Performs adapter fusion with the given adapters for the given input. """ @@ -217,44 +246,32 @@ def adapter_fusion(self, adapter_setup: Fuse, hidden_states, input_tensor, layer fusion_config = self.adapters_config.get_fusion(adapter_setup.name) last_adapter = self.adapters[adapter_setup.last()] hidden_states, query, residual = last_adapter.pre_forward( - hidden_states, input_tensor, layer_norm, fusion_config=fusion_config + state.hidden_states, state.input_tensor, state.layer_norm, fusion_config=fusion_config ) + state = state._replace(hidden_states=hidden_states, adapter_residual=residual) + + children_states = [] + for child in adapter_setup: + if isinstance(child, AdapterCompositionBlock): + self.check_composition_valid(adapter_setup, child, lvl) + composition_func = self.composition_to_func_map[type(child)] + child_state = composition_func(child, state, lvl=lvl + 1) + children_states.append(child_state) + elif child in self.adapter_modules: + child_state = self.compose_single(child, state, lvl=lvl + 1) + children_states.append(child_state) + else: + pass - up_list = [] - - for adapter_block in adapter_setup: - # Case 1: We have a nested stack -> call stack method - if isinstance(adapter_block, Stack): - _, up, _ = self.adapter_stack(adapter_block, hidden_states, input_tensor, layer_norm, lvl=lvl + 1) - if up is not None: # could be none if stack is empty - up_list.append(up) - # Case 2: We have a single adapter which is part of this module -> forward pass - elif adapter_block in self.adapters: - adapter_layer = self.adapters[adapter_block] - layer_output = adapter_layer( - hidden_states, residual_input=residual, output_gating=context.output_adapter_gating_scores - ) - up = layer_output[2] - self._store_gating_score(adapter_block, layer_output[-1]) - up_list.append(up) - # Case 3: nesting other composition blocks is invalid - elif isinstance(adapter_block, AdapterCompositionBlock): - raise ValueError( - "Invalid adapter setup. Cannot nest {} in {}".format( - adapter_block.__class__.__name__, adapter_setup.__class__.__name__ - ) - ) - # Case X: No adapter which is part of this module -> ignore - - if len(up_list) > 0: - up_list = torch.stack(up_list) + if len(children_states) > 0: + up_list = torch.stack([state.bottleneck_up for state in children_states]) up_list = up_list.permute(1, 2, 0, 3) fusion_output = self.adapter_fusion_layer[adapter_setup.name]( query, up_list, up_list, - residual, + state.adapter_residual, output_attentions=context.output_adapter_fusion_attentions, ) if context.output_adapter_fusion_attentions: @@ -263,291 +280,49 @@ def adapter_fusion(self, adapter_setup: Fuse, hidden_states, input_tensor, layer else: hidden_states = fusion_output - return hidden_states + return state._replace(hidden_states=hidden_states) - def adapter_split(self, adapter_setup: Split, hidden_states, input_tensor, layer_norm, lvl=0): + def compose_split(self, adapter_setup: Split, state: BottleneckState, lvl: int = 0): """ Splits the given input between the given adapters. """ - # config of _first_ of splitted adapters is significant - first_adapter = self.adapters[adapter_setup.first()] - hidden_states, query, residual = first_adapter.pre_forward(hidden_states, input_tensor, layer_norm) - - # split hidden representations and residuals at split index - split_hidden_states = [ - hidden_states[:, : adapter_setup.split_index, :], - hidden_states[:, adapter_setup.split_index :, :], - ] - split_input_tensor = [ - input_tensor[:, : adapter_setup.split_index, :], - input_tensor[:, adapter_setup.split_index :, :], - ] - split_residual = [ - residual[:, : adapter_setup.split_index, :], - residual[:, adapter_setup.split_index :, :], - ] - - for i, adapter_block in enumerate(adapter_setup): - # Case 1: We have a nested stack -> call stack method - if isinstance(adapter_block, Stack): - split_hidden_states[i], _, _ = self.adapter_stack( - adapter_block, split_hidden_states[i], split_input_tensor[i], layer_norm, lvl=lvl + 1 - ) - # Case 2: We have a nested split -> recursively call split - elif isinstance(adapter_block, Split): - split_hidden_states[i] = self.adapter_split( - adapter_block, split_hidden_states[i], split_input_tensor[i], layer_norm, lvl=lvl + 1 - ) - # Case 3: We have a nested batch split -> call batch split method - elif isinstance(adapter_block, BatchSplit): - split_hidden_states[i] = self.adapter_batchsplit( - adapter_block, split_hidden_states[i], split_input_tensor[i], layer_norm, lvl=lvl + 1 - ) - # Case 4: We have a nested average -> call average method - elif isinstance(adapter_block, Average): - split_hidden_states[i] = self.adapter_average_output( - adapter_block, split_hidden_states[i], split_input_tensor[i], layer_norm, lvl=lvl + 1 - ) - # Case 5: We have a single adapter which is part of this module -> forward pass - elif adapter_block in self.adapters: - adapter_layer = self.adapters[adapter_block] - context = ForwardContext.get_context() - layer_output = adapter_layer( - split_hidden_states[i], - residual_input=split_residual[i], - output_gating=context.output_adapter_gating_scores, - ) - split_hidden_states[i] = layer_output[0] - self._store_gating_score(adapter_block, layer_output[-1]) - # Case 6: nesting other composition blocks is invalid - elif isinstance(adapter_block, AdapterCompositionBlock): - raise ValueError( - "Invalid adapter setup. Cannot nest {} in {}".format( - adapter_block.__class__.__name__, adapter_setup.__class__.__name__ - ) - ) - # Case X: No adapter which is part of this module -> ignore - - hidden_states = torch.cat(split_hidden_states, dim=1) - return hidden_states - - def adapter_parallel(self, adapter_setup: Parallel, hidden_states, input_tensor, layer_norm, lvl=0): - """ - For parallel execution of the adapters on the same input. This means that the input is repeated N times before - feeding it to the adapters (where N is the number of adapters). - """ - - context = ForwardContext.get_context() - if not context.adapters_parallelized: - orig_batch_size = input_tensor.shape[0] - input_tensor = input_tensor.repeat(self.adapters_config.active_setup.parallel_channels, 1, 1) - hidden_states = hidden_states.repeat(self.adapters_config.active_setup.parallel_channels, 1, 1) - context.adapters_parallelized = True - else: - # The base model should handle replication of input. - # Therefore, we assume the (replicated) input batch to be divisible by the number of parallel channels. - if hidden_states.shape[0] % adapter_setup.parallel_channels != 0: - raise ValueError( - "The total input batch size in a Parallel adapter block must be divisible by the number of" - " parallel channels." - ) - orig_batch_size = hidden_states.shape[0] // adapter_setup.parallel_channels - - # We assume all adapters have the same config - first_adapter = self.adapters[adapter_setup.first()] - hidden_states, _, residual = first_adapter.pre_forward(hidden_states, input_tensor, layer_norm) - - # sequentially feed different parts of the blown-up batch into different adapters - children_hidden = [] - for i, child in enumerate(adapter_setup): - # Case 1: We have a nested stack -> call stack method - if isinstance(child, Stack): - child_hidden_states, _, _ = self.adapter_stack( - child, - hidden_states[i * orig_batch_size : (i + 1) * orig_batch_size], - input_tensor[i * orig_batch_size : (i + 1) * orig_batch_size], - layer_norm, - lvl=lvl + 1, - ) - children_hidden.append(child_hidden_states) - # Case 2: We have a nested batchsplit block -> call batchsplit method - elif isinstance(child, BatchSplit): - child_hidden_states = self.adapter_batchsplit( - child, - hidden_states[i * orig_batch_size : (i + 1) * orig_batch_size], - input_tensor[i * orig_batch_size : (i + 1) * orig_batch_size], - layer_norm, - lvl=lvl + 1, - ) - children_hidden.append(child_hidden_states) - # Case 3: We have a nested average block -> call average method - elif isinstance(child, Average): - child_hidden_states = self.adapter_average_output( - child, - hidden_states[i * orig_batch_size : (i + 1) * orig_batch_size], - input_tensor[i * orig_batch_size : (i + 1) * orig_batch_size], - layer_norm, - lvl=lvl + 1, - ) - children_hidden.append(child_hidden_states) - # Case 4: We have a single adapter which is part of this module -> forward pass - elif child in self.adapters: - adapter_layer = self.adapters[child] - context = ForwardContext.get_context() - layer_output = adapter_layer( - hidden_states[i * orig_batch_size : (i + 1) * orig_batch_size], - residual_input=residual[i * orig_batch_size : (i + 1) * orig_batch_size], - output_gating=context.output_adapter_gating_scores, - ) - child_hidden_states = layer_output[0] - self._store_gating_score(child, layer_output[-1]) - children_hidden.append(child_hidden_states) - # Case 5: nesting other composition blocks is invalid - elif isinstance(child, AdapterCompositionBlock): - raise ValueError( - "Invalid adapter setup. Cannot nest {} in {}".format( - child.__class__.__name__, adapter_setup.__class__.__name__ - ) - ) - # Case X: No adapter which is part of this module -> ignore - else: - children_hidden.append(hidden_states[i * orig_batch_size : (i + 1) * orig_batch_size]) - - # concatenate all outputs and return - hidden_states = torch.cat(children_hidden, 0) - return hidden_states, input_tensor - - def adapter_batchsplit(self, adapter_setup: BatchSplit, hidden_states, input_tensor, layer_norm, lvl=0): - if not sum(adapter_setup.batch_sizes) == hidden_states.shape[0]: + if sum(adapter_setup.splits) != state.hidden_states.shape[1]: raise IndexError( - "The given batch has a size of {} which is not compatible with batch_sizes {}".format( - hidden_states.shape[0], adapter_setup.batch_sizes + "The given input has sequence length {} which is not equal to the sum of splits {}".format( + state.hidden_states.shape[1], adapter_setup.splits ) ) - first_adapter = self.adapters[adapter_setup.first()] - hidden_states, _, residual = first_adapter.pre_forward(hidden_states, input_tensor, layer_norm) - children_hidden = [] - for i, adapter_block in enumerate(adapter_setup): - # compute ids of sequences thet should be passed to the ith adapter + state = self.pre_block(adapter_setup, state) + + children_states = [] + for i, child in enumerate(adapter_setup): batch_idx = ( - sum(adapter_setup.batch_sizes[:i]), - sum(adapter_setup.batch_sizes[: i + 1]), + sum(adapter_setup.splits[:i]), + sum(adapter_setup.splits[: i + 1]), ) - # Case 1: We have a nested stack -> call stack method - if isinstance(adapter_block, Stack): - child, _, _ = self.adapter_stack( - adapter_block, - hidden_states[batch_idx[0] : batch_idx[1]], - input_tensor[batch_idx[0] : batch_idx[1]], - layer_norm, - lvl=lvl + 1, - ) - children_hidden.append(child) - # Case 2: We have a nested split -> recursively call split - elif isinstance(adapter_block, Split): - child = self.adapter_split( - adapter_block, - hidden_states[batch_idx[0] : batch_idx[1]], - input_tensor[batch_idx[0] : batch_idx[1]], - layer_norm, - lvl=lvl + 1, - ) - children_hidden.append(child) - # Case 3: We have a nested batch split block -> call batchsplit method - elif isinstance(adapter_block, BatchSplit): - child = self.adapter_batchsplit( - adapter_block, - hidden_states[batch_idx[0] : batch_idx[1]], - input_tensor[batch_idx[0] : batch_idx[1]], - layer_norm, - lvl=lvl + 1, - ) - children_hidden.append(child) - # Case 4: We have a nested average block -> call average method - elif isinstance(adapter_block, Average): - child = self.adapter_average_output( - adapter_block, - hidden_states[batch_idx[0] : batch_idx[1]], - input_tensor[batch_idx[0] : batch_idx[1]], - layer_norm, - lvl=lvl + 1, - ) - children_hidden.append(child) - # Case 5: We have a single adapter which is part of this module -> forward pass - elif adapter_block in self.adapters: - - adapter_layer = self.adapters[adapter_block] - context = ForwardContext.get_context() - layer_output = adapter_layer( - hidden_states[batch_idx[0] : batch_idx[1]], - residual_input=residual[batch_idx[0] : batch_idx[1]], - output_gating=context.output_adapter_gating_scores, - ) - children_hidden.append(layer_output[0]) - self._store_gating_score(adapter_block, layer_output[-1]) - # Case 6: nesting other composition blocks is invalid - elif isinstance(adapter_block, AdapterCompositionBlock): - raise ValueError( - "Invalid adapter setup. Cannot nest {} in {}".format( - adapter_block.__class__.__name__, adapter_setup.__class__.__name__ - ) - ) - # Case X: No adapter which is part of this module -> ignore + child_state = BottleneckState( + state.hidden_states[:, batch_idx[0] : batch_idx[1], :], + state.input_tensor[:, batch_idx[0] : batch_idx[1], :], + state.adapter_residual[:, batch_idx[0] : batch_idx[1], :], + state.layer_norm, + state.bottleneck_up[:, batch_idx[0] : batch_idx[1], :] if state.bottleneck_up is not None else None, + ) + if isinstance(child, AdapterCompositionBlock): + self.check_composition_valid(adapter_setup, child, lvl) + composition_func = self.composition_to_func_map[type(child)] + child_state = composition_func(child, child_state, lvl=lvl + 1) + children_states.append(child_state) + elif child in self.adapter_modules: + child_state = self.compose_single(child, child_state, lvl=lvl + 1) + children_states.append(child_state) else: - children_hidden.append(hidden_states[batch_idx]) - - hidden_states = torch.cat(children_hidden, 0) - return hidden_states - - def adapter_average_output(self, adapter_setup: Average, hidden_states, input_tensor, layer_norm, lvl=0): - """ - For averaging the output representations of multiple adapters. - """ - context = ForwardContext.get_context() - - # We assume all adapters have the same config - first_adapter = self.adapters[adapter_setup.first()] - hidden_states, _, residual = first_adapter.pre_forward(hidden_states, input_tensor, layer_norm) - - children_hidden = [] - - for adapter_block in adapter_setup: - # Case 1: We have a nested stack -> call stack method - if isinstance(adapter_block, Stack): - child, _, _ = self.adapter_stack(adapter_block, hidden_states, input_tensor, layer_norm, lvl=lvl + 1) - children_hidden.append(child) - # Case 2: We have a nested split block -> call split method - elif isinstance(adapter_block, Split): - child = self.adapter_split(adapter_block, hidden_states, input_tensor, layer_norm, lvl=lvl + 1) - children_hidden.append(child) - # Case 3: We have a nested batch split block -> call batchsplit method - elif isinstance(adapter_block, BatchSplit): - child = self.adapter_batchsplit(adapter_block, hidden_states, input_tensor, layer_norm, lvl=lvl + 1) - children_hidden.append(child) - # Case 4: We have a single adapter which is part of this module -> forward pass - elif adapter_block in self.adapters: - adapter_layer = self.adapters[adapter_block] - layer_output = adapter_layer( - hidden_states, residual_input=residual, output_gating=context.output_adapter_gating_scores - ) - children_hidden.append(layer_output[0]) - self._store_gating_score(adapter_block, layer_output[-1]) - # Case 5: nesting other composition blocks is invalid - elif isinstance(adapter_block, AdapterCompositionBlock): - raise ValueError( - "Invalid adapter setup. Cannot nest {} in {}".format( - adapter_block.__class__.__name__, adapter_setup.__class__.__name__ - ) - ) - # Case X: No adapter which is part of this module -> ignore + pass - weights = torch.tensor(adapter_setup.weights).unsqueeze(1).unsqueeze(1).to(hidden_states.device) - hidden_states = torch.mean(torch.cat(children_hidden, 0) * weights, 0) + hidden_states = torch.cat([child.hidden_states for child in children_states], dim=1) + return state._replace(hidden_states=hidden_states) - return hidden_states - - def adapter_layer_forward(self, hidden_states, residual_input, layer_norm): + def bottleneck_layer_forward(self, hidden_states, residual_input, layer_norm): """Forward pass through the adapter layer. NOTE: This method should only be called if the calling module directly inherits from BottleneckLayer. Otherwise, call the regular forward() method. @@ -564,30 +339,13 @@ def adapter_layer_forward(self, hidden_states, residual_input, layer_norm): (residual_input,) = adjust_tensors_for_parallel(hidden_states, residual_input) # Replicate in both directions as residual might be larger (e.g. GPT-J) (hidden_states,) = adjust_tensors_for_parallel(residual_input, hidden_states) - adapter_setup = self.get_active_setup(self.adapters) + adapter_setup = self.get_active_setup() if adapter_setup is not None: input_hidden_states = hidden_states - if isinstance(adapter_setup, Stack): - hidden_states, _, residual_input = self.adapter_stack( - adapter_setup, hidden_states, residual_input, layer_norm - ) - elif isinstance(adapter_setup, Fuse): - hidden_states = self.adapter_fusion(adapter_setup, hidden_states, residual_input, layer_norm) - elif isinstance(adapter_setup, Split): - hidden_states = self.adapter_split(adapter_setup, hidden_states, residual_input, layer_norm) - elif isinstance(adapter_setup, Parallel): - # notice that we are overriding input tensor here to keep the same dim as hidden_states for the residual - # in case we were blowing up the batch for parallel processing of multiple adapters for the same input - hidden_states, residual_input = self.adapter_parallel( - adapter_setup, hidden_states, residual_input, layer_norm - ) - elif isinstance(adapter_setup, BatchSplit): - hidden_states = self.adapter_batchsplit(adapter_setup, hidden_states, residual_input, layer_norm) - elif isinstance(adapter_setup, Average): - hidden_states = self.adapter_average_output(adapter_setup, hidden_states, residual_input, layer_norm) - else: - raise ValueError(f"Invalid adapter setup {adapter_setup}") + state = BottleneckState(hidden_states, residual_input, residual_input, layer_norm) + state = self.compose(adapter_setup, state) + hidden_states, residual_input, _, _, _ = state last_adapter = self.adapters[adapter_setup.last()] hidden_states = last_adapter.post_forward(hidden_states, input_hidden_states, residual_input, layer_norm) @@ -610,4 +368,4 @@ def forward(self, hidden_states, residual_input, layer_norm): Returns: torch.Tensor: Output hidden states of the adapter layer. """ - return self.adapter_layer_forward(hidden_states, residual_input, layer_norm) + return self.bottleneck_layer_forward(hidden_states, residual_input, layer_norm) diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py index 977fe8ae88..a4c66c830c 100644 --- a/src/adapters/methods/lora.py +++ b/src/adapters/methods/lora.py @@ -94,6 +94,8 @@ def com_inv(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor: class LoRALayer(AdapterLayerBase): + adapter_modules_name = "loras" + def __init__( self, location_key: str, model_config: PretrainedConfig, adapters_config: ModelAdaptersConfig, *args, **kwargs ): @@ -313,7 +315,7 @@ def T(w): return torch.transpose(w, -2, -1) if self.fan_in_fan_out else w if not self.merged: - adapter_setup = self.get_active_setup(self.loras) + adapter_setup = self.get_active_setup() if adapter_setup is not None: if len(adapter_setup) == 1: lora = self.loras[adapter_setup[0]] @@ -496,7 +498,7 @@ def T(w): return torch.t(w) if self.fan_in_fan_out else w if not self.merged: - adapter_setup = self.get_active_setup(self.loras) + adapter_setup = self.get_active_setup() if adapter_setup is not None: if len(adapter_setup) == 1: result = F.linear(x, T(self.weight), bias=self.bias) diff --git a/src/adapters/methods/prefix_tuning.py b/src/adapters/methods/prefix_tuning.py index 06ebd83f50..6630d8f606 100644 --- a/src/adapters/methods/prefix_tuning.py +++ b/src/adapters/methods/prefix_tuning.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Union +from typing import Dict, List, NamedTuple, Optional, Union import torch import torch.nn.functional as F @@ -10,7 +10,7 @@ from ..composition import AdapterCompositionBlock, BatchSplit, Parallel, Stack, adjust_tensors_for_parallel from ..configuration import ModelAdaptersConfig, PrefixTuningConfig from ..context import AdapterSetup, ForwardContext -from .adapter_layer_base import AdapterLayerBase +from .adapter_layer_base import ComposableAdapterLayerBase from .modeling import Activation_Function_Class @@ -244,7 +244,28 @@ def forward(self, *args, **kwargs): return prefix_states -class PrefixTuningShim(AdapterLayerBase, nn.Module): +class PrefixTuningState(NamedTuple): + """ + Models the input and output states of a prefix tuning layer. + + Args: + key_states (torch.Tensor): The key states of the attention layer. + value_states (torch.Tensor): The value states of the attention layer. + residual_input (torch.Tensor): The residual input of the attention layer. + attention_mask (torch.Tensor, optional): The attention mask of the attention layer. + invert_mask (bool): Whether the attention mask is inverted (ie. using '1' for padding). + + """ + + key_states: torch.Tensor + value_states: torch.Tensor + residual_input: torch.Tensor + attention_mask: Optional[torch.Tensor] + invert_mask: bool + idx_slice: Optional[slice] = None + + +class PrefixTuningShim(ComposableAdapterLayerBase, nn.Module): """ Representation of a Prefix Tuning layer within one Transformer layer. This class implements `AdapterLayerBase` for compatibility with adapters. It uses `PrefixTuningPool` in the background and `set_pool()` must be called after @@ -256,6 +277,9 @@ class PrefixTuningShim(AdapterLayerBase, nn.Module): config (:class:`~transformers.PretrainedConfig`): The model config. """ + adapter_modules_name = "prefixes" + supported_compositions = [Stack, Parallel, BatchSplit] + def __init__( self, location_key: str, @@ -373,63 +397,31 @@ def get_adapter(self, adapter_name): return None - def single_forward( - self, - adapter_name: str, - key_states, - value_states, - residual_input, - attention_mask=None, - invert_mask=True, - idx_range=None, - ): - prefix_id = self.prefixes[adapter_name] - batch_size = key_states.size(0) - - # Retrieve pre-computed prefix states from context - context = ForwardContext.get_context() - # batch_size x n_heads x prefix_length x n_embd_per_head - prefix_keys, prefix_values = context.prefix_states[adapter_name][self.location_key][prefix_id] - - # select index range for batch split - if idx_range is not None: - prefix_keys = prefix_keys[idx_range] - prefix_values = prefix_values[idx_range] - - if adapter_name in self.prefix_gates: - gate = self.prefix_gates[adapter_name] - gate_output = torch.mean(torch.sigmoid(gate(residual_input)), dim=1) - self._store_gating_score(adapter_name, gate_output) - gate_output_key = gate_output[:, 0].view(-1, 1, 1, 1) - gate_output_value = gate_output[:, -1].view(-1, 1, 1, 1) - prefix_keys = prefix_keys * gate_output_key - prefix_values = prefix_values * gate_output_value - - # replicate for Parallel block - prefix_keys, prefix_values = adjust_tensors_for_parallel(key_states, prefix_keys, prefix_values) - - key_states = torch.cat([prefix_keys, key_states], dim=2) - value_states = torch.cat([prefix_values, value_states], dim=2) - if attention_mask is not None: - if attention_mask.dim() == 2: # e.g. for DistilBERT, attention_mask has shape (batch_size, seq_len) - prefix_mask = torch.ones(batch_size, prefix_keys.size(2)).to(attention_mask.device) - else: - prefix_mask = torch.ones(batch_size, 1, attention_mask.size(2), prefix_keys.size(2)).to( - attention_mask.device - ) - if invert_mask: - prefix_mask = 1.0 - prefix_mask - (prefix_mask,) = adjust_tensors_for_parallel(attention_mask, prefix_mask) - attention_mask = torch.cat([prefix_mask, attention_mask], dim=-1) - - return key_states, value_states, residual_input, attention_mask + def vslice(self, state: PrefixTuningState, slice_obj: slice) -> PrefixTuningState: + if state.idx_slice is None: + split_idx_slice = slice_obj + else: + split_idx_slice = slice( + state.idx_slice.start + slice_obj.start, + state.idx_slice.start + slice_obj.stop, + ) + return PrefixTuningState( + key_states=state.key_states[slice_obj], + value_states=state.value_states[slice_obj], + residual_input=state.residual_input[slice_obj], + attention_mask=state.attention_mask[slice_obj] if state.attention_mask is not None else None, + invert_mask=state.invert_mask, + idx_slice=split_idx_slice, + ) - def _pad_and_concat(self, max_prefix_length, outputs, invert_mask=True): - """Pads all key & value states to the lFongest prefix length in the current batch. + def pad_and_concat(self, states: List[PrefixTuningState]) -> PrefixTuningState: + """Pads all key & value states to the longest prefix length in the current batch. This is required e.g. for stacked prefix tunings. """ + max_prefix_length = max([state.key_states.shape[-2] for state in states]) all_key_states, all_value_states, all_residual_input, all_attention_mask = [], [], [], [] - for key_states, value_states, residual_input, attention_mask in outputs: + for state in states: + key_states, value_states, residual_input, attention_mask = state[:4] # pad sizes pad_length = max_prefix_length - key_states.shape[-2] pad_size = (0, 0, pad_length, 0) @@ -445,7 +437,7 @@ def _pad_and_concat(self, max_prefix_length, outputs, invert_mask=True): attention_mask, (max_prefix_length - attention_mask.shape[-1], 0), "constant", - 1.0 if invert_mask else 0.0, + 1.0 if state.invert_mask else 0.0, ) all_key_states.append(key_states) @@ -458,294 +450,87 @@ def _pad_and_concat(self, max_prefix_length, outputs, invert_mask=True): all_residual_input = torch.cat(all_residual_input, dim=0) all_attention_mask = torch.cat(all_attention_mask, dim=0) if attention_mask is not None else None - return all_key_states, all_value_states, all_residual_input, all_attention_mask + return PrefixTuningState( + key_states=all_key_states, + value_states=all_value_states, + residual_input=all_residual_input, + attention_mask=all_attention_mask, + invert_mask=states[0].invert_mask, + idx_slice=states[0].idx_slice, + ) - def adapter_stack( - self, - adapter_setup: Stack, - key_states, - value_states, - residual_input, - attention_mask=None, - invert_mask=True, - idx_range=None, - lvl=0, - ): - for adapter_stack_layer in adapter_setup: - # Break if setup is too deep - if isinstance(adapter_stack_layer, AdapterCompositionBlock) and lvl >= 1: - raise ValueError( - "Specified adapter setup is too deep. Cannot have {} at level {}".format( - adapter_stack_layer.__class__.__name__, lvl - ) - ) - # We have a nested parallel layer -> call parallel method - elif isinstance(adapter_stack_layer, Parallel): - key_states, value_states, residual_input, attention_mask = self.adapter_parallel( - adapter_stack_layer, - key_states, - value_states, - residual_input, - attention_mask, - invert_mask=invert_mask, - idx_range=idx_range, - lvl=lvl + 1, - ) - # We have a nested batch split block -> call batchsplit method - elif isinstance(adapter_stack_layer, BatchSplit): - key_states, value_states, residual_input, attention_mask = self.adapter_batchsplit( - adapter_stack_layer, - key_states, - value_states, - residual_input, - attention_mask, - invert_mask=invert_mask, - idx_range=idx_range, - lvl=lvl + 1, - ) - # We have a single prefix tuning module part of this model -> forward pass - elif adapter_stack_layer in self.prefixes: - key_states, value_states, _, attention_mask = self.single_forward( - adapter_stack_layer, - key_states, - value_states, - residual_input, - attention_mask, - invert_mask, - idx_range=idx_range, - ) - # Nesting other composition blocks is invalid - elif isinstance(adapter_stack_layer, AdapterCompositionBlock): - raise ValueError( - "Invalid adapter setup. Cannot nest {} in {}".format( - adapter_stack_layer.__class__.__name__, adapter_setup.__class__.__name__ - ) - ) - # As all prefix tuning modules are centrally stored, fail if not found. + def repeat(self, state: PrefixTuningState, channels: int) -> PrefixTuningState: + if state.attention_mask is not None: + if state.attention_mask.dim() == 2: # e.g. for DistilBERT, attention_mask has shape (batch_size, seq_len) + attention_mask = state.attention_mask.repeat(channels, 1) else: - raise ValueError(f"Unknown prefix tuning name '{adapter_stack_layer}'.") - - return key_states, value_states, residual_input, attention_mask - - def adapter_parallel( - self, - adapter_setup: Parallel, - key_states, - value_states, - residual_input, - attention_mask=None, - invert_mask=True, - idx_range=None, - lvl=0, - ): - """ - For parallel execution of the adapters on the same input. This means that the input is repeated N times before - feeding it to the adapters (where N is the number of adapters). - """ - - context = ForwardContext.get_context() - if not context.adapters_parallelized: - orig_batch_size = residual_input.shape[0] - residual_input = residual_input.repeat(self.adapters_config.active_setup.parallel_channels, 1, 1, 1) - key_states = key_states.repeat(self.adapters_config.active_setup.parallel_channels, 1, 1, 1) - value_states = value_states.repeat(self.adapters_config.active_setup.parallel_channels, 1, 1, 1) - if attention_mask is not None: - if attention_mask.dim() == 2: # e.g. for DistilBERT, attention_mask has shape (batch_size, seq_len) - attention_mask = attention_mask.repeat(self.adapters_config.active_setup.parallel_channels, 1) - else: - attention_mask = attention_mask.repeat( - self.adapters_config.active_setup.parallel_channels, 1, 1, 1 - ) - context.adapters_parallelized = True + attention_mask = state.attention_mask.repeat(channels, 1, 1, 1) else: - # The base model should handle replication of input. - # Therefore, we assume the (replicated) input batch to be divisible by the number of parallel channels. - if residual_input.shape[0] % adapter_setup.parallel_channels != 0: - raise ValueError( - "The total input batch size in a Parallel adapter block must be divisible by the number of" - " parallel channels." - ) - orig_batch_size = residual_input.shape[0] // adapter_setup.parallel_channels - - # sequentially feed different parts of the blown-up batch into different adapters - children_outputs = [] - # track which prefix is longest for padding in the end - max_prefix_length = 0 - for i, child in enumerate(adapter_setup): - # construct inputs to child modules - inputs = { - "key_states": key_states[i * orig_batch_size : (i + 1) * orig_batch_size], - "value_states": value_states[i * orig_batch_size : (i + 1) * orig_batch_size], - "residual_input": residual_input[i * orig_batch_size : (i + 1) * orig_batch_size], - "attention_mask": attention_mask[i * orig_batch_size : (i + 1) * orig_batch_size] - if attention_mask is not None - else None, - "invert_mask": invert_mask, - "idx_range": idx_range, - } + attention_mask = None + return PrefixTuningState( + key_states=state.key_states.repeat(channels, 1, 1, 1), + value_states=state.value_states.repeat(channels, 1, 1, 1), + residual_input=state.residual_input.repeat(channels, 1, 1), + attention_mask=attention_mask, + invert_mask=state.invert_mask, + idx_slice=state.idx_slice, + ) - # Case 1: We have a nested stack -> call stack method - if isinstance(child, Stack): - child_outputs = self.adapter_stack( - child, - **inputs, - lvl=lvl + 1, - ) - children_outputs.append(child_outputs) - # Case 2. We have a nested batchsplit block -> call batchsplit method - elif isinstance(child, BatchSplit): - child_outputs = self.adapter_batchsplit( - child, - **inputs, - lvl=lvl + 1, - ) - children_outputs.append(child_outputs) - # Case 3: We have a single adapter which is part of this module -> forward pass - elif child in self.prefixes: - child_outputs = self.single_forward( - child, - **inputs, - ) - children_outputs.append(child_outputs) - # Case 4: nesting other composition blocks is invalid - elif isinstance(child, AdapterCompositionBlock): - raise ValueError( - "Invalid adapter setup. Cannot nest {} in {}".format( - child.__class__.__name__, adapter_setup.__class__.__name__ - ) - ) - # As all prefix tuning modules are centrally stored, fail if not found. - else: - raise ValueError(f"Unknown prefix tuning name '{child}'.") + def mean(self, states: List[PrefixTuningState], weights: torch.Tensor) -> PrefixTuningState: + # TODO implement average composition + raise NotImplementedError() - # update max prefix length - current_prefix_length = child_outputs[0].shape[-2] - if current_prefix_length > max_prefix_length: - max_prefix_length = current_prefix_length + def compose_single(self, adapter_setup: str, state: PrefixTuningState, lvl: int = 0) -> PrefixTuningState: + prefix_id = self.prefixes[adapter_setup] + batch_size = state.key_states.size(0) - # concatenate all outputs and return - key_states, value_states, residual_input, attention_mask = self._pad_and_concat( - max_prefix_length, children_outputs, invert_mask=invert_mask - ) - return key_states, value_states, residual_input, attention_mask + # Retrieve pre-computed prefix states from context + context = ForwardContext.get_context() + # batch_size x n_heads x prefix_length x n_embd_per_head + prefix_keys, prefix_values = context.prefix_states[adapter_setup][self.location_key][prefix_id] + + # Select index range for batch split + # Ignore slices that go beyond the prefix states bsz + # (this is the case for slices produced by Parallel blocks which operate on replicated kv states) + if state.idx_slice is not None and state.idx_slice.start < prefix_keys.size(0): + prefix_keys = prefix_keys[state.idx_slice] + prefix_values = prefix_values[state.idx_slice] + + if adapter_setup in self.prefix_gates: + gate = self.prefix_gates[adapter_setup] + gate_output = torch.mean(torch.sigmoid(gate(state.residual_input)), dim=1) + self._store_gating_score(adapter_setup, gate_output) + gate_output_key = gate_output[:, 0].view(-1, 1, 1, 1) + gate_output_value = gate_output[:, -1].view(-1, 1, 1, 1) + prefix_keys = prefix_keys * gate_output_key + prefix_values = prefix_values * gate_output_value - def adapter_batchsplit( - self, - adapter_setup: BatchSplit, - key_states, - value_states, - residual_input, - attention_mask=None, - invert_mask=True, - idx_range=None, - lvl=0, - ): - if not sum(adapter_setup.batch_sizes) == key_states.shape[0]: - raise IndexError( - "The given batch has a size of {} which is not compatible with batch_sizes {}".format( - key_states.shape[0], adapter_setup.batch_sizes - ) - ) + # Replicate for Parallel block + prefix_keys, prefix_values = adjust_tensors_for_parallel(state.key_states, prefix_keys, prefix_values) - children_outputs = [] - # track which prefix is longest for padding in the end - max_prefix_length = 0 - for i, adapter_block in enumerate(adapter_setup): - # compute ids of sequences that should be passed to the ith adapter - if idx_range is None: - split_idx_range = range( - sum(adapter_setup.batch_sizes[:i]), - sum(adapter_setup.batch_sizes[: i + 1]), - ) + key_states = torch.cat([prefix_keys, state.key_states], dim=2) + value_states = torch.cat([prefix_values, state.value_states], dim=2) + if state.attention_mask is not None: + if state.attention_mask.dim() == 2: # e.g. for DistilBERT, attention_mask has shape (batch_size, seq_len) + prefix_mask = torch.ones(batch_size, prefix_keys.size(2)).to(state.attention_mask.device) else: - split_idx_range = range( - idx_range.start + sum(adapter_setup.batch_sizes[:i]), - idx_range.start + sum(adapter_setup.batch_sizes[: i + 1]), - ) - inputs = { - "key_states": key_states[split_idx_range], - "value_states": value_states[split_idx_range], - "residual_input": residual_input[split_idx_range], - "attention_mask": attention_mask[split_idx_range] if attention_mask is not None else None, - "invert_mask": invert_mask, - "idx_range": split_idx_range, - } - # Case 1: We have a nested stack -> call stack method - if isinstance(adapter_block, Stack): - child_outputs = self.adapter_stack( - adapter_block, - **inputs, - lvl=lvl + 1, - ) - children_outputs.append(child_outputs) - # Case 2: We have a nested batch split block -> call batchsplit method - elif isinstance(adapter_block, BatchSplit): - child_outputs = self.adapter_batchsplit( - adapter_block, - **inputs, - lvl=lvl + 1, - ) - children_outputs.append(child_outputs) - # Case 4: We have a single adapter which is part of this module -> forward pass - elif adapter_block in self.prefixes: - child_outputs = self.single_forward( - adapter_block, - **inputs, + prefix_mask = torch.ones(batch_size, 1, state.attention_mask.size(2), prefix_keys.size(2)).to( + state.attention_mask.device ) - children_outputs.append(child_outputs) - # Case 5: nesting other composition blocks is invalid - elif isinstance(adapter_block, AdapterCompositionBlock): - raise ValueError( - "Invalid adapter setup. Cannot nest {} in {}".format( - adapter_block.__class__.__name__, adapter_setup.__class__.__name__ - ) - ) - # As all prefix tuning modules are centrally stored, fail if not found. - else: - raise ValueError(f"Unknown prefix tuning name '{adapter_block}'.") - - # update max prefix length - current_prefix_length = child_outputs[0].shape[-2] - if current_prefix_length > max_prefix_length: - max_prefix_length = current_prefix_length + if state.invert_mask: + prefix_mask = 1.0 - prefix_mask + (prefix_mask,) = adjust_tensors_for_parallel(state.attention_mask, prefix_mask) + attention_mask = torch.cat([prefix_mask, state.attention_mask], dim=-1) + else: + attention_mask = None - # concatenate all outputs and return - key_states, value_states, residual_input, attention_mask = self._pad_and_concat( - max_prefix_length, children_outputs, invert_mask=invert_mask - ) - return key_states, value_states, residual_input, attention_mask + return state._replace(key_states=key_states, value_states=value_states, attention_mask=attention_mask) def forward(self, key_states, value_states, residual_input, attention_mask=None, invert_mask=True): - adapter_setup = self.get_active_setup(self.prefixes) + adapter_setup = self.get_active_setup() if adapter_setup is not None: - if isinstance(adapter_setup, Stack): - key_states, value_states, _, attention_mask = self.adapter_stack( - adapter_setup, - key_states, - value_states, - residual_input, - attention_mask=attention_mask, - invert_mask=invert_mask, - ) - elif isinstance(adapter_setup, Parallel): - key_states, value_states, _, attention_mask = self.adapter_parallel( - adapter_setup, - key_states, - value_states, - residual_input, - attention_mask=attention_mask, - invert_mask=invert_mask, - ) - elif isinstance(adapter_setup, BatchSplit): - key_states, value_states, _, attention_mask = self.adapter_batchsplit( - adapter_setup, - key_states, - value_states, - residual_input, - attention_mask=attention_mask, - invert_mask=invert_mask, - ) - else: - raise ValueError(f"Invalid adapter setup. Cannot use {adapter_setup} with prefix tuning.") + state = PrefixTuningState(key_states, value_states, residual_input, attention_mask, invert_mask) + state = self.compose(adapter_setup, state) + key_states, value_states, residual_input, attention_mask = state[:4] return key_states, value_states, attention_mask diff --git a/tests_adapters/composition/test_adapter_composition.py b/tests_adapters/composition/test_adapter_composition.py index d43a840552..b0528ab7b6 100644 --- a/tests_adapters/composition/test_adapter_composition.py +++ b/tests_adapters/composition/test_adapter_composition.py @@ -83,7 +83,7 @@ def test_simple_split(self): model = self.build_model() # pass over split setup - model.set_active_adapters(Split("a", "b", 64)) + model.set_active_adapters(Split("a", "b", splits=64)) self.training_pass(model) @@ -93,7 +93,7 @@ def test_stacked_split(self): model = self.build_model() # split into two stacks - model.set_active_adapters(Split(Stack("a", "b"), Stack("c", "d"), split_index=64)) + model.set_active_adapters(Split(Stack("a", "b"), Stack("c", "d"), splits=64)) self.training_pass(model) @@ -118,7 +118,7 @@ def test_mixed_stack(self): model.add_adapter_fusion(Fuse("a", "b")) model.to(torch_device) - model.set_active_adapters(Stack("a", Split("c", "d", split_index=64), Fuse("a", "b"))) + model.set_active_adapters(Stack("a", Split("c", "d", splits=64), Fuse("a", "b"))) self.training_pass(model) @@ -128,7 +128,7 @@ def test_nested_split(self): model = self.build_model() # split into two stacks - model.set_active_adapters(Split(Split("a", "b", split_index=32), "c", split_index=64)) + model.set_active_adapters(Split(Split("a", "b", splits=32), "c", splits=64)) self.training_pass(model) From 7770f43d6385a4954d66b334c4ef36fe4aa472de Mon Sep 17 00:00:00 2001 From: calpt Date: Sun, 8 Oct 2023 01:08:11 +0200 Subject: [PATCH 03/10] Update split block docs --- docs/adapter_composition.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/adapter_composition.md b/docs/adapter_composition.md index 1ee26806ea..05f85f3fdb 100644 --- a/docs/adapter_composition.md +++ b/docs/adapter_composition.md @@ -160,10 +160,10 @@ In the example, `attention_scores` holds a dictionary of the following form: Splitting the input between two adapters using the 'Split' block. ``` -The `Split` block can be used to split an input sequence between two adapters. -This is done by specifying a split index, at which the sequences should be divided. +The `Split` block can be used to split an input sequence between multiple adapters. +This is done by specifying split indices at which the sequences should be divided. In the following example, we split each input sequence between adapters `g` and `h`. -For each sequence, all tokens from 0 up to 63 are forwarded through `g` while all tokens beginning at index 64 are forwarded through `h`: +For each sequence, all tokens from 0 up to 63 are forwarded through `g` while the next 64 tokens are forwarded through `h`: ```python import adapters.composition as ac @@ -173,7 +173,7 @@ import adapters.composition as ac model.add_adapter("g") model.add_adapter("h") -model.active_adapters = ac.Split("g", "h", split_index=64) +model.active_adapters = ac.Split("g", "h", splits=[64, 64]) ``` ## `BatchSplit` @@ -286,7 +286,7 @@ E.g., we can nest a `Split` block within a `Stack` of adapters: ```python import adapters.composition as ac -model.active_adapters = ac.Stack("a", ac.Split("b", "c", split_index=60)) +model.active_adapters = ac.Stack("a", ac.Split("b", "c", splits=60)) ``` However, combinations of adapter composition blocks cannot be arbitrarily deep. All currently supported possibilities are visualized in the table below. From 55fdc0cbe2f695914108a9c0e208127b13bc617e Mon Sep 17 00:00:00 2001 From: calpt Date: Mon, 9 Oct 2023 21:31:55 +0200 Subject: [PATCH 04/10] PrefixTuningShim -> PrefixTuningLayer --- docs/contributing/adding_adapters_to_a_model.md | 2 +- src/adapters/methods/prefix_tuning.py | 6 +++--- src/adapters/model_mixin.py | 4 ++-- src/adapters/models/albert/mixin_albert.py | 4 ++-- src/adapters/models/bart/mixin_bart.py | 4 ++-- src/adapters/models/beit/mixin_beit.py | 4 ++-- src/adapters/models/bert/mixin_bert.py | 4 ++-- src/adapters/models/clip/mixin_clip.py | 6 ++++-- src/adapters/models/deberta/mixin_deberta.py | 4 ++-- src/adapters/models/deberta_v2/mixin_deberta_v2.py | 4 ++-- src/adapters/models/distilbert/mixin_distilbert.py | 4 ++-- src/adapters/models/gpt2/mixin_gpt2.py | 4 ++-- src/adapters/models/gptj/mixin_gptj.py | 4 ++-- src/adapters/models/llama/mixin_llama.py | 4 ++-- src/adapters/models/t5/mixin_t5.py | 4 ++-- src/adapters/models/vit/mixin_vit.py | 4 ++-- 16 files changed, 34 insertions(+), 32 deletions(-) diff --git a/docs/contributing/adding_adapters_to_a_model.md b/docs/contributing/adding_adapters_to_a_model.md index f574bb806a..430e146edb 100644 --- a/docs/contributing/adding_adapters_to_a_model.md +++ b/docs/contributing/adding_adapters_to_a_model.md @@ -27,7 +27,7 @@ Now that we have discussed the purpose of every file in `src/adapters/models/ Date: Tue, 10 Oct 2023 20:50:18 +0200 Subject: [PATCH 05/10] `adapter_layer_forward()` -> `bottleneck_layer_forward()` --- docs/contributing/adding_adapters_to_a_model.md | 2 +- src/adapters/models/beit/modeling_beit.py | 4 ++-- src/adapters/models/bert/modeling_bert.py | 4 ++-- .../models/bert_generation/modeling_bert_generation.py | 4 ++-- src/adapters/models/deberta/modeling_deberta.py | 4 ++-- src/adapters/models/deberta_v2/modeling_deberta_v2.py | 4 ++-- src/adapters/models/electra/modeling_electra.py | 4 ++-- src/adapters/models/roberta/modeling_roberta.py | 4 ++-- src/adapters/models/t5/modeling_t5.py | 6 +++--- src/adapters/models/vit/modeling_vit.py | 4 ++-- src/adapters/models/xlm_roberta/modeling_xlm_roberta.py | 4 ++-- src/adapters/models/xmod/modeling_xmod.py | 4 ++-- 12 files changed, 24 insertions(+), 24 deletions(-) diff --git a/docs/contributing/adding_adapters_to_a_model.md b/docs/contributing/adding_adapters_to_a_model.md index 430e146edb..9306e8d92f 100644 --- a/docs/contributing/adding_adapters_to_a_model.md +++ b/docs/contributing/adding_adapters_to_a_model.md @@ -28,7 +28,7 @@ Now that we have discussed the purpose of every file in `src/adapters/models/ torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states @@ -149,5 +149,5 @@ class BertOutputWithAdapters(BertOutputAdaptersMixin, BertOutput): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states diff --git a/src/adapters/models/bert_generation/modeling_bert_generation.py b/src/adapters/models/bert_generation/modeling_bert_generation.py index c21d2a3f4d..8f083fe295 100644 --- a/src/adapters/models/bert_generation/modeling_bert_generation.py +++ b/src/adapters/models/bert_generation/modeling_bert_generation.py @@ -36,7 +36,7 @@ class BertGenerationSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, BertGene def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states @@ -153,5 +153,5 @@ class BertGenerationOutputWithAdapters(BertOutputAdaptersMixin, BertGenerationOu def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states diff --git a/src/adapters/models/deberta/modeling_deberta.py b/src/adapters/models/deberta/modeling_deberta.py index 462685a85d..8197c19fb6 100644 --- a/src/adapters/models/deberta/modeling_deberta.py +++ b/src/adapters/models/deberta/modeling_deberta.py @@ -33,7 +33,7 @@ class DebertaSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, DebertaSelfOutp def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states @@ -41,7 +41,7 @@ class DebertaOutputWithAdapters(BertOutputAdaptersMixin, DebertaOutput): def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states diff --git a/src/adapters/models/deberta_v2/modeling_deberta_v2.py b/src/adapters/models/deberta_v2/modeling_deberta_v2.py index 79cd4e6a34..082e77a721 100644 --- a/src/adapters/models/deberta_v2/modeling_deberta_v2.py +++ b/src/adapters/models/deberta_v2/modeling_deberta_v2.py @@ -34,7 +34,7 @@ class DebertaV2SelfOutputWithAdapters(BertSelfOutputAdaptersMixin, DebertaV2Self def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states @@ -43,7 +43,7 @@ class DebertaV2OutputWithAdapters(BertOutputAdaptersMixin, DebertaV2Output): def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states diff --git a/src/adapters/models/electra/modeling_electra.py b/src/adapters/models/electra/modeling_electra.py index 0412b4dc10..35552782ce 100644 --- a/src/adapters/models/electra/modeling_electra.py +++ b/src/adapters/models/electra/modeling_electra.py @@ -122,7 +122,7 @@ class ElectraSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, ElectraSelfOutp def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states @@ -130,5 +130,5 @@ class ElectraOutputWithAdapters(BertOutputAdaptersMixin, ElectraOutput): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states diff --git a/src/adapters/models/roberta/modeling_roberta.py b/src/adapters/models/roberta/modeling_roberta.py index cf37a337ae..47a8ed35a9 100644 --- a/src/adapters/models/roberta/modeling_roberta.py +++ b/src/adapters/models/roberta/modeling_roberta.py @@ -142,7 +142,7 @@ class RobertaSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, RobertaSelfOutp def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states @@ -151,5 +151,5 @@ class RobertaOutputWithAdapters(BertOutputAdaptersMixin, RobertaOutput): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states diff --git a/src/adapters/models/t5/modeling_t5.py b/src/adapters/models/t5/modeling_t5.py index 7d7e467f0a..3440a4bb73 100644 --- a/src/adapters/models/t5/modeling_t5.py +++ b/src/adapters/models/t5/modeling_t5.py @@ -45,7 +45,7 @@ class T5LayerFFWithAdapters(T5FFLayerAdaptersMixin, T5LayerFF): def forward(self, hidden_states): forwarded_states = self.layer_norm(hidden_states) forwarded_states = self.DenseReluDense(forwarded_states) - hidden_states = self.adapter_layer_forward( + hidden_states = self.bottleneck_layer_forward( hidden_states=self.dropout(forwarded_states), residual_input=hidden_states, layer_norm=None ) return hidden_states @@ -207,7 +207,7 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, ) - hidden_states = self.adapter_layer_forward( + hidden_states = self.bottleneck_layer_forward( hidden_states=self.dropout(attention_output[0]), residual_input=hidden_states, layer_norm=None ) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them @@ -239,7 +239,7 @@ def forward( query_length=query_length, output_attentions=output_attentions, ) - layer_output = self.adapter_layer_forward( + layer_output = self.bottleneck_layer_forward( hidden_states=self.dropout(attention_output[0]), residual_input=hidden_states, layer_norm=None ) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them diff --git a/src/adapters/models/vit/modeling_vit.py b/src/adapters/models/vit/modeling_vit.py index 4ffb61d5f6..bb0fadd2ca 100644 --- a/src/adapters/models/vit/modeling_vit.py +++ b/src/adapters/models/vit/modeling_vit.py @@ -72,7 +72,7 @@ class ViTOutputWithAdapters(ViTOutputAdaptersMixin, ViTOutput): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.output_adapters.adapter_layer_forward(hidden_states, input_tensor, None) + hidden_states = self.output_adapters.bottleneck_layer_forward(hidden_states, input_tensor, None) return hidden_states @@ -94,7 +94,7 @@ def forward( attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - hidden_states = self.attention_adapters.adapter_layer_forward(attention_output, hidden_states, None) + hidden_states = self.attention_adapters.bottleneck_layer_forward(attention_output, hidden_states, None) # in ViT, layernorm is also applied after self-attention layer_output = self.layernorm_after(hidden_states) diff --git a/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py b/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py index cd8dd9bf08..a8d22284b7 100644 --- a/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py @@ -146,7 +146,7 @@ class XLMRobertaSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, XLMRobertaSe def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states @@ -155,5 +155,5 @@ class XLMRobertaOutputWithAdapters(BertOutputAdaptersMixin, XLMRobertaOutput): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states diff --git a/src/adapters/models/xmod/modeling_xmod.py b/src/adapters/models/xmod/modeling_xmod.py index 3a3a38066f..b772321667 100644 --- a/src/adapters/models/xmod/modeling_xmod.py +++ b/src/adapters/models/xmod/modeling_xmod.py @@ -140,7 +140,7 @@ class XmodSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, XmodSelfOutput): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, None) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, None) return hidden_states @@ -152,5 +152,5 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, lang_ layer_norm = self.adapter_layer_norm elif self.adapter_reuse_layer_norm: layer_norm = self.LayerNorm - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, layer_norm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, layer_norm) return hidden_states From 199fef53c084b6157f16a4c4f63e641111c68f4d Mon Sep 17 00:00:00 2001 From: calpt Date: Tue, 10 Oct 2023 21:37:22 +0200 Subject: [PATCH 06/10] Fix init --- src/adapters/methods/adapter_layer_base.py | 3 +++ src/adapters/methods/bottleneck.py | 1 + 2 files changed, 4 insertions(+) diff --git a/src/adapters/methods/adapter_layer_base.py b/src/adapters/methods/adapter_layer_base.py index 7a70a4d764..24069586a4 100644 --- a/src/adapters/methods/adapter_layer_base.py +++ b/src/adapters/methods/adapter_layer_base.py @@ -117,6 +117,9 @@ class ComposableAdapterLayerBase(AdapterLayerBase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self._init_mapping() + + def _init_mapping(self): self.composition_to_func_map = { Stack: self.compose_stack, Fuse: self.compose_fuse, diff --git a/src/adapters/methods/bottleneck.py b/src/adapters/methods/bottleneck.py index b802e960e2..7ec097d806 100644 --- a/src/adapters/methods/bottleneck.py +++ b/src/adapters/methods/bottleneck.py @@ -48,6 +48,7 @@ def __init__(self, location_key: str): self.location_key = location_key def init_adapters(self, model_config, adapters_config): + self._init_mapping() self.model_config = model_config self.adapters_config = adapters_config self.adapters = nn.ModuleDict(dict()) From 9cdf62d477d7d77fff2659e7cd8c8b49425c1470 Mon Sep 17 00:00:00 2001 From: calpt Date: Wed, 11 Oct 2023 22:17:48 +0200 Subject: [PATCH 07/10] Test fixes --- src/adapters/methods/adapter_layer_base.py | 4 ++-- tests_adapters/composition/test_adapter_composition.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/adapters/methods/adapter_layer_base.py b/src/adapters/methods/adapter_layer_base.py index 24069586a4..991839d144 100644 --- a/src/adapters/methods/adapter_layer_base.py +++ b/src/adapters/methods/adapter_layer_base.py @@ -238,7 +238,7 @@ def check_composition_valid(self, parent: AdapterCompositionBlock, child: Adapte ValueError: If the composition is invalid. """ # Break if setup is too deep - if lvl >= 1: + if type(parent) == Stack and lvl >= 1: raise ValueError( "Specified adapter setup is too deep. Cannot have {} at level {}".format(child.__class__.__name__, lvl) ) @@ -400,7 +400,7 @@ def compose_average(self, adapter_setup: Average, state: NamedTuple, lvl: int = else: pass - weights = torch.tensor(adapter_setup.weights).unsqueeze(1).unsqueeze(1).to(state[0].device) + weights = torch.tensor(adapter_setup.weights)[:, None, None, None].to(state[0].device) state = self.mean(children_states, weights) return state diff --git a/tests_adapters/composition/test_adapter_composition.py b/tests_adapters/composition/test_adapter_composition.py index b0528ab7b6..2670488cb9 100644 --- a/tests_adapters/composition/test_adapter_composition.py +++ b/tests_adapters/composition/test_adapter_composition.py @@ -21,10 +21,10 @@ def test_to_deep(self): def test_invalid_nesting_fusion(self): self.assertRaises(ValueError, lambda: parse_composition(Fuse(Fuse("a", "b"), "c"))) - self.assertRaises(ValueError, lambda: parse_composition(Fuse(Split("a", "b", 128), "c"))) + self.assertRaises(ValueError, lambda: parse_composition(Fuse(Split("a", "b", splits=128), "c"))) def test_invalid_nesting_split(self): - self.assertRaises(ValueError, lambda: parse_composition(Split("a", Fuse("b", "c"), 128))) + self.assertRaises(ValueError, lambda: parse_composition(Split("a", Fuse("b", "c"), splits=128))) @require_torch @@ -224,9 +224,9 @@ def test_average(self): model.set_active_adapters(Average("a", "b", "c", "d")) inputs = {} - inputs["input_ids"] = ids_tensor((1, 128), 1000) + inputs["input_ids"] = ids_tensor((2, 128), 1000) logits = model(**inputs).logits - self.assertEqual(logits.shape, (1, 2)) + self.assertEqual(logits.shape, (2, 2)) class PrefixTuningCompositionTest(AdapterCompositionTest): From 49521054070868d279564f9d2b567becf4bb01cd Mon Sep 17 00:00:00 2001 From: calpt Date: Wed, 11 Oct 2023 22:56:12 +0200 Subject: [PATCH 08/10] style --- src/adapters/methods/adapter_layer_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/adapters/methods/adapter_layer_base.py b/src/adapters/methods/adapter_layer_base.py index 991839d144..ec26142f0f 100644 --- a/src/adapters/methods/adapter_layer_base.py +++ b/src/adapters/methods/adapter_layer_base.py @@ -238,7 +238,7 @@ def check_composition_valid(self, parent: AdapterCompositionBlock, child: Adapte ValueError: If the composition is invalid. """ # Break if setup is too deep - if type(parent) == Stack and lvl >= 1: + if isinstance(parent, Stack) and lvl >= 1: raise ValueError( "Specified adapter setup is too deep. Cannot have {} at level {}".format(child.__class__.__name__, lvl) ) From 216669f08151fa36389fc3291a62bbfd57d50b6b Mon Sep 17 00:00:00 2001 From: calpt Date: Thu, 19 Oct 2023 10:09:23 +0200 Subject: [PATCH 09/10] Update src/adapters/methods/bottleneck.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Leon Engländer --- src/adapters/methods/bottleneck.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/adapters/methods/bottleneck.py b/src/adapters/methods/bottleneck.py index 7ec097d806..c150191695 100644 --- a/src/adapters/methods/bottleneck.py +++ b/src/adapters/methods/bottleneck.py @@ -215,7 +215,7 @@ def repeat(self, state: BottleneckState, channels: int) -> BottleneckState: state.bottleneck_up.repeat(channels, 1, 1) if state.bottleneck_up is not None else None, ) - def mean(self, states: List[NamedTuple], weights: torch.Tensor) -> NamedTuple: + def mean(self, states: List[BottleneckState], weights: torch.Tensor) -> BottleneckState: return BottleneckState( torch.mean(torch.stack([s.hidden_states for s in states], 0) * weights, dim=0), states[0].input_tensor, From f3099db39b5eb16d238b0923b36f3d321a45933c Mon Sep 17 00:00:00 2001 From: calpt Date: Wed, 25 Oct 2023 22:08:42 +0200 Subject: [PATCH 10/10] Add documentation for new layer base & update contributing guide --- docs/classes/adapter_layer.rst | 11 ++++- docs/classes/adapter_modules.rst | 7 --- docs/contributing/adding_adapter_methods.md | 53 ++++++++++++++------- src/adapters/__init__.py | 4 +- src/adapters/methods/adapter_layer_base.py | 40 ++++++++++++++++ src/adapters/methods/prefix_tuning.py | 1 + 6 files changed, 89 insertions(+), 27 deletions(-) delete mode 100644 docs/classes/adapter_modules.rst diff --git a/docs/classes/adapter_layer.rst b/docs/classes/adapter_layer.rst index d76d13dd52..01233d6328 100644 --- a/docs/classes/adapter_layer.rst +++ b/docs/classes/adapter_layer.rst @@ -1,5 +1,12 @@ -BottleneckLayer +Adapter Implementation ======================= -.. autoclass:: adapters.BottleneckLayer +The following classes define the common interfaces for all adapter methods. +They further hold logic shared by all adapter implementations. +All newly added adapter methods should inherit from either one of these classes. + +.. autoclass:: adapters.AdapterLayerBase + :members: + +.. autoclass:: adapters.ComposableAdapterLayerBase :members: diff --git a/docs/classes/adapter_modules.rst b/docs/classes/adapter_modules.rst deleted file mode 100644 index 46056142bd..0000000000 --- a/docs/classes/adapter_modules.rst +++ /dev/null @@ -1,7 +0,0 @@ -Adapter Modules -=============== - -Classes implementing task and language adapters. - -.. automodule:: adapters.modeling - :members: diff --git a/docs/contributing/adding_adapter_methods.md b/docs/contributing/adding_adapter_methods.md index de3f1937e3..29a7801579 100644 --- a/docs/contributing/adding_adapter_methods.md +++ b/docs/contributing/adding_adapter_methods.md @@ -20,28 +20,49 @@ Thus, each adapter method implementation at least should provide two classes: - a configuration class deriving from `AdapterConfigBase` that provides attributes for all configuration options of the method - a module class deriving from the abstract `AdapterLayerBase` that provides the method parameters and a set of standard adapter management functions + - modules supporting [adapter composition](https://docs.adapterhub.ml/adapter_composition.html) should instead derive from `ComposableAdapterLayerBase` -**📝 Steps** +### Configuration -- All configuration classes reside in `src/transformers/adapters/configuration.py`. - To add a new configuration class for a new method, create a new subclass of `AdapterConfigBase`. +All configuration classes reside in `src/adapters/configuration/adapter_config.py`. +- To add a new configuration class for a new method, create a new subclass of [`AdapterConfigBase`](adapters.AdapterConfigBase). Make sure to set the `architecture` attribute in your class. - - Finally, also make sure the config class is added to the `__init__.py` files in `src/transformers/adapters` and `src/transformers`. -- The `AdapterLayerBase` class from which any new adapter modules should derive resides in `src/transformers/adapters/layer.py`. - - This abstract base class defines a set of methods that should be implemented by each deriving class, - including methods for adding, enabling and deleting adapter weights. - - Most importantly, the module classes deriving from this base class should implement the forward pass through an adaptation component. - - The concrete implementation of these classes heavily depends on the specifics of the adapter method. - For a reference implementation, have a look at `BottleneckLayer` for bottleneck adapters. -- To actually make use of the newly implemented classes, it's finally necessary to integrate the forward calls to the modules in the actual model implementations. - - This, again, is highly dependent on how the adapter method interacts with the base model classes. Typically, module classes can be integrated either via mixins (see `src/transformers/adapters/mixins`) or directly as submodules of the respective model components. - - The model class integration has to be repeated for each supported Transformer model, as they typically don't share a codebase. At this point it is often important to consider where the adapters need to be added to the transformer model and whether there is an implementation that does not require more copying of classes than the current implementation. - Please try to integrate any new adapter method into every model class when it's reasonable. - You can find all currently supported model classes at https://docs.adapterhub.ml/model_overview.html. +- Finally, also make sure the config class is added to the `__init__.py` files in `src/adapters`. + +### Modeling + +All adapter method implementations reside in `src/adapters/methods`. + +#### For methods **without** composition support + +The [`AdapterLayerBase`](adapters.AdapterLayerBase) class from which any new adapter modules should derive resides in `src/adapters/methods/adapter_layer_base.py`. +- This abstract base class defines a set of methods that should be implemented by each deriving class, +including methods for adding, enabling and deleting adapter weights. These methods are marked as abstract in the base class. See [`AdapterLayerBase`](adapters.AdapterLayerBase) for details. +- Most importantly however, the module classes deriving from this base class should implement the forward pass through an adaptation component. +- The concrete implementation of these classes heavily depends on the specifics of the adapter method. + +#### For methods **with** composition support + +The [`ComposableAdapterLayerBase`](adapters.ComposableAdapterLayerBase) class (as subclass of [`AdapterLayerBase`](adapters.AdapterLayerBase)), which resides in `src/adapters/methods/adapter_layer_base.py` provides the basic skeleton for implementing adapter composition. +- Your deriving module class firstly should implement all methods required by [`AdapterLayerBase`](adapters.AdapterLayerBase). See section above for details. +- For adapter composition, the pre-implemented `compose()` method constitutes the main entry-point. This method should be called during the forward pass of your adapter module. +- `compose()` expects a `state` object, which is a generic named tuple object defined by your adapter method. This state object should hold all tensors (such as hidden states, attention masks etc.) and state attributes required for your adapter implementation. See `BottleneckState` for an example. +- Implementations for specific composition blocks are given in methods starting with `compose_`. Some composition blocks provide generic default implementations, some must be implemented by the deriving class if they should be supported. Make sure to list all supported composition blocks in the `supported_compositions` class attribute of your deriving module. +- In any case, a small set of helper methods should be implemented by any deriving module to support basic composition logic. These are marked as abstract methods in [`ComposableAdapterLayerBase`](adapters.ComposableAdapterLayerBase) and currently consist of the following: vslice(), pad_and_concat(), repeat(), mean(), compose_single(). See [`ComposableAdapterLayerBase`](adapters.ComposableAdapterLayerBase) for details. + +For a reference implementation, have a look at `BottleneckLayer` for bottleneck adapters. + +#### For all methods + +To actually make use of the newly implemented classes, it's finally necessary to integrate the forward calls to the modules in the actual model implementations. +- This, again, is highly dependent on how the adapter method interacts with the base model classes. Typically, module classes can be integrated either via mixins (see modules starting with "mixin" in `src/adapters/models`) or directly as submodules of the respective model components. +- The model class integration has to be repeated for each supported Transformer model, as they typically don't share a codebase. At this point it is often important to consider where the adapters need to be added to the transformer model and whether there is an implementation that does not require more copying of classes than the current implementation. +Please try to integrate any new adapter method into every model class when it's reasonable. +You can find all currently supported model classes at https://docs.adapterhub.ml/model_overview.html. **Additional things to consider** -- New adapter methods typically also require some changes in the `AdapterLoader` class in `src/transformers/adapters/loading.py` (also see [here](https://docs.adapterhub.ml/extending.html#loading-custom-module-weights)). +- New adapter methods typically also require some changes in the `AdapterLoader` class in `src/adapters/loading.py` (also see [here](https://docs.adapterhub.ml/extending.html#loading-custom-module-weights)). - Depending on the method to be integrated, further changes in other classes might be necessary. ## Testing diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py index 768e4ef082..bd78dee5de 100644 --- a/src/adapters/__init__.py +++ b/src/adapters/__init__.py @@ -78,7 +78,7 @@ "Seq2SeqLMHead", "TaggingHead", ], - "methods.adapter_layer_base": ["AdapterLayerBase"], + "methods.adapter_layer_base": ["AdapterLayerBase", "ComposableAdapterLayerBase"], "model_mixin": [ "EmbeddingAdaptersMixin", "InvertibleAdaptersMixin", @@ -182,7 +182,7 @@ Seq2SeqLMHead, TaggingHead, ) - from .methods.adapter_layer_base import AdapterLayerBase + from .methods.adapter_layer_base import AdapterLayerBase, ComposableAdapterLayerBase from .model_mixin import ( EmbeddingAdaptersMixin, InvertibleAdaptersMixin, diff --git a/src/adapters/methods/adapter_layer_base.py b/src/adapters/methods/adapter_layer_base.py index ec26142f0f..b89b75cb14 100644 --- a/src/adapters/methods/adapter_layer_base.py +++ b/src/adapters/methods/adapter_layer_base.py @@ -78,30 +78,70 @@ def _store_fusion_attentions(self, fusion_name, attentions): @abstractmethod def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: + """Adds a new adapter module to the layer. + + Args: + adapter_name (str): The name of the new adapter to add. + layer_idx (int): + The index of the adapters layer (this should be set once by the first added adapter and the kept fix). + + Returns: + bool: True if the adapter was added, False otherwise. + """ raise NotImplementedError() @abstractmethod def average_adapter(self, adapter_name: str, input_adapters: Dict[str, float]) -> bool: + """Averages a set of adapter modules into a new adapter module. + + Args: + adapter_name (str): The name of the new (averaged) adapter module to add. + input_adapters (Dict[str, float]): Either: + - a list of adapter names (with equal weighting). + - a dictionary of adapter names and their corresponding weights. + + Returns: + bool: True if the adapter was added, False otherwise. + """ raise NotImplementedError() @abstractmethod def delete_adapter(self, adapter_name: str): + """Deletes an adapter module from the layer. + + Args: + adapter_name (str): The name of the adapter to delete. + """ raise NotImplementedError() @abstractmethod def add_fusion_layer(self, adapter_names: Union[List, str]): + # TODO remove this method from the base class raise NotImplementedError() @abstractmethod def delete_fusion_layer(self, adapter_names: Union[List, str]): + # TODO remove this method from the base class raise NotImplementedError() @abstractmethod def enable_adapters(self, adapter_setup: AdapterCompositionBlock, unfreeze_adapters: bool, unfreeze_fusion: bool): + """Enables/ disables a set of adapter modules within the layer. + + Args: + adapter_setup (AdapterCompositionBlock): The adapter setup to enable/ disable. + unfreeze_adapters (bool): Whether to unfreeze the adapters. + unfreeze_fusion (bool): Whether to unfreeze the fusion layers. + """ raise NotImplementedError() @abstractmethod def get_adapter(self, adapter_name: str) -> nn.Module: + """Returns the adapter module with the given name. + + Args: + adapter_name (str): The name of the adapter module. + """ raise NotImplementedError() diff --git a/src/adapters/methods/prefix_tuning.py b/src/adapters/methods/prefix_tuning.py index 8e3330b858..3a8743a3f2 100644 --- a/src/adapters/methods/prefix_tuning.py +++ b/src/adapters/methods/prefix_tuning.py @@ -254,6 +254,7 @@ class PrefixTuningState(NamedTuple): residual_input (torch.Tensor): The residual input of the attention layer. attention_mask (torch.Tensor, optional): The attention mask of the attention layer. invert_mask (bool): Whether the attention mask is inverted (ie. using '1' for padding). + idx_slice (slice, optional): Id slice for slicing prefix states along the batch size dimension. """