diff --git a/examples/multimodal_vision/qwen_2_5_vl_example.py b/examples/multimodal_vision/qwen_2_5_vl_example.py
new file mode 100644
index 000000000..8a476fcce
--- /dev/null
+++ b/examples/multimodal_vision/qwen_2_5_vl_example.py
@@ -0,0 +1,132 @@
+import base64
+from io import BytesIO
+
+import torch
+from datasets import load_dataset
+from qwen_vl_utils import process_vision_info
+from transformers import AutoProcessor
+
+from llmcompressor.modifiers.quantization import GPTQModifier
+from llmcompressor.transformers import oneshot
+from llmcompressor.transformers.tracing import (
+    TraceableQwen2_5_VLForConditionalGeneration,
+)
+
+# Load model.
+model_id = "Qwen/Qwen2.5-VL-7B-Instruct"
+model = TraceableQwen2_5_VLForConditionalGeneration.from_pretrained(
+    model_id,
+    device_map="auto",
+    torch_dtype="auto",
+)
+processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+
+# Oneshot arguments
+DATASET_ID = "lmms-lab/flickr30k"
+DATASET_SPLIT = {"calibration": "test[:512]"}
+NUM_CALIBRATION_SAMPLES = 512
+MAX_SEQUENCE_LENGTH = 2048
+
+# Load dataset and preprocess.
+ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
+ds = ds.shuffle(seed=42)
+
+
+# Apply chat template and tokenize inputs.
+def preprocess_and_tokenize(example):
+    # preprocess
+    buffered = BytesIO()
+    example["image"].save(buffered, format="PNG")
+    encoded_image = base64.b64encode(buffered.getvalue())
+    encoded_image_text = encoded_image.decode("utf-8")
+    base64_qwen = f"data:image;base64,{encoded_image_text}"
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "image": base64_qwen},
+                {"type": "text", "text": "What does the image show?"},
+            ],
+        }
+    ]
+    text = processor.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+    image_inputs, video_inputs = process_vision_info(messages)
+
+    # tokenize
+    return processor(
+        text=[text],
+        images=image_inputs,
+        videos=video_inputs,
+        padding=False,
+        max_length=MAX_SEQUENCE_LENGTH,
+        truncation=True,
+    )
+
+
+ds = ds.map(preprocess_and_tokenize, remove_columns=ds["calibration"].column_names)
+
+
+# Define a oneshot data collator for multimodal inputs.
+def data_collator(batch):
+    assert len(batch) == 1
+    return {key: torch.tensor(value) for key, value in batch[0].items()}
+
+
+# Recipe
+recipe = [
+    GPTQModifier(
+        targets="Linear",
+        scheme="W4A16",
+        sequential_targets=["Qwen2_5_VLDecoderLayer"],
+        ignore=["lm_head", "re:visual.*"],
+    ),
+]
+
+# Perform oneshot
+oneshot(
+    model=model,
+    tokenizer=model_id,
+    dataset=ds,
+    recipe=recipe,
+    max_seq_length=MAX_SEQUENCE_LENGTH,
+    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+    trust_remote_code_model=True,
+    data_collator=data_collator,
+)
+
+# Confirm generations of the quantized model look sane.
+print("========== SAMPLE GENERATION ==============") +messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": "http://images.cocodataset.org/train2017/000000231895.jpg", + }, + {"type": "text", "text": "Please describe the animal in this image\n"}, + ], + } +] +prompt = processor.apply_chat_template(messages, add_generation_prompt=True) +image_inputs, video_inputs = process_vision_info(messages) +inputs = processor( + text=[prompt], + images=image_inputs, + videos=video_inputs, + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + return_tensors="pt", +).to("cuda") +output = model.generate(**inputs, max_new_tokens=100) +print(processor.decode(output[0], skip_special_tokens=True)) +print("==========================================") + + +# Save to disk compressed. +SAVE_DIR = model_id.split("/")[1] + "-W4A16-G128" +model.save_pretrained(SAVE_DIR, save_compressed=True) +processor.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/transformers/tracing/__init__.py b/src/llmcompressor/transformers/tracing/__init__.py index 39410a1ef..d5643efb8 100644 --- a/src/llmcompressor/transformers/tracing/__init__.py +++ b/src/llmcompressor/transformers/tracing/__init__.py @@ -13,6 +13,9 @@ from .whisper import ( WhisperForConditionalGeneration as TraceableWhisperForConditionalGeneration ) +from .qwen2_5_vl import ( + Qwen2_5_VLForConditionalGeneration as TraceableQwen2_5_VLForConditionalGeneration +) __all__ = [ "TraceableLlavaForConditionalGeneration", @@ -20,4 +23,5 @@ "TraceableQwen2VLForConditionalGeneration", "TraceableIdefics3ForConditionalGeneration", "TraceableWhisperForConditionalGeneration", + "TraceableQwen2_5_VLForConditionalGeneration", ] diff --git a/src/llmcompressor/transformers/tracing/qwen2_5_vl.py b/src/llmcompressor/transformers/tracing/qwen2_5_vl.py new file mode 100644 index 000000000..457adf1fb --- /dev/null +++ b/src/llmcompressor/transformers/tracing/qwen2_5_vl.py @@ -0,0 +1,550 @@ +# flake8: noqa +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# vllm-project: no copyright
+"""PyTorch Qwen2-5-VL model."""
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import CrossEntropyLoss
+
+from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig
+
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+    Qwen2_5_VisionTransformerPretrainedModel,
+    Qwen2_5_VLCausalLMOutputWithPast,
+    Qwen2_5_VLModel,
+    Qwen2_5_VLForConditionalGeneration,
+)
+from transformers.utils.fx import HFProxy
+
+
+# TRACING: cannot iterate input ids
+@torch.fx.wrap
+def get_rope_index(
+    config: Qwen2_5_VLConfig,
+    input_ids: Optional[torch.LongTensor] = None,
+    image_grid_thw: Optional[torch.LongTensor] = None,
+    video_grid_thw: Optional[torch.LongTensor] = None,
+    second_per_grid_ts: Optional[torch.Tensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    spatial_merge_size = config.vision_config.spatial_merge_size
+    image_token_id = config.image_token_id
+    video_token_id = config.video_token_id
+    vision_start_token_id = config.vision_start_token_id
+    mrope_position_deltas = []
+    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
+        total_input_ids = input_ids
+        if attention_mask is None:
+            attention_mask = torch.ones_like(total_input_ids)
+        position_ids = torch.ones(
+            3,
+            input_ids.shape[0],
+            input_ids.shape[1],
+            dtype=input_ids.dtype,
+            device=input_ids.device,
+        )
+        image_index, video_index = 0, 0
+        attention_mask = attention_mask.to(total_input_ids.device)
+        for i, input_ids in enumerate(total_input_ids):
+            input_ids = input_ids[attention_mask[i] == 1]
+            image_nums, video_nums = 0, 0
+            vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
+            vision_tokens = input_ids[vision_start_indices + 1]
+            image_nums = (vision_tokens == image_token_id).sum()
+            video_nums = (vision_tokens == video_token_id).sum()
+            input_tokens = input_ids.tolist()
+            llm_pos_ids_list: list = []
+            st = 0
+            remain_images, remain_videos = image_nums, video_nums
+            for _ in range(image_nums + video_nums):
+                if image_token_id in input_tokens and remain_images > 0:
+                    ed_image = input_tokens.index(image_token_id, st)
+                else:
+                    ed_image = len(input_tokens) + 1
+                if video_token_id in input_tokens and remain_videos > 0:
+                    ed_video = input_tokens.index(video_token_id, st)
+                else:
+                    ed_video = len(input_tokens) + 1
+                if ed_image < ed_video:
+                    t, h, w = (
+                        image_grid_thw[image_index][0],
+                        image_grid_thw[image_index][1],
+                        image_grid_thw[image_index][2],
+                    )
+                    second_per_grid_t = 0
+                    image_index += 1
+                    remain_images -= 1
+                    ed = ed_image
+
+                else:
+                    t, h, w = (
+                        video_grid_thw[video_index][0],
+                        video_grid_thw[video_index][1],
+                        video_grid_thw[video_index][2],
+                    )
+                    if second_per_grid_ts is not None:
+                        second_per_grid_t = second_per_grid_ts[video_index]
+                    else:
+                        second_per_grid_t = 1.0
+                    video_index += 1
+                    remain_videos -= 1
+                    ed = ed_video
+                llm_grid_t, llm_grid_h, llm_grid_w = (
+                    t.item(),
+                    h.item() // spatial_merge_size,
+                    w.item() // spatial_merge_size,
+                )
+                text_len = ed - st
+
+                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+                range_tensor = torch.arange(llm_grid_t).view(-1, 1)
+                expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
+
+                time_tensor = expanded_range * second_per_grid_t * config.vision_config.tokens_per_second
+
+                time_tensor_long = time_tensor.long()
+                t_index = time_tensor_long.flatten()
+
+                h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
+                w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
+                llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
+                st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+            if st < len(input_tokens):
+                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+                text_len = len(input_tokens) - st
+                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+            llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+            position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
+            mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
+        mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
+        return position_ids, mrope_position_deltas
+    else:
+        if attention_mask is not None:
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
+            max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
+            mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+        else:
+            position_ids = (
+                torch.arange(input_ids.shape[1], device=input_ids.device)
+                .view(1, 1, -1)
+                .expand(3, input_ids.shape[0], -1)
+            )
+            mrope_position_deltas = torch.zeros(
+                [input_ids.shape[0], 1],
+                device=input_ids.device,
+                dtype=input_ids.dtype,
+            )
+
+        return position_ids, mrope_position_deltas
+
+# TRACING: shape after get_rope_index is known
+def maybe_install_metadata_position_ids(
+    position_ids: Union[torch.Tensor, HFProxy],
+    input_ids: Union[torch.Tensor, HFProxy],
+):
+    if isinstance(position_ids, HFProxy):
+        batch_size, sequence_length = input_ids._metadata.shape
+
+        original_fn = position_ids.tracer.patched_torch_methods["empty"][1]
+        metadata = original_fn(
+            (3, batch_size, sequence_length), device=torch.device("meta")
+        )
+        position_ids.install_metadata(metadata)
+
+    return position_ids
+
+
+# TRACING: cannot condition on mask shape
+@torch.fx.wrap
+def _prepare_4d_causal_attention_mask_with_cache_position(
+    attention_mask: torch.Tensor,
+    sequence_length: int,
+    target_length: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    cache_position: torch.Tensor,
+    batch_size: int,
+    config: Qwen2_5_VLConfig,
+    past_key_values: Cache,
+):
+    if attention_mask is not None and attention_mask.dim() == 4:
+        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+        causal_mask = attention_mask
+    else:
+        min_dtype = torch.finfo(dtype).min
+        causal_mask = torch.full(
+            (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+        )
+        diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+        if config.sliding_window is not None:
+            # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
+            # the check is needed to verify is current checkpoint was trained with sliding window or not
+            if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
+                sliding_attend_mask = torch.arange(target_length, device=device) <= (
+                    cache_position.reshape(-1, 1) - config.sliding_window
+                )
+                diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
+        causal_mask *= diagonal_attend_mask
+        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+        if attention_mask is not None:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+            # TRACING: cannot condition on mask shape
+            if attention_mask.shape[-1] > target_length:
+                attention_mask = attention_mask[:, :target_length]
+            mask_length = attention_mask.shape[-1]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = padding_mask == 0
+            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                padding_mask, min_dtype
+            )
+    return causal_mask
+
+
+class Qwen2_5_VLModel(Qwen2_5_VLModel):
+    # TRACING: needs to use wrapped _prepare_4d_causal_attention_mask_with_cache_position
+    def _update_causal_mask(
+        self,
+        attention_mask: torch.Tensor,
+        input_tensor: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_key_values: Cache,
+        output_attentions: bool,
+    ):
+        if self.config._attn_implementation == "flash_attention_2":
+            if attention_mask is not None and past_key_values is not None:
+                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
+                if is_padding_right:
+                    raise ValueError(
+                        "You are attempting to perform batched generation with padding_side='right'"
+                        " this may lead to unexpected behaviour for Flash Attention version of Qwen2_5_VL. Make sure to "
+                        " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
+                    )
+            if attention_mask is not None and 0.0 in attention_mask:
+                return attention_mask
+            return None
+
+        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+        # to infer the attention mask.
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        using_static_cache = isinstance(past_key_values, StaticCache)
+        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if (
+            self.config._attn_implementation == "sdpa"
+            and not (using_static_cache or using_sliding_window_cache)
+            and not output_attentions
+        ):
+            # TRACING: Assume an SDPA causal mask
+            # if AttentionMaskConverter._ignore_causal_mask_sdpa(
+            #     attention_mask,
+            #     inputs_embeds=input_tensor,
+            #     past_key_values_length=past_seen_tokens,
+            #     sliding_window=self.config.sliding_window,
+            #     is_training=self.training,
+            # ):
+            if False:
+                return None
+
+        dtype, device = input_tensor.dtype, input_tensor.device
+        min_dtype = torch.finfo(dtype).min
+        sequence_length = input_tensor.shape[1]
+        # SlidingWindowCache or StaticCache
+        if using_sliding_window_cache or using_static_cache:
+            target_length = past_key_values.get_max_cache_shape()
+        # DynamicCache or no cache
+        else:
+            target_length = (
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, torch.Tensor)
+                else past_seen_tokens + sequence_length + 1
+            )
+
+        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+        # TRACING: use wrapped _prepare_4d_causal_attention_mask_with_cache_position
+        causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask,
+            sequence_length=sequence_length,
+            target_length=target_length,
+            dtype=dtype,
+            device=device,
+            cache_position=cache_position,
+            batch_size=input_tensor.shape[0],
+            config=self.config,
+            past_key_values=past_key_values,
+        )
+
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and attention_mask.device.type == "cuda"
+            and not output_attentions
+        ):
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+        return causal_mask
+
+
+class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
+    def __init__(self, config):
+        super().__init__(config)
+        self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)
+        # TRACING: use Qwen2_5_VLModel with wrapped _prepare_4d_causal_attention_mask_with_cache_position
+        self.model = Qwen2_5_VLModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.rope_deltas = None  # cache rope_deltas here
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        pixel_values_videos: Optional[torch.FloatTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        rope_deltas: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        second_per_grid_ts: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if inputs_embeds is None:
+            inputs_embeds = self.model.embed_tokens(input_ids)
+            if pixel_values is not None:
+                pixel_values = pixel_values.type(self.visual.dtype)
+                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+                n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
+                n_image_features = image_embeds.shape[0]
+                # TRACING: assume processing was done correctly
+                # if n_image_tokens != n_image_features:
+                if False:
+                    raise ValueError(
+                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                    )
+
+                mask = input_ids == self.config.image_token_id
+                mask_unsqueezed = mask.unsqueeze(-1)
+                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+                image_mask = mask_expanded.to(inputs_embeds.device)
+
+                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+            if pixel_values_videos is not None:
+                pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
+                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
+                n_video_features = video_embeds.shape[0]
+                if n_video_tokens != n_video_features:
+                    raise ValueError(
+                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                    )
+
+                mask = input_ids == self.config.video_token_id
+                mask_unsqueezed = mask.unsqueeze(-1)
+                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+                video_mask = mask_expanded.to(inputs_embeds.device)
+
+                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
+            if attention_mask is not None:
+                attention_mask = attention_mask.to(inputs_embeds.device)
+
+        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
+        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
+            # calculate RoPE index once per generation in the pre-fill stage only
+            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
+                # TRACING: use wrapped function
+                position_ids, rope_deltas = get_rope_index(
+                    self.config,
+                    input_ids,
+                    image_grid_thw,
+                    video_grid_thw,
+                    second_per_grid_ts,
+                    attention_mask,
+                )
+                # TRACING: the position_ids shape is known
+                position_ids = maybe_install_metadata_position_ids(position_ids, input_ids)
+                self.rope_deltas = rope_deltas
+            # then use the prev pre-calculated rope-deltas to get the correct position ids
+            else:
+                batch_size, seq_length, _ = inputs_embeds.shape
+                delta = (
+                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
+                    if cache_position is not None
+                    else 0
+                )
+                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+                if cache_position is not None:  # otherwise `deltas` is an int `0`
+                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                position_ids = position_ids.add(delta)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
+        outputs = self.model(
+            input_ids=None,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return Qwen2_5_VLCausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            rope_deltas=self.rope_deltas,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        position_ids=None,
+        use_cache=True,
+        pixel_values=None,
+        pixel_values_videos=None,
+        image_grid_thw=None,
+        video_grid_thw=None,
+        second_per_grid_ts=None,
+        **kwargs,
+    ):
+        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
+        # Exception 1: when passing input_embeds, input_ids may be missing entries
+        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        if past_key_values is not None:
+            if inputs_embeds is not None:  # Exception 1
+                input_ids = input_ids[:, -cache_position.shape[0] :]
+            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
+                input_ids = input_ids[:, cache_position]
+
+        if cache_position[0] != 0:
+            pixel_values = None
+            pixel_values_videos = None
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and cache_position[0] == 0:
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
+        else:
+            model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
+
+        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = inputs_embeds.shape
+                device = inputs_embeds.device
+            else:
+                batch_size, sequence_length = input_ids.shape
+                device = input_ids.device
+
+            # TRACING: use wrapped _prepare_4d_causal_attention_mask_with_cache_position
+            attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
+                attention_mask,
+                sequence_length=sequence_length,
+                target_length=past_key_values.get_max_cache_shape(),
+                dtype=self.lm_head.weight.dtype,
+                device=device,
+                cache_position=cache_position,
+                batch_size=batch_size,
+                config=self.config,
+                past_key_values=past_key_values,
+            )
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+                "attention_mask": attention_mask,
+                "pixel_values": pixel_values,
+                "pixel_values_videos": pixel_values_videos,
+                "image_grid_thw": image_grid_thw,
+                "video_grid_thw": video_grid_thw,
+                "cache_position": cache_position,
+                "second_per_grid_ts": second_per_grid_ts,
+            }
+        )
+        return model_inputs
diff --git a/src/llmcompressor/transformers/tracing/qwen2_vl.py b/src/llmcompressor/transformers/tracing/qwen2_vl.py
index 60aa4f51c..1bbd28d68 100644
--- a/src/llmcompressor/transformers/tracing/qwen2_vl.py
+++ b/src/llmcompressor/transformers/tracing/qwen2_vl.py
@@ -27,11 +27,7 @@
 import torch.nn as nn
 import torch.utils.checkpoint
 from torch.nn import CrossEntropyLoss
 
-from transformers.cache_utils import (
-    Cache,
-    SlidingWindowCache,
-    StaticCache,
-)
+from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig
@@ -47,7 +43,7 @@
 # TRACING: cannot iterate input ids
 @torch.fx.wrap
 def get_rope_index(
-    config,
+    config: Qwen2VLConfig,
     input_ids: torch.LongTensor,
     image_grid_thw: Optional[torch.LongTensor] = None,
     video_grid_thw: Optional[torch.LongTensor] = None,