Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Serialize Config from Model #7

Merged
merged 14 commits into from
Apr 16, 2024
2 changes: 2 additions & 0 deletions src/sparsetensors/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.

# flake8: noqa
# isort: skip_file

from .quant_args import *
from .quant_config import *
from .quant_scheme import *
3 changes: 2 additions & 1 deletion src/sparsetensors/quantization/lifecycle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
# limitations under the License.

# flake8: noqa
# isort: skip_file

from .calibration import *
from .forward import *
from .frozen import *
from .initialize import *
from .status import *
from .apply import *
113 changes: 113 additions & 0 deletions src/sparsetensors/quantization/lifecycle/apply.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from collections import OrderedDict
from typing import Iterable, Optional

from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration
from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization
from sparsetensors.quantization.lifecycle.initialize import (
initialize_module_for_quantization,
)
from sparsetensors.quantization.quant_config import (
QuantizationConfig,
QuantizationStatus,
)
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from sparsetensors.quantization.utils import iter_named_leaf_modules
from torch.nn import Module


__all__ = [
"apply_quantization_config",
"apply_quantization_status",
]


def apply_quantization_config(model: Module, config: QuantizationConfig):
    """
    Initializes the model for quantization in-place based on the given config

    :param model: model to apply quantization config to
    :param config: quantization config
    """
    # map every target pattern to its scheme; OrderedDict preserves the
    # ordering of targets as declared in the config
    scheme_by_target = OrderedDict(
        (target, scheme)
        for scheme in config.config_groups.values()
        for target in scheme.targets
    )

    # collect matches first so the module tree is not mutated while iterating
    matched_layers = []
    for name, submodule in iter_named_leaf_modules(model):
        if _find_first_name_or_class_match(name, submodule, config.ignore):
            # layer is explicitly ignored by the config
            continue
        target = _find_first_name_or_class_match(name, submodule, scheme_by_target)
        if target is not None:
            matched_layers.append((submodule, scheme_by_target[target]))

    # bring each matched layer up to the configured quantization status
    for layer, scheme in matched_layers:
        apply_quantization_status(
            module=layer,
            scheme=scheme,
            status=config.quantization_status,
        )


def apply_quantization_status(
    module: Module, scheme: QuantizationScheme, status: QuantizationStatus
):
    """
    Applies in place the quantization lifecycle up to the given status

    :param module: module to apply quantization to
    :param scheme: quantization scheme to apply
    :param status: status to update the module to
    """
    # each lifecycle step runs when the requested status is at or past it
    lifecycle_steps = (
        (QuantizationStatus.INITIALIZED, lambda: initialize_module_for_quantization(module, scheme)),
        (QuantizationStatus.CALIBRATION, lambda: set_module_for_calibration(module)),
        (QuantizationStatus.FROZEN, lambda: freeze_module_quantization(module)),
    )
    for threshold, step in lifecycle_steps:
        if status >= threshold:
            step()


def _find_first_name_or_class_match(
name: str,
module: Module,
targets: Iterable[str],
) -> Optional[str]:
# first element of targets that matches the given name
# if no name matches returns first target that matches the class name
# returns None otherwise
return _find_first_match(name, targets) or _find_first_match(
module.__class__.__name__, targets
)


def _find_first_match(value: str, targets: Iterable[str]) -> Optional[str]:
# returns first element of target that matches value either
# exactly or as a regex after 're:'
for target in targets:
if target.startswith("re:"):
pattern = target[3:]
if re.match(pattern, value):
return target
elif target == value:
return target
return None
2 changes: 1 addition & 1 deletion src/sparsetensors/quantization/lifecycle/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import logging

from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from sparsetensors.quantization.quant_config import QuantizationStatus
from torch.nn import Module


Expand Down
2 changes: 1 addition & 1 deletion src/sparsetensors/quantization/lifecycle/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from functools import wraps

import torch
from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsetensors.quantization.quant_config import QuantizationStatus
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module

Expand Down
7 changes: 5 additions & 2 deletions src/sparsetensors/quantization/lifecycle/frozen.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.


from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from sparsetensors.quantization.quant_config import QuantizationStatus
from torch.nn import Module


Expand All @@ -28,9 +28,12 @@ def freeze_module_quantization(module: Module):
return

# delete observers from module
observer_names = []
for submodule_name, _ in module.named_modules():
if "." not in submodule_name and submodule_name.endswith("_observer"):
# delete any observers that belong directly to this module
delattr(module, submodule_name)
observer_names.append(submodule_name)
for observer_name in observer_names:
delattr(module, observer_name)

module.quantization_status = QuantizationStatus.FROZEN
2 changes: 1 addition & 1 deletion src/sparsetensors/quantization/lifecycle/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

import torch
from sparsetensors.quantization.lifecycle.forward import wrap_module_forward_quantized
from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsetensors.quantization.quant_config import QuantizationStatus
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module, Parameter

Expand Down
69 changes: 68 additions & 1 deletion src/sparsetensors/quantization/quant_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,18 @@

