Commit

run bin.quant
horheynm committed Apr 12, 2024
1 parent 560ef13 commit 0804be3
Showing 13 changed files with 82 additions and 448 deletions.
27 changes: 16 additions & 11 deletions bin/quant.py
@@ -1,17 +1,18 @@
import torch
from torch.nn import Linear
# from sparseml.modifiers.quantization.utils.quantization_scheme import QuantizationScheme, QuantizationArgs

from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from sparseml.modifiers.quantization.lifecycle.initialize import initialize_module_for_quantization
from sparseml.modifiers.quantization.lifecycle.calibration import set_module_for_calibration
from sparseml.modifiers.quantization.lifecycle.frozen import freeze_module_quantization
from sparsetensors.quantization.lifecycle.initialize import initialize_module_for_quantization
from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration
from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization
num_bits = 8

scheme = QuantizationScheme(
input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False),
weights=QuantizationArgs(num_bits=num_bits, symmetric=True),
output_activations=None,
targets = ["*"],
)

layer = Linear(4, 4)
@@ -31,25 +32,29 @@
layer(torch.randn(4,4))
print(dict(layer.named_parameters())) # scale and zero point should have updated values
print(2)
for _ in range(10):
print("calib layers ")
for i in range(10):
print("iter", i)
layer(torch.randn(4,4))
print(dict(layer.named_parameters())) # scale and zero point should have updated values again since we did another pass

print(3)
breakpoint()
# breakpoint()


freeze_module_quantization(layer)
for _ in range(10):
print("freeze layers ")
for i in range(10):
# do more forward passes but show args are frozen
layer(torch.random.randn(4,4))
print("iter", i)
layer(torch.randn(4,4))
print(dict(layer.named_parameters())) # scale and zero point should not be updated now


# missing
# # missing

# correctness
# quantizing an entire model
# # correctness
# # quantizing an entire model
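
For orientation, here is a minimal sketch of the lifecycle this script exercises, initialize then calibrate then freeze, using only the sparsetensors APIs that appear in this diff; it is illustrative rather than a verbatim copy of bin/quant.py.

import torch
from torch.nn import Linear

from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from sparsetensors.quantization.lifecycle.initialize import initialize_module_for_quantization
from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration
from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization

# 8-bit scheme: asymmetric input activations, symmetric weights, no output quantization.
scheme = QuantizationScheme(
    input_activations=QuantizationArgs(num_bits=8, symmetric=False),
    weights=QuantizationArgs(num_bits=8, symmetric=True),
    output_activations=None,
    targets=["*"],
)

layer = Linear(4, 4)

# 1) Attach scale/zero-point parameters and observers, and wrap the forward pass.
initialize_module_for_quantization(layer, scheme)

# 2) Calibrate: forward passes let the observers update scale and zero point.
set_module_for_calibration(layer)
for _ in range(10):
    layer(torch.randn(4, 4))

# 3) Freeze: scale and zero point stop updating; later passes fake-quantize with fixed values.
freeze_module_quantization(layer)
layer(torch.randn(4, 4))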



2 changes: 1 addition & 1 deletion setup.py
@@ -25,7 +25,7 @@ def _setup_install_requires() -> List:
return ["torch>=1.7.0", "transformers<=4.40", "pydantic<2.7"]

def _setup_extras() -> Dict:
return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0"]}
return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0", "sparsezoo"]}

setup(
name="sparsetensors",
1 change: 0 additions & 1 deletion src/sparsetensors/quantization/lifecycle/__init__.py
@@ -19,4 +19,3 @@
from .frozen import *
from .initialize import *
from .status import *
from .initialize import *
src/sparsetensors/quantization/lifecycle/calibration.py
@@ -15,10 +15,9 @@

import logging

from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from torch.nn import Module

from sparseml.modifiers.quantization.lifecycle.status import QuantizationStatus


__all__ = [
"set_module_for_calibration",
@@ -41,4 +40,4 @@ def set_module_for_calibration(module: Module):
"to re-calibrate a frozen module"
)

module.quantization_status = QuantizationStatus.CALIBRATION
module.quantization_status = QuantizationStatus.CALIBRATION
21 changes: 13 additions & 8 deletions src/sparsetensors/quantization/lifecycle/forward.py
@@ -15,11 +15,16 @@
from functools import wraps

import torch
from sparsetensors.quantization.lifecycle.status import QuantizationStatus

# from sparsetensors.quantization.utils.quantization_scheme import (
# QuantizationArgs,
# QuantizationScheme,
# )
from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module

from sparseml.modifiers.quantization.lifecycle.status import QuantizationStatus

from sparseml.modifiers.quantization.utils.quantization_scheme import QuantizationScheme, QuantizationArgs

__all__ = ["wrap_module_forward_quantized"]

@@ -34,8 +39,8 @@ def quantize(
torch.round(
x / scale + zero_point,
),
0,
q_max,
0,
q_max,
)


@@ -83,7 +88,7 @@ def fake_quantize(
# q = quantize(w.unsqueeze(1), scale, zero, max_q).flatten()
# Q1[:, i] = q
# Q[:, i1:i2] = Q1
Q = quantize(x, scale, zero_point, max_q)
Q = quantize(x, scale, zero_point, max_q)
return dequantize(Q, scale, zero_point)


@@ -138,7 +143,7 @@ def _maybe_calibrate_or_quantize(
return value

scale = getattr(module, f"{base_name}_scale")
# zero_point = getattr(module, f"{base_name}_zero_point").data
# zero_point = getattr(module, f"{base_name}_zero_point").data
zero_point = getattr(module, f"{base_name}_zero_point")

print(scale, zero_point)
@@ -152,4 +157,4 @@ def _maybe_calibrate_or_quantize(
scale.data = updated_scale
zero_point.data = updated_zero_point

return fake_quantize(value, scale, zero_point, args)
return fake_quantize(value, scale, zero_point, args)
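
In plain terms, the quantize / fake_quantize path above reduces to the sketch below. The clamp-and-round step mirrors the hunk; the dequantize body and the simplified signatures are assumptions, since the diff only shows dequantize being called.

import torch

def quantize(x, scale, zero_point, q_max):
    # As in forward.py: map onto the integer grid, then clamp to [0, q_max].
    return torch.clamp(torch.round(x / scale + zero_point), 0, q_max)

def dequantize(x_q, scale, zero_point):
    # Assumed inverse of quantize (not shown in this diff): undo the shift, then rescale.
    return (x_q - zero_point) * scale

def fake_quantize(x, scale, zero_point, q_max):
    # Quantize then immediately dequantize, so downstream code sees the quantization
    # error while everything stays in floating point.
    return dequantize(quantize(x, scale, zero_point, q_max), scale, zero_point)

# Example with an 8-bit asymmetric range [0, 255]:
x_fq = fake_quantize(torch.randn(4, 4), scale=torch.tensor(0.1), zero_point=torch.tensor(128.0), q_max=255)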
3 changes: 1 addition & 2 deletions src/sparsetensors/quantization/lifecycle/frozen.py
@@ -13,10 +13,9 @@
# limitations under the License.


from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from torch.nn import Module

from sparseml.modifiers.quantization.lifecycle.status import QuantizationStatus


__all__ = [
"freeze_module_quantization",
27 changes: 13 additions & 14 deletions src/sparsetensors/quantization/lifecycle/initialize.py
@@ -16,17 +16,17 @@
import logging

import torch
from sparsetensors.quantization.lifecycle.forward import wrap_module_forward_quantized
from sparsetensors.quantization.lifecycle.status import QuantizationStatus

# from sparsetensors.quantization.utils.quantization_scheme import (
# QuantizationArgs,
# QuantizationScheme,
# )
from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module, Parameter

from sparseml.modifiers.quantization.lifecycle.forward import (
wrap_module_forward_quantized,
)
from sparseml.modifiers.quantization.lifecycle.status import QuantizationStatus
from sparseml.modifiers.quantization.utils.quantization_scheme import (
QuantizationArgs,
QuantizationScheme,
)


__all__ = [
"initialize_module_for_quantization",
@@ -39,9 +39,7 @@
def initialize_module_for_quantization(module: Module, scheme: QuantizationScheme):
if scheme.input_activations is not None:

_initialize_scale_zero_point_observer(
module, "input", scheme.input_activations
)
_initialize_scale_zero_point_observer(module, "input", scheme.input_activations)
if scheme.weights is not None:
if hasattr(module, "weight"):
_initialize_scale_zero_point_observer(module, "weight", scheme.weights)
@@ -52,7 +50,9 @@ def initialize_module_for_quantization(module: Module, scheme: QuantizationSchem
f"for {type(module)}"
)
if scheme.output_activations is not None:
_initialize_scale_zero_point_observer(module, "output", scheme.output_activations)
_initialize_scale_zero_point_observer(
module, "output", scheme.output_activations
)

module.quantization_scheme = scheme
module.quantization_status = QuantizationStatus.INITIALIZED
@@ -61,7 +61,6 @@ def initialize_module_for_quantization(module: Module, scheme: QuantizationSchem
wrap_module_forward_quantized(module, scheme)



def _initialize_scale_zero_point_observer(
module: Module, base_name: str, quantization_args: QuantizationArgs
):
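
The body of _initialize_scale_zero_point_observer is not part of this diff, so the sketch below is an assumption about what it sets up: a f"{base_name}_scale" / f"{base_name}_zero_point" parameter pair (the names forward.py reads back with getattr) plus an observer built from the QuantizationArgs. The initial values, dtypes, and the f"{base_name}_observer" attribute name are placeholders, not the library's actual choices.

import torch
from torch.nn import Module, Parameter

from sparsetensors.quantization.quant_args import QuantizationArgs

def _initialize_scale_zero_point_observer_sketch(
    module: Module, base_name: str, quantization_args: QuantizationArgs
):
    # Hypothetical: register the parameters that forward.py later reads via
    # getattr(module, f"{base_name}_scale") and getattr(module, f"{base_name}_zero_point").
    module.register_parameter(f"{base_name}_scale", Parameter(torch.ones(1), requires_grad=False))
    module.register_parameter(f"{base_name}_zero_point", Parameter(torch.zeros(1), requires_grad=False))

    # Hypothetical: attach the observer that refines scale / zero point during calibration.
    setattr(module, f"{base_name}_observer", quantization_args.get_observer())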
2 changes: 1 addition & 1 deletion src/sparsetensors/quantization/observers/__init__.py
@@ -16,4 +16,4 @@

from .base import *
from .memoryless import *
from .min_max import *
from .min_max import *
12 changes: 5 additions & 7 deletions src/sparsetensors/quantization/observers/base.py
@@ -14,12 +14,12 @@

from typing import Optional, Tuple

# from sparsetensors.quantization.utils.quantization_scheme import QuantizationArgs
from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsezoo.utils.registry import RegistryMixin
from torch import FloatTensor, IntTensor, Tensor
from torch.nn import Module

from sparseml.modifiers.quantization.utils.quantization_scheme import QuantizationArgs
from sparsezoo.utils.registry import RegistryMixin


__all__ = ["Observer"]

@@ -31,9 +31,7 @@ class Observer(Module, RegistryMixin):
pair
"""

def __init__(self,
quantization_args: QuantizationArgs
):
def __init__(self, quantization_args: QuantizationArgs):
self.quantization_args: QuantizationArgs = quantization_args
super().__init__()
self._scale = None
@@ -69,4 +67,4 @@ def get_qparams(
if observed is not None:
# re-calculate scale and zero point, update the stored value
self._scale, self._zero_point = self.calculate_qparams(observed)
return self._scale, self._zero_point
return self._scale, self._zero_point
7 changes: 4 additions & 3 deletions src/sparsetensors/quantization/observers/memoryless.py
@@ -15,10 +15,11 @@
from typing import Tuple

import torch
from sparsetensors.quantization.observers.base import Observer
from torch import FloatTensor, IntTensor, Tensor

from sparseml.modifiers.quantization.observers.base import Observer
# from sparseml.modifiers.quantization.utils.quantization_scheme import QuantizationArgs

# from sparsetensors.quantization.utils.quantization_scheme import QuantizationArgs


__all__ = ["MemorylessObserver"]
@@ -60,4 +61,4 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:

zero_point = (0 - min_val) / scale

return scale, zero_point
return scale, zero_point
7 changes: 3 additions & 4 deletions src/sparsetensors/quantization/observers/min_max.py
@@ -15,11 +15,10 @@
from typing import Tuple

import torch
from sparsetensors.quantization.observers.base import Observer
from sparsetensors.quantization.quant_args import QuantizationArgs
from torch import FloatTensor, IntTensor, Tensor

from sparseml.modifiers.quantization.observers.base import Observer
from sparseml.modifiers.quantization.utils.quantization_scheme import QuantizationArgs


__all__ = ["MinMaxObserver"]

@@ -77,4 +76,4 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:

zero_point = (0 - self.min_val) / scale

return scale, zero_point
return scale, zero_point
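
The arithmetic behind these observers is the usual asymmetric mapping. The zero-point line matches the hunks above; the scale formula is the standard min/max construction and is not shown in this diff, so treat the helper below as an illustrative assumption.

import torch

def asymmetric_qparams_sketch(observed: torch.Tensor, num_bits: int = 8):
    # MinMaxObserver keeps a running min/max across batches; the memoryless observer
    # recomputes the range per call. Either way the range should include zero (assumption).
    min_val = torch.clamp(observed.min(), max=0.0)
    max_val = torch.clamp(observed.max(), min=0.0)

    # Assumed scale over an unsigned [0, 2**num_bits - 1] integer range.
    scale = (max_val - min_val) / float(2**num_bits - 1)
    scale = torch.clamp(scale, min=1e-8)  # guard against a constant-zero input

    # Matches the diff: zero_point = (0 - min_val) / scale.
    zero_point = (0 - min_val) / scale
    return scale, zero_point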
25 changes: 23 additions & 2 deletions src/sparsetensors/quantization/quant_args.py
@@ -13,9 +13,9 @@
# limitations under the License.

from enum import Enum
from typing import Optional
from typing import Any, Dict, Optional

from pydantic import BaseModel
from pydantic import BaseModel, Field


__all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"]
@@ -61,3 +61,24 @@ class QuantizationArgs(BaseModel):
strategy: QuantizationStrategy = QuantizationStrategy.TENSOR
group_size: Optional[int] = None
block_structure: Optional[str] = None
observer: str = Field(
default="minmax",
description=(
"The class to use to compute the quantization params - scale and zero-point'"
),
)
observer_kwargs: Dict[str, Any] = Field(
default_factory=dict,
description=(
"optional dict of kwargs to be passed directly to torch quantization "
"Observers constructor excluding quantization range or symmetry"
),
)

def get_observer(self):
"""
:return: torch quantization FakeQuantize built based on these QuantizationArgs
"""
from sparsetensors.quantization.observers.base import Observer

return Observer.load_from_registry(self.observer, quantization_args=self)
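
Usage-wise, the new fields let callers resolve an observer directly from the args. The snippet below uses only names visible in this diff (get_observer, Observer.get_qparams, the "minmax" default); whether "minmax" resolves to MinMaxObserver is an assumption about how the registry names are assigned.

import torch

from sparsetensors.quantization.quant_args import QuantizationArgs

args = QuantizationArgs(num_bits=8, symmetric=False)  # observer defaults to "minmax"
observer = args.get_observer()  # Observer.load_from_registry("minmax", quantization_args=args)

x = torch.randn(16, 16)
scale, zero_point = observer.get_qparams(observed=x)  # recompute and cache scale / zero point
print(scale, zero_point)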
391 changes: 0 additions & 391 deletions src/sparsetensors/quantization/utils/quantization_scheme.py

This file was deleted.
