Improve fake_quant_with_min_max_vars (#20772)
* Fix fake_quant_with_min_max_vars

* Add FakeQuantWithMinMaxVars operation and use a shortcut for the TF backend.
james77777778 authored Jan 17, 2025
1 parent 69d7eff commit 6c9dfca
Showing 4 changed files with 164 additions and 53 deletions.
4 changes: 1 addition & 3 deletions keras/api/_tf_keras/keras/quantizers/__init__.py
@@ -12,7 +12,5 @@
from keras.src.quantizers.quantizers import abs_max_quantize
from keras.src.quantizers.quantizers import compute_float8_amax_history
from keras.src.quantizers.quantizers import compute_float8_scale
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_vars as fake_quant_with_min_max_vars_per_channel,
)
from keras.src.quantizers.quantizers import fake_quant_with_min_max_vars
from keras.src.quantizers.quantizers import quantize_and_dequantize
4 changes: 1 addition & 3 deletions keras/api/quantizers/__init__.py
@@ -12,7 +12,5 @@
from keras.src.quantizers.quantizers import abs_max_quantize
from keras.src.quantizers.quantizers import compute_float8_amax_history
from keras.src.quantizers.quantizers import compute_float8_scale
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_vars as fake_quant_with_min_max_vars_per_channel,
)
from keras.src.quantizers.quantizers import fake_quant_with_min_max_vars
from keras.src.quantizers.quantizers import quantize_and_dequantize
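
Note: in both API modules the aliased per-channel re-export is replaced by a direct re-export of fake_quant_with_min_max_vars. A minimal usage sketch of the public entry point (illustrative shapes and ranges; the axis argument is described in the quantizers.py changes below):

import numpy as np
from keras import quantizers

x = np.random.uniform(-4.0, 4.0, size=(2, 3, 4)).astype("float32")

# Per-tensor fake quantization: a scalar (min, max) pair clamps the whole tensor.
y = quantizers.fake_quant_with_min_max_vars(x, -3.0, 3.0, num_bits=8)

# Per-channel fake quantization: one (min, max) pair per channel along `axis`.
channel_min = np.array([-3.0, -2.0, -1.0, -4.0], dtype="float32")
channel_max = np.array([3.0, 2.0, 1.0, 4.0], dtype="float32")
y_per_channel = quantizers.fake_quant_with_min_max_vars(
    x, channel_min, channel_max, num_bits=8, axis=-1
)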
118 changes: 94 additions & 24 deletions keras/src/quantizers/quantizers.py
@@ -4,8 +4,11 @@
from keras.src import backend
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.backend import KerasTensor
from keras.src.backend import any_symbolic_tensors
from keras.src.backend.common.backend_utils import canonicalize_axis
from keras.src.backend.common.backend_utils import standardize_axis_for_numpy
from keras.src.ops.operation import Operation

"""Int8-related classes and methods"""

@@ -130,17 +133,20 @@ def get_config(self):

def adjust_and_nudge(min_range, max_range, num_bits, narrow_range):
"""Adjusts and nudges the quantization range for better accuracy."""
# Use higher precision for the computation.
compute_dtype = backend.result_type(min_range.dtype, "float32")
min_range = ops.cast(min_range, compute_dtype)
max_range = ops.cast(max_range, compute_dtype)

quant_max = ops.cast(ops.subtract(ops.power(2, num_bits), 1.0), "float32")

quant_min = ops.cast(0.0 if not narrow_range else 1.0, "float32")
quant_max = (1 << num_bits) - 1
quant_min = 0 if not narrow_range else 1
diff_range = ops.subtract(max_range, min_range)

# Calculate the scale and ensure it's positive
scale = ops.divide(
ops.subtract(max_range, min_range), ops.subtract(quant_max, quant_min)
)
scale = ops.divide(diff_range, quant_max - quant_min)

inv_scale = ops.reciprocal(scale)
# Re-calculate the inverse to avoid loss of precision
inv_scale = ops.divide(quant_max - quant_min, diff_range)

# Calculate the zero point from the min range
zero_point_from_min = quant_min - ops.divide(min_range, scale)
@@ -158,17 +164,37 @@ def adjust_and_nudge(min_range, max_range, num_bits, narrow_range):
return nudged_min, nudged_max, scale, inv_scale
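
As a rough numeric illustration of the nudging above, here is a plain-Python sketch (not part of the commit). The rounding and clamping of the zero point are elided from the visible hunk, so the floor(x + 0.5)-style rounding below is an assumption based on the standard fake-quant nudging scheme:

def nudge(min_range, max_range, num_bits=8, narrow_range=False):
    # Mirrors `adjust_and_nudge` with Python scalars instead of tensors.
    quant_max = (1 << num_bits) - 1
    quant_min = 0 if not narrow_range else 1
    scale = (max_range - min_range) / (quant_max - quant_min)
    zero_point_from_min = quant_min - min_range / scale
    # Round the zero point onto the integer grid and clamp it to the valid range.
    nudged_zero_point = min(quant_max, max(quant_min, int(zero_point_from_min + 0.5)))
    nudged_min = (quant_min - nudged_zero_point) * scale
    nudged_max = (quant_max - nudged_zero_point) * scale
    return nudged_min, nudged_max, scale

# A [-1.275, 1.275] range cannot place 0.0 exactly on the 8-bit grid, so it is
# nudged to roughly [-1.28, 1.27] with scale 0.01 (up to float rounding).
print(nudge(-1.275, 1.275))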


@keras_export("keras.quantizers.fake_quant_with_min_max_vars_per_channel")
class FakeQuantWithMinMaxVars(Operation):
def __init__(self, num_bits=8, narrow_range=False, axis=None):
super().__init__()
self.num_bits = num_bits
self.narrow_range = narrow_range
self.axis = axis

def call(self, inputs, min_vals, max_vals):
return fake_quant_with_min_max_vars(
inputs,
min_vals,
max_vals,
num_bits=self.num_bits,
narrow_range=self.narrow_range,
axis=self.axis,
)

def compute_output_spec(self, inputs, min_vals, max_vals):
return KerasTensor(inputs.shape, dtype=inputs.dtype)
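
The Operation subclass above is what makes the quantizer usable on symbolic KerasTensors (for example while tracing a functional model): compute_output_spec only propagates the input shape and dtype, no math runs. A small sketch mirroring the new symbolic test:

from keras import KerasTensor, quantizers

x = KerasTensor(shape=(2, 3, 4), dtype="float32")
y = quantizers.fake_quant_with_min_max_vars(x, -3.0, 3.0)
print(y.shape, y.dtype)  # (2, 3, 4) float32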


@keras_export("keras.quantizers.fake_quant_with_min_max_vars")
def fake_quant_with_min_max_vars(
inputs,
min_vals,
max_vals,
num_bits,
num_bits=8,
narrow_range=False,
axis=None,
):
"""
Perform per-tensor or per-channel fake quantization.
"""Perform per-tensor or per-channel fake quantization.
`[min_vals, max_vals]` define the clamping range for the `inputs`.
@@ -183,27 +209,68 @@ def fake_quant_with_min_max_vars(
`max_vals` to be trained.
Args:
inputs: Input tensor of float dtype.
inputs: Input Keras tensor of float dtype.
min_vals: A global minimum scalar or a per-channel minimum tensor.
max_vals: A global maximum scalar or a per-channel maximum tensor.
num_bits: Quantization bit width (e.g., `8` for int8).
narrow_range: Whether to use narrow quantization range.
num_bits: Quantization bit width (e.g., `8` for int8). Defaults to `8`.
narrow_range: Whether to use narrow quantization range. Defaults to
`False`.
axis: Axis along which to perform per-channel quantization. If `None`,
per-tensor quantization is performed. Defaults to `None`.
Returns:
Fake-quantized tensor
Tensor: A Keras tensor with fake quantization applied.
"""
if any_symbolic_tensors((inputs,)):
return FakeQuantWithMinMaxVars().symbolic_call(
inputs, min_vals, max_vals
)

inputs = ops.convert_to_tensor(inputs)
min_vals = ops.convert_to_tensor(min_vals)
max_vals = ops.convert_to_tensor(max_vals)
num_bits = int(num_bits)

if axis is not None:
axis = canonicalize_axis(axis, inputs.ndim)

# Shortcut for TensorFlow backend by using `tf.quantization.fake_quant_*`
# apis. This is necessary to be recognizable for the TFLite converter.
if backend.backend() == "tensorflow":
import tensorflow as tf

# `tf.quantization.fake_quant_*` only supports float32.
dtype = backend.standardize_dtype(inputs.dtype)
if axis is None:
outputs = tf.quantization.fake_quant_with_min_max_vars(
ops.cast(inputs, "float32"),
ops.cast(ops.reshape(min_vals, ()), "float32"),
ops.cast(ops.reshape(max_vals, ()), "float32"),
num_bits=num_bits,
narrow_range=narrow_range,
)
return ops.cast(outputs, dtype=dtype)
else:
# `tf.quantization.fake_quant_with_min_max_vars_per_channel` only
# supports the last channel for the per-channel quantization. We
# use `ops.swapaxes` for the pre- and post-processing.
last_axis = inputs.ndim - 1
inputs = ops.swapaxes(inputs, axis, last_axis)
outputs = tf.quantization.fake_quant_with_min_max_vars_per_channel(
ops.cast(inputs, "float32"),
ops.cast(min_vals, "float32"),
ops.cast(max_vals, "float32"),
num_bits=num_bits,
narrow_range=narrow_range,
)
outputs = ops.cast(outputs, dtype=dtype)
return ops.swapaxes(outputs, last_axis, axis)
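
A standalone sketch of the per-channel handling in the TensorFlow shortcut above: tf.quantization.fake_quant_with_min_max_vars_per_channel only quantizes along the last axis, so any other axis is swapped to the end and back. Shapes and ranges below are made up for the example, and it assumes the TensorFlow backend:

import tensorflow as tf
from keras import ops

x = tf.random.uniform((2, 5, 3))           # quantize along axis=1 (5 channels)
mins = tf.constant([-1.0, -2.0, -3.0, -4.0, -5.0])
maxs = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0])

x_t = ops.swapaxes(x, 1, 2)                # channel axis moved to the last position
y_t = tf.quantization.fake_quant_with_min_max_vars_per_channel(
    x_t, mins, maxs, num_bits=8, narrow_range=False
)
y = ops.swapaxes(y_t, 2, 1)                # restore the original layout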

@ops.custom_gradient
def _fake_quant_with_min_max_vars_per_channel(x, min_val, max_val):
dtype = backend.standardize_dtype(x.dtype)

# Calculate quantization parameters for all channels at once
nudged_min, nudged_max, scale, inv_scale = adjust_and_nudge(
min_val, max_val, num_bits, narrow_range
@@ -212,7 +279,9 @@ def _fake_quant_with_min_max_vars_per_channel(x, min_val, max_val):
quant_zero = ops.floor(
ops.add(ops.multiply(-nudged_min, inv_scale), 0.5)
)
x_clamped = ops.clip(x, nudged_min, nudged_max)
x_clamped = ops.clip(
x, ops.cast(nudged_min, x.dtype), ops.cast(nudged_max, x.dtype)
)
x_clamped_shifted = ops.subtract(x_clamped, nudged_min)
result = ops.multiply(
ops.floor(
@@ -225,33 +294,34 @@ def _fake_quant_with_min_max_vars_per_channel(x, min_val, max_val):
),
scale,
)
result = ops.cast(result, dtype=dtype)

# Create gradient mask for all channels
masks = ops.cast(
(x >= nudged_min) & (x <= nudged_max),
dtype="float32",
masks = ops.logical_and(
ops.greater_equal(x, nudged_min), ops.less_equal(x, nudged_max)
)

def grad(*args, upstream=None):
if upstream is None:
(upstream,) = args

# Gradient for x
dx = ops.multiply(upstream, masks)
dx = ops.where(masks, upstream, 0.0)
axes = [i for i in range(len(dx.shape)) if i != axis]

# Gradient for min_val
# When x is clipped to min, the gradient flows to min_val
min_mask = ops.cast(x <= nudged_min, dtype="float32")
grad_min = ops.multiply(upstream, min_mask)
min_mask = ops.less_equal(x, nudged_min)
grad_min = ops.where(min_mask, upstream, 0.0)
if axis is not None:
grad_min = ops.sum(grad_min, axis=axes)
else:
grad_min = ops.sum(grad_min)

# Gradient for max_val
# When x is clipped to max, the gradient flows to max_val
max_mask = ops.cast(x >= nudged_max, dtype="float32")
grad_max = ops.multiply(upstream, max_mask)
max_mask = ops.greater_equal(x, nudged_max)
grad_max = ops.where(max_mask, upstream, 0.0)
if axis is not None:
grad_max = ops.sum(grad_max, axis=axes)
else:
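The custom gradient above implements a straight-through estimator: the upstream gradient passes through wherever the input falls inside the nudged range, while gradient from clipped positions is routed to min_vals / max_vals instead. A NumPy sketch of that backward rule for the per-tensor case (values and nudged bounds are made up for illustration; the real code derives nudged_min/nudged_max via adjust_and_nudge and sums per channel when axis is set):

import numpy as np

upstream = np.ones((4,), dtype="float32")        # incoming gradient
x = np.array([-2.0, -0.5, 0.5, 2.0], dtype="float32")
nudged_min, nudged_max = -1.0, 1.0               # assumed for illustration

in_range = (x >= nudged_min) & (x <= nudged_max)
dx = np.where(in_range, upstream, 0.0)                     # gradient w.r.t. inputs
d_min = np.sum(np.where(x <= nudged_min, upstream, 0.0))   # gradient w.r.t. min_vals
d_max = np.sum(np.where(x >= nudged_max, upstream, 0.0))   # gradient w.r.t. max_vals

print(dx)            # [0. 1. 1. 0.]
print(d_min, d_max)  # 1.0 1.0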
91 changes: 68 additions & 23 deletions keras/src/quantizers/quantizers_test.py
@@ -1,3 +1,6 @@
import sys

import pytest
from absl.testing import parameterized

from keras.src import backend
@@ -104,6 +107,17 @@ def test_quantize_and_dequantize(self):
# A loose assertion due to an expected quantization error
self.assertAllClose(qdq_values, values, atol=5e-1)

@parameterized.named_parameters(
("per_tensor", None),
("per_channel", -1),
)
def test_fake_quant_with_min_max_vars_symbolic(self, axis):
x = backend.KerasTensor((2, 3, 4))
y = quantizers.fake_quant_with_min_max_vars(x, -3.0, 3.0, axis=axis)

self.assertIsInstance(y, backend.KerasTensor)
self.assertEqual(y.shape, (2, 3, 4))

@parameterized.named_parameters(
[
{
@@ -334,7 +348,11 @@ def test_quantize_and_dequantize(self):
},
]
)
def test_op(
@pytest.mark.skipif(
backend.backend() not in ("tensorflow", "jax", "torch"),
reason=f"{backend.backend()} doesn't support `custom_gradient`.",
)
def test_fake_quant_with_min_max_vars(
self,
input_mins,
input_maxs,
@@ -401,6 +419,8 @@ def test_op(
initial_gradients = ops.transpose(
ops.array(initial_gradients_list, dtype="float32")
)

# Test gradients.
if backend.backend() == "tensorflow":
import tensorflow as tf

@@ -420,58 +440,60 @@
)
return initial_gradients * tape.gradient(result, inputs)

gradients = test_op(
inputs, input_mins, input_maxs, num_bits, narrow_range, axis
)
# test gradients
self.assertAllClose(gradients, expected_backprops_wrt_input)

if backend.backend() == "torch":
import torch

def test_op(inputs, input_mins, input_maxs, num_bits, narrow_range):
def test_op(
inputs, input_mins, input_maxs, num_bits, narrow_range, axis
):
# Create tensor and enable gradient tracking
inputs = torch.tensor(
inputs, dtype=torch.float32, requires_grad=True
)

# Apply the quantization operation
result = quantizers.fake_quant_with_min_max_vars(
inputs, input_mins, input_maxs, num_bits, narrow_range
inputs, input_mins, input_maxs, num_bits, narrow_range, axis
)

# Compute gradients
result.backward(torch.ones_like(result))

return initial_gradients * inputs.grad

gradients = test_op(
inputs, input_min, input_max, num_bits, narrow_range
)
# test gradients
self.assertAllClose(gradients, expected_backprops_wrt_input)

if backend.backend() == "jax":
import jax

def test_op(inputs, input_mins, input_maxs, num_bits, narrow_range):
def test_op(
inputs, input_mins, input_maxs, num_bits, narrow_range, axis
):
# Define the function to compute gradients for
def quantize_fn(x):
return quantizers.fake_quant_with_min_max_vars(
x, input_mins, input_maxs, num_bits, narrow_range
x, input_mins, input_maxs, num_bits, narrow_range, axis
)

_, f_vjp = jax.vjp(quantize_fn, inputs)
# NOTE:python 3.10 input_gradients = f_vjp.args[0].args[0][0] !
input_gradients = f_vjp.args[0].args[0][1]

# NOTE: When python version >= 3.10, the gradients are at
# `f_vjp.args[0].args[0][0]`. Otherwise, they are at
# `f_vjp.args[0].args[0][1]`.
if sys.version_info >= (3, 10):
input_gradients = f_vjp.args[0].args[0][0]
else:
input_gradients = f_vjp.args[0].args[0][1]

return ops.multiply(initial_gradients, input_gradients)

gradients = test_op(
inputs, input_min, input_max, num_bits, narrow_range
)
# test gradients
gradients = test_op(
inputs, input_min, input_max, num_bits, narrow_range, axis
)
if backend.backend() != "jax" or not testing.jax_uses_gpu():
# JAX GPU produces less precise numbers, causing the CI to fail.
# For example, 127.5 / 255.0 results in 0.49999997 instead of 0.5.
self.assertAllClose(gradients, expected_backprops_wrt_input)

# Test outputs.
outputs = quantizers.fake_quant_with_min_max_vars(
inputs,
input_min,
@@ -481,3 +503,26 @@ def quantize_fn(x):
axis=axis,
)
self.assertAllClose(outputs, expected)

# Test bfloat16 & float16 dtype
outputs = quantizers.fake_quant_with_min_max_vars(
ops.cast(inputs, "bfloat16"),
input_min,
input_max,
num_bits=num_bits,
narrow_range=narrow_range,
axis=axis,
)
self.assertDType(outputs, "bfloat16")
self.assertAllClose(outputs, expected)

outputs = quantizers.fake_quant_with_min_max_vars(
ops.cast(inputs, "float16"),
input_min,
input_max,
num_bits=num_bits,
narrow_range=narrow_range,
axis=axis,
)
self.assertDType(outputs, "float16")
self.assertAllClose(outputs, expected)
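
Finally, a quick check mirroring the new dtype assertions: low-precision inputs are computed in higher precision internally and cast back, so the output dtype matches the input dtype (illustrative snippet):

import numpy as np
from keras import ops, quantizers

x16 = ops.cast(np.linspace(-4.0, 4.0, 8, dtype="float32"), "bfloat16")
y16 = quantizers.fake_quant_with_min_max_vars(x16, -3.0, 3.0)
print(y16.dtype)  # bfloat16 (exact repr depends on the backend)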
