Merge branch 'kylesayrs/hooks-mixin-remove-subsets', remote-tracking branch 'origin' into kylesayrs/hooks-mixin-keep
kylesayrs committed Jan 29, 2025
3 parents e3623cc + b61092b + eb83e67 commit 2d6e366
Showing 8 changed files with 632 additions and 8 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -39,6 +39,7 @@ Applying quantization with `llmcompressor`:
* [Activation quantization to `fp8`](examples/quantization_w8a8_fp8)
* [Weight only quantization to `int4`](examples/quantization_w4a16)
* [Quantizing MoE LLMs](examples/quantizing_moe)
* [Quantizing Multimodal VLMs](examples/multimodal_vision)

### User Guides
Deep dives into advanced usage of `llmcompressor`:
64 changes: 64 additions & 0 deletions examples/multimodal_vision/README.md
@@ -0,0 +1,64 @@
# Quantizing Multimodal Vision-Language Models #

<p align="center" style="text-align: center;">
<img src="http://images.cocodataset.org/train2017/000000231895.jpg" alt="sample image from MS COCO dataset"/>
</p>
<em>

```
<|system|>
You are a helpful assistant.
<|user|>
Please describe the animal in this image
<|assistant|>
The animal in the image is a white kitten.
It has a fluffy coat and is resting on a white keyboard.
The kitten appears to be comfortable and relaxed, possibly enjoying the warmth of the keyboard.
```
</em>

This directory contains example scripts for quantizing a variety of vision-language models using GPTQ quantization. Most examples do not demonstrate quantizing the separate vision encoder parameters, if they exist, as compressing these parameters offers little benefit with respect to the performance-accuracy tradeoff.

## Compressing Your Own Model ##
To use your own multimodal model, start with an existing example and change the `model_id` to match your own model stub.
```python3
model_id = "path/to/your/model"
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
torch_dtype="auto",
)
```

## Customizing GPTQModifier Parameters ##
The GPTQModifier is the modifier responsible for performing quantization of the model weights. For more information on quantizing with different weight schemes, see the `quantization_` examples in the [examples folder](/examples/).

```python3
recipe = [
GPTQModifier(
targets="Linear",
scheme="W4A16",
sequential_targets=["MistralDecoderLayer"],
ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"],
),
]
```

### Sequential Targets ###
Sequential targets are the modules that determine the granularity of error propagation and activation offloading when performing forward passes of the model. These are typically the "transformer blocks" of the model, also referred to as "layers" within llm-compressor.

Choosing sequential targets with higher granularity (for example "Linear" instead of "LlamaDecoderLayer") will result in fewer hessians being allocated at the same time, decreasing the memory requirements for compression. This may also increase the recovered accuracy of the model, as compression error is propagated at a higher granularity. However, using higher granularity sequential targets may also increase compression time, as more time is spent offloading and onloading activations.
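
For instance, a recipe could select either granularity; this is a minimal sketch assuming a Mistral-style model, and the decoder layer class name is illustrative.

```python3
from llmcompressor.modifiers.quantization import GPTQModifier

# Coarser granularity: hessians for every Linear in a decoder layer are held
# in memory at once, but activations are offloaded/onloaded less frequently.
recipe_layerwise = GPTQModifier(
    targets="Linear",
    scheme="W4A16",
    sequential_targets=["MistralDecoderLayer"],
    ignore=["re:.*lm_head"],
)

# Finer granularity: only one Linear module's hessian is held at a time,
# reducing peak memory at the cost of additional offloading overhead.
recipe_linearwise = GPTQModifier(
    targets="Linear",
    scheme="W4A16",
    sequential_targets=["Linear"],
    ignore=["re:.*lm_head"],
)
```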

### Ignore ###
If your model is not traceable for your desired dataset, first consider adding any problematic modules to the ignore list. Doing this prevents the model tracer from tracing the internals of those modules, thereby avoiding the untraceable operations.

## Tracing Errors ##
Because the architectures of vision-language models are often more complex than those of typical decoder-only text models, you may encounter `torch.fx.TraceError`s when attempting to quantize your model. For more information on `torch.fx.TraceError`s, why they occur, and how to resolve them, please see the [Model Tracing Guide](/src/llmcompressor/transformers/tracing/GUIDE.md).
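
As shown in [idefics3_example.py](/examples/multimodal_vision/idefics3_example.py), one remedy is to load one of the traceable model definitions bundled with llm-compressor in place of the stock `transformers` class, for example:

```python3
from llmcompressor.transformers.tracing import (
    TraceableIdefics3ForConditionalGeneration,
)

# Traceable drop-in replacement for Idefics3ForConditionalGeneration
model = TraceableIdefics3ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/Idefics3-8B-Llama3", device_map="auto", torch_dtype="auto"
)
```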

## Adding Your Own SmoothQuant Mappings ##
For a guide on adding SmoothQuant mappings for your model, see the [SmoothQuant Guide](/src/llmcompressor/modifiers/smoothquant/README.md).
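
As a rough sketch only (the parameters below reflect a typical `SmoothQuantModifier` setup, and the regex mappings are illustrative rather than tuned for any particular model), custom mappings can be supplied when the modifier is added to a recipe:

```python3
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier

# Each mapping pairs the layers whose weights absorb the scales (balance layers)
# with the preceding layer whose output activations are smoothed (smooth layer).
smoothquant = SmoothQuantModifier(
    smoothing_strength=0.8,
    mappings=[
        [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"],
        [["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"],
    ],
)
```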

## Adding Your Own Data Collator ##
Most examples utilize a generic `data_collator` which correctly collates data for most multimodal datasets. If you find that your model requires custom data collation (as is the case with [pixtral](/examples/multimodal_vision/pixtral_example.py)), you can modify this function to reflect those model-specific requirements.
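
For reference, the generic collator used throughout these examples simply tensorizes a single-sample batch; a model-specific variant (sketched below with hypothetical key handling, not the exact pixtral implementation) can reshape individual fields as needed:

```python3
import torch


# Generic collator: calibration runs with batch size 1
def data_collator(batch):
    assert len(batch) == 1
    return {key: torch.tensor(value) for key, value in batch[0].items()}


# Hypothetical model-specific variant: unwrap an extra nesting level around
# pixel_values, which some processors return as a list of per-image tensors
def custom_data_collator(batch):
    assert len(batch) == 1
    sample = batch[0]
    return {
        "input_ids": torch.LongTensor(sample["input_ids"]),
        "attention_mask": torch.tensor(sample["attention_mask"]),
        "pixel_values": torch.tensor(sample["pixel_values"])[0],
    }
```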
117 changes: 117 additions & 0 deletions examples/multimodal_vision/idefics3_example.py
@@ -0,0 +1,117 @@
import requests
import torch
from datasets import load_dataset
from PIL import Image
from transformers import AutoProcessor

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.tracing import TraceableIdefics3ForConditionalGeneration

# Load model.
model_id = "HuggingFaceM4/Idefics3-8B-Llama3" # or "HuggingFaceTB/SmolVLM-Instruct"
model = TraceableIdefics3ForConditionalGeneration.from_pretrained(
model_id, device_map="auto", torch_dtype="auto"
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Oneshot arguments
DATASET_ID = "lmms-lab/flickr30k"
DATASET_SPLIT = "test[:512]"
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 4096 # Seems to be required here


# Define a oneshot data collator for multimodal inputs.
def data_collator(batch):
assert len(batch) == 1
return {key: torch.tensor(value) for key, value in batch[0].items()}


# Recipe
recipe = [
GPTQModifier(
targets="Linear",
scheme="W4A16",
sequential_targets=["LlamaDecoderLayer"],
ignore=["re:.*lm_head", "re:model.vision_model.*", "re:model.connector.*"],
),
]

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))


# Apply chat template
def preprocess(example):
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "What does the image show?"},
{"type": "image"},
],
}
]
return {
"text": processor.apply_chat_template(
messages,
add_generation_prompt=True,
),
"images": example["image"],
}


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
return processor(
text=sample["text"],
images=sample["images"],
padding=False,
max_length=MAX_SEQUENCE_LENGTH,
truncation=True,
)


# avoid errors with writer_batch_size
ds = ds.map(tokenize, writer_batch_size=1, remove_columns=ds.column_names)

# Perform oneshot
oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
trust_remote_code_model=True,
data_collator=data_collator,
)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Please describe the animal in this image\n"},
{"type": "image"},
],
},
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
image_url = "http://images.cocodataset.org/train2017/000000231895.jpg"
raw_image = Image.open(requests.get(image_url, stream=True).raw)

inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================")

# Save to disk compressed.
SAVE_DIR = model_id.split("/")[1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)
12 changes: 11 additions & 1 deletion src/llmcompressor/modifiers/utils/hooks.py
@@ -29,12 +29,17 @@ class HooksMixin(BaseModel):
- modifier.remove_hooks()
"""

# attached to global HooksMixin class
_HOOKS_DISABLED: ClassVar[bool] = False
_HOOKS_KEEP_ENABLED: ClassVar[Set[RemovableHandle]] = set()

# attached to local subclasses
_hooks: Set[RemovableHandle] = set()

@classmethod
@contextlib.contextmanager
@@ -93,7 +98,12 @@ def wrapped_hook(*args, **kwargs):
return handle

def remove_hooks(self, handles: Optional[Set[RemovableHandle]] = None):
"""Remove all hooks belonging to a modifier"""
"""
Removes hooks registered by this modifier
:param handles: optional list of handles to remove, defaults to all hooks
registerd by this modifier
"""
if handles is None:
handles = self._hooks

@@ -181,7 +181,11 @@ def is_sparse24_bitmask_supported(
return False

if not is_model_quantized(model):
logger.warning(
"Compressed Sparse-only 2:4 models are not supported in vLLM<=0.7.0, "
"consider saving with `disable_sparse_compression` set, "
"`model.save_pretrained(..., disable_sparse_compression=True)`"
)
return True

# when model is quantized, and has 2:4 sparsity
12 changes: 6 additions & 6 deletions src/llmcompressor/transformers/tracing/GUIDE.md
@@ -16,14 +16,14 @@ a [Sequential Pipeline](/src/llmcompressor/pipelines/sequential/pipeline.py)
is required in order to offload activations and reduce memory usage as well as propagate
the activation error induced by compression.

For example, let's say we want to quantize a basic 3 layer model using the
[GPTQModifier](/src/llmcompressor/modifiers/quantization/gptq/base.py) and 512
calibration samples. The [Sequential Pipeline](/src/llmcompressor/pipelines/sequential/pipeline.py)
first identifies each of the layers (sequential targets) within the model. Then, the
pipeline runs each of the 512 samples, one sample at a time, through the first layer.
When one sample completes its forward pass through the layer, its activations are
used by the [GPTQModifier](/src/llmcompressor/modifiers/quantization/gptq/base.py)
to calibrate the hessian and the layer output is offloaded to the cpu. After all 512 samples have been
passed through the layer, the [GPTQModifier](/src/llmcompressor/modifiers/quantization/gptq/base.py)
uses the recorded activations to compress the weights of the modules within the layer.
Once module compression is complete, the offloaded activations are used to perform the
@@ -242,7 +242,7 @@ def _prepare_cross_attention_mask(...) -> ...:
<img alt="Wrapped Function" src="assets/wrapped_function.jpg" height="5%" />
</p>
<p align="center">
<em>This image depicts how the internals of the <code>_prepare_cross_attention_mask</code> function are replaced by a single <code>call_module</code> operation, similar to how modules can be ignored as featured in section 1</em>
</p>

Please note that wrapped functions must be defined at the module-level, meaning that
4 changes: 4 additions & 0 deletions src/llmcompressor/transformers/tracing/__init__.py
@@ -7,9 +7,13 @@
from .qwen2_vl import (
Qwen2VLForConditionalGeneration as TraceableQwen2VLForConditionalGeneration,
)
from .idefics3 import (
Idefics3ForConditionalGeneration as TraceableIdefics3ForConditionalGeneration
)

__all__ = [
"TraceableLlavaForConditionalGeneration",
"TraceableMllamaForConditionalGeneration",
"TraceableQwen2VLForConditionalGeneration",
"TraceableIdefics3ForConditionalGeneration"
]