
Commit

integrate full lifecycle support, QuantizationStatus updates, add tinyllama test
Benjamin committed Apr 12, 2024
1 parent 8465015 commit 24e04b6
Showing 12 changed files with 218 additions and 33 deletions.
1 change: 0 additions & 1 deletion src/sparsetensors/quantization/__init__.py
@@ -18,4 +18,3 @@
from .quant_args import *
from .quant_config import *
from .quant_scheme import *
from .apply import *
3 changes: 2 additions & 1 deletion src/sparsetensors/quantization/lifecycle/__init__.py
@@ -13,9 +13,10 @@
# limitations under the License.

# flake8: noqa
# isort: skip_file

from .calibration import *
from .forward import *
from .frozen import *
from .initialize import *
from .status import *
from .apply import *
@@ -16,16 +16,23 @@
from collections import OrderedDict
from typing import Iterable, Optional, Tuple

from sparsetensors.quantization.quant_config import QuantizationConfig
from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration
from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization
from sparsetensors.quantization.lifecycle.initialize import (
initialize_module_for_quantization,
)
from sparsetensors.quantization.quant_config import (
QuantizationConfig,
QuantizationStatus,
)
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module


__all__ = ["apply_quantization_config"]


# TODO: to be ported from sparseml, placeholder only for now
def initialize_module_for_quantization(module, scheme):
pass
__all__ = [
"apply_quantization_config",
"apply_quantization_status",
]


def apply_quantization_config(model: Module, config: QuantizationConfig):
@@ -42,21 +49,48 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
for target in scheme.targets:
target_to_scheme[target] = scheme

# build list of layers to target to avoid mutating submodule dict during iteration
layer_quant_scheme_pairs = []
for name, submodule in _iter_named_leaf_modules(model):
if _find_first_name_or_class_match(name, submodule, config.ignore):
continue  # layer matches ignore list, skip it
target = _find_first_name_or_class_match(name, submodule, target_to_scheme)
if target is not None:
# target matched, initialize layer from the matched scheme
# TODO: add follow on lifecycle calls based on the quantization status
initialize_module_for_quantization(submodule, target_to_scheme[target])
# target matched - add layer and scheme to target list
layer_quant_scheme_pairs.append((submodule, target_to_scheme[target]))

# apply current quantization status for each matched pair
for layer, scheme in layer_quant_scheme_pairs:
apply_quantization_status(
module=layer,
scheme=scheme,
status=config.quantization_status,
)


def apply_quantization_status(
module: Module, scheme: QuantizationScheme, status: QuantizationStatus
):
"""
Applies in place the quantization lifecycle up to the given status
:param module: module to apply quantization to
:param scheme: quantization scheme to apply
:param status: status to update the module to
"""
if status >= QuantizationStatus.INITIALIZED:
initialize_module_for_quantization(module, scheme)
if status >= QuantizationStatus.CALIBRATION:
set_module_for_calibration(module)
if status >= QuantizationStatus.FROZEN:
freeze_module_quantization(module)


def _iter_named_leaf_modules(model: Module) -> Iterable[Tuple[str, Module]]:
# yields modules that do not have any submodules
# TODO: potentially expand to add list of allowed submodules such as observers
for name, submodule in model.named_modules():
if len(submodule.modules()) == 0:
if len(list(submodule.children())) == 0:
yield name, submodule


@@ -69,7 +103,7 @@ def _find_first_name_or_class_match(
# if no name matches returns first target that matches the class name
# returns None otherwise
return _find_first_match(name, targets) or _find_first_match(
module.__class__.__name, targets
module.__class__.__name__, targets
)


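For orientation, a minimal usage sketch of the new entry point (a toy model and a hand-written config dict following the same schema as the tinyllama test below; not part of the commit):

    import torch
    from sparsetensors.quantization.lifecycle import apply_quantization_config
    from sparsetensors.quantization.quant_config import QuantizationConfig

    # toy model, for illustration only
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 4))

    config = QuantizationConfig.parse_obj(
        {
            "quant_method": "sparseml",
            "format": "fakequant",
            "quantization_status": "calibration",
            "global_compression_ratio": None,
            "config_groups": {
                "group_1": {
                    "weights": {
                        "num_bits": 8,
                        "type": "int",
                        "symmetric": True,
                        "strategy": "tensor",
                    },
                    "input_activations": None,
                    "targets": ["Linear"],
                }
            },
            "ignore": [],
        }
    )

    # runs the lifecycle on every matched Linear up to CALIBRATION:
    # initialize_module_for_quantization(...), then set_module_for_calibration(...)
    apply_quantization_config(model, config)
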
2 changes: 1 addition & 1 deletion src/sparsetensors/quantization/lifecycle/calibration.py
@@ -15,7 +15,7 @@

import logging

from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from sparsetensors.quantization.quant_config import QuantizationStatus
from torch.nn import Module


2 changes: 1 addition & 1 deletion src/sparsetensors/quantization/lifecycle/forward.py
@@ -15,8 +15,8 @@
from functools import wraps

import torch
from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsetensors.quantization.quant_config import QuantizationStatus
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module

7 changes: 5 additions & 2 deletions src/sparsetensors/quantization/lifecycle/frozen.py
@@ -13,7 +13,7 @@
# limitations under the License.


from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from sparsetensors.quantization.quant_config import QuantizationStatus
from torch.nn import Module


@@ -28,9 +28,12 @@ def freeze_module_quantization(module: Module):
return

# delete observers from module
observer_names = []
for submodule_name, _ in module.named_modules():
if "." not in submodule_name and submodule_name.endswith("_observer"):
# delete any observers that belong directly to this module
delattr(module, submodule_name)
observer_names.append(submodule_name)
for observer_name in observer_names:
delattr(module, observer_name)

module.quantization_status = QuantizationStatus.FROZEN
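
The snapshot-then-delete pattern above (like the layer_quant_scheme_pairs list in apply.py) exists because deleting a submodule while named_modules() is still iterating mutates the module registry mid-iteration. A minimal repro of the failure mode, with a hypothetical module defined only for illustration:

    import torch

    class Observed(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight_observer = torch.nn.Identity()

    m = Observed()
    for name, _ in m.named_modules():
        if name.endswith("_observer"):
            # removes the entry from m._modules while the named_modules()
            # generator is iterating it -> RuntimeError (dict changed size)
            delattr(m, name)

Collecting the names first, as the new code does, sidesteps this entirely.
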
2 changes: 1 addition & 1 deletion src/sparsetensors/quantization/lifecycle/initialize.py
@@ -17,8 +17,8 @@

import torch
from sparsetensors.quantization.lifecycle.forward import wrap_module_forward_quantized
from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsetensors.quantization.quant_config import QuantizationStatus
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module, Parameter

26 changes: 25 additions & 1 deletion src/sparsetensors/quantization/quant_config.py
@@ -19,7 +19,11 @@
from sparsetensors.quantization.quant_scheme import QuantizationScheme


__all__ = ["QuantizationStatus", "QuantizationConfig"]
__all__ = [
"QuantizationStatus",
"QuantizationConfig",
"LIFECYCLE_ORDER",
]


class QuantizationStatus(Enum):
@@ -41,6 +45,26 @@ class QuantizationStatus(Enum):
FROZEN = "frozen"
COMPRESSED = "compressed"

@classmethod
def lifecycle_order(cls) -> List["QuantizationStatus"]:
"""
:return: list of correct quantization lifecycle order
"""
        return LIFECYCLE_ORDER

def __ge__(self, other):
if not isinstance(other, self.__class__):
raise NotImplementedError
return LIFECYCLE_ORDER.index(self) >= LIFECYCLE_ORDER.index(other)


LIFECYCLE_ORDER = [
QuantizationStatus.INITIALIZED,
QuantizationStatus.CALIBRATION,
QuantizationStatus.FROZEN,
QuantizationStatus.COMPRESSED,
]


class QuantizationConfig(BaseModel):
"""
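
Because __ge__ compares positions in LIFECYCLE_ORDER rather than the enum string values, the chained "if status >= ..." checks in apply_quantization_status run every lifecycle step at or below the requested status, in order. For example:

    from sparsetensors.quantization.quant_config import QuantizationStatus

    assert QuantizationStatus.FROZEN >= QuantizationStatus.INITIALIZED
    assert QuantizationStatus.FROZEN >= QuantizationStatus.CALIBRATION
    assert not (QuantizationStatus.CALIBRATION >= QuantizationStatus.FROZEN)

    # so status=FROZEN triggers initialize, calibration setup, and freeze in
    # sequence, while status=CALIBRATION stops after calibration setup
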
13 changes: 0 additions & 13 deletions ...etensors/quantization/lifecycle/status.py → tests/__init__.py
@@ -11,16 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum


__all__ = [
"QuantizationStatus",
]


class QuantizationStatus(Enum):
INITIALIZED = "INITIALIZED"
CALIBRATION = "CALIBRATION"
FROZEN = "FROZEN"
13 changes: 13 additions & 0 deletions tests/quantization/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions tests/quantization/lifecycle/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
111 changes: 111 additions & 0 deletions tests/quantization/lifecycle/test_apply.py
@@ -0,0 +1,111 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from sparsetensors.quantization.lifecycle import apply_quantization_config
from sparsetensors.quantization.quant_config import QuantizationConfig
from transformers import AutoModelForCausalLM


def test_apply_quantization_config_tinyllama():
quant_config = get_sample_tinyllama_quant_config()
model = get_tinyllama_model()

# check that model is not already quantized
for module in model.modules():
_test_layer_quantization_status(module, inputs=False, weights=False)

# apply quant config to model
apply_quantization_config(model, quant_config)

# check for correct application of quant config
num_linears = 0
num_embeddings = 0
num_rotary_embeddings = 0
for module in model.modules():
module_type = module.__class__.__name__
if module_type == "Linear":
num_linears += 1
_test_layer_quantization_status(module, inputs=True, weights=True)
elif module_type == "Embedding":
num_embeddings += 1
_test_layer_quantization_status(module, inputs=False, weights=True)
elif module_type == "LlamaRotaryEmbedding":
num_rotary_embeddings += 1
_test_layer_quantization_status(module, inputs=False, weights=False)

# sanity check correct number of layers targeted
assert num_linears == 155
assert num_embeddings == 1
assert num_rotary_embeddings == 22


def _test_layer_quantization_status(module, inputs: bool, weights: bool):
# check if quantization is applied at all (true if inputs or weights targeted)
quantized = inputs or weights
assert hasattr(module, "quantization_scheme") == quantized
assert hasattr(module, "quantization_status") == quantized

# check for inputs
assert hasattr(module, "input_scale") == inputs
assert hasattr(module, "input_zero_point") == inputs
assert hasattr(module, "weight_scale") == weights
assert hasattr(module, "weight_zero_point") == weights


def get_tinyllama_model():
return AutoModelForCausalLM.from_pretrained(
"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
)


def get_sample_tinyllama_quant_config():
config_dict = {
"quant_method": "sparseml",
"format": "fakequant",
"quantization_status": "frozen",
"global_compression_ratio": None,
"config_groups": {
"group_1": {
"weights": {
"num_bits": 8,
"type": "int",
"symmetric": True,
"strategy": "tensor",
},
"input_activations": {
"num_bits": 8,
"type": "int",
"symmetric": True,
"strategy": "tensor",
},
"targets": ["Linear"],
},
"group_2": {
"weights": {
"num_bits": 8,
"type": "int",
"symmetric": False,
"strategy": "tensor",
},
"input_activations": None,
"targets": ["Embedding"],
},
},
"ignore": ["LlamaRotaryEmbedding"],
}
return QuantizationConfig.parse_obj(config_dict)


test_apply_quantization_config_tinyllama()
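
The asserted counts follow from the TinyLlama-1.1B architecture, assuming the standard Llama module layout in transformers (22 decoder layers, one rotary embedding per attention block):

    layers = 22
    linears_per_layer = 7  # q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj
    num_linears = layers * linears_per_layer + 1  # +1 for lm_head -> 155
    num_embeddings = 1  # embed_tokens
    num_rotary_embeddings = layers  # one LlamaRotaryEmbedding per layer -> 22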
