
Commit edc35a1

Author: Sara Adkins
Serialize Config from Model (#7)
* Apply quantization config implementation
* add TODO
* integrate full lifecycle support, QuantizationStatus updates, add tinyllama test
* fix comment
* initial implementation
* add unit test
* cleanup is_quantized
* clean up targets and ignore lists
* global compression ratio and docstrings
* make sure scale/zp on correct device
* helper for model quantization
1 parent 514e4db commit edc35a1

File tree: 6 files changed, +234 -19 lines changed

src/sparsetensors/quantization/lifecycle/apply.py

+2-9
@@ -14,7 +14,7 @@

 import re
 from collections import OrderedDict
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, Optional

 from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration
 from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization
@@ -25,6 +25,7 @@
     QuantizationConfig,
     QuantizationStatus,
 )
+from sparsetensors.quantization.utils import iter_named_leaf_modules
 from torch.nn import Module


@@ -76,14 +77,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
         model.apply(freeze_module_quantization)


-def _iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
-    # yields modules that do not have any submodules
-    # TODO: potentially expand to add list of allowed submodules such as observers
-    for name, submodule in model.named_modules():
-        if len(list(submodule.children())) == 0:
-            yield name, submodule
-
-
 def _find_first_name_or_class_match(
     name: str,
     module: Module,
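For reference, iter_named_leaf_modules (the shared helper that replaces the private _iter_named_leaf_modules above, now exported from sparsetensors.quantization.utils) yields only modules with no children. A minimal sketch of what that means, using a toy model that is illustrative only and not part of this commit:

import torch.nn as nn
from sparsetensors.quantization.utils import iter_named_leaf_modules

# illustrative toy model: the Sequential container has children, so only the
# Linear and ReLU leaves are yielded
toy = nn.Sequential(nn.Linear(8, 8), nn.ReLU())

for name, module in iter_named_leaf_modules(toy):
    print(name, type(module).__name__)  # prints "0 Linear" and "1 ReLU"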

src/sparsetensors/quantization/quant_config.py

+64-2
@@ -15,8 +15,15 @@
 from enum import Enum
 from typing import Dict, List, Optional

-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from sparsetensors.quantization.quant_scheme import QuantizationScheme
+from sparsetensors.quantization.utils import (
+    calculate_compression_ratio,
+    is_module_quantized,
+    iter_named_leaf_modules,
+    module_type,
+)
+from torch.nn import Module


 __all__ = [
@@ -89,4 +96,59 @@ class QuantizationConfig(BaseModel):
     format: str = "fakequant"
     quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
     global_compression_ratio: Optional[float] = None
-    ignore: Optional[List[str]] = None
+    ignore: Optional[List[str]] = Field(default_factory=list)
+
+    @staticmethod
+    def from_pretrained(model: Module) -> "QuantizationConfig":
+        """
+        Converts a model into its associated QuantizationConfig based on the
+        QuantizationScheme attached to each quantized module
+
+        :param model: model to calculate quantization scheme of
+        :return: filled out QuantizationConfig for the input model
+        """
+        quant_scheme_to_layers = []
+        quantization_status = None
+        ignore = {}
+        quantization_type_names = set()
+        for name, submodule in iter_named_leaf_modules(model):
+            layer_type = module_type(submodule)
+            if not is_module_quantized(submodule):
+                if layer_type not in ignore:
+                    ignore[layer_type] = []
+                ignore[layer_type].append(name)
+            else:
+                quantization_status = submodule.quantization_status
+                scheme = submodule.quantization_scheme
+                quantization_type_names.add(layer_type)
+
+                match_found = False
+                for existing_scheme in quant_scheme_to_layers:
+                    if scheme == existing_scheme:
+                        match_found = True
+                        break
+                if not match_found:
+                    quant_scheme_to_layers.append(scheme)
+
+        # clean up the ignore list; a layer type can be left out entirely if none
+        # of its instances are quantized
+        consolidated_ignore = []
+        for layer_type, ignore_names in ignore.items():
+            if layer_type in quantization_type_names:
+                # specific layers of a quantized type are ignored
+                consolidated_ignore += ignore_names
+            # else we leave it off the ignore list, doesn't fall under any of the
+            # existing quantization schemes so it won't be quantized
+
+        config_groups = {}
+        for idx, scheme in enumerate(quant_scheme_to_layers):
+            group_name = "group_" + str(idx)
+            config_groups[group_name] = scheme
+
+        compression_ratio = calculate_compression_ratio(model)
+        return QuantizationConfig(
+            config_groups=config_groups,
+            quantization_status=quantization_status,
+            global_compression_ratio=compression_ratio,
+            ignore=consolidated_ignore,
+        )
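A minimal usage sketch of the new constructor, not part of the diff itself: it mirrors test_serialize_config_tinyllama further below and assumes the test helpers get_tinyllama_model and get_sample_tinyllama_quant_config (only partially shown in this commit) are in scope:

from sparsetensors.quantization.lifecycle import apply_quantization_config
from sparsetensors.quantization.quant_config import QuantizationConfig

# test helpers from tests/quantization/lifecycle/test_apply.py (assumed available);
# they load a TinyLlama checkpoint and build the sample config used in the tests
model = get_tinyllama_model()
quant_config = get_sample_tinyllama_quant_config()

# attach QuantizationScheme objects to the model's modules
apply_quantization_config(model, quant_config)

# reconstruct a QuantizationConfig purely from the schemes attached to the model
serialized = QuantizationConfig.from_pretrained(model)
print(serialized.dict())  # pydantic dict, ready to persist alongside the model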
src/sparsetensors/quantization/utils/__init__.py (new file)

+16-0
@@ -0,0 +1,16 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+from .helpers import *
src/sparsetensors/quantization/utils/helpers.py (new file)

+117-0
@@ -0,0 +1,117 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple
+
+import torch
+from torch.nn import Module
+from tqdm import tqdm
+
+
+__all__ = [
+    "is_module_quantized",
+    "is_model_quantized",
+    "iter_named_leaf_modules",
+    "module_type",
+    "calculate_compression_ratio",
+]
+
+
+def is_module_quantized(module: Module) -> bool:
+    """
+    Check if a module is quantized, based on the existence of a non-empty quantization
+    scheme
+
+    :param module: pytorch module to check
+    :return: True if module is quantized, False otherwise
+    """
+    if not hasattr(module, "quantization_scheme"):
+        return False
+
+    if module.quantization_scheme.weights is not None:
+        return True
+
+    if module.quantization_scheme.input_activations is not None:
+        return True
+
+    if module.quantization_scheme.output_activations is not None:
+        return True
+
+    return False
+
+
+def is_model_quantized(model: Module) -> bool:
+    """
+    Check if any modules in a model are quantized, based on the existence of a non-empty
+    quantization scheme in at least one module
+
+    :param model: pytorch model
+    :return: True if model is quantized, False otherwise
+    """
+
+    for _, submodule in iter_named_leaf_modules(model):
+        if is_module_quantized(submodule):
+            return True
+
+    return False
+
+
+def module_type(module: Module) -> str:
+    """
+    Gets a string representation of a module type
+
+    :param module: pytorch module to get type of
+    :return: module type as a string
+    """
+    return type(module).__name__
+
+
+def iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
+    # yields modules that do not have any submodules
+    # TODO: potentially expand to add list of allowed submodules such as observers
+    for name, submodule in model.named_modules():
+        if len(list(submodule.children())) == 0:
+            yield name, submodule
+
+
+def calculate_compression_ratio(model: Module) -> float:
+    """
+    Calculates the quantization compression ratio of a pytorch model, based on the
+    number of bits needed to represent the total weights in compressed form. Does not
+    take into account activation quantizations.
+
+    :param model: pytorch module to calculate compression ratio for
+    :return: compression ratio of the whole model
+    """
+    total_compressed = 0.0
+    total_uncompressed = 0.0
+    for name, submodule in tqdm(
+        iter_named_leaf_modules(model),
+        desc="Calculating quantization compression ratio",
+    ):
+        for parameter in model.parameters():
+            try:
+                uncompressed_bits = torch.finfo(parameter.dtype).bits
+            except TypeError:
+                uncompressed_bits = torch.iinfo(parameter.dtype).bits
+            compressed_bits = uncompressed_bits
+            if is_module_quantized(submodule):
+                compressed_bits = submodule.quantization_scheme.weights.num_bits
+            else:
+                print(name)
+            num_weights = parameter.numel()
+            total_compressed += compressed_bits * num_weights
+            total_uncompressed += uncompressed_bits * num_weights
+
+    return total_uncompressed / total_compressed
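To make the ratio concrete, a small arithmetic sketch (illustrative numbers only, not from this commit): weights left in fp32 count as 32 bits each, weights covered by an 8-bit scheme count as 8 bits, and the ratio is total uncompressed bits over total compressed bits.

# illustrative-only sketch of the compression ratio formula
fp32_bits = 32
quantized_bits = 8

quantized_weights = 1_000    # weights in modules with an 8-bit scheme
unquantized_weights = 200    # weights in modules left in fp32

total_uncompressed = fp32_bits * (quantized_weights + unquantized_weights)               # 38,400 bits
total_compressed = quantized_bits * quantized_weights + fp32_bits * unquantized_weights  # 14,400 bits

print(total_uncompressed / total_compressed)  # ~2.67x, bounded above by 32/8 = 4x for 8-bit weights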

tests/quantization/lifecycle/test_apply.py

+34-7
@@ -14,7 +14,10 @@


 from sparsetensors.quantization.lifecycle import apply_quantization_config
-from sparsetensors.quantization.quant_config import QuantizationConfig
+from sparsetensors.quantization.quant_config import (
+    QuantizationConfig,
+    QuantizationStatus,
+)
 from transformers import AutoModelForCausalLM


@@ -33,7 +36,9 @@ def test_apply_quantization_config_tinyllama():
     num_linears = 0
     num_embeddings = 0
     num_rotary_embeddings = 0
-    for module in model.modules():
+    for name, module in model.named_modules():
+        if name in quant_config.ignore:
+            continue
         module_type = module.__class__.__name__
         if module_type == "Linear":
             num_linears += 1
@@ -46,11 +51,36 @@ def test_apply_quantization_config_tinyllama():
            _test_layer_quantization_status(module, inputs=False, weights=False)

     # sanity check correct number of layers targeted
-    assert num_linears == 155
+    assert num_linears == 154  # 155 Linear layers - 1 that gets ignored
     assert num_embeddings == 1
     assert num_rotary_embeddings == 22


+def test_serialize_config_tinyllama():
+    quant_config = get_sample_tinyllama_quant_config()
+    model = get_tinyllama_model()
+
+    # check that model is not already quantized
+    for module in model.modules():
+        _test_layer_quantization_status(module, inputs=False, weights=False)
+
+    # apply quant config to model
+    apply_quantization_config(model, quant_config)
+
+    serialized_config = QuantizationConfig.from_pretrained(model)
+    assert len(serialized_config.config_groups) == 2
+    assert serialized_config.config_groups["group_0"].targets == ["Embedding"]
+    assert serialized_config.config_groups["group_0"].input_activations is None
+    assert serialized_config.config_groups["group_1"].targets == ["Linear"]
+    assert serialized_config.config_groups["group_1"].input_activations is not None
+    assert serialized_config.quantization_status == QuantizationStatus.FROZEN
+    assert serialized_config.format == "fakequant"
+    assert serialized_config.quant_method == "sparseml"
+    assert serialized_config.ignore == ["model.layers.1.mlp.down_proj"]
+    assert serialized_config.global_compression_ratio > 1.0
+    assert serialized_config.global_compression_ratio < 8.0
+
+
 def _test_layer_quantization_status(module, inputs: bool, weights: bool):
     # check if quantization is applied at all (true if inputs or weights targeted)
     quantized = inputs or weights
@@ -105,9 +135,6 @@ def get_sample_tinyllama_quant_config():
                "targets": ["Embedding"],
            },
        },
-        "ignore": ["LlamaRotaryEmbedding"],
+        "ignore": ["LlamaRotaryEmbedding", "model.layers.1.mlp.down_proj"],
     }
     return QuantizationConfig.parse_obj(config_dict)
-
-
-test_apply_quantization_config_tinyllama()

tests/quantization/test_quant_config.py

+1-1
@@ -31,7 +31,7 @@ def test_basic_config():
     assert config.format == "fakequant"
     assert config.quantization_status == QuantizationStatus.INITIALIZED
     assert config.global_compression_ratio is None
-    assert config.ignore is None
+    assert isinstance(config.ignore, list) and len(config.ignore) == 0


 def test_full_config():
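The test change tracks the switch from a None default to Field(default_factory=list) for QuantizationConfig.ignore; a minimal standalone pydantic sketch (assuming only pydantic is installed, class names are illustrative) of why that matters to callers such as the "name in quant_config.ignore" check in the lifecycle test:

from typing import List, Optional
from pydantic import BaseModel, Field

class OldStyle(BaseModel):
    ignore: Optional[List[str]] = None                          # previous default

class NewStyle(BaseModel):
    ignore: Optional[List[str]] = Field(default_factory=list)   # new default

print(OldStyle().ignore)  # None -> membership checks need a guard first
print(NewStyle().ignore)  # []   -> "name in config.ignore" is always safe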
