Removing attention mask patching (#791)
* Removing attention mask patching

* ruff
echarlaix authored Jul 1, 2024
1 parent eeb1df0 commit 92fe39f
Showing 2 changed files with 2 additions and 110 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/check_code_quality.yml
@@ -51,4 +51,4 @@ jobs:
       - name: Check style with ruff
         run: |
           source venv/bin/activate
-          ruff .
+          ruff check .
110 changes: 1 addition & 109 deletions optimum/intel/utils/modeling_utils.py
@@ -14,7 +14,7 @@

 import re
 from pathlib import Path
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Union

 import torch
 from huggingface_hub import HfApi, HfFolder
@@ -23,114 +23,6 @@
 MULTI_QUERY_ATTN_MODELS = {"falcon", "gpt_bigcode"}


-# Modified from transformers.models.bloom.modeling_bloom._make_causal_mask
-def _make_causal_mask(
-    input_ids_shape: torch.Size,
-    device: torch.device,
-    past_key_values_length: int,
-    dtype: torch.dtype = torch.bool,
-) -> torch.BoolTensor:
-    """
-    Make causal mask used for bi-directional self-attention.
-    """
-    batch_size, target_length = input_ids_shape
-    mask = torch.zeros((target_length, target_length + past_key_values_length), dtype=dtype, device=device)
-    seq_ids = torch.arange(target_length, device=device)
-
-    mask[:, past_key_values_length:] = (
-        (seq_ids[:, None] < seq_ids[None, :]) * torch.finfo(dtype).min
-        if torch.is_floating_point(mask)
-        else seq_ids[:, None] < seq_ids[None, :]
-    )
-
-    return mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
-
-
-# Modified from transformers.models.bloom.modeling_bloom._prepare_attn_mask
-def _prepare_attn_mask(
-    attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
-) -> torch.BoolTensor:
-    from transformers.models.bloom.modeling_bloom import _expand_mask
-
-    # create causal mask
-    # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
-    combined_attention_mask = None
-    device = attention_mask.device
-    _, src_length = input_shape
-
-    combined_attention_mask = _make_causal_mask(
-        input_shape, device=device, past_key_values_length=past_key_values_length
-    )
-    # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
-    expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
-    combined_attention_mask = (
-        expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
-    )
-
-    return combined_attention_mask
-
-
-# Modified from transformers.models.llama.modeling_llama._prepare_decoder_attention_mask
-def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length):
-    from transformers.models.llama.modeling_llama import _expand_mask
-
-    # create causal mask
-    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-    combined_attention_mask = None
-
-    combined_attention_mask = _make_causal_mask(
-        input_shape,
-        device=inputs_embeds.device,
-        past_key_values_length=past_key_values_length,
-        dtype=inputs_embeds.dtype,
-    )
-
-    if attention_mask is not None:
-        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
-            inputs_embeds.device
-        )
-        combined_attention_mask = (
-            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
-        )
-
-    return combined_attention_mask
-
-
-# Modified from transformers.models.mistral.modeling_mistral._prepare_decoder_sliding_window_attention_mask
-def _prepare_decoder_sliding_window_attention_mask(
-    attention_mask: torch.Tensor,
-    input_shape: Tuple[int, int],
-    inputs_embeds: torch.Tensor,
-    past_key_values_length: int,
-    sliding_window: int,
-):
-    from transformers.models.mistral.modeling_mistral import _expand_mask, _make_sliding_window_causal_mask
-
-    # create causal mask
-    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-    combined_attention_mask = None
-
-    combined_attention_mask = _make_sliding_window_causal_mask(
-        input_shape,
-        device=inputs_embeds.device,
-        dtype=inputs_embeds.dtype,
-        past_key_values_length=past_key_values_length,
-        sliding_window=sliding_window,
-    )
-
-    if attention_mask is not None:
-        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
-            inputs_embeds.device
-        )
-        combined_attention_mask = (
-            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
-        )
-
-    return combined_attention_mask
-
-
 def get_model_device(model: torch.nn.Module) -> torch.device:
     """
     Determines the device on which a PyTorch model is currently residing.