diff --git a/src/sparsetensors/quantization/quant_config.py b/src/sparsetensors/quantization/quant_config.py
index b2b40b06..c70e7c45 100644
--- a/src/sparsetensors/quantization/quant_config.py
+++ b/src/sparsetensors/quantization/quant_config.py
@@ -18,6 +18,7 @@ from pydantic import BaseModel, Field
 from sparsetensors.quantization.quant_scheme import QuantizationScheme
 from sparsetensors.quantization.utils import (
+    calculate_compression_ratio,
     is_module_quantized,
     iter_named_leaf_modules,
     module_type,
@@ -100,7 +101,11 @@ class QuantizationConfig(BaseModel):
     @staticmethod
     def from_pretrained(model: Module) -> "QuantizationConfig":
         """
-        TODO: fill in docstrings
+        Converts a model into its associated QuantizationConfig based on the
+        QuantizationScheme attached to each quantized module
+
+        :param model: model to calculate the quantization scheme of
+        :return: filled out QuantizationConfig for the input model
         """
         quant_scheme_to_layers = []
         quantization_status = None
@@ -125,6 +130,8 @@ def from_pretrained(model: Module) -> "QuantizationConfig":
             if not match_found:
                 quant_scheme_to_layers.append(scheme)
 
+        # clean up the ignore list; a layer type can be left out if none of its
+        # instances are quantized
         consolidated_ignore = []
         for layer_type, ignore_names in ignore.items():
             if layer_type in quantization_type_names:
@@ -138,8 +145,10 @@ def from_pretrained(model: Module) -> "QuantizationConfig":
             group_name = "group_" + str(idx)
             config_groups[group_name] = scheme
 
+        compression_ratio = calculate_compression_ratio(model)
         return QuantizationConfig(
             config_groups=config_groups,
             quantization_status=quantization_status,
+            global_compression_ratio=compression_ratio,
             ignore=consolidated_ignore,
         )
diff --git a/src/sparsetensors/quantization/utils/helpers.py b/src/sparsetensors/quantization/utils/helpers.py
index 6156f364..9f2bfffd 100644
--- a/src/sparsetensors/quantization/utils/helpers.py
+++ b/src/sparsetensors/quantization/utils/helpers.py
@@ -14,13 +14,27 @@
 from typing import Tuple
 
+import torch
 from torch.nn import Module
+from tqdm import tqdm
 
-__all__ = ["is_module_quantized", "iter_named_leaf_modules", "module_type"]
+__all__ = [
+    "is_module_quantized",
+    "iter_named_leaf_modules",
+    "module_type",
+    "calculate_compression_ratio",
+]
 
 
 def is_module_quantized(module: Module) -> bool:
+    """
+    Check if a module is quantized, based on the existence of a non-empty quantization
+    scheme
+
+    :param module: pytorch module to check
+    :return: True if module is quantized, False otherwise
+    """
     if not hasattr(module, "quantization_scheme"):
         return False
@@ -37,6 +51,12 @@ def is_module_quantized(module: Module) -> bool:
 
 
 def module_type(module: Module) -> str:
+    """
+    Gets a string representation of a module type
+
+    :param module: pytorch module to get the type of
+    :return: module type as a string
+    """
     return type(module).__name__
 
 
@@ -46,3 +66,33 @@ def iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
     for name, submodule in model.named_modules():
         if len(list(submodule.children())) == 0:
             yield name, submodule
+
+
+def calculate_compression_ratio(model: Module) -> float:
+    """
+    Calculates the quantization compression ratio of a pytorch model, based on the
+    number of bits needed to represent the total weights in compressed form. Does not
+    take into account activation quantizations.
+
+    :param model: pytorch module to calculate compression ratio for
+    :return: compression ratio of the whole model
+    """
+    total_compressed = 0.0
+    total_uncompressed = 0.0
+    for name, submodule in tqdm(
+        iter_named_leaf_modules(model),
+        desc="Calculating quantization compression ratio",
+    ):
+        for parameter in submodule.parameters():
+            try:
+                uncompressed_bits = torch.finfo(parameter.dtype).bits
+            except TypeError:
+                uncompressed_bits = torch.iinfo(parameter.dtype).bits
+            compressed_bits = uncompressed_bits
+            if is_module_quantized(submodule):
+                compressed_bits = submodule.quantization_scheme.weights.num_bits
+            num_weights = parameter.numel()
+            total_compressed += compressed_bits * num_weights
+            total_uncompressed += uncompressed_bits * num_weights
+
+    return total_uncompressed / total_compressed
diff --git a/tests/quantization/lifecycle/test_apply.py b/tests/quantization/lifecycle/test_apply.py
index 02f46890..eeb29a41 100644
--- a/tests/quantization/lifecycle/test_apply.py
+++ b/tests/quantization/lifecycle/test_apply.py
@@ -77,6 +77,8 @@ def test_serialize_config_tinyllama():
     assert serialized_config.format == "fakequant"
     assert serialized_config.quant_method == "sparseml"
     assert serialized_config.ignore == ["model.layers.1.mlp.down_proj"]
+    assert serialized_config.global_compression_ratio > 1.0
+    assert serialized_config.global_compression_ratio < 8.0
 
 
 def _test_layer_quantization_status(module, inputs: bool, weights: bool):
diff --git a/tests/quantization/test_quant_config.py b/tests/quantization/test_quant_config.py
index 40a82cd7..92b68ab7 100644
--- a/tests/quantization/test_quant_config.py
+++ b/tests/quantization/test_quant_config.py
@@ -31,7 +31,7 @@ def test_basic_config():
     assert config.format == "fakequant"
    assert config.quantization_status == QuantizationStatus.INITIALIZED
     assert config.global_compression_ratio is None
-    assert config.ignore is None
+    assert isinstance(config.ignore, list) and len(config.ignore) == 0
 
 
 def test_full_config():
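For reference, a minimal usage sketch of the new helper, using the import path added in this diff; the toy two-layer model and the printed value are illustrative only and are not part of this change:

    import torch

    from sparsetensors.quantization.utils import calculate_compression_ratio

    # Two fp32 Linear leaves with no quantization_scheme attached: every weight keeps
    # its 32-bit width, so the reported ratio is 1.0. With 4-bit weight schemes attached
    # the ratio approaches 8.0, which is why test_apply.py asserts a value between 1.0 and 8.0.
    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))
    print(calculate_compression_ratio(model))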