Serialize Config from Model (#7)
* Apply quantization config implementation

* add TODO

* integrate full lifecycle support, QuantizationStatus updates, add tinyllama test

* fix comment

* initial implementation

* add unit test

* cleanup is_quantized

* clean up targets and ignore lists

* global compression ratio and docstrings

* make sure scale/zp on correct device

* helper for model quantization
Sara Adkins authored Apr 16, 2024
1 parent 514e4db commit edc35a1
Showing 6 changed files with 234 additions and 19 deletions.
11 changes: 2 additions & 9 deletions src/sparsetensors/quantization/lifecycle/apply.py
@@ -14,7 +14,7 @@

import re
from collections import OrderedDict
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, Optional

from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration
from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization
@@ -25,6 +25,7 @@
    QuantizationConfig,
    QuantizationStatus,
)
from sparsetensors.quantization.utils import iter_named_leaf_modules
from torch.nn import Module


@@ -76,14 +77,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
        model.apply(freeze_module_quantization)


-def _iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
-    # yields modules that do not have any submodules
-    # TODO: potentially expand to add list of allowed submodules such as observers
-    for name, submodule in model.named_modules():
-        if len(list(submodule.children())) == 0:
-            yield name, submodule


def _find_first_name_or_class_match(
    name: str,
    module: Module,
66 changes: 64 additions & 2 deletions src/sparsetensors/quantization/quant_config.py
@@ -15,8 +15,15 @@
from enum import Enum
from typing import Dict, List, Optional

-from pydantic import BaseModel
+from pydantic import BaseModel, Field
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from sparsetensors.quantization.utils import (
    calculate_compression_ratio,
    is_module_quantized,
    iter_named_leaf_modules,
    module_type,
)
from torch.nn import Module


__all__ = [
@@ -89,4 +96,59 @@ class QuantizationConfig(BaseModel):
    format: str = "fakequant"
    quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
    global_compression_ratio: Optional[float] = None
-    ignore: Optional[List[str]] = None
+    ignore: Optional[List[str]] = Field(default_factory=list)

    @staticmethod
    def from_pretrained(model: Module) -> "QuantizationConfig":
        """
        Converts a model into its associated QuantizationConfig based on the
        QuantizationScheme attached to each quantized module

        :param model: model to calculate quantization scheme of
        :return: filled out QuantizationConfig for the input model
        """
        quant_scheme_to_layers = []
        quantization_status = None
        ignore = {}
        quantization_type_names = set()
        for name, submodule in iter_named_leaf_modules(model):
            layer_type = module_type(submodule)
            if not is_module_quantized(submodule):
                if layer_type not in ignore:
                    ignore[layer_type] = []
                ignore[layer_type].append(name)
            else:
                quantization_status = submodule.quantization_status
                scheme = submodule.quantization_scheme
                quantization_type_names.add(layer_type)

                match_found = False
                for existing_scheme in quant_scheme_to_layers:
                    if scheme == existing_scheme:
                        match_found = True
                        break
                if not match_found:
                    quant_scheme_to_layers.append(scheme)

        # clean up the ignore list: we can leave out a layer type entirely if none of
        # its instances are quantized
        consolidated_ignore = []
        for layer_type, ignore_names in ignore.items():
            if layer_type in quantization_type_names:
                # specific layers of a quantized type are ignored
                consolidated_ignore += ignore_names
            # else we leave it off the ignore list; it doesn't fall under any of the
            # existing quantization schemes, so it won't be quantized anyway

        config_groups = {}
        for idx, scheme in enumerate(quant_scheme_to_layers):
            group_name = "group_" + str(idx)
            config_groups[group_name] = scheme

        compression_ratio = calculate_compression_ratio(model)
        return QuantizationConfig(
            config_groups=config_groups,
            quantization_status=quantization_status,
            global_compression_ratio=compression_ratio,
            ignore=consolidated_ignore,
        )
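
For orientation, a minimal sketch of the round trip this commit enables. Here my_model and quant_config are placeholders for a loaded torch module and an existing QuantizationConfig; they are not objects defined in this diff.

# Sketch only: assumes quant_config was built elsewhere and my_model is any torch.nn.Module
from sparsetensors.quantization.lifecycle import apply_quantization_config
from sparsetensors.quantization.quant_config import QuantizationConfig

apply_quantization_config(my_model, quant_config)          # attach schemes to matching modules
serialized = QuantizationConfig.from_pretrained(my_model)  # rebuild a config from the model
print(serialized.config_groups, serialized.ignore, serialized.global_compression_ratio)

The test_serialize_config_tinyllama test added below exercises this same flow.
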
16 changes: 16 additions & 0 deletions src/sparsetensors/quantization/utils/__init__.py
@@ -0,0 +1,16 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa
from .helpers import *
117 changes: 117 additions & 0 deletions src/sparsetensors/quantization/utils/helpers.py
@@ -0,0 +1,117 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch
from torch.nn import Module
from tqdm import tqdm


__all__ = [
"is_module_quantized",
"is_model_quantized",
"iter_named_leaf_modules",
"module_type",
"calculate_compression_ratio",
]


def is_module_quantized(module: Module) -> bool:
    """
    Check if a module is quantized, based on the existence of a non-empty quantization
    scheme

    :param module: pytorch module to check
    :return: True if module is quantized, False otherwise
    """
    if not hasattr(module, "quantization_scheme"):
        return False

    if module.quantization_scheme.weights is not None:
        return True

    if module.quantization_scheme.input_activations is not None:
        return True

    if module.quantization_scheme.output_activations is not None:
        return True

    return False


def is_model_quantized(model: Module) -> bool:
    """
    Check if any modules in a model are quantized, based on the existence of a non-empty
    quantization scheme in at least one module

    :param model: pytorch model
    :return: True if model is quantized, False otherwise
    """
    for _, submodule in iter_named_leaf_modules(model):
        if is_module_quantized(submodule):
            return True

    return False


def module_type(module: Module) -> str:
    """
    Gets a string representation of a module type

    :param module: pytorch module to get type of
    :return: module type as a string
    """
    return type(module).__name__


def iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
    # yields modules that do not have any submodules
    # TODO: potentially expand to add list of allowed submodules such as observers
    for name, submodule in model.named_modules():
        if len(list(submodule.children())) == 0:
            yield name, submodule


def calculate_compression_ratio(model: Module) -> float:
    """
    Calculates the quantization compression ratio of a pytorch model, based on the
    number of bits needed to represent the total weights in compressed form. Does not
    take into account activation quantizations.

    :param model: pytorch module to calculate compression ratio for
    :return: compression ratio of the whole model
    """
    total_compressed = 0.0
    total_uncompressed = 0.0
    for name, submodule in tqdm(
        iter_named_leaf_modules(model),
        desc="Calculating quantization compression ratio",
    ):
        # only count the parameters owned by this leaf module
        for parameter in submodule.parameters():
            try:
                uncompressed_bits = torch.finfo(parameter.dtype).bits
            except TypeError:
                uncompressed_bits = torch.iinfo(parameter.dtype).bits
            compressed_bits = uncompressed_bits
            if is_module_quantized(submodule):
                compressed_bits = submodule.quantization_scheme.weights.num_bits
            num_weights = parameter.numel()
            total_compressed += compressed_bits * num_weights
            total_uncompressed += uncompressed_bits * num_weights

    return total_uncompressed / total_compressed
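
The global ratio returned here is total uncompressed weight bits divided by total compressed weight bits, counting weights only. A quick arithmetic sketch with illustrative numbers, not taken from the tests:

# Illustrative numbers only
uncompressed = 32 * 1_000 + 32 * 1_000  # two leaf modules, 1k float32 weights each
compressed = 8 * 1_000 + 32 * 1_000     # only the first module is quantized to 8-bit weights
ratio = uncompressed / compressed       # 64000 / 40000 = 1.6
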
41 changes: 34 additions & 7 deletions tests/quantization/lifecycle/test_apply.py
@@ -14,7 +14,10 @@


from sparsetensors.quantization.lifecycle import apply_quantization_config
-from sparsetensors.quantization.quant_config import QuantizationConfig
+from sparsetensors.quantization.quant_config import (
+    QuantizationConfig,
+    QuantizationStatus,
+)
from transformers import AutoModelForCausalLM


@@ -33,7 +36,9 @@ def test_apply_quantization_config_tinyllama():
    num_linears = 0
    num_embeddings = 0
    num_rotary_embeddings = 0
-    for module in model.modules():
+    for name, module in model.named_modules():
+        if name in quant_config.ignore:
+            continue
        module_type = module.__class__.__name__
        if module_type == "Linear":
            num_linears += 1
@@ -46,11 +51,36 @@
            _test_layer_quantization_status(module, inputs=False, weights=False)

    # sanity check correct number of layers targeted
-    assert num_linears == 155
+    assert num_linears == 154  # 155 Linear layers - 1 that gets ignored
    assert num_embeddings == 1
    assert num_rotary_embeddings == 22


def test_serialize_config_tinyllama():
    quant_config = get_sample_tinyllama_quant_config()
    model = get_tinyllama_model()

    # check that model is not already quantized
    for module in model.modules():
        _test_layer_quantization_status(module, inputs=False, weights=False)

    # apply quant config to model
    apply_quantization_config(model, quant_config)

    serialized_config = QuantizationConfig.from_pretrained(model)
    assert len(serialized_config.config_groups) == 2
    assert serialized_config.config_groups["group_0"].targets == ["Embedding"]
    assert serialized_config.config_groups["group_0"].input_activations is None
    assert serialized_config.config_groups["group_1"].targets == ["Linear"]
    assert serialized_config.config_groups["group_1"].input_activations is not None
    assert serialized_config.quantization_status == QuantizationStatus.FROZEN
    assert serialized_config.format == "fakequant"
    assert serialized_config.quant_method == "sparseml"
    assert serialized_config.ignore == ["model.layers.1.mlp.down_proj"]
    assert serialized_config.global_compression_ratio > 1.0
    assert serialized_config.global_compression_ratio < 8.0
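
The ignore assertion above follows from the consolidation logic in from_pretrained: LlamaRotaryEmbedding drops out because no module of that type is quantized, while the single ignored Linear is kept because other Linear modules are quantized. A toy rendering of that rule, with illustrative module names:

ignored = {
    "LlamaRotaryEmbedding": ["model.layers.0.self_attn.rotary_emb"],
    "Linear": ["model.layers.1.mlp.down_proj"],
}
quantized_types = {"Linear", "Embedding"}
consolidated = [
    name
    for layer_type, names in ignored.items()
    if layer_type in quantized_types  # keep ignores only for otherwise-quantized types
    for name in names
]
# consolidated == ["model.layers.1.mlp.down_proj"]
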


def _test_layer_quantization_status(module, inputs: bool, weights: bool):
    # check if quantization is applied at all (true if inputs or weights targeted)
    quantized = inputs or weights
@@ -105,9 +135,6 @@ def get_sample_tinyllama_quant_config():
"targets": ["Embedding"],
},
},
"ignore": ["LlamaRotaryEmbedding"],
"ignore": ["LlamaRotaryEmbedding", "model.layers.1.mlp.down_proj"],
}
return QuantizationConfig.parse_obj(config_dict)


test_apply_quantization_config_tinyllama()
2 changes: 1 addition & 1 deletion tests/quantization/test_quant_config.py
@@ -31,7 +31,7 @@ def test_basic_config():
    assert config.format == "fakequant"
    assert config.quantization_status == QuantizationStatus.INITIALIZED
    assert config.global_compression_ratio is None
-    assert config.ignore is None
+    assert isinstance(config.ignore, list) and len(config.ignore) == 0


def test_full_config():
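
The ignore default changes from None to an empty list via Field(default_factory=list), which is what the updated assertion checks. A minimal sketch of that behavior, using a stand-in Cfg model rather than the library's class:

from typing import List, Optional

from pydantic import BaseModel, Field


class Cfg(BaseModel):
    ignore: Optional[List[str]] = Field(default_factory=list)


a, b = Cfg(), Cfg()
a.ignore.append("model.layers.1.mlp.down_proj")
assert b.ignore == []  # each instance starts with its own fresh empty list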