Skip to content

Commit

Permalink
validation + update name
Browse files Browse the repository at this point in the history
  • Loading branch information
dsikka committed Oct 11, 2024
1 parent 0e52e66 commit 2d5f667
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/compressed_tensors/quantization/lifecycle/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from compressed_tensors.quantization.cache import QuantizedKVParameterCache
from compressed_tensors.quantization.observers.helpers import (
calculate_range,
compute_memoryless_zp_and_scales,
compute_dynamic_scales_and_zp,
)
from compressed_tensors.quantization.quant_args import (
QuantizationArgs,
Expand Down Expand Up @@ -380,7 +380,7 @@ def maybe_calibrate_or_quantize(

if args.dynamic:
# dynamic quantization - no need to invoke observer
scale, zero_point = compute_memoryless_zp_and_scales(value=value, args=args)
scale, zero_point = compute_dynamic_scales_and_zp(value=value, args=args)
else:
# static quantization - get previous scale and zero point from layer
scale = getattr(module, f"{base_name}_scale")
Expand Down
7 changes: 4 additions & 3 deletions src/compressed_tensors/quantization/observers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@
"calculate_qparams",
"get_observer_token_count",
"calculate_range",
"compute_memoryless_zp_and_scales",
"compute_dynamic_scales_and_zp",
]


def compute_memoryless_zp_and_scales(value: Tensor, args: QuantizationArgs):
def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs):
"""
Returns the min and max values of observed tensor
Returns the computed scales and zero points for dynamic activation
quantization.
:param value: tensor to calculate quantization parameters for
:param args: quantization args
Expand Down
24 changes: 24 additions & 0 deletions src/compressed_tensors/quantization/quant_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from enum import Enum
from typing import Any, Dict, Optional, Union

Expand Down Expand Up @@ -171,6 +172,8 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
strategy = model.strategy
group_size = model.group_size
actorder = model.actorder
dynamic = model.dynamic
observer = model.observer

# infer strategy
if strategy is None:
Expand Down Expand Up @@ -207,6 +210,27 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
"activation ordering"
)

# if we have not set an observer and we
# are running static quantization, use minmax
if not observer and not dynamic:
model.observer = "minmax"

if dynamic:
if strategy not in (
QuantizationStrategy.TOKEN,
QuantizationStrategy.TENSOR,
):
raise ValueError(
f"One of {QuantizationStrategy.TOKEN} or "
f"{QuantizationStrategy.TENSOR} must be used for dynamic ",
"quantization",
)
if observer is not None:
warnings.warn(
"No observer is used for dynamic quantization, setting to None"
)
model.observer = None

# write back modified values
model.strategy = strategy
return model
Expand Down

0 comments on commit 2d5f667

Please sign in to comment.