From bee1fad5f0dbff5ce4c8eacdc3051b16a55aff16 Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Mon, 15 Apr 2024 15:06:21 +0000
Subject: [PATCH] draft tests before meeting

---
 bin/quant.py                                   |  54 +++++++
 .../quantization/lifecycle/forward.py          |   5 +-
 .../lifecycle/test_calibration.py              |  46 ++++++
 .../quantization/lifecycle/test_end_to_end.py  | 132 ++++++++++++++++++
 .../quantization/lifecycle/test_forward.py     |  64 +++++++++
 5 files changed, 297 insertions(+), 4 deletions(-)
 create mode 100644 bin/quant.py
 create mode 100644 tests/sparsetensors/quantization/lifecycle/test_calibration.py
 create mode 100644 tests/sparsetensors/quantization/lifecycle/test_end_to_end.py
 create mode 100644 tests/sparsetensors/quantization/lifecycle/test_forward.py

diff --git a/bin/quant.py b/bin/quant.py
new file mode 100644
index 00000000..513e04c9
--- /dev/null
+++ b/bin/quant.py
@@ -0,0 +1,54 @@
+import torch
+from torch.nn import Linear
+
+from sparsetensors.quantization.quant_args import QuantizationArgs
+from sparsetensors.quantization.quant_scheme import QuantizationScheme
+from sparsetensors.quantization.lifecycle.initialize import initialize_module_for_quantization
+from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration
+from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization
+num_bits = 8
+
+scheme = QuantizationScheme(
+    input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False),
+    weights=QuantizationArgs(num_bits=num_bits, symmetric=True),
+    output_activations=None,
+    targets=["*"],
+)
+
+layer = Linear(4, 4)
+print(layer)
+print(dict(layer.named_parameters()))
+
+
+initialize_module_for_quantization(layer, scheme)
+print(layer)  # should see observer under layer now
+print(0)
+print(dict(layer.named_parameters()))  # should see empty tensors for scale and zero point now
+print(1)
+
+
+set_module_for_calibration(layer)
+# do a calibration step
+layer(torch.randn(4, 4))
+print(dict(layer.named_parameters()))  # scale and zero point should have updated values
+print(2)
+print("calib layers ")
+for i in range(10):
+    print("iter", i)
+    layer(torch.randn(4, 4))
+print(dict(layer.named_parameters()))  # scale and zero point should have updated values again since we did another pass
+
+print(3)
+# breakpoint()
+
+
+freeze_module_quantization(layer)
+print("freeze layers ")
+for i in range(10):
+    # do more forward passes but show args are frozen
+    print("iter", i)
+    layer(torch.randn(4, 4))
+print(dict(layer.named_parameters()))  # scale and zero point should not be updated now
+
+
+# # missing
\ No newline at end of file
diff --git a/src/sparsetensors/quantization/lifecycle/forward.py b/src/sparsetensors/quantization/lifecycle/forward.py
index ab20e29b..42921429 100644
--- a/src/sparsetensors/quantization/lifecycle/forward.py
+++ b/src/sparsetensors/quantization/lifecycle/forward.py
@@ -21,7 +21,7 @@
 from torch.nn import Module
 
 
-__all__ = ["wrap_module_forward_quantized"]
+__all__ = ["wrap_module_forward_quantized", "quantize", "dequantize", "fake_quantize"]
 
 
 def quantize(
@@ -67,7 +67,6 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
     @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
     def wrapped_forward(self, *args, **kwargs):
         input_ = args[0]
-
         if scheme.input_activations is not None:
             # calibrate and (fake) quantize input activations when applicable
             input_ = _maybe_calibrate_or_quantize(
@@ -113,8 +112,6 @@ def _maybe_calibrate_or_quantize(
     scale = getattr(module, f"{base_name}_scale")
     # zero_point = getattr(module, f"{base_name}_zero_point").data
f"{base_name}_zero_point").data zero_point = getattr(module, f"{base_name}_zero_point") - print(scale, zero_point) - if module.quantization_status == QuantizationStatus.CALIBRATION: # get observer and get new quant params from observation observer = getattr(module, f"{base_name}_observer") diff --git a/tests/sparsetensors/quantization/lifecycle/test_calibration.py b/tests/sparsetensors/quantization/lifecycle/test_calibration.py new file mode 100644 index 00000000..172b586a --- /dev/null +++ b/tests/sparsetensors/quantization/lifecycle/test_calibration.py @@ -0,0 +1,46 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +import pytest +from sparsetensors.quantization.quant_args import QuantizationArgs +from sparsetensors.quantization.quant_scheme import QuantizationScheme +from torch.nn import Linear + + +@pytest.fixture(scope="module") +def create_quantization_scheme(): + def quantization_scheme( + targets: List[str], + weights: Optional[QuantizationArgs] = None, + input_activations: Optional[QuantizationArgs] = None, + output_activations: Optional[QuantizationArgs] = None, + ): + return QuantizationScheme( + targets=targets, + weights=weights, + input_activations=input_activations, + output_activations=output_activations, + ) + + return quantization_scheme + + +def test_set_module_for_calibration(create_quantization_scheme): + quantization_scheme = create_quantization_scheme( + targets=["*"], + ) + + layer = Linear(4, 4) diff --git a/tests/sparsetensors/quantization/lifecycle/test_end_to_end.py b/tests/sparsetensors/quantization/lifecycle/test_end_to_end.py new file mode 100644 index 00000000..4112e334 --- /dev/null +++ b/tests/sparsetensors/quantization/lifecycle/test_end_to_end.py @@ -0,0 +1,132 @@ +import torch +from torch.nn import Linear + +from typing import Optional, List +import pytest +from sparsetensors.quantization.quant_args import QuantizationArgs +from sparsetensors.quantization.quant_scheme import QuantizationScheme +from sparsetensors.quantization.lifecycle.initialize import initialize_module_for_quantization +from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration +from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization +from sparsetensors.quantization.lifecycle.status import QuantizationStatus + + +@pytest.fixture(scope="module") +def create_quantization_scheme(): + def quantization_scheme( + targets: List[str], + weights: Optional[QuantizationArgs] = None, + input_activations: Optional[QuantizationArgs] = None, + output_activations: Optional[QuantizationArgs] = None, + ): + return QuantizationScheme( + targets=targets, + weights=weights, + input_activations=input_activations, + output_activations=output_activations, + ) + + return quantization_scheme + + +def test_lifecyle(create_quantization_scheme): + num_bits = 8 + + quantization_scheme = create_quantization_scheme( + targets=["*"], + 
+        weights=QuantizationArgs(num_bits=num_bits, symmetric=True),
+        input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False),
+    )
+
+    layer = Linear(4, 4)
+    layer.weight.data *= 100
+
+    # updated layer keys check
+    expected_layer_keys = {"weight", "bias"}
+    for key in layer.state_dict().keys():
+        expected_layer_keys.remove(key)
+    assert len(expected_layer_keys) == 0
+
+
+    initialize_module_for_quantization(layer, quantization_scheme)
+    expected_layer_keys = {
+        "input_scale",
+        "input_zero_point",
+        "weight_scale",
+        "weight_zero_point",
+        "weight",
+        "bias",
+    }
+    for key in layer.state_dict().keys():
+        expected_layer_keys.remove(key)
+    assert len(expected_layer_keys) == 0
+
+    assert hasattr(layer, "quantization_scheme")
+    assert hasattr(layer, "quantization_status")
+    assert layer.quantization_status == QuantizationStatus.INITIALIZED
+
+    set_module_for_calibration(layer)
+    assert layer.quantization_status == QuantizationStatus.CALIBRATION
+
+    # do a calibration step
+    print(dict(layer.named_parameters()))  # scale and zero point should have updated values
+    original_tensor = layer.weight.data
+    original_input_zero_point = layer.input_zero_point
+    original_input_scale = layer.input_scale
+    original_weight_scale = layer.weight_scale
+    original_weight_zero_point = layer.weight_zero_point
+
+    print()
+    print()
+    print()
+    print()
+    print()
+    print()
+
+    layer(torch.randn(4, 4))
+
+    # zero-points and scale
+    updated_tensor = layer.weight.data
+    updated_input_zero_point = layer.input_zero_point
+    updated_input_scale = layer.input_scale
+    updated_weight_scale = layer.weight_scale
+    updated_weight_zero_point = layer.weight_zero_point
+
+    print(original_tensor, updated_tensor)
+    print(original_input_zero_point, updated_input_zero_point)
+    print(original_input_scale, updated_input_scale)
+    print(original_weight_scale, updated_weight_scale)
+    print(original_weight_zero_point, updated_weight_zero_point)
+
+
+    # breakpoint()
+
+
+
+
+
+
+    print(dict(layer.named_parameters()))  # scale and zero point should have updated values
+    # breakpoint()
+
+    print(2)
+    print("calib layers ")
+    for i in range(10):
+        print("iter", i)
+        layer(torch.randn(4, 4))
+    print(dict(layer.named_parameters()))  # scale and zero point should have updated values again since we did another pass
+
+    print(3)
+    # breakpoint()
+
+
+    freeze_module_quantization(layer)
+    print("freeze layers ")
+    for i in range(10):
+        # do more forward passes but show args are frozen
+        print("iter", i)
+        layer(torch.randn(4, 4))
+    print(dict(layer.named_parameters()))  # scale and zero point should not be updated now
+
+
+    # # missing
\ No newline at end of file
diff --git a/tests/sparsetensors/quantization/lifecycle/test_forward.py b/tests/sparsetensors/quantization/lifecycle/test_forward.py
new file mode 100644
index 00000000..fc9ea467
--- /dev/null
+++ b/tests/sparsetensors/quantization/lifecycle/test_forward.py
@@ -0,0 +1,64 @@
+from typing import List, Optional
+
+import pytest
+from sparsetensors.quantization.lifecycle.initialize import (
+    initialize_module_for_quantization,
+)
+from sparsetensors.quantization.lifecycle.status import QuantizationStatus
+from sparsetensors.quantization.quant_args import QuantizationArgs
+from sparsetensors.quantization.quant_scheme import QuantizationScheme
+from torch.nn import Linear
+
+from sparsetensors.quantization.lifecycle.forward import wrap_module_forward_quantized
+
+
+@pytest.fixture(scope="module")
+def create_quantization_scheme():
+    def quantization_scheme(
+        targets: List[str],
+        weights: Optional[QuantizationArgs] = None,
+        input_activations: Optional[QuantizationArgs] = None,
+        output_activations: Optional[QuantizationArgs] = None,
+    ):
+        return QuantizationScheme(
+            targets=targets,
+            weights=weights,
+            input_activations=input_activations,
+            output_activations=output_activations,
+        )
+
+    return quantization_scheme
+
+
+def test_wrap_module_forward_quantized__forward_overwrite(create_quantization_scheme):
+    num_bits = 8
+    quantization_scheme = create_quantization_scheme(
+        targets=["*"],
+        weights=QuantizationArgs(num_bits=num_bits, symmetric=True),
+        input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False),
+    )
+    layer = Linear(4, 4)
+
+    func_forward = layer.forward.__func__
+
+    # check that the forward call is overwritten
+    wrap_module_forward_quantized(layer, quantization_scheme)
+
+    assert not func_forward == layer.forward.__func__
+
+
+def test_wrap_module_forward_quantized__forward_call(create_quantization_scheme):
+    num_bits = 8
+    quantization_scheme = create_quantization_scheme(
+        targets=["*"],
+        weights=QuantizationArgs(num_bits=num_bits, symmetric=True),
+        input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False),
+    )
+    layer = Linear(4, 4)
+    layer.weight.data *= 100
+
+    data = layer.weight.data
+
+    wrap_module_forward_quantized(layer, quantization_scheme)
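
Note: test_set_module_for_calibration in test_calibration.py above is still a stub; it builds a scheme and a Linear layer but never exercises calibration. A minimal sketch of how it could be finished, mirroring the lifecycle flow already shown in bin/quant.py and test_lifecycle (this assumes the QuantizationStatus, initialize_module_for_quantization, and set_module_for_calibration imports are added to test_calibration.py):

def test_set_module_for_calibration(create_quantization_scheme):
    # weight-only scheme targeting every module, as in the other draft tests
    quantization_scheme = create_quantization_scheme(
        targets=["*"],
        weights=QuantizationArgs(num_bits=8, symmetric=True),
    )
    layer = Linear(4, 4)

    # initialization attaches the scheme, lifecycle status, and empty scale/zero-point params
    initialize_module_for_quantization(layer, quantization_scheme)
    assert layer.quantization_status == QuantizationStatus.INITIALIZED

    # moving the module into calibration should advance the lifecycle status
    set_module_for_calibration(layer)
    assert layer.quantization_status == QuantizationStatus.CALIBRATION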