Commit
[Observers] pull shared logic into a helper function (#13)
bfineran authored Apr 17, 2024
1 parent 7bbeb65 commit fa01b71
Showing 4 changed files with 57 additions and 32 deletions.
2 changes: 2 additions & 0 deletions src/sparsetensors/quantization/observers/__init__.py
@@ -13,7 +13,9 @@
 # limitations under the License.
 
 # flake8: noqa
+# isort: skip_file
 
+from .helpers import *
 from .base import *
 from .memoryless import *
 from .min_max import *
51 changes: 51 additions & 0 deletions src/sparsetensors/quantization/observers/helpers.py
@@ -0,0 +1,51 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch
from sparsetensors.quantization.quant_args import QuantizationArgs
from torch import FloatTensor, IntTensor, Tensor


__all__ = ["calculate_qparams"]


def calculate_qparams(
    min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs
) -> Tuple[FloatTensor, IntTensor]:
    """
    :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s)
        from
    :param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s)
        from
    :param quantization_args: settings for quantization
    :return: tuple of the calculated scale(s) and zero point(s)
    """
    bit_range = 2**quantization_args.num_bits - 1
    if quantization_args.symmetric:
        symmetric_range = 2 * max(min_vals.abs(), max_vals.abs())
        scales = symmetric_range / bit_range
        zero_points = torch.tensor(0).to(torch.int8)
    else:
        # non-symmetric
        observed_range = max_vals - min_vals
        scales = observed_range / bit_range

        # scales from a 0 range should be set to 1
        scales[observed_range == 0] = 1

        zero_points = ((0 - min_vals) / scales).to(torch.int8)

    return scales, zero_points
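
As a quick sanity check, here is a minimal sketch of how the new helper might be exercised. The QuantizationArgs keyword arguments are assumptions (only the num_bits and symmetric attributes appear in this diff), and single-element tensors are used since the symmetric branch relies on Python's built-in max:

    import torch

    from sparsetensors.quantization.observers.helpers import calculate_qparams
    from sparsetensors.quantization.quant_args import QuantizationArgs

    # single-element min/max tensors, as the per-tensor observers produce
    min_vals = torch.tensor([-1.5])
    max_vals = torch.tensor([3.0])

    # symmetric: the range spans 2 * max(|min|, |max|) and the zero point stays at 0
    sym_args = QuantizationArgs(num_bits=8, symmetric=True)  # assumed kwargs
    scales, zero_points = calculate_qparams(min_vals, max_vals, sym_args)
    # scales == 6.0 / 255, zero_points == 0

    # asymmetric: the range spans max - min and the zero point shifts to cover it
    asym_args = QuantizationArgs(num_bits=8, symmetric=False)  # assumed kwargs
    scales, zero_points = calculate_qparams(min_vals, max_vals, asym_args)
    # scales == 4.5 / 255, zero_points == 85, i.e. (0 - min) / scale cast to int8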
18 changes: 2 additions & 16 deletions src/sparsetensors/quantization/observers/memoryless.py
@@ -16,6 +16,7 @@
 
 import torch
 from sparsetensors.quantization.observers.base import Observer
+from sparsetensors.quantization.observers.helpers import calculate_qparams
 from torch import FloatTensor, IntTensor, Tensor
 
 
@@ -36,26 +37,11 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
         """
         # TODO: Add support for full range of quantization Args, only supports 8bit
         # per tensor
-        bit_range = 255
         min_val = observed.min()
         max_val = observed.max()
 
         # ensure zero is in the range
         min_val = torch.min(min_val, torch.zeros_like(min_val))
         max_val = torch.max(max_val, torch.zeros_like(max_val))
 
-        if self.quantization_args.symmetric:
-            symmetric_range = 2 * max(min_val.abs(), max_val.abs())
-            scale = symmetric_range / bit_range
-            zero_point = torch.tensor(0).to(torch.int8)
-        else:
-            # non-symmetric
-            observed_range = max_val - min_val
-            scale = observed_range / bit_range
-
-            # scales from a 0 range should be set to 1
-            scale[observed_range == 0] = 1
-
-            zero_point = ((0 - min_val) / scale).to(torch.int8)
-
-        return scale, zero_point
+        return calculate_qparams(min_val, max_val, self.quantization_args)
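
A rough usage sketch of the refactored memoryless observer follows. The class name MemorylessObserver and a constructor taking QuantizationArgs are assumptions, since the class definition is collapsed in this diff:

    import torch

    from sparsetensors.quantization.observers.memoryless import MemorylessObserver  # assumed class name
    from sparsetensors.quantization.quant_args import QuantizationArgs

    args = QuantizationArgs(num_bits=8, symmetric=False)  # assumed kwargs
    observer = MemorylessObserver(quantization_args=args)  # assumed constructor

    # each call recomputes min/max from scratch, then delegates to the shared helper
    weights = torch.randn(64, 64)
    scale, zero_point = observer.calculate_qparams(weights)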
18 changes: 2 additions & 16 deletions src/sparsetensors/quantization/observers/min_max.py
@@ -16,6 +16,7 @@
 
 import torch
 from sparsetensors.quantization.observers.base import Observer
+from sparsetensors.quantization.observers.helpers import calculate_qparams
 from sparsetensors.quantization.quant_args import QuantizationArgs
 from torch import FloatTensor, IntTensor, Tensor
 
@@ -44,7 +45,6 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
         """
         # TODO: Add support for full range of quantization Args, only supports 8bit
         # per tensor
-        bit_range = 255
         min_val = torch.tensor([observed.min()])
         max_val = torch.tensor([observed.max()])
 
@@ -62,18 +62,4 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
 
         self.counter += 1
 
-        if self.quantization_args.symmetric:
-            symmetric_range = 2 * max(min_val.abs(), max_val.abs())
-            scale = symmetric_range / bit_range
-            zero_point = torch.tensor(0).to(torch.int8)
-        else:
-            # non-symmetric
-            observed_range = max_val - min_val
-            scale = observed_range / bit_range
-
-            # scales from a 0 range should be set to 1
-            scale[observed_range == 0] = 1
-
-            zero_point = ((0 - min_val) / scale).to(torch.int8)
-
-        return scale, zero_point
+        return calculate_qparams(min_val, max_val, self.quantization_args)
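
For context on what the returned scale and zero point feed into, here is a generic affine quantize/dequantize roundtrip. This is standard int8 quantization math under assumed signed clamp bounds, not code from this repository:

    import torch

    def fake_quantize(x: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor) -> torch.Tensor:
        # quantize to integers, clamp to an assumed signed int8 range, then dequantize
        q = torch.clamp(torch.round(x / scale) + zero_point, min=-128, max=127)
        return (q - zero_point) * scale

    # example with the asymmetric qparams from the earlier sketch
    x = torch.randn(16)
    x_hat = fake_quantize(x, scale=torch.tensor(4.5 / 255), zero_point=torch.tensor(85))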
