diff --git a/src/sparsetensors/quantization/lifecycle/apply.py b/src/sparsetensors/quantization/lifecycle/apply.py
index 7ad46f1f..4c78568d 100644
--- a/src/sparsetensors/quantization/lifecycle/apply.py
+++ b/src/sparsetensors/quantization/lifecycle/apply.py
@@ -14,7 +14,7 @@
 
 import re
 from collections import OrderedDict
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, Optional
 
 from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration
 from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization
@@ -25,6 +25,7 @@
     QuantizationConfig,
     QuantizationStatus,
 )
+from sparsetensors.quantization.utils import iter_named_leaf_modules
 from torch.nn import Module
 
 
@@ -76,14 +77,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
         model.apply(freeze_module_quantization)
 
 
-def _iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
-    # yields modules that do not have any submodules
-    # TODO: potentially expand to add list of allowed submodules such as observers
-    for name, submodule in model.named_modules():
-        if len(list(submodule.children())) == 0:
-            yield name, submodule
-
-
 def _find_first_name_or_class_match(
     name: str,
     module: Module,
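
Note: a minimal sketch (not part of this diff) of what the relocated iter_named_leaf_modules helper yields, assuming the package is installed and importable as sparsetensors.quantization.utils, matching the import added above. Only modules with no children are yielded, so container modules such as Sequential are skipped.

    # hypothetical illustration, not from the PR
    import torch.nn as nn
    from sparsetensors.quantization.utils import iter_named_leaf_modules

    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
    for name, submodule in iter_named_leaf_modules(model):
        print(name, type(submodule).__name__)
    # prints "0 Linear", "1 ReLU", "2 Linear" -- the Sequential container itself is skipped
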
diff --git a/src/sparsetensors/quantization/quant_config.py b/src/sparsetensors/quantization/quant_config.py
index 813c7197..c70e7c45 100644
--- a/src/sparsetensors/quantization/quant_config.py
+++ b/src/sparsetensors/quantization/quant_config.py
@@ -15,8 +15,15 @@
 from enum import Enum
 from typing import Dict, List, Optional
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from sparsetensors.quantization.quant_scheme import QuantizationScheme
+from sparsetensors.quantization.utils import (
+    calculate_compression_ratio,
+    is_module_quantized,
+    iter_named_leaf_modules,
+    module_type,
+)
+from torch.nn import Module
 
 
 __all__ = [
@@ -89,4 +96,59 @@ class QuantizationConfig(BaseModel):
     format: str = "fakequant"
     quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
     global_compression_ratio: Optional[float] = None
-    ignore: Optional[List[str]] = None
+    ignore: Optional[List[str]] = Field(default_factory=list)
+
+    @staticmethod
+    def from_pretrained(model: Module) -> "QuantizationConfig":
+        """
+        Converts a model into its associated QuantizationConfig based on the
+        QuantizationScheme attached to each quantized module
+
+        :param model: model to calculate the quantization scheme of
+        :return: filled out QuantizationConfig for the input model
+        """
+        quant_scheme_to_layers = []
+        quantization_status = None
+        ignore = {}
+        quantization_type_names = set()
+        for name, submodule in iter_named_leaf_modules(model):
+            layer_type = module_type(submodule)
+            if not is_module_quantized(submodule):
+                if layer_type not in ignore:
+                    ignore[layer_type] = []
+                ignore[layer_type].append(name)
+            else:
+                quantization_status = submodule.quantization_status
+                scheme = submodule.quantization_scheme
+                quantization_type_names.add(layer_type)
+
+                match_found = False
+                for existing_scheme in quant_scheme_to_layers:
+                    if scheme == existing_scheme:
+                        match_found = True
+                        break
+                if not match_found:
+                    quant_scheme_to_layers.append(scheme)
+
+        # clean up the ignore list, we can leave out layer types if none of their
+        # instances are quantized
+        consolidated_ignore = []
+        for layer_type, ignore_names in ignore.items():
+            if layer_type in quantization_type_names:
+                # only specific layers of a quantized type are ignored
+                consolidated_ignore += ignore_names
+            # else we leave the type off the ignore list entirely; it doesn't fall
+            # under any of the existing quantization schemes, so it won't be quantized
+
+        config_groups = {}
+        for idx, scheme in enumerate(quant_scheme_to_layers):
+            group_name = "group_" + str(idx)
+            config_groups[group_name] = scheme
+
+        compression_ratio = calculate_compression_ratio(model)
+        return QuantizationConfig(
+            config_groups=config_groups,
+            quantization_status=quantization_status,
+            global_compression_ratio=compression_ratio,
+            ignore=consolidated_ignore,
+        )
diff --git a/src/sparsetensors/quantization/utils/__init__.py b/src/sparsetensors/quantization/utils/__init__.py
new file mode 100644
index 00000000..a91f9e5d
--- /dev/null
+++ b/src/sparsetensors/quantization/utils/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+from .helpers import *
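
The from_pretrained hook added above is the serialization entry point: it walks the leaf modules, groups identical QuantizationSchemes into config_groups, and consolidates the ignore list. A minimal usage sketch follows (not part of this diff; `model` and `config_dict` are hypothetical placeholders for any torch module and a config dict like the one built in the tests below):

    # hypothetical usage sketch; `model` and `config_dict` are assumed to exist
    from sparsetensors.quantization.lifecycle import apply_quantization_config
    from sparsetensors.quantization.quant_config import QuantizationConfig

    quant_config = QuantizationConfig.parse_obj(config_dict)
    apply_quantization_config(model, quant_config)

    serialized = QuantizationConfig.from_pretrained(model)
    print(list(serialized.config_groups))        # e.g. ['group_0', 'group_1']
    print(serialized.global_compression_ratio)   # uncompressed bits / compressed bits
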
diff --git a/src/sparsetensors/quantization/utils/helpers.py b/src/sparsetensors/quantization/utils/helpers.py
new file mode 100644
index 00000000..52bebf58
--- /dev/null
+++ b/src/sparsetensors/quantization/utils/helpers.py
@@ -0,0 +1,117 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple
+
+import torch
+from torch.nn import Module
+from tqdm import tqdm
+
+
+__all__ = [
+    "is_module_quantized",
+    "is_model_quantized",
+    "iter_named_leaf_modules",
+    "module_type",
+    "calculate_compression_ratio",
+]
+
+
+def is_module_quantized(module: Module) -> bool:
+    """
+    Check if a module is quantized, based on the existence of a non-empty
+    quantization scheme
+
+    :param module: pytorch module to check
+    :return: True if module is quantized, False otherwise
+    """
+    if not hasattr(module, "quantization_scheme"):
+        return False
+
+    if module.quantization_scheme.weights is not None:
+        return True
+
+    if module.quantization_scheme.input_activations is not None:
+        return True
+
+    if module.quantization_scheme.output_activations is not None:
+        return True
+
+    return False
+
+
+def is_model_quantized(model: Module) -> bool:
+    """
+    Check if any modules in a model are quantized, based on the existence of a
+    non-empty quantization scheme in at least one module
+
+    :param model: pytorch model
+    :return: True if model is quantized, False otherwise
+    """
+
+    for _, submodule in iter_named_leaf_modules(model):
+        if is_module_quantized(submodule):
+            return True
+
+    return False
+
+
+def module_type(module: Module) -> str:
+    """
+    Gets a string representation of a module type
+
+    :param module: pytorch module to get the type of
+    :return: module type as a string
+    """
+    return type(module).__name__
+
+
+def iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
+    # yields modules that do not have any submodules
+    # TODO: potentially expand to add list of allowed submodules such as observers
+    for name, submodule in model.named_modules():
+        if len(list(submodule.children())) == 0:
+            yield name, submodule
+
+
+def calculate_compression_ratio(model: Module) -> float:
+    """
+    Calculates the quantization compression ratio of a pytorch model, based on the
+    number of bits needed to represent the total weights in compressed form. Does not
+    take into account activation quantizations.
+
+    :param model: pytorch module to calculate compression ratio for
+    :return: compression ratio of the whole model
+    """
+    total_compressed = 0.0
+    total_uncompressed = 0.0
+    for _, submodule in tqdm(
+        iter_named_leaf_modules(model),
+        desc="Calculating quantization compression ratio",
+    ):
+        # only count the parameters owned by this leaf module
+        for parameter in submodule.parameters():
+            try:
+                uncompressed_bits = torch.finfo(parameter.dtype).bits
+            except TypeError:
+                uncompressed_bits = torch.iinfo(parameter.dtype).bits
+            compressed_bits = uncompressed_bits
+            if is_module_quantized(submodule):
+                compressed_bits = submodule.quantization_scheme.weights.num_bits
+            num_weights = parameter.numel()
+            total_compressed += compressed_bits * num_weights
+            total_uncompressed += uncompressed_bits * num_weights
+
+    return total_uncompressed / total_compressed
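
The global compression ratio computed above is a parameter-count weighted ratio of uncompressed to compressed bit widths. A toy calculation (the layer sizes and bit widths below are illustrative, not taken from the PR) shows how the weighted average works out, and why a partially quantized model lands somewhere between 1.0 and the full bit-width ratio, as the tests below expect:

    # illustrative arithmetic only; numbers are made up
    layers = [
        {"num_weights": 1_000_000, "uncompressed_bits": 32, "compressed_bits": 8},  # 8-bit quantized
        {"num_weights": 250_000, "uncompressed_bits": 32, "compressed_bits": 32},   # left unquantized
    ]
    total_uncompressed = sum(l["num_weights"] * l["uncompressed_bits"] for l in layers)
    total_compressed = sum(l["num_weights"] * l["compressed_bits"] for l in layers)
    print(total_uncompressed / total_compressed)  # 40_000_000 / 16_000_000 = 2.5
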
["LlamaRotaryEmbedding", "model.layers.1.mlp.down_proj"], } return QuantizationConfig.parse_obj(config_dict) - - -test_apply_quantization_config_tinyllama() diff --git a/tests/quantization/test_quant_config.py b/tests/quantization/test_quant_config.py index 40a82cd7..92b68ab7 100644 --- a/tests/quantization/test_quant_config.py +++ b/tests/quantization/test_quant_config.py @@ -31,7 +31,7 @@ def test_basic_config(): assert config.format == "fakequant" assert config.quantization_status == QuantizationStatus.INITIALIZED assert config.global_compression_ratio is None - assert config.ignore is None + assert isinstance(config.ignore, list) and len(config.ignore) == 0 def test_full_config():