Merge pull request #26 from neuralmagic/sa/debug_ppl
Quantization Examples and Correctness Fixes
Sara Adkins authored Apr 18, 2024
2 parents fd9545d + 6140bd2 commit ed3e2c7
Showing 7 changed files with 160 additions and 32 deletions.
85 changes: 85 additions & 0 deletions examples/llama_1.1b/ex_config_quantization.py
@@ -0,0 +1,85 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tqdm import tqdm
from torch.utils.data import RandomSampler
from compressed_tensors.quantization import (
    apply_quantization_config,
    freeze_module_quantization,
    QuantizationConfig,
    QuantizationStatus,
)
from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
from sparseml.transformers.finetune.data.base import TextGenerationDataset
from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator
from torch.utils.data import DataLoader
from sparseml.pytorch.utils import tensors_to_device
import torch

config_file = "example_quant_config.json"
model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
dataset_name = "open_platypus"
split = "train"
num_calibration_samples = 512
max_seq_length = 1024
pad_to_max_length = False
output_dir = "./llama1.1b_new_quant_out"
device = "cuda:0" if torch.cuda_is_available() else "cpu"

model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device)
model.eval() # no grad or updates needed for base model
config = QuantizationConfig.parse_file(config_file)

# set status to calibration
config.quantization_status = QuantizationStatus.CALIBRATION

# initialize quantization
apply_quantization_config(model, config)

# create dataset
tokenizer = AutoTokenizer.from_pretrained(model_name)
data_args = DataTrainingArguments(
    dataset=dataset_name,
    max_seq_length=max_seq_length,
    pad_to_max_length=pad_to_max_length,
)
dataset_manager = TextGenerationDataset.load_from_registry(
    data_args.dataset,
    data_args=data_args,
    split=split,
    tokenizer=tokenizer,
)
calib_dataset = dataset_manager.tokenize_and_process(
    dataset_manager.get_raw_dataset()
)
data_loader = DataLoader(
    calib_dataset,
    batch_size=1,
    collate_fn=DefaultDataCollator(),
    sampler=RandomSampler(calib_dataset),
)

# run calibration
for idx, sample in tqdm(enumerate(data_loader), desc="Running calibration"):
    sample = tensors_to_device(sample, device)
    _ = model(**sample)

    if idx >= num_calibration_samples:
        break

# freeze params after calibration
model.apply(freeze_module_quantization)

# this functionality will move but for now we need to get the save override from
# SparseML in order to save the config
from sparseml.transformers.compression import modify_save_pretrained
modify_save_pretrained(model)
model.save_pretrained(output_dir)
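
The branch name (sa/debug_ppl) and the commit title point at perplexity correctness, so a quick sanity check of the calibrated model can be useful after saving. A minimal sketch, under the assumption that model, data_loader, tensors_to_device, and device from the script above are still in scope; it is not part of the committed example:

import math

# approximate perplexity over a few calibration batches
model.eval()
total_loss, num_batches = 0.0, 0
with torch.no_grad():
    for idx, sample in enumerate(data_loader):
        sample = tensors_to_device(sample, device)
        input_ids = sample["input_ids"]
        # feed the inputs back in as labels to get the standard causal LM loss
        outputs = model(
            input_ids=input_ids,
            attention_mask=sample.get("attention_mask"),
            labels=input_ids,
        )
        total_loss += outputs.loss.item()
        num_batches += 1
        if num_batches >= 16:
            break

print(f"approximate perplexity: {math.exp(total_loss / num_batches):.2f}")
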
37 changes: 27 additions & 10 deletions examples/llama_1.1b/ex_sparseml_quantization.py
@@ -13,28 +13,45 @@
# limitations under the License.

from sparseml.transformers import oneshot, SparseAutoModelForCausalLM
from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
from sparseml.transformers.finetune.data.base import TextGenerationDataset
from transformers import AutoTokenizer
import torch

dataset_name = "open_platypus"
overwrite_output_dir = True
splits = {"calibration": "train"}
seed = 42
output_dir = "./llama_1.1b_quant_mod_only"
num_calibration_samples = 1024
recipe = "example_quant_recipe.yaml"
model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
dataset_name = "open_platypus"
split = "train"
num_calibration_samples = 512
max_seq_length = 1024
pad_to_max_length = False
output_dir = "./llama1.1b_old_quant_out"
device = "cuda:0" if torch.cuda_is_available() else "cpu"

model = SparseAutoModelForCausalLM.from_pretrained(model_name, device_map="cuda:0")
model = SparseAutoModelForCausalLM.from_pretrained(model_name, device_map=device)

tokenizer = AutoTokenizer.from_pretrained(model_name)
data_args = DataTrainingArguments(
    dataset=dataset_name,
    max_seq_length=max_seq_length,
    pad_to_max_length=pad_to_max_length,
)
dataset_manager = TextGenerationDataset.load_from_registry(
    data_args.dataset,
    data_args=data_args,
    split=split,
    tokenizer=tokenizer,
)
calib_dataset = dataset_manager.tokenize_and_process(
    dataset_manager.get_raw_dataset()
)

oneshot(
    model=model_name,
    dataset=dataset_name,
    output_dir=output_dir,
    overwrite_output_dir=overwrite_output_dir,
    splits = splits,
    overwrite_output_dir=True,
    max_seq_length = max_seq_length,
    seed=seed,
    num_calibration_samples=num_calibration_samples,
    recipe=recipe,
    pad_to_max_length=pad_to_max_length
13 changes: 3 additions & 10 deletions examples/llama_1.1b/example_quant_config.json
@@ -1,7 +1,6 @@
{
"quant_method": "sparseml",
"format": "fakequant",
"quantization_status": "frozen",
"global_compression_ratio": null,
"config_groups": {
"group_1": {
@@ -14,7 +13,7 @@
"input_activations": {
"num_bits": 8,
"type": "int",
"symmetric": true,
"symmetric": false,
"strategy": "tensor"
},
"targets": ["Linear"]
@@ -23,17 +22,11 @@
"weights": {
"num_bits": 8,
"type": "int",
"symmetric": false,
"symmetric": true,
"strategy": "tensor"
},
"input_activations": null,
"targets": ["Embedding"]
}
},
"ignore": [
"LlamaRotaryEmbedding", "LlamaRMSNorm", "SiLUActivation",
"model.layers.1.mlp.down_proj", "MatMulLeftInput_QK", "MatMulRightInput_QK",
"MatMulOutput_QK", "MatMulLeftInput_PV", "MatMulRightInput_PV",
"MatMulOutput_PV"
]
"ignore": ["model.layers.0.mlp.down_proj"]
}
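
Not part of the diff: as a rough check that the trimmed config still parses, the JSON can be loaded with the same QuantizationConfig API that ex_config_quantization.py uses. The path and the attribute names on the parsed objects are assumptions that mirror the JSON keys:

from compressed_tensors.quantization import QuantizationConfig

# path is an assumption; point it at the example config in a local checkout
config = QuantizationConfig.parse_file("examples/llama_1.1b/example_quant_config.json")

for name, group in config.config_groups.items():
    print(name, "targets:", group.targets)
    if group.weights is not None:
        print("  weights:", group.weights.num_bits, "bits, symmetric =", group.weights.symmetric)
    if group.input_activations is not None:
        print("  activations:", group.input_activations.num_bits, "bits, symmetric =",
              group.input_activations.symmetric)

print("ignored modules:", config.ignore)
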
32 changes: 32 additions & 0 deletions examples/llama_1.1b/example_quant_recipe.yaml
@@ -0,0 +1,32 @@
test_stage:
  quant_modifiers:
    QuantizationModifier:
      ignore:
        - model.layers.0.mlp.down_proj
        - LlamaRotaryEmbedding
        - LlamaRMSNorm
        - SiLU
        - MatMulLeftInput_QK
        - MatMulRightInput_QK
        - MatMulOutput_QK
        - MatMulLeftInput_PV
        - MatMulRightInput_PV
        - MatMulOutput_PV
      scheme_overrides:
        Linear:
          weights:
            num_bits: 8
            symmetric: true
            strategy: "tensor"
          input_activations:
            num_bits: 8
            symmetric: false
            strategy: "tensor"
          output_activations: null
        Embedding:
          weights:
            num_bits: 8
            symmetric: true
            strategy: "tensor"
          input_activations: null
          output_activations: null
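
The recipe expresses the same scheme in the YAML form consumed by SparseML's oneshot entrypoint (passed as recipe= in ex_sparseml_quantization.py above): symmetric int8 weights and asymmetric int8 input activations for Linear, plus the ignore list. A quick parse check, sketched as an assumption rather than part of the example (requires PyYAML, and the path is hypothetical):

import yaml

with open("examples/llama_1.1b/example_quant_recipe.yaml") as f:
    recipe = yaml.safe_load(f)

modifier = recipe["test_stage"]["quant_modifiers"]["QuantizationModifier"]
print("ignore:", modifier["ignore"])
print("Linear weights:", modifier["scheme_overrides"]["Linear"]["weights"])
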
10 changes: 6 additions & 4 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -29,13 +29,14 @@ def quantize(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    q_min: torch.Tensor,
    q_max: torch.Tensor,
) -> torch.Tensor:
    return torch.clamp(
        torch.round(
            x / scale + zero_point,
        ),
        0,
        q_min,
        q_max,
    )

@@ -56,9 +57,11 @@ def fake_quantize(
    zero_point: torch.Tensor,
    args: QuantizationArgs,
) -> torch.Tensor:
    max_q = torch.tensor(2**args.num_bits - 1, device=x.device)
    bit_range = 2**args.num_bits
    max_q = torch.tensor(bit_range / 2 - 1, device=x.device)
    min_q = torch.tensor(-bit_range / 2, device=x.device)
    Q = torch.zeros_like(x)
    Q = quantize(x, scale, zero_point, max_q)
    Q = quantize(x, scale, zero_point, min_q, max_q)
    return dequantize(Q, scale, zero_point)


@@ -114,7 +117,6 @@ def _maybe_calibrate_or_quantize(

    device = next(module.parameters()).device
    scale = getattr(module, f"{base_name}_scale")
    # zero_point = getattr(module, f"{base_name}_zero_point").data
    zero_point = getattr(module, f"{base_name}_zero_point")

    if module.quantization_status == QuantizationStatus.CALIBRATION:
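
The quantize/fake_quantize changes above move the clamp from the unsigned window [0, 2**num_bits - 1] to the signed window [-2**(num_bits-1), 2**(num_bits-1) - 1], matching the signed int8 zero points now produced by calculate_qparams. A standalone sketch of the same round-trip in plain PyTorch; the helper name and sample values are illustrative, not library code:

import torch

def fake_quant_int8(x, scale, zero_point, num_bits=8):
    # signed integer window, e.g. [-128, 127] for 8 bits
    bit_range = 2 ** num_bits
    q_min, q_max = -bit_range / 2, bit_range / 2 - 1
    q = torch.clamp(torch.round(x / scale + zero_point), q_min, q_max)
    return (q - zero_point) * scale  # dequantize

# asymmetric example: observed values in [-1, 3]
x = torch.tensor([-1.0, 0.0, 1.5, 3.0])
scale = torch.tensor(4.0 / 255)               # (max - min) / (2**8 - 1)
zero_point = torch.tensor(-64.0)              # ~ (0 - min) / scale - 128
print(fake_quant_int8(x, scale, zero_point))  # values round-trip within one step (~0.016)
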
3 changes: 2 additions & 1 deletion src/compressed_tensors/quantization/observers/helpers.py
@@ -34,6 +34,7 @@ def calculate_qparams(
    :return: tuple of the calculated scale(s) and zero point(s)
    """
    bit_range = 2**quantization_args.num_bits - 1
    bit_min = -(bit_range + 1) / 2
    if quantization_args.symmetric:
        symmetric_range = 2 * max(min_vals.abs(), max_vals.abs())
        scales = symmetric_range / bit_range
@@ -46,6 +47,6 @@
        # scales from a 0 range should be set to 1
        scales[observed_range == 0] = 1

        zero_points = ((0 - min_vals) / scales).to(torch.int8)
        zero_points = ((0 - min_vals) / scales + bit_min).to(torch.int8)

    return scales, zero_points
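
With bit_min added, asymmetric zero points now land in the signed int8 range instead of being truncated at the unsigned end. A rough worked example of the non-symmetric branch for 8-bit values observed in [-1.0, 3.0]; it is standalone and mirrors the code above rather than calling it:

import torch

num_bits = 8
min_vals = torch.tensor([-1.0])
max_vals = torch.tensor([3.0])

bit_range = 2**num_bits - 1                 # 255
bit_min = -(bit_range + 1) / 2              # -128.0

scales = (max_vals - min_vals) / bit_range  # ~0.0157
zero_points = ((0 - min_vals) / scales + bit_min).to(torch.int8)

print(scales)       # tensor([0.0157])
print(zero_points)  # tensor([-64], dtype=torch.int8); without bit_min this would be 63
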
12 changes: 5 additions & 7 deletions src/compressed_tensors/quantization/observers/min_max.py
@@ -28,7 +28,7 @@
class MinMaxObserver(Observer):
    """
    Implements a dynamic quantization observer that sets the scale and
    zero point based on the latest observed value
    zero point based on the overall min and max value
    """

    def __init__(self, quantization_args: QuantizationArgs):
@@ -43,15 +43,14 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
        :param observed: observed tensor to calculate quantization parameters for
        :return: tuple of scale and zero point derived from the observed tensor
        """
        # TODO: Add support for full range of quantization Args, only supports 8bit
        # per tensor

        min_val = torch.tensor([observed.min()])
        max_val = torch.tensor([observed.max()])

        # update running average
        # update global min and max
        if self.counter > 0:
            self.min_val = (self.min_val * self.counter + min_val) / (self.counter + 1)
            self.max_val = (self.max_val * self.counter + max_val) / (self.counter + 1)
            self.min_val = torch.min(min_val, self.min_val)
            self.max_val = torch.max(max_val, self.max_val)
        else:
            self.min_val = min_val
            self.max_val = max_val
@@ -61,5 +60,4 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
        max_val = torch.max(self.max_val, torch.zeros_like(self.max_val))

        self.counter += 1

        return calculate_qparams(min_val, max_val, self.quantization_args)
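
The observer now tracks a true global min and max instead of a running average of per-batch extremes, so a single outlier batch during calibration widens the quantization range rather than being averaged away. A small standalone sketch of the two update rules, with illustrative values that are not library code:

import torch

batch_maxes = torch.tensor([1.0, 1.1, 0.9, 6.0, 1.0])  # one outlier batch

# old behavior: running average of the observed maxima
avg_max = batch_maxes[0]
for i, m in enumerate(batch_maxes[1:], start=1):
    avg_max = (avg_max * i + m) / (i + 1)
print(avg_max)            # tensor(2.), the outlier is mostly averaged away

# new behavior: global maximum
print(batch_maxes.max())  # tensor(6.), the range covers every observed value
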
