Apply quantization config implementation #4

Merged · 4 commits · Apr 15, 2024
2 changes: 2 additions & 0 deletions src/sparsetensors/quantization/__init__.py
@@ -13,6 +13,8 @@
# limitations under the License.

# flake8: noqa
# isort: skip_file

from .quant_args import *
from .quant_config import *
from .quant_scheme import *
3 changes: 2 additions & 1 deletion src/sparsetensors/quantization/lifecycle/__init__.py
@@ -13,9 +13,10 @@
# limitations under the License.

# flake8: noqa
# isort: skip_file

from .calibration import *
from .forward import *
from .frozen import *
from .initialize import *
-from .status import *
+from .apply import *
120 changes: 120 additions & 0 deletions src/sparsetensors/quantization/lifecycle/apply.py
@@ -0,0 +1,120 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from collections import OrderedDict
from typing import Iterable, Optional, Tuple

from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration
from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization
from sparsetensors.quantization.lifecycle.initialize import (
    initialize_module_for_quantization,
)
from sparsetensors.quantization.quant_config import (
    QuantizationConfig,
    QuantizationStatus,
)
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module


__all__ = [
    "apply_quantization_config",
    "apply_quantization_status",
]


def apply_quantization_config(model: Module, config: QuantizationConfig):
    """
    Initializes the model for quantization in-place based on the given config

    :param model: model to apply quantization config to
    :param config: quantization config
    """
    # build mapping of targets to schemes for easier matching
    # use ordered dict to preserve target ordering in config
    target_to_scheme = OrderedDict()
    for scheme in config.config_groups.values():
        for target in scheme.targets:
            target_to_scheme[target] = scheme

    # build list of layers to target to avoid mutating submodule dict during iteration
    layer_quant_scheme_pairs = []
    for name, submodule in _iter_named_leaf_modules(model):
        if _find_first_name_or_class_match(name, submodule, config.ignore):
            continue  # layer matches ignore list, continue
        target = _find_first_name_or_class_match(name, submodule, target_to_scheme)
        if target is not None:
            # target matched - add layer and scheme to target list
            layer_quant_scheme_pairs.append((submodule, target_to_scheme[target]))

    # apply current quantization status for each matched pair
    for layer, scheme in layer_quant_scheme_pairs:
        apply_quantization_status(
            module=layer,
            scheme=scheme,
            status=config.quantization_status,
        )


def apply_quantization_status(
    module: Module, scheme: QuantizationScheme, status: QuantizationStatus
):
    """
    Applies in place the quantization lifecycle up to the given status

    :param module: module to apply quantization to
    :param scheme: quantization scheme to apply
    :param status: status to update the module to
    """
    if status >= QuantizationStatus.INITIALIZED:
        initialize_module_for_quantization(module, scheme)
    if status >= QuantizationStatus.CALIBRATION:
        set_module_for_calibration(module)
    if status >= QuantizationStatus.FROZEN:
        freeze_module_quantization(module)


def _iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
    # yields modules that do not have any submodules
    # TODO: potentially expand to add list of allowed submodules such as observers
    for name, submodule in model.named_modules():
        if len(list(submodule.children())) == 0:
            yield name, submodule


def _find_first_name_or_class_match(
    name: str,
    module: Module,
    targets: Iterable[str],
) -> Optional[str]:
    # returns the first element of targets that matches the given name
    # if no name matches, returns the first target that matches the class name
    # returns None otherwise
    return _find_first_match(name, targets) or _find_first_match(
        module.__class__.__name__, targets
    )


def _find_first_match(value: str, targets: Iterable[str]) -> Optional[str]:
    # returns the first element of targets that matches value either
    # exactly or as a regex after the 're:' prefix
    for target in targets:
        if target.startswith("re:"):
            pattern = target[3:]
            if re.match(pattern, value):
                return target
        elif target == value:
            return target
    return None
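To make the new entry point concrete for reviewers, here is a rough usage sketch based only on the code in this diff; the checkpoint name and the `re:`-prefixed target are illustrative, and the remaining config fields mirror the sample config used in the tests below.

```python
from sparsetensors.quantization.lifecycle import apply_quantization_config
from sparsetensors.quantization.quant_config import QuantizationConfig
from transformers import AutoModelForCausalLM

# any HF causal LM works; this checkpoint is just the one the tests use
model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
)

# targets may be module class names ("Linear") or submodule-name regexes prefixed
# with "re:" - the "re:.*down_proj$" entry below is a hypothetical example
config = QuantizationConfig.parse_obj(
    {
        "quant_method": "sparseml",
        "format": "fakequant",
        "quantization_status": "frozen",
        "global_compression_ratio": None,
        "config_groups": {
            "group_1": {
                "weights": {
                    "num_bits": 8,
                    "type": "int",
                    "symmetric": True,
                    "strategy": "tensor",
                },
                "input_activations": None,
                "targets": ["Linear", "re:.*down_proj$"],
            },
        },
        "ignore": ["LlamaRotaryEmbedding"],
    }
)

# walks every leaf module, skips the ignore list, and applies the lifecycle
# (initialize -> calibration -> frozen) up to config.quantization_status in place
apply_quantization_config(model, config)
```

Because `target_to_scheme` is built as an OrderedDict, the first matching target in config order decides the scheme when several groups could claim the same layer.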
2 changes: 1 addition & 1 deletion src/sparsetensors/quantization/lifecycle/calibration.py
@@ -15,7 +15,7 @@

import logging

-from sparsetensors.quantization.lifecycle.status import QuantizationStatus
+from sparsetensors.quantization.quant_config import QuantizationStatus
from torch.nn import Module


2 changes: 1 addition & 1 deletion src/sparsetensors/quantization/lifecycle/forward.py
@@ -15,8 +15,8 @@
from functools import wraps

import torch
-from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from sparsetensors.quantization.quant_args import QuantizationArgs
+from sparsetensors.quantization.quant_config import QuantizationStatus
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module

7 changes: 5 additions & 2 deletions src/sparsetensors/quantization/lifecycle/frozen.py
@@ -13,7 +13,7 @@
# limitations under the License.


-from sparsetensors.quantization.lifecycle.status import QuantizationStatus
+from sparsetensors.quantization.quant_config import QuantizationStatus
from torch.nn import Module


@@ -28,9 +28,12 @@ def freeze_module_quantization(module: Module):
         return
 
     # delete observers from module
+    observer_names = []
     for submodule_name, _ in module.named_modules():
         if "." not in submodule_name and submodule_name.endswith("_observer"):
             # delete any observers that belong directly to this module
-            delattr(module, submodule_name)
+            observer_names.append(submodule_name)
+    for observer_name in observer_names:
+        delattr(module, observer_name)
 
     module.quantization_status = QuantizationStatus.FROZEN
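The two-pass change above (collect observer names first, then delete) avoids removing entries from the module's submodule dict while `named_modules()` is still iterating over it. A minimal, self-contained illustration of the pattern with a stand-in module (nothing below is from this repo):

```python
from torch.nn import Identity, Module


class ToyQuantizedLayer(Module):
    # stand-in for a layer that owns observer submodules
    def __init__(self):
        super().__init__()
        self.input_observer = Identity()
        self.weight_observer = Identity()


module = ToyQuantizedLayer()

# Deleting inside the loop mutates module._modules mid-iteration, which can
# raise "RuntimeError: dictionary changed size during iteration":
#
#     for name, _ in module.named_modules():
#         if name.endswith("_observer"):
#             delattr(module, name)

# Collect first, delete after - the same shape as the updated
# freeze_module_quantization above
observer_names = [
    name
    for name, _ in module.named_modules()
    if "." not in name and name.endswith("_observer")
]
for name in observer_names:
    delattr(module, name)

assert not hasattr(module, "input_observer")
assert not hasattr(module, "weight_observer")
```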
2 changes: 1 addition & 1 deletion src/sparsetensors/quantization/lifecycle/initialize.py
@@ -17,8 +17,8 @@

import torch
from sparsetensors.quantization.lifecycle.forward import wrap_module_forward_quantized
-from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from sparsetensors.quantization.quant_args import QuantizationArgs
+from sparsetensors.quantization.quant_config import QuantizationStatus
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module, Parameter

26 changes: 25 additions & 1 deletion src/sparsetensors/quantization/quant_config.py
@@ -19,7 +19,11 @@
from sparsetensors.quantization.quant_scheme import QuantizationScheme


__all__ = ["QuantizationStatus", "QuantizationConfig"]
__all__ = [
"QuantizationStatus",
"QuantizationConfig",
"LIFECYCLE_ORDER",
]


class QuantizationStatus(Enum):
@@ -41,6 +45,26 @@ class QuantizationStatus(Enum):
    FROZEN = "frozen"
    COMPRESSED = "compressed"

    @classmethod
    def lifecycle_order(cls) -> List["QuantizationStatus"]:
        """
        :return: list of statuses in the correct quantization lifecycle order
        """
        return LIFECYCLE_ORDER

    def __ge__(self, other):
        if not isinstance(other, self.__class__):
            raise NotImplementedError
        return LIFECYCLE_ORDER.index(self) >= LIFECYCLE_ORDER.index(other)


LIFECYCLE_ORDER = [
    QuantizationStatus.INITIALIZED,
    QuantizationStatus.CALIBRATION,
    QuantizationStatus.FROZEN,
    QuantizationStatus.COMPRESSED,
]


class QuantizationConfig(BaseModel):
"""
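The `__ge__` override added above is what lets `apply_quantization_status` run every lifecycle step up to the requested status with plain comparisons. A small sketch of the ordering semantics, assuming the enum and `LIFECYCLE_ORDER` exactly as defined in this file:

```python
from sparsetensors.quantization.quant_config import QuantizationStatus

# statuses compare by their position in LIFECYCLE_ORDER, not by name or value
assert QuantizationStatus.FROZEN >= QuantizationStatus.INITIALIZED
assert QuantizationStatus.FROZEN >= QuantizationStatus.CALIBRATION
assert not (QuantizationStatus.CALIBRATION >= QuantizationStatus.FROZEN)

# so a config with quantization_status="frozen" makes apply_quantization_status
# initialize, set up calibration, and freeze a module, in that order
```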
13 changes: 0 additions & 13 deletions ...etensors/quantization/lifecycle/status.py → tests/__init__.py
@@ -11,16 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum


__all__ = [
    "QuantizationStatus",
]


class QuantizationStatus(Enum):
    INITIALIZED = "INITIALIZED"
    CALIBRATION = "CALIBRATION"
    FROZEN = "FROZEN"
13 changes: 13 additions & 0 deletions tests/quantization/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions tests/quantization/lifecycle/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
113 changes: 113 additions & 0 deletions tests/quantization/lifecycle/test_apply.py
@@ -0,0 +1,113 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from sparsetensors.quantization.lifecycle import apply_quantization_config
from sparsetensors.quantization.quant_config import QuantizationConfig
from transformers import AutoModelForCausalLM


def test_apply_quantization_config_tinyllama():
    quant_config = get_sample_tinyllama_quant_config()
    model = get_tinyllama_model()

    # check that model is not already quantized
    for module in model.modules():
        _test_layer_quantization_status(module, inputs=False, weights=False)

    # apply quant config to model
    apply_quantization_config(model, quant_config)

    # check for correct application of quant config
    num_linears = 0
    num_embeddings = 0
    num_rotary_embeddings = 0
    for module in model.modules():
        module_type = module.__class__.__name__
        if module_type == "Linear":
            num_linears += 1
            _test_layer_quantization_status(module, inputs=True, weights=True)
        elif module_type == "Embedding":
            num_embeddings += 1
            _test_layer_quantization_status(module, inputs=False, weights=True)
        elif module_type == "LlamaRotaryEmbedding":
            num_rotary_embeddings += 1
            _test_layer_quantization_status(module, inputs=False, weights=False)

    # sanity check correct number of layers targeted
    assert num_linears == 155
    assert num_embeddings == 1
    assert num_rotary_embeddings == 22


def _test_layer_quantization_status(module, inputs: bool, weights: bool):
    # check if quantization is applied at all (true if inputs or weights targeted)
    quantized = inputs or weights
    assert hasattr(module, "quantization_scheme") == quantized
    assert hasattr(module, "quantization_status") == quantized

    # check inputs match expected
    assert hasattr(module, "input_scale") == inputs
    assert hasattr(module, "input_zero_point") == inputs

    # check weights match expected
    assert hasattr(module, "weight_scale") == weights
    assert hasattr(module, "weight_zero_point") == weights


def get_tinyllama_model():
    return AutoModelForCausalLM.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
    )


def get_sample_tinyllama_quant_config():
    config_dict = {
        "quant_method": "sparseml",
        "format": "fakequant",
        "quantization_status": "frozen",
        "global_compression_ratio": None,
        "config_groups": {
            "group_1": {
                "weights": {
                    "num_bits": 8,
                    "type": "int",
                    "symmetric": True,
                    "strategy": "tensor",
                },
                "input_activations": {
                    "num_bits": 8,
                    "type": "int",
                    "symmetric": True,
                    "strategy": "tensor",
                },
                "targets": ["Linear"],
            },
            "group_2": {
                "weights": {
                    "num_bits": 8,
                    "type": "int",
                    "symmetric": False,
                    "strategy": "tensor",
                },
                "input_activations": None,
                "targets": ["Embedding"],
            },
        },
        "ignore": ["LlamaRotaryEmbedding"],
    }
    return QuantizationConfig.parse_obj(config_dict)


if __name__ == "__main__":
    test_apply_quantization_config_tinyllama()