From a9152e78b099ff55058d6d91198dbc4962b509a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Sat, 6 Apr 2024 17:30:37 +0200 Subject: [PATCH] Upgrade Transformers to v4.38.x (#654) Changes: - HF changed parts of the Llama model implementation - HF added a `LlamaForQuestionAnswering`. However, this model has a wrong base model name. I added a workaround that solves this problem until this is fixed in Transformers (https://github.com/huggingface/transformers/pull/29258) --------- Co-authored-by: calpt --- docs/adapter_composition.md | 4 +- docs/classes/models/auto.rst | 3 + docs/classes/models/bart.rst | 2 +- docs/classes/models/electra.rst | 2 +- docs/classes/models/llama.rst | 5 + docs/conf.py | 2 +- docs/index.rst | 1 + docs/methods.md | 2 +- docs/prediction_heads.md | 12 +- docs/quickstart.md | 3 +- docs/training.md | 2 +- hf_transformers | 2 +- setup.py | 12 +- src/adapters/head_utils.py | 12 ++ src/adapters/hub_mixin.py | 3 +- src/adapters/models/__init__.py | 3 +- src/adapters/models/bart/adapter_model.py | 4 +- src/adapters/models/llama/adapter_model.py | 3 + src/adapters/models/llama/mixin_llama.py | 6 + src/adapters/models/llama/modeling_llama.py | 123 ++++++++------------ src/adapters/utils.py | 4 +- tests/test_llama.py | 3 +- 22 files changed, 107 insertions(+), 106 deletions(-) diff --git a/docs/adapter_composition.md b/docs/adapter_composition.md index 0851e7970f..5ff2d4284f 100644 --- a/docs/adapter_composition.md +++ b/docs/adapter_composition.md @@ -19,7 +19,7 @@ model.active_adapters = "adapter_name" - You cannot activate an adapter before previously adding it to the model using either ``add_adapter()`` or ``load_adapter()``. - All adapters not mentioned in the ``active_adapters`` setup are ignored, although they might have been loaded into the model. Thus, after adding an adapter, make sure to activate it. ``` -Note that we also could have used the [`set_active_adapters`](adapters.) method with `model.set_active_adapters("adapter_name")` which does the same. +Note that we also could have used the `set_active_adapters` method with `model.set_active_adapters("adapter_name")` which does the same. Alternatively, the [`AdapterSetup`](adapters.AdapterSetup) context manager allows dynamic configuration of activated setups without changing the model state: @@ -125,7 +125,7 @@ model.active_adapters = ac.Fuse("d", "e", "f") To learn how training an _AdapterFusion_ layer works, check out [this Colab notebook](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/03_Adapter_Fusion.ipynb) from the `adapters` repo. -#### Retrieving AdapterFusion attentions +### Retrieving AdapterFusion attentions Finally, it is possible to retrieve the attention scores computed by each fusion layer in a forward pass of the model. These scores can be used for analyzing the fused adapter blocks and can serve as the basis for visualizations similar to those in the AdapterFusion paper. diff --git a/docs/classes/models/auto.rst b/docs/classes/models/auto.rst index f4081a77c4..a276854894 100644 --- a/docs/classes/models/auto.rst +++ b/docs/classes/models/auto.rst @@ -4,6 +4,9 @@ Auto Classes Similar to the ``AutoModel`` classes built-in into HuggingFace Transformers, adapters provides an ``AutoAdapterModel`` class. As with other auto classes, the correct adapter model class is automatically instantiated based on the pre-trained model passed to the ``from_pretrained()`` method. +.. note:: + If the model loaded with the ``from_pretrained(...)`` function has a head, this head gets loaded as well. However, this only works for non-sharded models. If you want to load a sharded model with a head, you first need to load the model and then the head separately. + AutoAdapterModel ~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/classes/models/bart.rst b/docs/classes/models/bart.rst index 5ea8eeb11f..67a5e56572 100644 --- a/docs/classes/models/bart.rst +++ b/docs/classes/models/bart.rst @@ -22,4 +22,4 @@ BartAdapterModel .. autoclass:: adapters.BartAdapterModel :members: - :inherited-members: BartPretrainedModel + :inherited-members: BartPreTrainedModel diff --git a/docs/classes/models/electra.rst b/docs/classes/models/electra.rst index d67a96d8d5..e0dc9c5ef4 100644 --- a/docs/classes/models/electra.rst +++ b/docs/classes/models/electra.rst @@ -1,5 +1,5 @@ ELECTRA -====== +======= The ELECTRA model was proposed in the paper `ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators `__. ELECTRA is a new pretraining approach which trains two diff --git a/docs/classes/models/llama.rst b/docs/classes/models/llama.rst index c7fffe1834..f650f93225 100644 --- a/docs/classes/models/llama.rst +++ b/docs/classes/models/llama.rst @@ -1,6 +1,11 @@ LLaMA ----------------------------------------------------------------------------------------------------------------------- +.. note:: + Loading a ``LlamaForQuestionAnswering`` via [`AutoAdapterModel`](adapters.AutoAdapterModel) or via [`LlamaAdapterModel`](adapters.LlamaAdapterModel) does not load the head, even if the model is not sharded. Please load the base model first and then subsequently the head. + Note that for sharded models the head is never automatically loaded as described here: [Auto Classes](auto.rst) + + The LLaMA model was proposed in `LLaMA: Open and Efficient Foundation Language Models `__ by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample. It is a collection of foundation language diff --git a/docs/conf.py b/docs/conf.py index 417623a94b..746e77d4a9 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -90,7 +90,7 @@ def skip_head_member(app, what, name, obj, skip, options): if type(obj).__name__ == "function" and "inherited-members" in options and (m := re.match(r"add\_(.*)\_head$", name)): - cls_name = options["inherited-members"].replace("PreTrainedModel", "AdapterModel").replace("PretrainedModel", "AdapterModel") + cls_name = list(options["inherited-members"])[0].replace("PreTrainedModel", "AdapterModel").replace("PretrainedModel", "AdapterModel") cls = vars(sys.modules["adapters"])[cls_name] # HACK: currently parses head type from name head_type_str = m.group(1).replace("qa", "question_answering") diff --git a/docs/index.rst b/docs/index.rst index 4d13a1942d..5f87f7ae1e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -82,6 +82,7 @@ Currently, we support the PyTorch versions of all models as listed on the `Model classes/models/gptj classes/models/llama classes/models/mbart + classes/models/mt5 classes/models/roberta classes/models/t5 classes/models/vit diff --git a/docs/methods.md b/docs/methods.md index 06ad700e68..535b23d088 100644 --- a/docs/methods.md +++ b/docs/methods.md @@ -1,7 +1,7 @@ # Adapter Methods On this page, we present all adapter methods currently integrated into the `adapters` library. -A tabular overview of adapter methods is provided [here](overview.html#table-of-adapter-methods). +A tabular overview of adapter methods is provided [here](overview.md#table-of-adapter-methods). Additionally, options to combine multiple adapter methods in a single setup are presented [on the next page](method_combinations.md). ## Bottleneck Adapters diff --git a/docs/prediction_heads.md b/docs/prediction_heads.md index 33786385d6..eba5079024 100644 --- a/docs/prediction_heads.md +++ b/docs/prediction_heads.md @@ -6,7 +6,7 @@ We will take a look at the `AdapterModel` classes (e.g. `BertAdapterModel`) intr ```{eval-rst} .. tip:: We recommend to use the `AdapterModel classes <#adaptermodel-classes>`_ whenever possible. - They have been created specifically for working with adapters and provide more flexibility. + These **flexible** models have been created specifically for working with adapters. ``` ## AdapterModel classes @@ -18,16 +18,14 @@ First, we load pre-trained model from the Hugging Face Hub via the [`AutoAdapter model = AutoAdapterModel.from_pretrained("bert-base-uncased") ``` -By default, this model doesn't have any heads yet. We add a new one in the next step: +By default, this model doesn't have any heads yet, so let's add a new binary sequence classification head on top of our model: ```python model.add_classification_head("mrpc", num_labels=2) ``` -The line above adds a binary sequence classification head on top of our model. -Because this head is named, we could add multiple other heads with different names to the same model. -This is especially useful if used together with matching adapter modules. -To learn more about the different head types and the configuration options, please refer to the class references of the respective model classes, e.g. [`BertAdapterModel`](adapters.BertAdapterModel). +All heads have a name, we called this new head `"mrpc"`. Since all heads are named, we can add multiple other heads with different names to the same model. +To see the head types of a model and how they can get configured, please refer to the class references of the respective model classes, e.g. [`BertAdapterModel`](adapters.BertAdapterModel). -Now, of course, we would like to train our classification head together with an adapter, so let's add one: +A head alone is just one layer with very few parameters. Hence, we want to train our classification head together with an adapter, so let's add one: ```python model.add_adapter("mrpc", config="seq_bn") model.set_active_adapters("mrpc") diff --git a/docs/quickstart.md b/docs/quickstart.md index c1181f1d28..4d64b51e41 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -120,4 +120,5 @@ model.delete_adapter(adapter_name) _We also have a Quickstart Colab notebook for adapter training:_ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/01_Adapter_Training.ipynb) -For more examples on training different adapter setups, refer to the section on [Adapter Training](training.md). +For more examples of training different adapter setups, refer to the section on [Adapter Training](training.md). +Further information on using adapters with prediction heads can be found in the [Prediction Heads](prediction_heads.md) section. diff --git a/docs/training.md b/docs/training.md index 649ace6e28..95ffeb2f1e 100644 --- a/docs/training.md +++ b/docs/training.md @@ -84,7 +84,7 @@ model.set_active_adapters(task_name) ### Step D - Switch to `AdapterTrainer` class -Finally, we exchange the `Trainer` class built into Transformers for the [`AdapterTrainer`](transformers.adapters.AdapterTrainer) class that is optimized for training adapter methods. +Finally, we exchange the `Trainer` class built into Transformers for the [`AdapterTrainer`](adapters.trainer.AdapterTrainer) class that is optimized for training adapter methods. See [below for more information](#adaptertrainer). Technically, this change is not required as no changes to the training loop are required for training adapters. diff --git a/hf_transformers b/hf_transformers index a7cab3c283..a0857740c0 160000 --- a/hf_transformers +++ b/hf_transformers @@ -1 +1 @@ -Subproject commit a7cab3c283312b8d4de5df3bbe719971e24f4281 +Subproject commit a0857740c0e6127485c11476650314df3accc2b6 diff --git a/setup.py b/setup.py index eff6de831c..84beca00ae 100644 --- a/setup.py +++ b/setup.py @@ -51,16 +51,16 @@ "sacremoses", "scikit-learn", "sentencepiece>=0.1.91,!=0.1.92", - "sphinx-copybutton", - "sphinx-markdown-tables", + "sphinx-copybutton==0.5.2", + "sphinx-markdown-tables==0.0.17", "sphinx-rtd-theme==0.4.3", # sphinx-rtd-theme==0.5.0 introduced big changes in the style. - "sphinx==3.2.1", + "sphinx==5.0.2", "sphinxext-opengraph==0.4.1", - "sphinx-intl", - "sphinx-multiversion", + "sphinx-intl==2.1.0", + "sphinx-multiversion==0.2.4", "timeout-decorator", "torch>=1.10,!=1.12.0", - "transformers~=4.36.0", + "transformers~=4.38.1", ] diff --git a/src/adapters/head_utils.py b/src/adapters/head_utils.py index 2144fbe5ee..ec78430e02 100644 --- a/src/adapters/head_utils.py +++ b/src/adapters/head_utils.py @@ -498,6 +498,7 @@ }, "layers": [None, "qa_outputs"], }, + # T5 "T5ForConditionalGeneration": { "config": { "head_type": "seq2seq_lm", @@ -526,6 +527,7 @@ "classification_head.out_proj", ], }, + # DeBERTaV2 "DebertaV2ForSequenceClassification": { "config": { "head_type": "classification", @@ -575,6 +577,7 @@ }, "layers": [None, "pooler.dense", None, None, "classifier"], }, + # DeBERTa "DebertaForSequenceClassification": { "config": { "head_type": "classification", @@ -641,6 +644,15 @@ }, "layers": ["lm_head"], }, + "LlamaForQuestionAnswering": { + "config": { + "head_type": "question_answering", + "layers": 1, + "activation_function": None, + }, + "layers": [None, "qa_outputs"], + }, + # Electra "ElectraForTokenClassification": { "config": { "head_type": "tagging", diff --git a/src/adapters/hub_mixin.py b/src/adapters/hub_mixin.py index ece00238a1..7a1009c5b8 100644 --- a/src/adapters/hub_mixin.py +++ b/src/adapters/hub_mixin.py @@ -70,7 +70,8 @@ def _save_adapter_card( metrics: Optional[List[str]] = None, **kwargs ): - all_tags = {"adapter-transformers"} # TODO: change this tag once changed on HF side + # Key remains "adapter-transformers", see: https://github.com/huggingface/huggingface.js/pull/459 + all_tags = {"adapter-transformers"} datasets = set() # Dataset/ Task info dataset_name = None diff --git a/src/adapters/models/__init__.py b/src/adapters/models/__init__.py index 46eba733b7..ff19c38f3b 100644 --- a/src/adapters/models/__init__.py +++ b/src/adapters/models/__init__.py @@ -17,7 +17,7 @@ from .distilbert.mixin_distilbert import DistilBertModelAdaptersMixin, DistilBertTransformerAdaptersMixin from .gpt2.mixin_gpt2 import GPT2ModelAdapterMixin from .gptj.mixin_gptj import GPTJMLPAdaptersMixin, GPTJModelAdapterMixin -from .llama.mixin_llama import LlamaModelAdapterMixin +from .llama.mixin_llama import LlamaForQuestionAnsweringAdapterMixin, LlamaModelAdapterMixin from .t5.mixin_t5 import ( T5BlockAdaptersMixin, T5ForCondiditionalGenerationWithHeadsMixin, @@ -83,4 +83,5 @@ "BertGenerationEncoder": BertModelAdaptersMixin, "BertGenerationLayer": BertLayerAdaptersMixin, "LlamaModel": LlamaModelAdapterMixin, + "LlamaForQuestionAnswering": LlamaForQuestionAnsweringAdapterMixin, } diff --git a/src/adapters/models/bart/adapter_model.py b/src/adapters/models/bart/adapter_model.py index dec5a838c2..384955cc11 100644 --- a/src/adapters/models/bart/adapter_model.py +++ b/src/adapters/models/bart/adapter_model.py @@ -5,7 +5,7 @@ BART_START_DOCSTRING, BartConfig, BartModel, - BartPretrainedModel, + BartPreTrainedModel, shift_tokens_right, ) from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward @@ -18,7 +18,7 @@ @add_start_docstrings( "BART Model with the option to add multiple flexible prediction heads on top.", BART_START_DOCSTRING ) -class BartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, BartPretrainedModel): +class BartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, BartPreTrainedModel): _tied_weights_keys = [ "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", diff --git a/src/adapters/models/llama/adapter_model.py b/src/adapters/models/llama/adapter_model.py index 97cc0c4e3f..16bea405d4 100644 --- a/src/adapters/models/llama/adapter_model.py +++ b/src/adapters/models/llama/adapter_model.py @@ -1,4 +1,5 @@ import logging +from typing import Optional import torch @@ -58,6 +59,7 @@ def forward( past_key_values=None, inputs_embeds=None, use_cache=None, + cache_position: Optional[torch.LongTensor] = None, output_attentions=None, output_hidden_states=None, return_dict=None, @@ -79,6 +81,7 @@ def forward( position_ids=position_ids, inputs_embeds=inputs_embeds, use_cache=use_cache, + cache_position=cache_position, output_attentions=output_attentions, return_dict=return_dict, output_hidden_states=output_hidden_states, diff --git a/src/adapters/models/llama/mixin_llama.py b/src/adapters/models/llama/mixin_llama.py index aae339433c..4f15c0d948 100644 --- a/src/adapters/models/llama/mixin_llama.py +++ b/src/adapters/models/llama/mixin_llama.py @@ -44,3 +44,9 @@ def post_embedding_forward(self, module, args, embedding_output): embedding_output = self.invertible_adapters_forward(embedding_output) # Prompt tuning not yet supported return embedding_output + + +class LlamaForQuestionAnsweringAdapterMixin: + # this is needed because Transformers v4.38.1 is inconsistent with the naming of the base model but didn't change the base_model_prefix + # TODO: remove this when the inconsistency is fixed and remove the LlamaForQuestionAnsweringAdapterMixin from `src/adapters/models/__init__.py` + base_model_prefix = "transformer" diff --git a/src/adapters/models/llama/modeling_llama.py b/src/adapters/models/llama/modeling_llama.py index baefac75a9..752cfccb10 100644 --- a/src/adapters/models/llama/modeling_llama.py +++ b/src/adapters/models/llama/modeling_llama.py @@ -27,11 +27,7 @@ import torch.utils.checkpoint from torch import nn -from adapters.composition import ( - adjust_tensors_for_parallel, - adjust_tensors_for_parallel_, - match_attn_matrices_for_parallel, -) +from adapters.composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel from transformers.cache_utils import Cache from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, apply_rotary_pos_emb, repeat_kv from transformers.utils import logging @@ -53,14 +49,9 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use" - " `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() if self.config.pretraining_tp > 1: @@ -94,20 +85,13 @@ def forward( ) (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - "The cache structure has changed since version v4.36. If you are using" - f" {self.__class__.__name__} for auto-regressive decoding with k/v caching, please make sure to" - " initialize the attention class with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + past_key_value = getattr(self, "past_key_value", past_key_value) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -120,22 +104,12 @@ def forward( attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - # Make adjustments since (parallel) prefix tuning changes the attention mask - kv_seq_len = key_states.shape[-2] bsz = key_states.shape[0] - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask + if attention_mask is not None: # no matter the length, we just slice it + if cache_position is not None: + causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) @@ -174,18 +148,9 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # LlamaFlashAttention2 attention does not support output_attentions - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use" - " `attention_mask` instead.`" - ) - - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop("padding_mask") - output_attentions = False bsz, q_len, _ = hidden_states.size() @@ -206,23 +171,21 @@ def forward( ) (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) key_states, value_states, attention_mask = self.prefix_tuning( key_states, value_states, hidden_states, attention_mask ) (query_states,) = adjust_tensors_for_parallel(key_states, query_states) - # Make adjustments since (parallel) prefix tuning changes the attention mask - kv_seq_len = key_states.shape[-2] bsz = key_states.shape[0] + past_key_value = getattr(self, "past_key_value", past_key_value) + if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache @@ -241,8 +204,10 @@ def forward( input_dtype = query_states.dtype if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() # Handle the case where the model is quantized - if hasattr(self.config, "_pre_quantization_dtype"): + elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype else: target_dtype = self.q_proj.weight.dtype @@ -281,6 +246,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. @@ -297,6 +263,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) bsz, q_len, _ = hidden_states.size() @@ -314,15 +281,14 @@ def forward( ) (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + past_key_value = getattr(self, "past_key_value", past_key_value) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -333,19 +299,15 @@ def forward( ) (query_states,) = adjust_tensors_for_parallel(key_states, query_states) - # Make adjustments since (parallel) prefix tuning changes the attention mask - kv_seq_len = key_states.shape[-2] bsz = key_states.shape[0] - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) + causal_mask = attention_mask + if attention_mask is not None and cache_position is not None: + causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]] # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: + if query_states.device.type == "cuda" and causal_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() @@ -354,14 +316,12 @@ def forward( query_states, key_states, value_states, - attn_mask=attention_mask, + attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) @@ -377,12 +337,15 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -391,8 +354,12 @@ def forward( (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use" + " `attention_mask` instead.`" + ) - adjust_tensors_for_parallel_(hidden_states, attention_mask, position_ids) residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -405,6 +372,8 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, + **kwargs, ) hidden_states = self.attention_adapters(hidden_states, residual, None) diff --git a/src/adapters/utils.py b/src/adapters/utils.py index 0e3b20cabe..1dbe06cf24 100644 --- a/src/adapters/utils.py +++ b/src/adapters/utils.py @@ -837,8 +837,8 @@ def prefix_attention_mask(attention_mask, dim: int = 3, prefix_value: int = 0): The value to use for the prefix_attention_mask. Defaults to 0, however some models, e.g. DistilBert, use different values. BERT like models invert their extended_attention_mask, hence they use 0 as value for not masked tokens. This inversion is usually done in the forward method of the model in 2 different ways: - 1) by calling self.invert_attention_mask, as BERT does 2) by doing the inversion manually, e.g. ALBERT - does: `extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min` + 1) by calling self.invert_attention_mask, as BERT does 2) by doing the inversion manually, e.g. ALBERT + does: `extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min` """ forward_context = ForwardContext.get_context() diff --git a/tests/test_llama.py b/tests/test_llama.py index 9cb6fcfda2..e8cf0557a0 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -61,4 +61,5 @@ class LlamaClassConversionTest( LlamaAdapterTestBase, unittest.TestCase, ): - pass + def test_conversion_question_answering_model(self): + raise self.skipTest("We don't support the Llama QA model.")