from pydantic import BaseModel
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from sparsetensors.quantization.utils import (
is_module_quantized,
iter_named_leaf_modules,
)
from torch.nn import Module


__all__ = ["QuantizationStatus", "QuantizationConfig"]
__all__ = [
"QuantizationStatus",
"QuantizationConfig",
"LIFECYCLE_ORDER",
]


class QuantizationStatus(Enum):
Expand All @@ -41,6 +50,26 @@ class QuantizationStatus(Enum):
FROZEN = "frozen"
COMPRESSED = "compressed"

@classmethod
def lifecycle_order(cls) -> List["QuantizationStatus"]:
    """
    :return: list of correct quantization lifecycle order
    """
    # bug fix: previously a bare `return` that yielded None despite the
    # docstring's promise of the ordered lifecycle list
    return LIFECYCLE_ORDER

def __ge__(self, other):
    """
    Order statuses by their position in LIFECYCLE_ORDER.

    :param other: status to compare against
    :return: True if self is at or past other in the lifecycle
    """
    if not isinstance(other, self.__class__):
        # Per the data model, rich comparisons should return the
        # NotImplemented singleton for unsupported operands (letting Python
        # try the reflected operation / raise TypeError), not raise
        # NotImplementedError
        return NotImplemented
    return LIFECYCLE_ORDER.index(self) >= LIFECYCLE_ORDER.index(other)


# Canonical progression of statuses through the quantization lifecycle;
# consumed by QuantizationStatus comparison operators (e.g. __ge__)
LIFECYCLE_ORDER = [
    QuantizationStatus.INITIALIZED,
    QuantizationStatus.CALIBRATION,
    QuantizationStatus.FROZEN,
    QuantizationStatus.COMPRESSED,
]


class QuantizationConfig(BaseModel):
"""
Expand All @@ -66,3 +95,41 @@ class QuantizationConfig(BaseModel):
quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
global_compression_ratio: Optional[float] = None
ignore: Optional[List[str]] = None

@staticmethod
def from_pretrained(model: Module) -> "QuantizationConfig":
    """
    Serializes a QuantizationConfig from a quantized model by inspecting the
    quantization scheme attached to each leaf module.

    :param model: model to extract the quantization config from
    :return: QuantizationConfig describing the model's quantized layers,
        with unquantized leaf layers collected in the ignore list
    """
    quant_scheme_to_layers = []
    # NOTE(review): status is taken from the last quantized module seen and
    # stays None if no module is quantized — presumably all modules share
    # one status; verify against callers
    quantization_status = None
    ignore = []
    for name, submodule in iter_named_leaf_modules(model):
        if not is_module_quantized(submodule):
            # unquantized leaf layers are recorded so a reload skips them
            ignore.append(name)
            continue
        quantization_status = submodule.quantization_status
        scheme = submodule.quantization_scheme

        # group layers sharing an identical scheme under one target list
        for existing_scheme, layer_names in quant_scheme_to_layers:
            if scheme == existing_scheme:
                layer_names.append(name)
                break
        else:  # no matching scheme yet - start a new group
            quant_scheme_to_layers.append((scheme, [name]))

    # each unique scheme becomes a named config group targeting its layers
    config_groups = {}
    for idx, (scheme, targets) in enumerate(quant_scheme_to_layers):
        scheme.targets = targets
        config_groups["group_" + str(idx)] = scheme

    return QuantizationConfig(
        config_groups=config_groups,
        quantization_status=quantization_status,
        ignore=ignore,
    )
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum


__all__ = [
"QuantizationStatus",
]


class QuantizationStatus(Enum):
INITIALIZED = "INITIALIZED"
CALIBRATION = "CALIBRATION"
FROZEN = "FROZEN"
# flake8: noqa
from .helpers import *
36 changes: 36 additions & 0 deletions src/sparsetensors/quantization/utils/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterator, Tuple

from torch.nn import Module


__all__ = ["is_module_quantized", "iter_named_leaf_modules", "module_type"]


def is_module_quantized(module: Module) -> bool:
    """
    :param module: module to check
    :return: True if a quantization_scheme attribute is attached to module
    """
    # sentinel-based getattr keeps exact hasattr semantics, including when
    # the attribute is set to None
    missing = object()
    return getattr(module, "quantization_scheme", missing) is not missing
Satrat marked this conversation as resolved.
Show resolved Hide resolved


def module_type(module: Module) -> str:
    """
    :param module: module to inspect
    :return: class name of the module (e.g. "Linear")
    """
    return module.__class__.__name__


def iter_named_leaf_modules(model: Module) -> Iterator[Tuple[str, Module]]:
    """
    Yields (name, module) pairs for modules that have no child modules.

    :param model: model to iterate over
    :return: iterator of (name, leaf module) tuples
    """
    # annotation fix: this is a generator, so the return type is an
    # Iterator of tuples rather than a single Tuple
    # TODO: potentially expand to add list of allowed submodules such as observers
    for name, submodule in model.named_modules():
        if len(list(submodule.children())) == 0:
            yield name, submodule
13 changes: 13 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions tests/quantization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions tests/quantization/lifecycle/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading