Serialize Config from Model #7

Merged
merged 14 commits on Apr 16, 2024
11 changes: 2 additions & 9 deletions src/sparsetensors/quantization/lifecycle/apply.py
@@ -14,7 +14,7 @@

import re
from collections import OrderedDict
from typing import Iterable, Optional, Tuple
from typing import Iterable, Optional

from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration
from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization
@@ -25,6 +25,7 @@
QuantizationConfig,
QuantizationStatus,
)
from sparsetensors.quantization.utils import iter_named_leaf_modules
from torch.nn import Module


@@ -76,14 +77,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
model.apply(freeze_module_quantization)


def _iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
# yields modules that do not have any submodules
# TODO: potentially expand to add list of allowed submodules such as observers
for name, submodule in model.named_modules():
if len(list(submodule.children())) == 0:
yield name, submodule


def _find_first_name_or_class_match(
name: str,
module: Module,
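The leaf-module helper removed here is re-added as a shared utility in src/sparsetensors/quantization/utils/helpers.py below. For readers unfamiliar with what counts as a leaf, a toy illustration (the toy model is an assumption, not part of this repo):

# Toy sketch: only modules with no child modules are yielded.
import torch.nn as nn
from sparsetensors.quantization.utils import iter_named_leaf_modules

toy = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.ReLU(), nn.Linear(4, 2)))
print([name for name, _ in iter_named_leaf_modules(toy)])
# prints ['0', '1.0', '1.1'] -- the root and the inner Sequential are skipped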
66 changes: 64 additions & 2 deletions src/sparsetensors/quantization/quant_config.py
@@ -15,8 +15,15 @@
from enum import Enum
from typing import Dict, List, Optional

from pydantic import BaseModel
from pydantic import BaseModel, Field
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from sparsetensors.quantization.utils import (
calculate_compression_ratio,
is_module_quantized,
iter_named_leaf_modules,
module_type,
)
from torch.nn import Module


__all__ = [
@@ -89,4 +96,59 @@ class QuantizationConfig(BaseModel):
format: str = "fakequant"
quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
global_compression_ratio: Optional[float] = None
ignore: Optional[List[str]] = None
ignore: Optional[List[str]] = Field(default_factory=list)

@staticmethod
def from_pretrained(model: Module) -> "QuantizationConfig":
"""
Converts a model into its associated QuantizationConfig based on the
QuantizationScheme attached to each quantized module

:param model: model to calculate quantization scheme of
:return: filled out QuantizationConfig for the input model
"""
quant_scheme_to_layers = []
quantization_status = None
ignore = {}
quantization_type_names = set()
for name, submodule in iter_named_leaf_modules(model):
Contributor comment: See the TODO comment about allowing for exceptions in leaf nodes for observers. This will be relevant for non-frozen quantized models.
layer_type = module_type(submodule)
if not is_module_quantized(submodule):
if layer_type not in ignore:
ignore[layer_type] = []
ignore[layer_type].append(name)
else:
quantization_status = submodule.quantization_status
scheme = submodule.quantization_scheme
quantization_type_names.add(layer_type)

match_found = False
for existing_scheme in quant_scheme_to_layers:
if scheme == existing_scheme:
match_found = True
break
if not match_found:
quant_scheme_to_layers.append(scheme)

# clean up ignore list, we can leave out layer types if none of the
# instances are quantized
consolidated_ignore = []
for layer_type, ignore_names in ignore.items():
if layer_type in quantization_type_names:
# specific layers of a quantized type are ignored
consolidated_ignore += ignore_names
# else we leave it off the ignore list, doesn't fall under any of the
# existing quantization schemes so it won't be quantized

config_groups = {}
for idx, scheme in enumerate(quant_scheme_to_layers):
group_name = "group_" + str(idx)
config_groups[group_name] = scheme

compression_ratio = calculate_compression_ratio(model)
return QuantizationConfig(
config_groups=config_groups,
quantization_status=quantization_status,
global_compression_ratio=compression_ratio,
ignore=consolidated_ignore,
)
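For orientation, a minimal usage sketch of the new static method. The model variable is assumed to be a torch.nn.Module that already has quantization schemes attached (e.g. via apply_quantization_config); the .dict()/.json() calls are standard pydantic v1 serialization, consistent with the parse_obj() used in the tests below.

# Minimal sketch, not part of this diff.
from sparsetensors.quantization.quant_config import QuantizationConfig

config = QuantizationConfig.from_pretrained(model)  # model: already-quantized nn.Module (assumed)
print(config.dict())        # config_groups, ignore, global_compression_ratio, ...
with open("quantization_config.json", "w") as f:
    f.write(config.json())  # pydantic v1 JSON dump, ready to ship with the checkpoint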
16 changes: 16 additions & 0 deletions src/sparsetensors/quantization/utils/__init__.py
@@ -0,0 +1,16 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa
from .helpers import *
117 changes: 117 additions & 0 deletions src/sparsetensors/quantization/utils/helpers.py
@@ -0,0 +1,117 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch
from torch.nn import Module
from tqdm import tqdm


__all__ = [
"is_module_quantized",
"is_model_quantized",
"iter_named_leaf_modules",
"module_type",
"calculate_compression_ratio",
]


def is_module_quantized(module: Module) -> bool:
"""
Check if a module is quantized, based on the existence of a non-empty quantization
scheme

:param module: pytorch module to check
:return: True if module is quantized, False otherwise
"""
if not hasattr(module, "quantization_scheme"):
return False

if module.quantization_scheme.weights is not None:
return True

if module.quantization_scheme.input_activations is not None:
return True

if module.quantization_scheme.output_activations is not None:
return True

return False


def is_model_quantized(model: Module) -> bool:
"""
Check if any modules in a model are quantized, based on the existence of a non-empty
quantization scheme in at least one module

:param model: pytorch model
:return: True if model is quantized, False otherwise
"""

for _, submodule in iter_named_leaf_modules(model):
if is_module_quantized(submodule):
return True

return False


def module_type(module: Module) -> str:
"""
Gets a string representation of a module type

:param module: pytorch module to get type of
:return: module type as a string
"""
return type(module).__name__


def iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
# yields modules that do not have any submodules
# TODO: potentially expand to add list of allowed submodules such as observers
for name, submodule in model.named_modules():
if len(list(submodule.children())) == 0:
yield name, submodule


def calculate_compression_ratio(model: Module) -> float:
"""
Calculates the quantization compression ratio of a pytorch model, based on the
number of bits needed to represent the total weights in compressed form. Does not
take into account activation quantizations.

:param model: pytorch module to calculate compression ratio for
:return: compression ratio of the whole model
"""
total_compressed = 0.0
total_uncompressed = 0.0
for name, submodule in tqdm(
iter_named_leaf_modules(model),
desc="Calculating quantization compression ratio",
):
for parameter in submodule.parameters():
try:
uncompressed_bits = torch.finfo(parameter.dtype).bits
except TypeError:
uncompressed_bits = torch.iinfo(parameter.dtype).bits
compressed_bits = uncompressed_bits
if is_module_quantized(submodule):
compressed_bits = submodule.quantization_scheme.weights.num_bits
num_weights = parameter.numel()
total_compressed += compressed_bits * num_weights
total_uncompressed += uncompressed_bits * num_weights

return total_uncompressed / total_compressed
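To make the ratio concrete, a small worked example with assumed sizes (not taken from this PR): a model with two leaf layers of 1,000 fp32 weights each, where only the first is quantized to 8-bit weights.

# Hypothetical sizes: 2 leaf layers x 1,000 fp32 weights, first one quantized to 8 bits.
total_uncompressed = 32 * 1000 + 32 * 1000     # 64,000 bits at the original dtype
total_compressed = 8 * 1000 + 32 * 1000        # 40,000 bits after weight quantization
ratio = total_uncompressed / total_compressed  # 1.6; a fully 8-bit fp32 model would hit 4.0

This is why the serialization test below only checks that the ratio lands between 1.0 and 8.0 rather than asserting an exact value.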
41 changes: 34 additions & 7 deletions tests/quantization/lifecycle/test_apply.py
@@ -14,7 +14,10 @@


from sparsetensors.quantization.lifecycle import apply_quantization_config
from sparsetensors.quantization.quant_config import QuantizationConfig
from sparsetensors.quantization.quant_config import (
QuantizationConfig,
QuantizationStatus,
)
from transformers import AutoModelForCausalLM


@@ -33,7 +36,9 @@ def test_apply_quantization_config_tinyllama():
num_linears = 0
num_embeddings = 0
num_rotary_embeddings = 0
for module in model.modules():
for name, module in model.named_modules():
if name in quant_config.ignore:
continue
module_type = module.__class__.__name__
if module_type == "Linear":
num_linears += 1
@@ -46,11 +51,36 @@ def test_apply_quantization_config_tinyllama():
_test_layer_quantization_status(module, inputs=False, weights=False)

# sanity check correct number of layers targeted
assert num_linears == 155
assert num_linears == 154 # 155 Linear layers - 1 that gets ignored
assert num_embeddings == 1
assert num_rotary_embeddings == 22


def test_serialize_config_tinyllama():
quant_config = get_sample_tinyllama_quant_config()
model = get_tinyllama_model()

# check that model is not already quantized
for module in model.modules():
_test_layer_quantization_status(module, inputs=False, weights=False)

# apply quant config to model
apply_quantization_config(model, quant_config)

serialized_config = QuantizationConfig.from_pretrained(model)
assert len(serialized_config.config_groups) == 2
assert serialized_config.config_groups["group_0"].targets == ["Embedding"]
assert serialized_config.config_groups["group_0"].input_activations is None
assert serialized_config.config_groups["group_1"].targets == ["Linear"]
assert serialized_config.config_groups["group_1"].input_activations is not None
assert serialized_config.quantization_status == QuantizationStatus.FROZEN
assert serialized_config.format == "fakequant"
assert serialized_config.quant_method == "sparseml"
assert serialized_config.ignore == ["model.layers.1.mlp.down_proj"]
assert serialized_config.global_compression_ratio > 1.0
assert serialized_config.global_compression_ratio < 8.0


def _test_layer_quantization_status(module, inputs: bool, weights: bool):
# check if quantization is applied at all (true if inputs or weights targeted)
quantized = inputs or weights
@@ -105,9 +135,6 @@ def get_sample_tinyllama_quant_config():
"targets": ["Embedding"],
},
},
"ignore": ["LlamaRotaryEmbedding"],
"ignore": ["LlamaRotaryEmbedding", "model.layers.1.mlp.down_proj"],
}
return QuantizationConfig.parse_obj(config_dict)


test_apply_quantization_config_tinyllama()
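A note on the ignore assertion in test_serialize_config_tinyllama above: the sample config ignores both the LlamaRotaryEmbedding type and one named Linear, but only the named layer survives serialization, because no quantization scheme targets rotary embeddings and that type is therefore dropped during consolidation. A rough sketch of that step, with assumed inputs mirroring from_pretrained:

# Assumed inputs for illustration only; the rotary embedding name is hypothetical.
ignore = {
    "LlamaRotaryEmbedding": ["model.layers.0.self_attn.rotary_emb"],
    "Linear": ["model.layers.1.mlp.down_proj"],
}
quantized_types = {"Linear", "Embedding"}  # layer types covered by some quant scheme
consolidated = [
    name
    for layer_type, names in ignore.items()
    if layer_type in quantized_types
    for name in names
]
assert consolidated == ["model.layers.1.mlp.down_proj"]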
2 changes: 1 addition & 1 deletion tests/quantization/test_quant_config.py
@@ -31,7 +31,7 @@ def test_basic_config():
assert config.format == "fakequant"
assert config.quantization_status == QuantizationStatus.INITIALIZED
assert config.global_compression_ratio is None
assert config.ignore is None
assert isinstance(config.ignore, list) and len(config.ignore) == 0


def test_full_config():
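A closing note on the ignore default change asserted just above: moving from None to Field(default_factory=list) lets callers run membership checks such as name in quant_config.ignore (used in the updated test loop earlier) without a None guard, while default_factory keeps each instance from sharing one mutable list. A minimal sketch of the difference, using assumed class names:

# Minimal pydantic v1 sketch; Old/New are assumed names, not classes from this repo.
from typing import List, Optional
from pydantic import BaseModel, Field

class Old(BaseModel):
    ignore: Optional[List[str]] = None

class New(BaseModel):
    ignore: Optional[List[str]] = Field(default_factory=list)

print("x" in (Old().ignore or []))  # needs a None guard
print("x" in New().ignore)          # safe by default; fresh list per instance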