From 6c9dfcafae24711a07391b68e4abe89019662e90 Mon Sep 17 00:00:00 2001 From: "Hongyu, Chiu" <20734616+james77777778@users.noreply.github.com> Date: Sat, 18 Jan 2025 06:00:46 +0800 Subject: [PATCH] Improve `fake_quant_with_min_max_vars` (#20772) * Fix fake_quant_with_min_max_vars * Add FakeQuantWithMinMaxVars operation and use shortcut for TF backend. --- .../_tf_keras/keras/quantizers/__init__.py | 4 +- keras/api/quantizers/__init__.py | 4 +- keras/src/quantizers/quantizers.py | 118 ++++++++++++++---- keras/src/quantizers/quantizers_test.py | 91 ++++++++++---- 4 files changed, 164 insertions(+), 53 deletions(-) diff --git a/keras/api/_tf_keras/keras/quantizers/__init__.py b/keras/api/_tf_keras/keras/quantizers/__init__.py index 8b11f6a3d63..2a6a083a099 100644 --- a/keras/api/_tf_keras/keras/quantizers/__init__.py +++ b/keras/api/_tf_keras/keras/quantizers/__init__.py @@ -12,7 +12,5 @@ from keras.src.quantizers.quantizers import abs_max_quantize from keras.src.quantizers.quantizers import compute_float8_amax_history from keras.src.quantizers.quantizers import compute_float8_scale -from keras.src.quantizers.quantizers import ( - fake_quant_with_min_max_vars as fake_quant_with_min_max_vars_per_channel, -) +from keras.src.quantizers.quantizers import fake_quant_with_min_max_vars from keras.src.quantizers.quantizers import quantize_and_dequantize diff --git a/keras/api/quantizers/__init__.py b/keras/api/quantizers/__init__.py index 8b11f6a3d63..2a6a083a099 100644 --- a/keras/api/quantizers/__init__.py +++ b/keras/api/quantizers/__init__.py @@ -12,7 +12,5 @@ from keras.src.quantizers.quantizers import abs_max_quantize from keras.src.quantizers.quantizers import compute_float8_amax_history from keras.src.quantizers.quantizers import compute_float8_scale -from keras.src.quantizers.quantizers import ( - fake_quant_with_min_max_vars as fake_quant_with_min_max_vars_per_channel, -) +from keras.src.quantizers.quantizers import fake_quant_with_min_max_vars from keras.src.quantizers.quantizers import quantize_and_dequantize diff --git a/keras/src/quantizers/quantizers.py b/keras/src/quantizers/quantizers.py index 2f7db0c9787..26ae800ce8f 100644 --- a/keras/src/quantizers/quantizers.py +++ b/keras/src/quantizers/quantizers.py @@ -4,8 +4,11 @@ from keras.src import backend from keras.src import ops from keras.src.api_export import keras_export +from keras.src.backend import KerasTensor +from keras.src.backend import any_symbolic_tensors from keras.src.backend.common.backend_utils import canonicalize_axis from keras.src.backend.common.backend_utils import standardize_axis_for_numpy +from keras.src.ops.operation import Operation """Int8-related classes and methods""" @@ -130,17 +133,20 @@ def get_config(self): def adjust_and_nudge(min_range, max_range, num_bits, narrow_range): """Adjusts and nudges the quantization range for better accuracy.""" + # Use higher precision for the computation. 
+ compute_dtype = backend.result_type(min_range.dtype, "float32") + min_range = ops.cast(min_range, compute_dtype) + max_range = ops.cast(max_range, compute_dtype) - quant_max = ops.cast(ops.subtract(ops.power(2, num_bits), 1.0), "float32") - - quant_min = ops.cast(0.0 if not narrow_range else 1.0, "float32") + quant_max = (1 << num_bits) - 1 + quant_min = 0 if not narrow_range else 1 + diff_range = ops.subtract(max_range, min_range) # Calculate the scale and ensure it's positive - scale = ops.divide( - ops.subtract(max_range, min_range), ops.subtract(quant_max, quant_min) - ) + scale = ops.divide(diff_range, quant_max - quant_min) - inv_scale = ops.reciprocal(scale) + # Re-calculate the inverse to avoid loss of precision + inv_scale = ops.divide(quant_max - quant_min, diff_range) # Calculate the zero point from the min range zero_point_from_min = quant_min - ops.divide(min_range, scale) @@ -158,17 +164,37 @@ def adjust_and_nudge(min_range, max_range, num_bits, narrow_range): return nudged_min, nudged_max, scale, inv_scale -@keras_export("keras.quantizers.fake_quant_with_min_max_vars_per_channel") +class FakeQuantWithMinMaxVars(Operation): + def __init__(self, num_bits=8, narrow_range=False, axis=None): + super().__init__() + self.num_bits = num_bits + self.narrow_range = narrow_range + self.axis = axis + + def call(self, inputs, min_vals, max_vals): + return fake_quant_with_min_max_vars( + inputs, + min_vals, + max_vals, + num_bits=self.num_bits, + narrow_range=self.narrow_range, + axis=self.axis, + ) + + def compute_output_spec(self, inputs, min_vals, max_vals): + return KerasTensor(inputs.shape, dtype=inputs.dtype) + + +@keras_export("keras.quantizers.fake_quant_with_min_max_vars") def fake_quant_with_min_max_vars( inputs, min_vals, max_vals, - num_bits, + num_bits=8, narrow_range=False, axis=None, ): - """ - Perform per-tensor or per-channel fake quantization. + """Perform per-tensor or per-channel fake quantization. `[min_vals, max_vals]` define the clamping range for the `inputs`. @@ -183,27 +209,68 @@ def fake_quant_with_min_max_vars( `max_vals` to be trained. Args: - inputs: Input tensor of float dtype. + inputs: Input Keras tensor of float dtype. min_vals: A global minimum scalar or a per-channel minimum tensor. max_vals: A global maximum scalar or a per-channel maximum tensor. - num_bits: Quantization bit width (e.g., `8` for int8). - narrow_range: Whether to use narrow quantization range. + num_bits: Quantization bit width (e.g., `8` for int8). Defaults to `8`. + narrow_range: Whether to use narrow quantization range. Defaults to + `False`. axis: Axis along which to perform per-channel quantization. If `None`, per-tensor quantization is performed. Defaults to `None`. Returns: - Fake-quantized tensor + Tensor: A Keras tensor with fake quantization applied. """ + if any_symbolic_tensors((inputs,)): + return FakeQuantWithMinMaxVars().symbolic_call( + inputs, min_vals, max_vals + ) + inputs = ops.convert_to_tensor(inputs) min_vals = ops.convert_to_tensor(min_vals) max_vals = ops.convert_to_tensor(max_vals) + num_bits = int(num_bits) if axis is not None: axis = canonicalize_axis(axis, inputs.ndim) + # Shortcut for TensorFlow backend by using `tf.quantization.fake_quant_*` + # apis. This is necessary to be recognizable for the TFLite converter. + if backend.backend() == "tensorflow": + import tensorflow as tf + + # `tf.quantization.fake_quant_*` only supports float32. 
+ dtype = backend.standardize_dtype(inputs.dtype) + if axis is None: + outputs = tf.quantization.fake_quant_with_min_max_vars( + ops.cast(inputs, "float32"), + ops.cast(ops.reshape(min_vals, ()), "float32"), + ops.cast(ops.reshape(max_vals, ()), "float32"), + num_bits=num_bits, + narrow_range=narrow_range, + ) + return ops.cast(outputs, dtype=dtype) + else: + # `tf.quantization.fake_quant_with_min_max_vars_per_channel` only + # supports the last channel for the per-channel quantization. We + # use `ops.swapaxes` for the pre- and post-processing. + last_axis = inputs.ndim - 1 + inputs = ops.swapaxes(inputs, axis, last_axis) + outputs = tf.quantization.fake_quant_with_min_max_vars_per_channel( + ops.cast(inputs, "float32"), + ops.cast(min_vals, "float32"), + ops.cast(max_vals, "float32"), + num_bits=num_bits, + narrow_range=narrow_range, + ) + outputs = ops.cast(outputs, dtype=dtype) + return ops.swapaxes(outputs, last_axis, axis) + @ops.custom_gradient def _fake_quant_with_min_max_vars_per_channel(x, min_val, max_val): + dtype = backend.standardize_dtype(x.dtype) + # Calculate quantization parameters for all channels at once nudged_min, nudged_max, scale, inv_scale = adjust_and_nudge( min_val, max_val, num_bits, narrow_range @@ -212,7 +279,9 @@ def _fake_quant_with_min_max_vars_per_channel(x, min_val, max_val): quant_zero = ops.floor( ops.add(ops.multiply(-nudged_min, inv_scale), 0.5) ) - x_clamped = ops.clip(x, nudged_min, nudged_max) + x_clamped = ops.clip( + x, ops.cast(nudged_min, x.dtype), ops.cast(nudged_max, x.dtype) + ) x_clamped_shifted = ops.subtract(x_clamped, nudged_min) result = ops.multiply( ops.floor( @@ -225,11 +294,11 @@ def _fake_quant_with_min_max_vars_per_channel(x, min_val, max_val): ), scale, ) + result = ops.cast(result, dtype=dtype) # Create gradient mask for all channels - masks = ops.cast( - (x >= nudged_min) & (x <= nudged_max), - dtype="float32", + masks = ops.logical_and( + ops.greater_equal(x, nudged_min), ops.less_equal(x, nudged_max) ) def grad(*args, upstream=None): @@ -237,12 +306,13 @@ def grad(*args, upstream=None): (upstream,) = args # Gradient for x - dx = ops.multiply(upstream, masks) + dx = ops.where(masks, upstream, 0.0) axes = [i for i in range(len(dx.shape)) if i != axis] + # Gradient for min_val # When x is clipped to min, the gradient flows to min_val - min_mask = ops.cast(x <= nudged_min, dtype="float32") - grad_min = ops.multiply(upstream, min_mask) + min_mask = ops.less_equal(x, nudged_min) + grad_min = ops.where(min_mask, upstream, 0.0) if axis is not None: grad_min = ops.sum(grad_min, axis=axes) else: @@ -250,8 +320,8 @@ def grad(*args, upstream=None): # Gradient for max_val # When x is clipped to max, the gradient flows to max_val - max_mask = ops.cast(x >= nudged_max, dtype="float32") - grad_max = ops.multiply(upstream, max_mask) + max_mask = ops.greater_equal(x, nudged_max) + grad_max = ops.where(max_mask, upstream, 0.0) if axis is not None: grad_max = ops.sum(grad_max, axis=axes) else: diff --git a/keras/src/quantizers/quantizers_test.py b/keras/src/quantizers/quantizers_test.py index d71f8fe2a1f..1fc7c94df7d 100644 --- a/keras/src/quantizers/quantizers_test.py +++ b/keras/src/quantizers/quantizers_test.py @@ -1,3 +1,6 @@ +import sys + +import pytest from absl.testing import parameterized from keras.src import backend @@ -104,6 +107,17 @@ def test_quantize_and_dequantize(self): # A loose assertion due to an expected quantization error self.assertAllClose(qdq_values, values, atol=5e-1) + @parameterized.named_parameters( + ("per_tensor", 
None), + ("per_channel", -1), + ) + def test_fake_quant_with_min_max_vars_symbolic(self, axis): + x = backend.KerasTensor((2, 3, 4)) + y = quantizers.fake_quant_with_min_max_vars(x, -3.0, 3.0, axis=axis) + + self.assertIsInstance(y, backend.KerasTensor) + self.assertEqual(y.shape, (2, 3, 4)) + @parameterized.named_parameters( [ { @@ -334,7 +348,11 @@ def test_quantize_and_dequantize(self): }, ] ) - def test_op( + @pytest.mark.skipif( + backend.backend() not in ("tensorflow", "jax", "torch"), + reason=f"{backend.backend()} doesn't support `custom_gradient`.", + ) + def test_fake_quant_with_min_max_vars( self, input_mins, input_maxs, @@ -401,6 +419,8 @@ def test_op( initial_gradients = ops.transpose( ops.array(initial_gradients_list, dtype="float32") ) + + # Test gradients. if backend.backend() == "tensorflow": import tensorflow as tf @@ -420,16 +440,12 @@ def test_op( ) return initial_gradients * tape.gradient(result, inputs) - gradients = test_op( - inputs, input_mins, input_maxs, num_bits, narrow_range, axis - ) - # test gradients - self.assertAllClose(gradients, expected_backprops_wrt_input) - if backend.backend() == "torch": import torch - def test_op(inputs, input_mins, input_maxs, num_bits, narrow_range): + def test_op( + inputs, input_mins, input_maxs, num_bits, narrow_range, axis + ): # Create tensor and enable gradient tracking inputs = torch.tensor( inputs, dtype=torch.float32, requires_grad=True @@ -437,7 +453,7 @@ def test_op(inputs, input_mins, input_maxs, num_bits, narrow_range): # Apply the quantization operation result = quantizers.fake_quant_with_min_max_vars( - inputs, input_mins, input_maxs, num_bits, narrow_range + inputs, input_mins, input_maxs, num_bits, narrow_range, axis ) # Compute gradients @@ -445,33 +461,39 @@ def test_op(inputs, input_mins, input_maxs, num_bits, narrow_range): return initial_gradients * inputs.grad - gradients = test_op( - inputs, input_min, input_max, num_bits, narrow_range - ) - # test gradients - self.assertAllClose(gradients, expected_backprops_wrt_input) - if backend.backend() == "jax": import jax - def test_op(inputs, input_mins, input_maxs, num_bits, narrow_range): + def test_op( + inputs, input_mins, input_maxs, num_bits, narrow_range, axis + ): # Define the function to compute gradients for def quantize_fn(x): return quantizers.fake_quant_with_min_max_vars( - x, input_mins, input_maxs, num_bits, narrow_range + x, input_mins, input_maxs, num_bits, narrow_range, axis ) _, f_vjp = jax.vjp(quantize_fn, inputs) - # NOTE:python 3.10 input_gradients = f_vjp.args[0].args[0][0] ! - input_gradients = f_vjp.args[0].args[0][1] + + # NOTE: When python version >= 3.10, the gradients are at + # `f_vjp.args[0].args[0][0]`. Otherwise, they are at + # `f_vjp.args[0].args[0][1]`. + if sys.version_info >= (3, 10): + input_gradients = f_vjp.args[0].args[0][0] + else: + input_gradients = f_vjp.args[0].args[0][1] return ops.multiply(initial_gradients, input_gradients) - gradients = test_op( - inputs, input_min, input_max, num_bits, narrow_range - ) - # test gradients + gradients = test_op( + inputs, input_min, input_max, num_bits, narrow_range, axis + ) + if backend.backend() != "jax" or not testing.jax_uses_gpu(): + # JAX GPU produces less precise numbers, causing the CI to fail. + # For example, 127.5 / 255.0 results in 0.49999997 instead of 0.5. self.assertAllClose(gradients, expected_backprops_wrt_input) + + # Test outputs. 
outputs = quantizers.fake_quant_with_min_max_vars( inputs, input_min, @@ -481,3 +503,26 @@ def quantize_fn(x): axis=axis, ) self.assertAllClose(outputs, expected) + + # Test bfloat16 & float16 dtype + outputs = quantizers.fake_quant_with_min_max_vars( + ops.cast(inputs, "bfloat16"), + input_min, + input_max, + num_bits=num_bits, + narrow_range=narrow_range, + axis=axis, + ) + self.assertDType(outputs, "bfloat16") + self.assertAllClose(outputs, expected) + + outputs = quantizers.fake_quant_with_min_max_vars( + ops.cast(inputs, "float16"), + input_min, + input_max, + num_bits=num_bits, + narrow_range=narrow_range, + axis=axis, + ) + self.assertDType(outputs, "float16") + self.assertAllClose(outputs, expected)
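
A minimal usage sketch of the API touched by this patch. The public entry point becomes `keras.quantizers.fake_quant_with_min_max_vars(inputs, min_vals, max_vals, num_bits=8, narrow_range=False, axis=None)`; the concrete input values, shapes, and the `axis=-1` choice below are illustrative assumptions, not taken from the patch:

    import numpy as np

    from keras import ops
    from keras import quantizers

    # Random float32 activations; any float tensor works.
    x = ops.convert_to_tensor(
        np.random.uniform(-4.0, 4.0, size=(2, 3, 4)).astype("float32")
    )

    # Per-tensor fake quantization: scalar min/max, default 8-bit range.
    y = quantizers.fake_quant_with_min_max_vars(x, -3.0, 3.0)

    # Per-channel fake quantization: one (min, max) pair per channel along
    # the chosen axis (here the last axis, purely for illustration).
    min_vals = ops.convert_to_tensor([-3.0, -2.0, -1.0, -4.0])
    max_vals = ops.convert_to_tensor([3.0, 2.0, 1.0, 4.0])
    y_per_channel = quantizers.fake_quant_with_min_max_vars(
        x, min_vals, max_vals, num_bits=8, narrow_range=False, axis=-1
    )

On the TensorFlow backend the call is routed through `tf.quantization.fake_quant_with_min_max_vars` (or the `_per_channel` variant), keeping the op recognizable to the TFLite converter; other backends take the custom-gradient path added above. Symbolic `KerasTensor` inputs are also supported through the new `FakeQuantWithMinMaxVars` operation.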