Merge branch 'main' into channelwise-quant
horheynm authored Apr 19, 2024
2 parents f6769c3 + 06200fc commit e2af2b5
Showing 18 changed files with 182 additions and 46 deletions.
85 changes: 85 additions & 0 deletions examples/llama_1.1b/ex_config_quantization.py
@@ -0,0 +1,85 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tqdm import tqdm
from torch.utils.data import RandomSampler
from compressed_tensors.quantization import (
apply_quantization_config,
freeze_module_quantization,
QuantizationConfig,
QuantizationStatus,
)
from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
from sparseml.transformers.finetune.data.base import TextGenerationDataset
from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator
from torch.utils.data import DataLoader
from sparseml.pytorch.utils import tensors_to_device
import torch

config_file = "example_quant_config.json"
model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
dataset_name = "open_platypus"
split = "train"
num_calibration_samples = 512
max_seq_length = 1024
pad_to_max_length = False
output_dir = "./llama1.1b_new_quant_out"
device = "cuda:0" if torch.cuda_is_available() else "cpu"

model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device)
model.eval() # no grad or updates needed for base model
config = QuantizationConfig.parse_file(config_file)

# set status to calibration
config.quantization_status = QuantizationStatus.CALIBRATION

# initialize quantization
apply_quantization_config(model, config)

# create dataset
tokenizer = AutoTokenizer.from_pretrained(model_name)
data_args = DataTrainingArguments(
dataset=dataset_name,
max_seq_length=max_seq_length,
pad_to_max_length=pad_to_max_length,
)
dataset_manager = TextGenerationDataset.load_from_registry(
data_args.dataset,
data_args=data_args,
split=split,
tokenizer=tokenizer,
)
calib_dataset = dataset_manager.tokenize_and_process(
dataset_manager.get_raw_dataset()
)
data_loader = DataLoader(
calib_dataset, batch_size=1, collate_fn=DefaultDataCollator(), sampler=RandomSampler(calib_dataset)
)

# run calibration
with torch.no_grad():  # calibration only runs forward passes; no gradients needed
    for idx, sample in tqdm(enumerate(data_loader), desc="Running calibration"):
        sample = tensors_to_device(sample, device)
        _ = model(**sample)

        if idx >= num_calibration_samples:
            break

# freeze params after calibration
model.apply(freeze_module_quantization)

# this functionality will move but for now we need to get the save override from
# SparseML in order to save the config
from sparseml.transformers.compression import modify_save_pretrained
modify_save_pretrained(model)
model.save_pretrained(output_dir)
37 changes: 27 additions & 10 deletions examples/llama_1.1b/ex_sparseml_quantization.py
@@ -13,28 +13,45 @@
# limitations under the License.

from sparseml.transformers import oneshot, SparseAutoModelForCausalLM
from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
from sparseml.transformers.finetune.data.base import TextGenerationDataset
from transformers import AutoTokenizer
import torch

dataset_name = "open_platypus"
overwrite_output_dir = True
splits = {"calibration": "train"}
seed = 42
output_dir = "./llama_1.1b_quant_mod_only"
num_calibration_samples = 1024
recipe = "example_quant_recipe.yaml"
model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
dataset_name = "open_platypus"
split = "train"
num_calibration_samples = 512
max_seq_length = 1024
pad_to_max_length = False
output_dir = "./llama1.1b_old_quant_out"
device = "cuda:0" if torch.cuda_is_available() else "cpu"

model = SparseAutoModelForCausalLM.from_pretrained(model_name, device_map="cuda:0")
model = SparseAutoModelForCausalLM.from_pretrained(model_name, device_map=device)

tokenizer = AutoTokenizer.from_pretrained(model_name)
data_args = DataTrainingArguments(
dataset=dataset_name,
max_seq_length=max_seq_length,
pad_to_max_length=pad_to_max_length,
)
dataset_manager = TextGenerationDataset.load_from_registry(
data_args.dataset,
data_args=data_args,
split=split,
tokenizer=tokenizer,
)
calib_dataset = dataset_manager.tokenize_and_process(
dataset_manager.get_raw_dataset()
)

oneshot(
model=model_name,
dataset=dataset_name,
output_dir=output_dir,
overwrite_output_dir=overwrite_output_dir,
splits=splits,
overwrite_output_dir=True,
max_seq_length=max_seq_length,
seed=seed,
num_calibration_samples=num_calibration_samples,
recipe=recipe,
pad_to_max_length=pad_to_max_length,
)
13 changes: 3 additions & 10 deletions examples/llama_1.1b/example_quant_config.json
@@ -1,7 +1,6 @@
{
"quant_method": "sparseml",
"format": "fakequant",
"quantization_status": "frozen",
"global_compression_ratio": null,
"config_groups": {
"group_1": {
@@ -14,7 +13,7 @@
"input_activations": {
"num_bits": 8,
"type": "int",
"symmetric": true,
"symmetric": false,
"strategy": "tensor"
},
"targets": ["Linear"]
@@ -23,17 +22,11 @@
"weights": {
"num_bits": 8,
"type": "int",
"symmetric": false,
"symmetric": true,
"strategy": "tensor"
},
"input_activations": null,
"targets": ["Embedding"]
}
},
"ignore": [
"LlamaRotaryEmbedding", "LlamaRMSNorm", "SiLUActivation",
"model.layers.1.mlp.down_proj", "MatMulLeftInput_QK", "MatMulRightInput_QK",
"MatMulOutput_QK", "MatMulLeftInput_PV", "MatMulRightInput_PV",
"MatMulOutput_PV"
]
"ignore": ["model.layers.0.mlp.down_proj"]
}
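
As a quick check, the edited config round-trips through the same loader the example script above uses; QuantizationConfig.parse_file and the config_groups field come straight from that code and the JSON, while the printed attributes are assumptions about the parsed pydantic objects:

from compressed_tensors.quantization import QuantizationConfig

# parse and validate the JSON against the quantization config schema
config = QuantizationConfig.parse_file("example_quant_config.json")

# each group should carry its own scheme and target module types
for name, scheme in config.config_groups.items():
    print(name, scheme.targets)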
32 changes: 32 additions & 0 deletions examples/llama_1.1b/example_quant_recipe.yaml
@@ -0,0 +1,32 @@
test_stage:
  quant_modifiers:
    QuantizationModifier:
      ignore:
        - model.layers.0.mlp.down_proj
        - LlamaRotaryEmbedding
        - LlamaRMSNorm
        - SiLU
        - MatMulLeftInput_QK
        - MatMulRightInput_QK
        - MatMulOutput_QK
        - MatMulLeftInput_PV
        - MatMulRightInput_PV
        - MatMulOutput_PV
      scheme_overrides:
        Linear:
          weights:
            num_bits: 8
            symmetric: true
            strategy: "tensor"
          input_activations:
            num_bits: 8
            symmetric: false
            strategy: "tensor"
          output_activations: null
        Embedding:
          weights:
            num_bits: 8
            symmetric: true
            strategy: "tensor"
          input_activations: null
          output_activations: null
2 changes: 1 addition & 1 deletion setup.py
@@ -18,7 +18,7 @@

def _setup_packages() -> List:
return find_packages(
"src", include=["compressed-tensors", "compressed-tensors.*"], exclude=["*.__pycache__.*"]
"src", include=["compressed_tensors", "compressed_tensors.*"], exclude=["*.__pycache__.*"]
)

def _setup_install_requires() -> List:
2 changes: 1 addition & 1 deletion src/compressed_tensors/compressors/sparse_bitmask.py
@@ -17,9 +17,9 @@

import numpy
import torch
from safetensors import safe_open
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.utils import get_nested_weight_mappings, merge_names
from safetensors import safe_open
from torch import Tensor
from tqdm import tqdm

2 changes: 1 addition & 1 deletion src/compressed_tensors/config/base.py
@@ -14,8 +14,8 @@

from typing import Optional

from pydantic import BaseModel
from compressed_tensors.registry import RegistryMixin
from pydantic import BaseModel


__all__ = ["CompressionConfig"]
4 changes: 3 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/apply.py
@@ -16,7 +16,9 @@
from collections import OrderedDict
from typing import Iterable, Optional

from compressed_tensors.quantization.lifecycle.calibration import set_module_for_calibration
from compressed_tensors.quantization.lifecycle.calibration import (
set_module_for_calibration,
)
from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
from compressed_tensors.quantization.lifecycle.initialize import (
initialize_module_for_quantization,
10 changes: 6 additions & 4 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -29,13 +29,14 @@ def quantize(
x: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
q_min: torch.Tensor,
q_max: torch.Tensor,
) -> torch.Tensor:
return torch.clamp(
torch.round(
x / scale + zero_point,
),
0,
q_min,
q_max,
)

@@ -56,9 +57,11 @@ def fake_quantize(
zero_point: torch.Tensor,
args: QuantizationArgs,
) -> torch.Tensor:
max_q = torch.tensor(2**args.num_bits - 1, device=x.device)
bit_range = 2**args.num_bits
max_q = torch.tensor(bit_range / 2 - 1, device=x.device)
min_q = torch.tensor(-bit_range / 2, device=x.device)
Q = torch.zeros_like(x)
Q = quantize(x, scale, zero_point, max_q)
Q = quantize(x, scale, zero_point, min_q, max_q)
return dequantize(Q, scale, zero_point)


@@ -114,7 +117,6 @@ def _maybe_calibrate_or_quantize(

device = next(module.parameters()).device
scale = getattr(module, f"{base_name}_scale")
# zero_point = getattr(module, f"{base_name}_zero_point").data
zero_point = getattr(module, f"{base_name}_zero_point")

if module.quantization_status == QuantizationStatus.CALIBRATION:
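
Net effect of the forward.py change: quantize() now clamps to a signed integer range rather than [0, 2**num_bits - 1], so int8 values land in [-128, 127]. A minimal standalone sketch of the same arithmetic in plain torch (function names mirror the diff; this is not the library's public API):

import torch

def quantize(x, scale, zero_point, q_min, q_max):
    # map to the integer grid, round, then clamp into the signed range
    return torch.clamp(torch.round(x / scale + zero_point), q_min, q_max)

def fake_quantize(x, scale, zero_point, num_bits=8):
    bit_range = 2**num_bits
    q_max = torch.tensor(bit_range / 2 - 1, device=x.device)  # 127 for int8
    q_min = torch.tensor(-bit_range / 2, device=x.device)     # -128 for int8
    q = quantize(x, scale, zero_point, q_min, q_max)
    return (q - zero_point) * scale  # dequantize back to float

x = torch.randn(8)
y = fake_quantize(x, scale=torch.tensor(0.05), zero_point=torch.tensor(0.0))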
4 changes: 3 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -17,7 +17,9 @@
from typing import Optional

import torch
from compressed_tensors.quantization.lifecycle.forward import wrap_module_forward_quantized
from compressed_tensors.quantization.lifecycle.forward import (
wrap_module_forward_quantized,
)
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
3 changes: 2 additions & 1 deletion src/compressed_tensors/quantization/observers/helpers.py
@@ -34,6 +34,7 @@ def calculate_qparams(
:return: tuple of the calculated scale(s) and zero point(s)
"""
bit_range = 2**quantization_args.num_bits - 1
bit_min = -(bit_range + 1) / 2
if quantization_args.symmetric:
symmetric_range = 2 * max(min_vals.abs(), max_vals.abs())
scales = symmetric_range / bit_range
@@ -46,6 +47,6 @@ def calculate_qparams(
# scales from a 0 range should be set to 1
scales[observed_range == 0] = 1

zero_points = ((0 - min_vals) / scales).to(torch.int8)
zero_points = torch.round(((0.0 - min_vals) / scales + bit_min)).to(torch.int8)

return scales, zero_points
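
Worked through for 8 bits: bit_range = 255 and bit_min = -128, so the zero point now shifts min_vals onto -128 instead of 0, matching the signed quantize range in forward.py. A hedged standalone sketch of the asymmetric path (signature and variable names are illustrative, not the exact helper):

import torch

def calculate_qparams(min_vals, max_vals, num_bits=8):
    bit_range = 2**num_bits - 1     # 255 for int8
    bit_min = -(bit_range + 1) / 2  # -128 for int8
    scales = (max_vals - min_vals) / bit_range
    scales[scales == 0] = 1.0       # scales from a 0 range should be set to 1
    # shift so min_vals maps to bit_min rather than 0
    zero_points = torch.round((0.0 - min_vals) / scales + bit_min).to(torch.int8)
    return scales, zero_points

scales, zero_points = calculate_qparams(torch.tensor([-1.0]), torch.tensor([1.0]))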
22 changes: 12 additions & 10 deletions src/compressed_tensors/quantization/observers/min_max.py
@@ -28,7 +28,7 @@
class MinMaxObserver(Observer):
"""
Implements a dynamic quantization observer that sets the scale and
zero point based on the latest observed value
zero point based on the overall min and max value
"""

def __init__(self, quantization_args: QuantizationArgs):
@@ -56,12 +56,14 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:

# update global min and max
if self.counter > 0:
self.min_vals = (self.min_vals * self.counter + min_vals) / (
self.counter + 1
)
self.max_vals = (self.max_vals * self.counter + max_vals) / (
self.counter + 1
)
self.min_vals = torch.min(min_vals, self.min_vals)
self.max_vals = torch.max(max_vals, self.max_vals)
else:
self.min_vals = min_vals
self.max_vals = max_vals
@@ -76,10 +78,10 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
min_val = torch.tensor([observed.min()])
max_val = torch.tensor([observed.max()])

# update running average
# update global min and max
if self.counter > 0:
self.min_val = (self.min_val * self.counter + min_val) / (self.counter + 1)
self.max_val = (self.max_val * self.counter + max_val) / (self.counter + 1)
self.min_val = torch.min(min_val, self.min_val)
self.max_val = torch.max(max_val, self.max_val)
else:
self.min_val = min_val
self.max_val = max_val
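
The observer now widens its range monotonically: each calibration batch can only push the tracked minimum down and the maximum up, instead of being averaged in. An illustrative reduction of that accumulation logic (not the Observer class itself):

import torch

class GlobalMinMax:
    """Tracks the global min/max seen across all calibration batches."""

    def __init__(self):
        self.min_vals = None
        self.max_vals = None

    def update(self, observed: torch.Tensor):
        batch_min = observed.min().reshape(1)
        batch_max = observed.max().reshape(1)
        if self.min_vals is None:  # first batch initializes the range
            self.min_vals, self.max_vals = batch_min, batch_max
        else:  # later batches can only widen it, never shrink it
            self.min_vals = torch.min(batch_min, self.min_vals)
            self.max_vals = torch.max(batch_max, self.max_vals)

obs = GlobalMinMax()
for batch in (torch.randn(16) for _ in range(4)):
    obs.update(batch)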
2 changes: 1 addition & 1 deletion src/compressed_tensors/quantization/quant_config.py
@@ -15,14 +15,14 @@
from enum import Enum
from typing import Dict, List, Optional

from pydantic import BaseModel, Field
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.utils import (
calculate_compression_ratio,
is_module_quantized,
iter_named_leaf_modules,
module_type,
)
from pydantic import BaseModel, Field
from torch.nn import Module


2 changes: 1 addition & 1 deletion src/compressed_tensors/quantization/quant_scheme.py
@@ -14,8 +14,8 @@

from typing import List, Optional

from pydantic import BaseModel
from compressed_tensors.quantization.quant_args import QuantizationArgs
from pydantic import BaseModel


__all__ = ["QuantizationScheme"]
2 changes: 1 addition & 1 deletion tests/quantization/test_quant_args.py
@@ -13,12 +13,12 @@
# limitations under the License.

import pytest
from pydantic import ValidationError
from compressed_tensors.quantization import (
QuantizationArgs,
QuantizationStrategy,
QuantizationType,
)
from pydantic import ValidationError


def test_defaults():
2 changes: 1 addition & 1 deletion tests/quantization/test_quant_config.py
@@ -14,12 +14,12 @@


import pytest
from pydantic import ValidationError
from compressed_tensors.quantization import (
QuantizationConfig,
QuantizationScheme,
QuantizationStatus,
)
from pydantic import ValidationError


def test_basic_config():
2 changes: 1 addition & 1 deletion tests/quantization/test_quant_scheme.py
@@ -13,8 +13,8 @@
# limitations under the License.

import pytest
from pydantic import ValidationError
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from pydantic import ValidationError


def test_basic_scheme():
2 changes: 1 addition & 1 deletion tests/test_bitmask.py
@@ -17,8 +17,8 @@

import pytest
import torch
from safetensors.torch import save_file
from compressed_tensors import BitmaskCompressor, BitmaskConfig, BitmaskTensor
from safetensors.torch import save_file


@pytest.mark.parametrize(
