
Commit

Merge branch 'main' into composability-v2
rahul-tuli authored Jan 21, 2025
2 parents 2c51f2d + 4b805fe commit 4038e72
Showing 8 changed files with 788 additions and 12 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/test-check-transformers.yaml
@@ -3,6 +3,8 @@ on:
   pull_request:
     branches: main
     types: [ labeled, synchronize ]
+  push:
+    branches: main

 env:
   CADENCE: "commit"
@@ -15,7 +17,7 @@ env:
 jobs:
   transformers-tests:
     runs-on: gcp-k8s-vllm-l4-solo
-    if: contains(github.event.pull_request.labels.*.name, 'ready')
+    if: contains(github.event.pull_request.labels.*.name, 'ready') || github.event_name == 'push'
     steps:
       - uses: actions/setup-python@v5
         with:
93 changes: 93 additions & 0 deletions examples/multimodal_vision/phi3_vision_example.py
@@ -0,0 +1,93 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoProcessor

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.utils.data_collator import phi3_vision_data_collator

# Load model.
model_id = "microsoft/Phi-3-vision-128k-instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype="auto",
    trust_remote_code=True,
    _attn_implementation="eager",
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
processor.chat_template = processor.tokenizer.chat_template

# Oneshot arguments
DATASET_ID = "lmms-lab/flickr30k"
DATASET_SPLIT = "test[:512]"
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))


# Apply chat template
def preprocess(example):
    messages = [{"role": "user", "content": "<|image_1|>\nWhat does the image show?"}]
    return {
        "text": processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
        ),
        "images": example["image"],
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return processor(
        text=sample["text"],
        images=sample["images"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
    )


# Long data lengths produced by the phi3_vision processor can lead to integer
# overflows when mapping; avoid this by setting writer_batch_size=1.
ds = ds.map(tokenize, writer_batch_size=1, remove_columns=ds.column_names)


# Recipe
recipe = [
    GPTQModifier(
        targets="Linear",
        scheme="W4A16",
        sequential_targets=["Phi3DecoderLayer"],
        ignore=["lm_head", "re:model.vision_embed_tokens.*"],
    ),
]

# Perform oneshot
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    data_collator=phi3_vision_data_collator,
)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=20)
print(processor.decode(output[0]))
print("==========================================")

# Save to disk compressed.
SAVE_DIR = model_id.split("/")[1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)
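
Not part of this commit: as a rough follow-up sketch, the compressed folder written to SAVE_DIR above can typically be loaded back for inference. The snippet below assumes a vLLM installation with compressed-tensors support and reuses the directory name produced by the example (Phi-3-vision-128k-instruct-W4A16-G128); it is only a text-only smoke test, not part of the committed example.

# Hypothetical follow-up (assumption, not in this commit): load the compressed
# checkpoint saved above with vLLM and run a short text-only generation.
from vllm import LLM, SamplingParams

llm = LLM(model="Phi-3-vision-128k-instruct-W4A16-G128", trust_remote_code=True)
outputs = llm.generate(["Hello my name is"], SamplingParams(max_tokens=20))
print(outputs[0].outputs[0].text)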
123 changes: 123 additions & 0 deletions examples/multimodal_vision/qwen2_vl_example.py
@@ -0,0 +1,123 @@
import base64
from io import BytesIO

from datasets import load_dataset
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.tracing import TraceableQwen2VLForConditionalGeneration
from llmcompressor.transformers.utils.data_collator import qwen2_vl_data_collator

# Load model.
model_id = "Qwen/Qwen2-VL-2B-Instruct"
model = TraceableQwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype="auto",
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Oneshot arguments
DATASET_ID = "lmms-lab/flickr30k"
DATASET_SPLIT = {"calibration": "test[:512]"}
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42)


# Apply chat template and tokenize inputs.
def preprocess_and_tokenize(example):
    # preprocess
    buffered = BytesIO()
    example["image"].save(buffered, format="PNG")
    encoded_image = base64.b64encode(buffered.getvalue())
    encoded_image_text = encoded_image.decode("utf-8")
    base64_qwen = f"data:image;base64,{encoded_image_text}"
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": base64_qwen},
                {"type": "text", "text": "What does the image show?"},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)

    # tokenize
    return processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
    )


ds = ds.map(preprocess_and_tokenize, remove_columns=ds["calibration"].column_names)

# Recipe
recipe = [
    GPTQModifier(
        targets="Linear",
        scheme="W4A16",
        sequential_targets=["Qwen2VLDecoderLayer"],
        ignore=["lm_head", "re:visual.*"],
    ),
]

# Perform oneshot
oneshot(
    model=model,
    tokenizer=model_id,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    data_collator=qwen2_vl_data_collator,
)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "http://images.cocodataset.org/train2017/000000231895.jpg",
            },
            {"type": "text", "text": "Please describe the animal in this image\n"},
        ],
    }
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[prompt],
    images=image_inputs,
    videos=video_inputs,
    padding=False,
    max_length=MAX_SEQUENCE_LENGTH,
    truncation=True,
    return_tensors="pt",
).to("cuda")
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================")


# Save to disk compressed.
SAVE_DIR = model_id.split("/")[1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)
7 changes: 4 additions & 3 deletions src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -67,7 +67,8 @@ class GPTQModifier(Modifier, HooksMixin):
     - run_sequential / run_layer_sequential / run_basic
         - make_empty_hessian
         - accumulate_hessian
-        - quantize_weight
+    - on_sequential_batch_end
+        - quantize_weight
     - on_finalize
         - remove_hooks()
         - model.apply(freeze_module_quantization)
@@ -191,7 +192,7 @@ def on_initialize_structure(self, state: State, **kwargs):
         if self._quantization_modifier:
             self._quantization_modifier.on_initialize_structure(state, **kwargs)

-    def on_initialize(self, state: "State", **kwargs) -> bool:
+    def on_initialize(self, state: State, **kwargs) -> bool:
         """
         Initialize and run the GPTQ algorithm on the current state
@@ -271,7 +272,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
             run_basic(state.model, state.data.calib, self)
         return True

-    def on_finalize(self, state: "State", **kwargs) -> bool:
+    def on_finalize(self, state: State, **kwargs) -> bool:
         """
         disable the quantization observers used by the OBCQ algorithm
4 changes: 4 additions & 0 deletions src/llmcompressor/transformers/tracing/__init__.py
@@ -5,9 +5,13 @@
 from .mllama import (
     MllamaForConditionalGeneration as TraceableMllamaForConditionalGeneration,
 )
+from .qwen2_vl import (
+    Qwen2VLForConditionalGeneration as TraceableQwen2VLForConditionalGeneration,
+)

 __all__ = [
     "TraceableLlavaForConditionalGeneration",
     "TraceableMllamaForConditionalGeneration",
     "TraceableMistralForCausalLM",
+    "TraceableQwen2VLForConditionalGeneration",
 ]