Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VLM] Qwen 2.5 VL #1113

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions examples/multimodal_vision/qwen_2_5_vl_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import base64
from io import BytesIO

import torch
from datasets import load_dataset
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.tracing import (
TraceableQwen2_5_VLForConditionalGeneration,
)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a quick ELI5 would be nice to have at the top

Suggested change
# The following example shows how the Qwen2.5 Vision-Language model can be W4A16 quantized

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the readme in this folder provides enough context for a user to know how to use this example/what it's doing

# Load model.
model_id = "Qwen/Qwen2.5-VL-7B-Instruct"
model = TraceableQwen2_5_VLForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype="auto",
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Oneshot arguments
DATASET_ID = "lmms-lab/flickr30k"
DATASET_SPLIT = {"calibration": "test[:512]"}
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42)


# Apply chat template and tokenize inputs.
def preprocess_and_tokenize(example):
# preprocess
buffered = BytesIO()
example["image"].save(buffered, format="PNG")
encoded_image = base64.b64encode(buffered.getvalue())
encoded_image_text = encoded_image.decode("utf-8")
base64_qwen = f"data:image;base64,{encoded_image_text}"
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": base64_qwen},
{"type": "text", "text": "What does the image show?"},
],
}
]
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)

# tokenize
return processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=False,
max_length=MAX_SEQUENCE_LENGTH,
truncation=True,
)


ds = ds.map(preprocess_and_tokenize, remove_columns=ds["calibration"].column_names)


# Define a oneshot data collator for multimodal inputs.
def data_collator(batch):
assert len(batch) == 1
return {key: torch.tensor(value) for key, value in batch[0].items()}


# Recipe
recipe = [
GPTQModifier(
targets="Linear",
scheme="W4A16",
sequential_targets=["Qwen2_5_VLDecoderLayer"],
ignore=["lm_head", "re:visual.*"],
),
]

# Perform oneshot
oneshot(
model=model,
tokenizer=model_id,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
trust_remote_code_model=True,
data_collator=data_collator,
)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": "http://images.cocodataset.org/train2017/000000231895.jpg",
},
{"type": "text", "text": "Please describe the animal in this image\n"},
],
}
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
text=[prompt],
images=image_inputs,
videos=video_inputs,
padding=False,
max_length=MAX_SEQUENCE_LENGTH,
truncation=True,
return_tensors="pt",
).to("cuda")
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================")


# Save to disk compressed.
SAVE_DIR = model_id.split("/")[1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)
4 changes: 4 additions & 0 deletions src/llmcompressor/transformers/tracing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,15 @@
from .whisper import (
WhisperForConditionalGeneration as TraceableWhisperForConditionalGeneration
)
from .qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration as TraceableQwen2_5_VLForConditionalGeneration
)

__all__ = [
"TraceableLlavaForConditionalGeneration",
"TraceableMllamaForConditionalGeneration",
"TraceableQwen2VLForConditionalGeneration",
"TraceableIdefics3ForConditionalGeneration",
"TraceableWhisperForConditionalGeneration",
"TraceableQwen2_5_VLForConditionalGeneration",
]
Loading