global compression ratio and docstrings
Sara Adkins committed Apr 16, 2024
1 parent 845bfb9 commit 1a7984c
Showing 4 changed files with 66 additions and 3 deletions.
11 changes: 10 additions & 1 deletion src/sparsetensors/quantization/quant_config.py
@@ -18,6 +18,7 @@
from pydantic import BaseModel, Field
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from sparsetensors.quantization.utils import (
    calculate_compression_ratio,
    is_module_quantized,
    iter_named_leaf_modules,
    module_type,
@@ -100,7 +101,11 @@ class QuantizationConfig(BaseModel):
    @staticmethod
    def from_pretrained(model: Module) -> "QuantizationConfig":
        """
        TODO: fill in docstrings
        Converts a model into its associated QuantizationConfig based on the
        QuantizationScheme attached to each quantized module

        :param model: model to calculate quantization scheme of
        :return: filled out QuantizationConfig for the input model
        """
        quant_scheme_to_layers = []
        quantization_status = None
@@ -125,6 +130,8 @@ def from_pretrained(model: Module) -> "QuantizationConfig":
                if not match_found:
                    quant_scheme_to_layers.append(scheme)

        # clean up ignore list, we can leave out layer types if none of the
        # instances are quantized
        consolidated_ignore = []
        for layer_type, ignore_names in ignore.items():
            if layer_type in quantization_type_names:
@@ -138,8 +145,10 @@ def from_pretrained(model: Module) -> "QuantizationConfig":
            group_name = "group_" + str(idx)
            config_groups[group_name] = scheme

        compression_ratio = calculate_compression_ratio(model)
        return QuantizationConfig(
            config_groups=config_groups,
            quantization_status=quantization_status,
            global_compression_ratio=compression_ratio,
            ignore=consolidated_ignore,
        )
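
For orientation, a minimal usage sketch of the updated entry point (not part of the commit; it assumes `model` is a torch.nn.Module that has already been through the quantization lifecycle, so its quantized leaf modules carry a `quantization_scheme` attribute):

    from sparsetensors.quantization.quant_config import QuantizationConfig

    # `model` is an assumed input: an already-quantized torch.nn.Module
    config = QuantizationConfig.from_pretrained(model)

    # one config group per unique scheme, plus the consolidated ignore list
    # and the newly computed global compression ratio
    print(config.config_groups)
    print(config.ignore)
    print(config.global_compression_ratio)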
54 changes: 53 additions & 1 deletion src/sparsetensors/quantization/utils/helpers.py
@@ -14,13 +14,27 @@

from typing import Tuple

import torch
from torch.nn import Module
from tqdm import tqdm


__all__ = ["is_module_quantized", "iter_named_leaf_modules", "module_type"]
__all__ = [
"is_module_quantized",
"iter_named_leaf_modules",
"module_type",
"calculate_compression_ratio",
]


def is_module_quantized(module: Module) -> bool:
    """
    Check if a module is quantized, based on the existence of a non-empty quantization
    scheme

    :param module: pytorch module to check
    :return: True if module is quantized, False otherwise
    """
    if not hasattr(module, "quantization_scheme"):
        return False

@@ -37,6 +51,12 @@ def is_module_quantized(module: Module) -> bool:


def module_type(module: Module) -> str:
    """
    Gets a string representation of a module type

    :param module: pytorch module to get type of
    :return: module type as a string
    """
    return type(module).__name__


@@ -46,3 +66,35 @@ def iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
    for name, submodule in model.named_modules():
        if len(list(submodule.children())) == 0:
            yield name, submodule


def calculate_compression_ratio(model: Module) -> float:
    """
    Calculates the quantization compression ratio of a pytorch model, based on the
    number of bits needed to represent the total weights in compressed form. Does not
    take into account activation quantizations.

    :param model: pytorch module to calculate compression ratio for
    :return: compression ratio of the whole model
    """
    total_compressed = 0.0
    total_uncompressed = 0.0
    for name, submodule in tqdm(
        iter_named_leaf_modules(model),
        desc="Calculating quantization compression ratio",
    ):
        # count each submodule's own parameters, not the whole model's
        for parameter in submodule.parameters():
            try:
                uncompressed_bits = torch.finfo(parameter.dtype).bits
            except TypeError:
                # integer dtypes have no finfo; fall back to iinfo
                uncompressed_bits = torch.iinfo(parameter.dtype).bits
            compressed_bits = uncompressed_bits
            if is_module_quantized(submodule):
                compressed_bits = submodule.quantization_scheme.weights.num_bits
            num_weights = parameter.numel()
            total_compressed += compressed_bits * num_weights
            total_uncompressed += uncompressed_bits * num_weights

    return total_uncompressed / total_compressed
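
To make the returned ratio concrete, an illustrative calculation with hypothetical sizes (not from the commit): two fp32 weight tensors of 1,000 elements each, one quantized to 4-bit weights and one left unquantized.

    fp32_bits = 32
    total_uncompressed = (1_000 + 1_000) * fp32_bits  # 64,000 bits
    total_compressed = 1_000 * 4 + 1_000 * fp32_bits  # 36,000 bits
    print(total_uncompressed / total_compressed)      # ~1.78x

A model whose fp32 weights were all quantized to 4 bits would approach the maximum ratio of 32 / 4 = 8.0.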
2 changes: 2 additions & 0 deletions tests/quantization/lifecycle/test_apply.py
@@ -77,6 +77,8 @@ def test_serialize_config_tinyllama():
    assert serialized_config.format == "fakequant"
    assert serialized_config.quant_method == "sparseml"
    assert serialized_config.ignore == ["model.layers.1.mlp.down_proj"]
    assert serialized_config.global_compression_ratio > 1.0
    assert serialized_config.global_compression_ratio < 8.0


def _test_layer_quantization_status(module, inputs: bool, weights: bool):
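The new bounds follow from that maximum: assuming the tinyllama fixture stores fp32 weights quantized at 4 bits, the ratio cannot exceed 32 / 4 = 8.0, and it stays strictly below that because the ignored model.layers.1.mlp.down_proj weights are counted at full width in both totals; conversely, any quantized layer at all pushes the ratio above 1.0.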
2 changes: 1 addition & 1 deletion tests/quantization/test_quant_config.py
@@ -31,7 +31,7 @@ def test_basic_config():
    assert config.format == "fakequant"
    assert config.quantization_status == QuantizationStatus.INITIALIZED
    assert config.global_compression_ratio is None
    assert config.ignore is None
    assert isinstance(config.ignore, list) and len(config.ignore) == 0


def test_full_config():
