diff --git a/src/sparsetensors/quantization/observers/__init__.py b/src/sparsetensors/quantization/observers/__init__.py index d0362b8f..7e7ea908 100644 --- a/src/sparsetensors/quantization/observers/__init__.py +++ b/src/sparsetensors/quantization/observers/__init__.py @@ -13,7 +13,9 @@ # limitations under the License. # flake8: noqa +# isort: skip_file +from .helpers import * from .base import * from .memoryless import * from .min_max import * diff --git a/src/sparsetensors/quantization/observers/helpers.py b/src/sparsetensors/quantization/observers/helpers.py new file mode 100644 index 00000000..0ec086a7 --- /dev/null +++ b/src/sparsetensors/quantization/observers/helpers.py @@ -0,0 +1,51 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import torch +from sparsetensors.quantization.quant_args import QuantizationArgs +from torch import FloatTensor, IntTensor, Tensor + + +__all__ = ["calculate_qparams"] + + +def calculate_qparams( + min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs +) -> Tuple[FloatTensor, IntTensor]: + """ + :param min_vals: tensor of min value(s) to caluclate scale(s) and zero point(s) + from + :param max_vals: tensor of max value(s) to caluclate scale(s) and zero point(s) + from + :param quantization_args: settings to quantization + :return: tuple of the calculated scale(s) and zero point(s) + """ + bit_range = 2**quantization_args.num_bits - 1 + if quantization_args.symmetric: + symmetric_range = 2 * max(min_vals.abs(), max_vals.abs()) + scales = symmetric_range / bit_range + zero_points = torch.tensor(0).to(torch.int8) + else: + # non-symmetric + observed_range = max_vals - min_vals + scales = observed_range / bit_range + + # scales from a 0 range should be set to 1 + scales[observed_range == 0] = 1 + + zero_points = ((0 - min_vals) / scales).to(torch.int8) + + return scales, zero_points diff --git a/src/sparsetensors/quantization/observers/memoryless.py b/src/sparsetensors/quantization/observers/memoryless.py index 5fd92a6e..0a70d6d2 100644 --- a/src/sparsetensors/quantization/observers/memoryless.py +++ b/src/sparsetensors/quantization/observers/memoryless.py @@ -16,6 +16,7 @@ import torch from sparsetensors.quantization.observers.base import Observer +from sparsetensors.quantization.observers.helpers import calculate_qparams from torch import FloatTensor, IntTensor, Tensor @@ -36,7 +37,6 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]: """ # TODO: Add support for full range of quantization Args, only supports 8bit # per tensor - bit_range = 255 min_val = observed.min() max_val = observed.max() @@ -44,18 +44,4 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]: min_val = torch.min(min_val, torch.zeros_like(min_val)) max_val = torch.max(max_val, torch.zeros_like(max_val)) - if self.quantization_args.symmetric: - symmetric_range = 2 * max(min_val.abs(), max_val.abs()) - scale = symmetric_range / bit_range - zero_point = torch.tensor(0).to(torch.int8) - else: - # non-symmetric - observed_range = max_val - min_val - scale = observed_range / bit_range - - # scales from a 0 range should be set to 1 - scale[observed_range == 0] = 1 - - zero_point = ((0 - min_val) / scale).to(torch.int8) - - return scale, zero_point + return calculate_qparams(min_val, max_val, self.quantization_args) diff --git a/src/sparsetensors/quantization/observers/min_max.py b/src/sparsetensors/quantization/observers/min_max.py index e73805b4..1d2b4dc7 100644 --- a/src/sparsetensors/quantization/observers/min_max.py +++ b/src/sparsetensors/quantization/observers/min_max.py @@ -16,6 +16,7 @@ import torch from sparsetensors.quantization.observers.base import Observer +from sparsetensors.quantization.observers.helpers import calculate_qparams from sparsetensors.quantization.quant_args import QuantizationArgs from torch import FloatTensor, IntTensor, Tensor @@ -44,7 +45,6 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]: """ # TODO: Add support for full range of quantization Args, only supports 8bit # per tensor - bit_range = 255 min_val = torch.tensor([observed.min()]) max_val = torch.tensor([observed.max()]) @@ -62,18 +62,4 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]: self.counter += 1 - if self.quantization_args.symmetric: - symmetric_range = 2 * max(min_val.abs(), max_val.abs()) - scale = symmetric_range / bit_range - zero_point = torch.tensor(0).to(torch.int8) - else: - # non-symmetric - observed_range = max_val - min_val - scale = observed_range / bit_range - - # scales from a 0 range should be set to 1 - scale[observed_range == 0] = 1 - - zero_point = ((0 - min_val) / scale).to(torch.int8) - - return scale, zero_point + return calculate_qparams(min_val, max_val, self.quantization_args)