Commit: fix example
Sara Adkins committed Apr 18, 2024
1 parent a0af21a commit 20986a6
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions examples/llama_1.1b/ex_config_quantization.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from tqdm import tqdm
-
+from torch.utils.data import RandomSampler
 from compressed_tensors.quantization import (
     apply_quantization_config,
     freeze_module_quantization,
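
An aside on the import just added (not part of the commit): RandomSampler shuffles dataset indices, and passing a seeded torch.Generator makes the shuffled calibration order reproducible across runs. A minimal sketch with a hypothetical toy dataset:

import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

# Hypothetical stand-in for the tokenized calibration set
toy_dataset = TensorDataset(torch.arange(100).unsqueeze(1))

# Seeding the sampler's generator fixes the shuffled order across runs
generator = torch.Generator().manual_seed(42)
sampler = RandomSampler(toy_dataset, generator=generator)

loader = DataLoader(toy_dataset, batch_size=1, sampler=sampler)
first_five = [batch[0].item() for _, batch in zip(range(5), loader)]
print(first_five)  # same five shuffled indices on every run with this seed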
@@ -24,7 +24,7 @@
 from sparseml.transformers.finetune.data.base import TextGenerationDataset
 from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator
 from torch.utils.data import DataLoader
-
+from sparseml.pytorch.utils import tensors_to_device
 
 config_file = "example_quant_config.json"
 model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
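
For context (an aside, not part of the commit): tensors_to_device moves every tensor in a nested batch structure onto the target device, which is what lets the calibration loop below ship a collated dict straight to the GPU. A rough, simplified stand-in illustrating the assumed behavior:

import torch

def tensors_to_device_sketch(tensors, device):
    # Simplified stand-in for sparseml.pytorch.utils.tensors_to_device:
    # recursively move tensors inside dicts/lists/tuples to `device`
    if isinstance(tensors, torch.Tensor):
        return tensors.to(device)
    if isinstance(tensors, dict):
        return {k: tensors_to_device_sketch(v, device) for k, v in tensors.items()}
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(tensors_to_device_sketch(t, device) for t in tensors)
    return tensors

batch = {"input_ids": torch.ones(1, 8, dtype=torch.long)}
batch = tensors_to_device_sketch(batch, "cpu")  # the example script uses "cuda:0"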
@@ -62,16 +62,22 @@
     dataset_manager.get_raw_dataset()
 )
 data_loader = DataLoader(
-    calib_dataset, batch_size=1, collate_fn=DefaultDataCollator()
+    calib_dataset, batch_size=1, collate_fn=DefaultDataCollator(), sampler=RandomSampler(calib_dataset)
 )
 
 # run calibration
-for idx, sample in tqdm(enumerate(data_loader)):
+for idx, sample in tqdm(enumerate(data_loader), desc="Running calibration"):
     sample = tensors_to_device(sample, "cuda:0")
     _ = model(**sample)
 
     if idx >= num_calibration_samples:
         break
 
 # freeze params after calibration
 model.apply(freeze_module_quantization)
+
+# this functionality will move but for now we need to get the save override from
+# SparseML in order to save the config
+from sparseml.transformers.compression import modify_save_pretrained
+modify_save_pretrained(model)
+model.save_pretrained(output_dir)
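
Taken together, the change shuffles the calibration data, labels the progress bar, and saves the quantization config through SparseML's save override. The loop's core pattern, streaming a bounded number of shuffled samples through the model forward-only and then freezing the observed quantization parameters, can be tried in isolation. Below is a minimal, self-contained toy sketch; ToyCalibrationSet and ToyModel are hypothetical stand-ins for the script's tokenized dataset and TinyLlama model, and the freeze step is represented only by a comment:

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, RandomSampler
from tqdm import tqdm

class ToyCalibrationSet(Dataset):
    # Hypothetical stand-in for the tokenized TextGenerationDataset
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return {"x": torch.randn(8)}

class ToyModel(nn.Module):
    # Hypothetical stand-in for the TinyLlama model; the real script
    # has a quantization config applied before calibration
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

model = ToyModel()
num_calibration_samples = 16
calib_dataset = ToyCalibrationSet()

data_loader = DataLoader(
    calib_dataset, batch_size=1, sampler=RandomSampler(calib_dataset)
)

# forward-only calibration pass, mirroring the loop in the commit
with torch.no_grad():
    for idx, sample in tqdm(enumerate(data_loader), desc="Running calibration"):
        _ = model(**sample)
        if idx >= num_calibration_samples:
            break

# at this point the real script calls model.apply(freeze_module_quantization)

Sampling with RandomSampler rather than iterating the dataset in order helps the calibration pass see a representative slice of the data instead of just the first N rows.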
