diff --git a/conftest.py b/conftest.py index 93673cd20..6e6ad8e7b 100644 --- a/conftest.py +++ b/conftest.py @@ -46,7 +46,11 @@ def pytest_configure(config): config.addinivalue_line( "markers", "is_pt_flax_cross_test: mark test to run only when PT and FLAX interactions are tested" ) + config.addinivalue_line("markers", "is_pipeline_test: mark test to run only when pipelines are tested") config.addinivalue_line("markers", "is_staging_test: mark test to run only in the staging environment") + config.addinivalue_line("markers", "accelerate_tests: mark test that require accelerate") + config.addinivalue_line("markers", "agent_tests: mark the agent tests that are run on their specific schedule") + config.addinivalue_line("markers", "not_device_test: mark the tests always running on cpu") def pytest_addoption(parser): diff --git a/hf_transformers b/hf_transformers index 53fad641c..052e652d6 160000 --- a/hf_transformers +++ b/hf_transformers @@ -1 +1 @@ -Subproject commit 53fad641cfdb5105e2470bcf3ef17ea8e25cc300 +Subproject commit 052e652d6d53c2b26ffde87e039b723949a53493 diff --git a/pyproject.toml b/pyproject.toml index ad2437e9a..3dca5b20d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,11 @@ [tool.black] line-length = 119 target-version = ['py38', 'py39', 'py310'] + +# copied from HF for testing +[tool.pytest.ini_options] +markers = [ + "flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')", + "bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests", + "generate: marks tests that use the GenerationTesterMixin" +] diff --git a/setup.py b/setup.py index 6fd87f9ca..e7389c8be 100644 --- a/setup.py +++ b/setup.py @@ -60,7 +60,7 @@ "timeout-decorator", "torch", "torchvision", - "transformers~=4.45.2", + "transformers~=4.46.3", ] diff --git a/src/adapters/head_utils.py b/src/adapters/head_utils.py index 1e3e0760d..6f419744b 100644 --- a/src/adapters/head_utils.py +++ b/src/adapters/head_utils.py @@ -705,6 +705,14 @@ }, "layers": [None, "score"], }, + "MistralForQuestionAnswering": { + "config": { + "head_type": "question_answering", + "layers": 1, + "activation_function": None, + }, + "layers": [None, "qa_outputs"], + }, # Electra "ElectraForTokenClassification": { "config": { diff --git a/src/adapters/heads/model_mixin.py b/src/adapters/heads/model_mixin.py index 13a0d8d35..ced4ff075 100644 --- a/src/adapters/heads/model_mixin.py +++ b/src/adapters/heads/model_mixin.py @@ -139,10 +139,8 @@ def tie_weights(self): super().tie_weights() - def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None): - old_embeddings = self.get_input_embeddings() - new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of) - self.set_input_embeddings(new_embeddings) + def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean_resizing=True): + super()._resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing) # if word embeddings are not tied, make sure that lm head is resized as well if not self.config.tie_word_embeddings: diff --git a/src/adapters/models/distilbert/modeling_distilbert.py b/src/adapters/models/distilbert/modeling_distilbert.py index cd18e8165..5e5cb960f 100644 --- a/src/adapters/models/distilbert/modeling_distilbert.py +++ b/src/adapters/models/distilbert/modeling_distilbert.py @@ -25,13 +25,26 @@ import torch from torch import nn -from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention, 
TransformerBlock +from transformers.models.distilbert.modeling_distilbert import ( + DistilBertFlashAttention2, + DistilBertSdpaAttention, + MultiHeadSelfAttention, + TransformerBlock, +) +from transformers.utils import is_flash_attn_2_available, logging from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel from ...utils import prefix_attention_mask from .mixin_distilbert import DistilBertMultiHeadSelfAttentionMixin, DistilBertTransfomerBlockAdaptersMixin +if is_flash_attn_2_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + + class MultiHeadSelfAttentionWithAdapters(DistilBertMultiHeadSelfAttentionMixin, MultiHeadSelfAttention): def forward( self, @@ -66,18 +79,20 @@ def shape(x: torch.Tensor) -> torch.Tensor: def unshape(x: torch.Tensor) -> torch.Tensor: """group heads""" - return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head) + return x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.n_heads * dim_per_head) q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head) k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head) v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head) + # >>> START AH Changes <<< q, k, v = match_attn_matrices_for_parallel(q, k, v) (mask,) = adjust_tensors_for_parallel(q, mask) k, v, mask = self.prefix_tuning(k, v, value, mask, invert_mask=False) bs = k.size(0) # reset for Parallel block (q,) = adjust_tensors_for_parallel(k, q) + # >>> END AH Changes <<< mask_reshp = (bs, 1, 1, k.size(2)) @@ -105,6 +120,172 @@ def unshape(x: torch.Tensor) -> torch.Tensor: return (context,) +class DistilBertSdpaAttentionWithAdapters(DistilBertMultiHeadSelfAttentionMixin, DistilBertSdpaAttention): + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, ...]: + """ + Parameters: + query: torch.tensor(bs, seq_length, dim) + key: torch.tensor(bs, seq_length, dim) + value: torch.tensor(bs, seq_length, dim) + mask: torch.tensor(bs, seq_length) + + Returns: + weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs, + seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True` + """ + if output_attentions or head_mask is not None: + logger.warning_once( + "DistilBertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support" + " `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but specifying" + " the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be" + ' removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + query, + key, + value, + mask, + head_mask, + output_attentions, + ) + + batch_size, _, _ = query.size() + dim_per_head = self.dim // self.n_heads + + def shape(x: torch.Tensor) -> torch.Tensor: + """separate heads""" + # keep first dim due to parallel composition + return x.view(x.shape[0], -1, self.n_heads, dim_per_head).transpose(1, 2) + + def unshape(x: torch.Tensor) -> torch.Tensor: + """group heads""" + return x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.n_heads * dim_per_head) + + q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head) + k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head) + v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head) + + # >>> START AH Changes <<< + q, k, v = match_attn_matrices_for_parallel(q, k, v) + (mask,) = adjust_tensors_for_parallel(q, mask) + + k, v, mask = self.prefix_tuning(k, v, value, mask, invert_mask=False) + (q,) = adjust_tensors_for_parallel(k, q) + # >>> END AH Changes <<< + + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. + # Reference: https://github.com/pytorch/pytorch/issues/112577 + if self.require_contiguous_qkv and q.device.type == "cuda" and mask is not None: + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=self.dropout_prob if self.training else 0.0, + is_causal=False, + ) + + attn_output = unshape(attn_output) + attn_output = self.out_lin(attn_output) + + return (attn_output,) + + +class DistilBertFlashAttention2WithAdapters(DistilBertMultiHeadSelfAttentionMixin, DistilBertFlashAttention2): + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, ...]: + """ + Parameters: + query: torch.tensor(bs, seq_length, dim) + key: torch.tensor(bs, seq_length, dim) + value: torch.tensor(bs, seq_length, dim) + mask: torch.tensor(bs, seq_length) + + Returns: + weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs, + seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True` + """ + batch_size, q_length, dim = query.size() + + dim_per_head = self.dim // self.n_heads + + def reshape(x: torch.Tensor) -> torch.Tensor: + """separate heads""" + return x.view(x.shape[0], -1, self.n_heads, dim_per_head) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query_states = reshape(self.q_lin(query)) + key_states = reshape(self.k_lin(key)) + value_states = reshape(self.v_lin(value)) + + attn_dropout = self.config.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + if query_states.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_lin.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_weights = _flash_attention_forward( + query_states, + key_states, + value_states, + mask, + q_length, + dropout=attn_dropout, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_weights_reshaped = attn_weights.reshape(batch_size, q_length, self.n_heads * dim_per_head) + attn_output = self.out_lin(attn_weights_reshaped) + + if output_attentions: + return (attn_output, attn_weights) + else: + return (attn_output,) + + class TransformerBlockWithAdapters(DistilBertTransfomerBlockAdaptersMixin, TransformerBlock): def forward( self, @@ -123,7 +304,7 @@ def forward( torch.tensor(bs, seq_length, dim) The output of the transformer block contextualization. """ adjust_tensors_for_parallel_(x, attn_mask) - attn_mask = prefix_attention_mask(attn_mask, dim=1, prefix_value=1) # type: ignore + attn_mask = prefix_attention_mask(attn_mask, dim=[2, 3], prefix_value=1) # type: ignore # Self-Attention sa_output = self.attention( diff --git a/src/adapters/models/mt5/modeling_mt5.py b/src/adapters/models/mt5/modeling_mt5.py index 5bbd1ba1c..05141a08c 100644 --- a/src/adapters/models/mt5/modeling_mt5.py +++ b/src/adapters/models/mt5/modeling_mt5.py @@ -16,8 +16,8 @@ import torch from torch import nn -from torch.utils.checkpoint import checkpoint +from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions from transformers.models.mt5.modeling_mt5 import ( MT5Attention, @@ -26,7 +26,7 @@ MT5LayerSelfAttention, MT5Stack, ) -from transformers.utils import logging +from transformers.utils import is_torchdynamo_compiling, logging from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel from ..t5.mixin_t5 import ( @@ -52,6 +52,7 @@ def forward(self, hidden_states): class MT5AttentionWithAdapters(T5AttentionAdaptersMixin, MT5Attention): + def forward( self, hidden_states, @@ -63,105 +64,99 @@ def forward( query_length=None, use_cache=False, output_attentions=False, + cache_position=None, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
""" # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) batch_size, seq_length = hidden_states.shape[:2] - real_seq_length = seq_length + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + + query_states = self.q(hidden_states) + # >>> START AH Changes <<< + # adapt bsz for lora parallel + query_states = query_states.view(query_states.shape[0], -1, self.n_heads, self.key_value_proj_dim).transpose( + 1, 2 + ) + # >>> END AH Changes <<< if past_key_value is not None: - assert ( - len(past_key_value) == 2 - ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length - - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - - def shape(states): - """projection""" - # keep first dim due to parallel composition - return states.view(states.shape[0], -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(key_states.shape[0], -1, self.n_heads, self.key_value_proj_dim).transpose( + 1, 2 + ) + value_states = value_states.view( + value_states.shape[0], -1, self.n_heads, self.key_value_proj_dim + ).transpose(1, 2) if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, 
seq_length, dim_per_head) - - # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None - ) + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + # >>> START AH Changes <<< query_states, key_states, value_states = match_attn_matrices_for_parallel( query_states, key_states, value_states ) (mask,) = adjust_tensors_for_parallel(query_states, mask) - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None - key_states, value_states, mask = self.prefix_tuning(key_states, value_states, hidden_states, mask) (query_states,) = adjust_tensors_for_parallel(key_states, query_states) batch_size, key_length = key_states.shape[0], key_states.shape[2] - - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 - - if position_bias is None: + # >>> END AH Changes <<< + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) + + # >>> START AH Changes <<< + # For Prefix Tuning, when training with AdapterDrop, we must additionally check that the sequence lengths of + # both positional encoding and the scores account for the prefix tokens. + # This is because the positional encoding is calculated only once in the beginning and then used for all layers. + # However, if the encoding was calculated without the prefix tokens due to AdapterDrop having dropped an + # adapter layer in the beginning, the positional encoding will be shorter than the scores, resulting in a + # dimension mismatch when adding the positional encoding to the scores. 
+ if position_bias is None or position_bias.shape[3] != scores.shape[3]: + # >>> END AH Changes <<< + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -171,21 +166,22 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias_masked = position_bias scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + outputs = (attn_output, past_key_value, position_bias) if output_attentions: outputs = outputs + (attn_weights,) @@ -202,6 +198,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -212,6 +209,7 @@ def forward( past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = self.bottleneck_layer_forward( hidden_states=self.dropout(attention_output[0]), residual_input=hidden_states, layer_norm=None @@ -232,6 +230,7 @@ def forward( use_cache=False, query_length=None, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -244,6 +243,7 @@ def forward( use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, + cache_position=cache_position, ) 
layer_output = self.bottleneck_layer_forward( hidden_states=self.dropout(attention_output[0]), residual_input=hidden_states, layer_norm=None @@ -267,6 +267,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): # Model parallel if self.model_parallel: @@ -278,10 +279,12 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # >>> START AH Changes <<< if self.is_decoder and encoder_hidden_states is not None: input_ids, encoder_attention_mask = adjust_tensors_for_parallel( encoder_hidden_states, input_ids, encoder_attention_mask ) + # >>> END AH Changes <<< if input_ids is not None and inputs_embeds is not None: err_msg_prefix = "decoder_" if self.is_decoder else "" @@ -297,6 +300,13 @@ def forward( err_msg_prefix = "decoder_" if self.is_decoder else "" raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + if inputs_embeds is None: if self.embed_tokens is None: raise ValueError("You have to initialize the model with valid token embeddings") @@ -304,28 +314,57 @@ def forward( batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - if use_cache is True: if not self.is_decoder: raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") - if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: - encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones( - batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if self.is_decoder and (use_cache or past_key_values is not None): + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + elif attention_mask is not None: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min + else: + causal_mask = None # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -333,22 +372,16 @@ def forward( encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long + ) encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None @@ -356,18 +389,20 @@ def forward( encoder_decoder_position_bias = None hidden_states = self.dropout(inputs_embeds) + # >>> START AH Changes <<< if not self.is_decoder: hidden_states = self.post_embedding_forward(hidden_states) + # >>> END AH Changes <<< - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, layer_module in enumerate(self.block): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] # Model parallel if self.model_parallel: torch.cuda.set_device(hidden_states.device) # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) + if causal_mask is not None: + causal_mask = causal_mask.to(hidden_states.device) if position_bias is not None: position_bias = position_bias.to(hidden_states.device) if encoder_hidden_states is not None: @@ -384,17 +419,10 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_outputs = self._gradient_checkpointing_func( + layer_module.forward, hidden_states, - extended_attention_mask, + causal_mask, position_bias, encoder_hidden_states, encoder_extended_attention_mask, @@ -402,20 +430,26 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, + return_dict, + cache_position, ) else: layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, ) # layer_outputs is a tuple with: @@ -423,11 +457,11 @@ def custom_forward(*inputs): if use_cache is False: layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, present_key_value_state = layer_outputs[:2] + hidden_states, next_decoder_cache = layer_outputs[:2] - attention_mask, extended_attention_mask = adjust_tensors_for_parallel( - hidden_states, attention_mask, extended_attention_mask - ) + # >>> START AH Changes <<< + (attention_mask,) = adjust_tensors_for_parallel(hidden_states, attention_mask) + # >>> END AH Changes <<< # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), @@ -435,16 +469,15 @@ def 
custom_forward(*inputs): position_bias = layer_outputs[2] if self.is_decoder and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) + # >>> START AH Changes <<< if position_bias is not None: position_bias = adjust_tensors_for_parallel(hidden_states, position_bias)[0] if encoder_decoder_position_bias is not None: encoder_decoder_position_bias = adjust_tensors_for_parallel( hidden_states, encoder_decoder_position_bias )[0] + # >>> END AH Changes <<< if output_attentions: all_attentions = all_attentions + (layer_outputs[3],) @@ -464,12 +497,18 @@ def custom_forward(*inputs): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -478,7 +517,7 @@ def custom_forward(*inputs): ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, diff --git a/src/adapters/models/t5/modeling_t5.py b/src/adapters/models/t5/modeling_t5.py index d745cf40c..09b969bb1 100644 --- a/src/adapters/models/t5/modeling_t5.py +++ b/src/adapters/models/t5/modeling_t5.py @@ -16,8 +16,8 @@ import torch from torch import nn -from torch.utils.checkpoint import checkpoint +from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions from transformers.models.t5.modeling_t5 import ( T5Attention, @@ -26,7 +26,7 @@ T5LayerSelfAttention, T5Stack, ) -from transformers.utils import logging +from transformers.utils import is_torchdynamo_compiling, logging from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel from .mixin_t5 import ( @@ -52,6 +52,7 @@ def forward(self, hidden_states): class T5AttentionWithAdapters(T5AttentionAdaptersMixin, T5Attention): + def forward( self, hidden_states, @@ -63,88 +64,73 @@ def forward( query_length=None, use_cache=False, output_attentions=False, + cache_position=None, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
""" # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) batch_size, seq_length = hidden_states.shape[:2] - real_seq_length = seq_length + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + + query_states = self.q(hidden_states) + # >>> START AH Changes <<< + # adapt bsz for lora parallel + query_states = query_states.view(query_states.shape[0], -1, self.n_heads, self.key_value_proj_dim).transpose( + 1, 2 + ) + # >>> END AH Changes <<< if past_key_value is not None: - assert ( - len(past_key_value) == 2 - ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length - - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - - def shape(states): - """projection""" - # keep first dim due to parallel composition - return states.view(states.shape[0], -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(key_states.shape[0], -1, self.n_heads, self.key_value_proj_dim).transpose( + 1, 2 + ) + value_states = value_states.view( + value_states.shape[0], -1, self.n_heads, self.key_value_proj_dim + ).transpose(1, 2) if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, 
seq_length, dim_per_head) - - # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None - ) + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + # >>> START AH Changes <<< query_states, key_states, value_states = match_attn_matrices_for_parallel( query_states, key_states, value_states ) (mask,) = adjust_tensors_for_parallel(query_states, mask) - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None - key_states, value_states, mask = self.prefix_tuning(key_states, value_states, hidden_states, mask) (query_states,) = adjust_tensors_for_parallel(key_states, query_states) batch_size, key_length = key_states.shape[0], key_states.shape[2] + # >>> END AH Changes <<< + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 - + # >>> START AH Changes <<< # For Prefix Tuning, when training with AdapterDrop, we must additionally check that the sequence lengths of # both positional encoding and the scores account for the prefix tokens. # This is because the positional encoding is calculated only once in the beginning and then used for all layers. @@ -152,22 +138,25 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): # adapter layer in the beginning, the positional encoding will be shorter than the scores, resulting in a # dimension mismatch when adding the positional encoding to the scores. 
if position_bias is None or position_bias.shape[3] != scores.shape[3]: + # >>> END AH Changes <<< + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -177,21 +166,22 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias_masked = position_bias scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + outputs = (attn_output, past_key_value, position_bias) if output_attentions: outputs = outputs + (attn_weights,) @@ -208,6 +198,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -218,6 +209,7 @@ def forward( past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = self.bottleneck_layer_forward( hidden_states=self.dropout(attention_output[0]), residual_input=hidden_states, layer_norm=None @@ -238,6 +230,7 @@ def forward( use_cache=False, query_length=None, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -250,6 +243,7 @@ def forward( use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, + cache_position=cache_position, ) 
layer_output = self.bottleneck_layer_forward( hidden_states=self.dropout(attention_output[0]), residual_input=hidden_states, layer_norm=None @@ -273,6 +267,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): # Model parallel if self.model_parallel: @@ -284,10 +279,12 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # >>> START AH Changes <<< if self.is_decoder and encoder_hidden_states is not None: input_ids, encoder_attention_mask = adjust_tensors_for_parallel( encoder_hidden_states, input_ids, encoder_attention_mask ) + # >>> END AH Changes <<< if input_ids is not None and inputs_embeds is not None: err_msg_prefix = "decoder_" if self.is_decoder else "" @@ -303,6 +300,13 @@ def forward( err_msg_prefix = "decoder_" if self.is_decoder else "" raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + if inputs_embeds is None: if self.embed_tokens is None: raise ValueError("You have to initialize the model with valid token embeddings") @@ -310,28 +314,57 @@ def forward( batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - if use_cache is True: if not self.is_decoder: raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") - if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: - encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones( - batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if self.is_decoder and (use_cache or past_key_values is not None): + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + elif attention_mask is not None: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min + else: + causal_mask = None # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -339,22 +372,16 @@ def forward( encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long + ) encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None @@ -362,18 +389,20 @@ def forward( encoder_decoder_position_bias = None hidden_states = self.dropout(inputs_embeds) + # >>> START AH Changes <<< if not self.is_decoder: hidden_states = self.post_embedding_forward(hidden_states) + # >>> END AH Changes <<< - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, layer_module in enumerate(self.block): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] # Model parallel if self.model_parallel: torch.cuda.set_device(hidden_states.device) # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) + if causal_mask is not None: + causal_mask = causal_mask.to(hidden_states.device) if position_bias is not None: position_bias = position_bias.to(hidden_states.device) if encoder_hidden_states is not None: @@ -390,17 +419,10 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_outputs = self._gradient_checkpointing_func( + layer_module.forward, hidden_states, - extended_attention_mask, + causal_mask, position_bias, encoder_hidden_states, encoder_extended_attention_mask, @@ -408,20 +430,26 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, + return_dict, + cache_position, ) else: layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, ) # layer_outputs is a tuple with: @@ -429,11 +457,11 @@ def custom_forward(*inputs): if use_cache is False: layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, present_key_value_state = layer_outputs[:2] + hidden_states, next_decoder_cache = layer_outputs[:2] - attention_mask, extended_attention_mask = adjust_tensors_for_parallel( - hidden_states, attention_mask, extended_attention_mask - ) + # >>> START AH Changes <<< + (attention_mask,) = adjust_tensors_for_parallel(hidden_states, attention_mask) + # >>> END AH Changes <<< # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), @@ -441,16 +469,15 @@ def 
custom_forward(*inputs): position_bias = layer_outputs[2] if self.is_decoder and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) + # >>> START AH Changes <<< if position_bias is not None: position_bias = adjust_tensors_for_parallel(hidden_states, position_bias)[0] if encoder_decoder_position_bias is not None: encoder_decoder_position_bias = adjust_tensors_for_parallel( hidden_states, encoder_decoder_position_bias )[0] + # >>> END AH Changes <<< if output_attentions: all_attentions = all_attentions + (layer_outputs[3],) @@ -470,12 +497,18 @@ def custom_forward(*inputs): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -484,7 +517,7 @@ def custom_forward(*inputs): ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, diff --git a/tests/test_clip.py b/tests/test_clip.py index 704c7a164..30be353f7 100644 --- a/tests/test_clip.py +++ b/tests/test_clip.py @@ -33,7 +33,7 @@ class CLIPVisionAdapterTestBase(VisionAdapterTestBase): config_class = CLIPVisionConfig config = make_config( CLIPVisionConfig, - image_size=30, + image_size=224, hidden_size=32, num_hidden_layers=4, num_attention_heads=4, @@ -64,7 +64,7 @@ class CLIPVisionWithProjectionAdapterTestBase(VisionAdapterTestBase): config_class = CLIPVisionConfig config = make_config( CLIPVisionConfig, - image_size=30, + image_size=224, hidden_size=32, num_hidden_layers=4, num_attention_heads=4, @@ -161,7 +161,7 @@ class CLIPAdapterTestBase(AdapterTestBase): intermediate_size=37, ), CLIPVisionConfig( - image_size=30, + image_size=224, hidden_size=32, num_hidden_layers=4, num_attention_heads=4,
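
Usage sketch for the DistilBERT attention changes: the patch adds SDPA and FlashAttention-2 attention classes with adapter hooks, mirroring the upstream DistilBertSdpaAttention / DistilBertFlashAttention2 implementations. The sketch below is a minimal illustration, not part of the patch; it assumes "distilbert-base-uncased" as a stand-in checkpoint, the default bottleneck adapter config, and the standard adapters.init entry point.

import torch
import adapters
from transformers import AutoModel, AutoTokenizer

# SDPA is selected at load time; the adapter wrappers cover whichever attention class is active.
model = AutoModel.from_pretrained("distilbert-base-uncased", attn_implementation="sdpa")
adapters.init(model)

model.add_adapter("demo_adapter")          # default (bottleneck) adapter config
model.set_active_adapters("demo_adapter")

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
inputs = tokenizer("Adapters with SDPA attention", return_tensors="pt")
with torch.no_grad():
    hidden = model(**inputs).last_hidden_state
print(hidden.shape)  # (1, seq_len, hidden_size)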
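
Usage sketch for the T5/MT5 cache migration: the stacks above now route past_key_values through the Cache API (an EncoderDecoderCache wrapping two DynamicCache instances) instead of legacy tuples, matching the upstream 4.46 implementations; legacy tuples are still converted with a deprecation warning. The sketch below is a minimal illustration with "t5-small" as a stand-in checkpoint and plain Transformers classes (the adapter-wrapped stacks keep the same interface).

import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

enc = tokenizer("translate English to German: The house is small.", return_tensors="pt")
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

# Self-attention and cross-attention caches live in one EncoderDecoderCache object,
# which the decoder stack updates in place during generation.
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())

with torch.no_grad():
    out = model(
        input_ids=enc.input_ids,
        attention_mask=enc.attention_mask,
        decoder_input_ids=decoder_input_ids,
        past_key_values=past_key_values,
        use_cache=True,
    )
next_token = out.logits[:, -1].argmax(dim=-1)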