[lifecycle] docstrings + ux update to work with torch.apply
Benjamin committed Apr 16, 2024
1 parent 20283a0 commit 0221639
Showing 4 changed files with 45 additions and 21 deletions.
28 changes: 9 additions & 19 deletions src/sparsetensors/quantization/lifecycle/apply.py
@@ -25,7 +25,6 @@
     QuantizationConfig,
     QuantizationStatus,
 )
-from sparsetensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module


@@ -49,41 +48,32 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
         for target in scheme.targets:
             target_to_scheme[target] = scheme
 
-    # build list of layers to target to avoid mutating submodule dict during iteration
-    layer_quant_scheme_pairs = []
+    # mark appropriate layers for quantization by setting their quantization schemes
     for name, submodule in _iter_named_leaf_modules(model):
         if _find_first_name_or_class_match(name, submodule, config.ignore):
             continue  # layer matches ignore list, continue
         target = _find_first_name_or_class_match(name, submodule, target_to_scheme)
         if target is not None:
-            # target matched - add layer and scheme to target list
-            layer_quant_scheme_pairs.append((submodule, target_to_scheme[target]))
+            submodule.quantization_scheme = target_to_scheme[target]
 
-    # apply current quantization status for each matched pair
-    for layer, scheme in layer_quant_scheme_pairs:
-        apply_quantization_status(
-            module=layer,
-            scheme=scheme,
-            status=config.quantization_status,
-        )
+    # apply current quantization status across all targeted layers
+    apply_quantization_status(model, config.quantization_status)
 
 
-def apply_quantization_status(
-    module: Module, scheme: QuantizationScheme, status: QuantizationStatus
-):
+def apply_quantization_status(model: Module, status: QuantizationStatus):
     """
     Applies in place the quantization lifecycle up to the given status
-    :param module: module to apply quantization to
-    :param scheme: quantization scheme to apply
+    :param model: model to apply quantization to
     :param status: status to update the module to
     """
     if status >= QuantizationStatus.INITIALIZED:
-        initialize_module_for_quantization(module, scheme)
+        model.apply(initialize_module_for_quantization)
     if status >= QuantizationStatus.CALIBRATION:
-        set_module_for_calibration(module)
+        model.apply(set_module_for_calibration)
     if status >= QuantizationStatus.FROZEN:
-        freeze_module_quantization(module)
+        model.apply(freeze_module_quantization)
 
 
 def _iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
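
For orientation, here is a minimal sketch of the flow this commit enables: tag targeted leaf modules with a scheme (which is what apply_quantization_config now does via submodule.quantization_scheme) and advance the lifecycle for the whole model in a single call. Only the function names and the module paths touched by this commit are taken from the diff; the QuantizationScheme constructor arguments and the QuantizationStatus import path are assumptions for illustration.

import torch
from torch.nn import Linear

# paths below mirror the files in this commit; quant_config is an assumed path
from sparsetensors.quantization.lifecycle.apply import apply_quantization_status
from sparsetensors.quantization.quant_config import QuantizationStatus  # assumed
from sparsetensors.quantization.quant_scheme import QuantizationScheme

model = torch.nn.Sequential(Linear(64, 64), Linear(64, 8))

# mimic apply_quantization_config: attach a scheme to each targeted leaf module
# (constructor kwargs are assumed for illustration)
scheme = QuantizationScheme(targets=["Linear"])
for submodule in model.modules():
    if isinstance(submodule, Linear):
        submodule.quantization_scheme = scheme

# one call walks the whole model with model.apply(...); modules without a
# quantization_scheme attribute are skipped by each lifecycle hook
apply_quantization_status(model, QuantizationStatus.CALIBRATION)
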
8 changes: 8 additions & 0 deletions src/sparsetensors/quantization/lifecycle/calibration.py
@@ -28,6 +28,14 @@
 
 
 def set_module_for_calibration(module: Module):
+    """
+    marks a layer as ready for calibration which activates observers
+    to update scales and zero points on each forward pass
+    apply to full model with `model.apply(set_module_for_calibration)`
+    :param module: module to set for calibration
+    """
     if not getattr(module, "quantization_scheme", None):
         # no quantization scheme nothing to do
         return
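
A rough usage sketch of the pattern the new docstring suggests. The toy model and calibration loop are illustrative assumptions; in practice the modules would already carry a quantization_scheme set by apply_quantization_config, and untagged modules simply hit the early return above.

import torch
from torch.nn import Linear

from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration

# toy stand-in; in practice the modules would already be tagged and initialized
# by apply_quantization_config / apply_quantization_status
model = torch.nn.Sequential(Linear(16, 16), Linear(16, 4))

# stage every targeted module in one call; untagged modules are left untouched
model.apply(set_module_for_calibration)

# observers update scales and zero points on each forward pass
with torch.no_grad():
    for _ in range(8):
        model(torch.randn(2, 16))
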
7 changes: 7 additions & 0 deletions src/sparsetensors/quantization/lifecycle/frozen.py
@@ -23,6 +23,13 @@
 
 
 def freeze_module_quantization(module: Module):
+    """
+    deletes observers so static quantization is completed.
+    apply to full model with `model.apply(freeze_module_quantization)`
+    :param module: module to freeze quantization for
+    """
     if not getattr(module, "quantization_scheme", None):
         # no quantization scheme nothing to do
         return
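
The matching sketch for freezing once calibration data has been run through the model, continuing the `model` from the calibration example above (same assumptions); calling apply_quantization_status(model, QuantizationStatus.FROZEN) from apply.py reaches the same endpoint by walking the full status ladder.

from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization

# observers are removed from every tagged module; untagged modules are skipped
model.apply(freeze_module_quantization)
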
23 changes: 21 additions & 2 deletions src/sparsetensors/quantization/lifecycle/initialize.py
@@ -14,6 +14,7 @@
 
 
 import logging
+from typing import Optional
 
 import torch
 from sparsetensors.quantization.lifecycle.forward import wrap_module_forward_quantized
@@ -31,9 +31,27 @@
 _LOGGER = logging.getLogger(__name__)
 
 
-def initialize_module_for_quantization(module: Module, scheme: QuantizationScheme):
-    if scheme.input_activations is not None:
+def initialize_module_for_quantization(
+    module: Module,
+    scheme: Optional[QuantizationScheme] = None,
+):
+    """
+    attaches appropriate scales, zero points, and observers to a layer
+    given its target quantization scheme
+    apply to full model with `model.apply(initialize_module_for_quantization)`
+    :param module: module to initialize for quantization
+    :param scheme: scheme to use for quantization. if None is provided,
+        will attempt to use the scheme stored in the module under
+        `quantization_scheme`; if no scheme is found there, the layer is skipped
+    """
+    scheme = scheme or getattr(module, "quantization_scheme", None)
+    if scheme is None:
+        # no scheme passed and layer not targeted for quantization - skip
+        return
+
+    if scheme.input_activations is not None:
         _initialize_scale_zero_point_observer(module, "input", scheme.input_activations)
     if scheme.weights is not None:
         if hasattr(module, "weight"):
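
The optional scheme argument can be exercised in two ways, sketched below. The QuantizationScheme constructor kwargs are an assumption for illustration; the call patterns follow the signature added in this file.

import torch
from torch.nn import Linear

from sparsetensors.quantization.lifecycle.initialize import (
    initialize_module_for_quantization,
)
from sparsetensors.quantization.quant_scheme import QuantizationScheme

layer = Linear(16, 16)
scheme = QuantizationScheme(targets=["Linear"])  # kwargs assumed for illustration

# option 1: pass the scheme explicitly for a single module
initialize_module_for_quantization(layer, scheme)

# option 2: rely on a scheme already attached under `quantization_scheme`,
# which is what model.apply(initialize_module_for_quantization) depends on
layer.quantization_scheme = scheme
model = torch.nn.Sequential(layer, Linear(16, 4))
model.apply(initialize_module_for_quantization)  # the untagged Linear is skipped
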
