[Audio] Whisper Example and Readme (#1106)
## Purpose ##
* Show example of quantizing whisper audio model

## Changes ##
* Add whisper audio model example
* Add traceable whisper definition (only need to comment out a value error check)
* The embedded audio is achieved using [github attached files](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/attaching-files). While there's no official word on how long these files are maintained, if it is found that the file is deleted at some point, then we can replace it with a link to the file uploaded to the repo.

## Testing ##
Successfully quantized whisper models and generated reasonable sample outputs:
* https://huggingface.co/nm-testing/whisper-tiny-W4A16-G128
* https://huggingface.co/nm-testing/whisper-large-v2-W4A16-G128

---------
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Showing 7 changed files with 393 additions and 3 deletions.
@@ -0,0 +1,88 @@
# Quantizing Multimodal Audio Models #

https://github.com/user-attachments/assets/6732c60b-1ebe-4bed-b409-c16c4415dff5

<em>
Audio provided by Daniel Galvez et al. under a Creative Commons license

```
<|startoftranscript|> <|en|>
...
<|transcribe|> <|notimestamps|>
that's where you have a lot of windows in the south no actually that's passive solar
and passive solar is something that was developed and designed in the 1960s and 70s
and it was a great thing for what it was at the time but it's not a passive house
```
</em>

This directory contains example scripts for quantizing a variety of audio language models using GPTQ quantization.
## Compressing Your Own Model ##
To use your own multimodal model, start with an existing example and change the `model_id` to match your own model stub.

```python3
model_id = "path/to/your/model"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype="auto",
)
```
## Customizing GPTQModifier Parameters ##
The `GPTQModifier` is the modifier responsible for quantizing the model weights. For more information on quantizing with different weight schemes, see the `quantization_` examples in the [examples folder](/examples/).

```python3
recipe = [
    GPTQModifier(
        targets="Linear",
        scheme="W4A16",
        sequential_targets=["WhisperEncoderLayer", "WhisperDecoderLayer"],
        ignore=["lm_head"],
    )
]
```
### Sequential Targets ###
Sequential targets are the modules which determine the granularity of error propagation and activation offloading when performing forward passes of the model. These are typically the "transformer blocks" of the model, also referred to as "layers" within llm-compressor.

Choosing sequential targets with higher granularity (for example "Linear" instead of "LlamaDecoderLayer") will result in fewer Hessians being allocated at the same time, decreasing the memory requirements for compression. This may also increase the recovered accuracy of the model, as compression error is propagated at a higher granularity. However, using higher granularity sequential targets may also increase compression time, as more time is spent offloading and onloading activations.
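
For illustration, here is a minimal sketch (not a tuned recommendation for Whisper) of a recipe that uses finer-grained sequential targets, trading longer calibration time for lower peak memory:

```python3
# Sketch: targeting individual Linear modules instead of whole
# encoder/decoder layers allocates fewer Hessians at once (lower memory)
# but spends more time onloading and offloading activations.
recipe = [
    GPTQModifier(
        targets="Linear",
        scheme="W4A16",
        sequential_targets=["Linear"],
        ignore=["lm_head"],
    )
]
```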
### Ignore ###
If your model is not traceable for your desired dataset, first consider adding any problematic modules to the ignore list. Doing this prevents the model tracer from tracing the internals of those modules, thereby avoiding the untraceable operations.
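
For example, in this hedged sketch the extra entry in the ignore list is a hypothetical placeholder, standing in for whichever module your tracing error points at:

```python3
# Sketch: "MyUntraceableModule" is a placeholder class name; replace it
# with the module reported in your model's torch.fx.TraceError.
recipe = [
    GPTQModifier(
        targets="Linear",
        scheme="W4A16",
        ignore=["lm_head", "MyUntraceableModule"],
    )
]
```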
## Tracing Errors ##
Because the architectures of audio-language models are often more complex than those of typical decoder-only text models, you may encounter `torch.fx.TraceError`s when attempting to quantize your model. For more information on `torch.fx.TraceError`s, why they occur, and how to resolve them, please see the [Model Tracing Guide](/src/llmcompressor/transformers/tracing/GUIDE.md).
## Adding Your Own SmoothQuant Mappings ##
For a guide on adding SmoothQuant mappings for your model, see the [SmoothQuant Guide](/src/llmcompressor/modifiers/smoothquant/README.md).
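
As a rough sketch only (the regexes below are illustrative assumptions, not verified Whisper mappings; consult the guide for the exact format), mappings pair the layers to be balanced with the preceding layer whose output they consume:

```python3
# Sketch: each mapping pairs [layers_to_balance] with the preceding
# smoothing layer. These regexes are placeholders; adapt them to your
# architecture following the SmoothQuant Guide.
recipe = [
    SmoothQuantModifier(
        smoothing_strength=0.8,
        mappings=[
            [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*self_attn_layer_norm"],
            [["re:.*fc1"], "re:.*final_layer_norm"],
        ],
    ),
    GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]
```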
## Adding Your Own Data Collator ##
Most examples utilize a generic `data_collator` which correctly collates data for most multimodal datasets. If you find that your model needs custom data collation (as is the case with [pixtral](/examples/multimodal_vision/pixtral_example.py)), you can modify this function to reflect these model-specific requirements.
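
For reference, the collator used in the whisper example below simply converts each field of a single-sample batch into a tensor; a custom collator would replace this logic:

```python3
import torch

# Collator used in the whisper example: assumes batch size 1 and converts
# each preprocessed field into a tensor. Extend this function if your
# model needs padding, stacking, or other model-specific handling.
def data_collator(batch):
    assert len(batch) == 1
    return {key: torch.tensor(value) for key, value in batch[0].items()}
```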
## Sample Audio Provided Under a Creative Commons Attribution License ##
https://creativecommons.org/licenses/by/4.0/legalcode

```
@article{DBLP:journals/corr/abs-2111-09344,
  author     = {Daniel Galvez and
                Greg Diamos and
                Juan Ciro and
                Juan Felipe Cer{\'{o}}n and
                Keith Achorn and
                Anjali Gopi and
                David Kanter and
                Maximilian Lam and
                Mark Mazumder and
                Vijay Janapa Reddi},
  title      = {The People's Speech: {A} Large-Scale Diverse English Speech Recognition
                Dataset for Commercial Usage},
  journal    = {CoRR},
  volume     = {abs/2111.09344},
  year       = {2021},
  url        = {https://arxiv.org/abs/2111.09344},
  eprinttype = {arXiv},
  eprint     = {2111.09344},
  timestamp  = {Mon, 22 Nov 2021 16:44:07 +0100},
  biburl     = {https://dblp.org/rec/journals/corr/abs-2111-09344.bib},
  bibsource  = {dblp computer science bibliography, https://dblp.org}
}
```
@@ -0,0 +1,116 @@
import torch
from datasets import load_dataset
from transformers import WhisperProcessor

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.tracing import TraceableWhisperForConditionalGeneration

# Select model and load it.
MODEL_ID = "openai/whisper-large-v2"

model = TraceableWhisperForConditionalGeneration.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
)
model.config.forced_decoder_ids = None
processor = WhisperProcessor.from_pretrained(MODEL_ID)

# Configure the processor for the dataset task.
processor.tokenizer.set_prefix_tokens(language="en", task="transcribe")

# Select calibration dataset.
DATASET_ID = "MLCommons/peoples_speech"
DATASET_SUBSET = "test"
DATASET_SPLIT = "test"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(
    DATASET_ID,
    DATASET_SUBSET,
    split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]",
    trust_remote_code=True,
)


def preprocess(example):
    return {
        "array": example["audio"]["array"],
        "sampling_rate": example["audio"]["sampling_rate"],
        "text": " " + example["text"].capitalize(),
    }


ds = ds.map(preprocess, remove_columns=ds.column_names)


# Process inputs.
def process(sample):
    audio_inputs = processor(
        audio=sample["array"],
        sampling_rate=sample["sampling_rate"],
        return_tensors="pt",
    )

    text_inputs = processor(
        text=sample["text"], add_special_tokens=True, return_tensors="pt"
    )
    text_inputs["decoder_input_ids"] = text_inputs["input_ids"]
    del text_inputs["input_ids"]

    return dict(**audio_inputs, **text_inputs)


ds = ds.map(process, remove_columns=ds.column_names)


# Define a oneshot data collator for multimodal inputs.
def data_collator(batch):
    assert len(batch) == 1
    return {key: torch.tensor(value) for key, value in batch[0].items()}


# Recipe
recipe = [
    SmoothQuantModifier(smoothing_strength=0.8),
    GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    data_collator=data_collator,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
sample_features = next(iter(ds))["input_features"]
sample_decoder_ids = [processor.tokenizer.prefix_tokens]
sample_input = {
    "input_features": torch.tensor(sample_features).to(model.device),
    "decoder_input_ids": torch.tensor(sample_decoder_ids).to(model.device),
}

output = model.generate(**sample_input, language="en")
print(processor.batch_decode(output, skip_special_tokens=True))
print("==========================================\n\n")
# that's where you have a lot of windows in the south no actually that's passive solar
# and passive solar is something that was developed and designed in the 1960s and 70s
# and it was a great thing for what it was at the time but it's not a passive house

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)