diff --git a/src/sparsetensors/quantization/__init__.py b/src/sparsetensors/quantization/__init__.py index 7227f889..9fde69a3 100644 --- a/src/sparsetensors/quantization/__init__.py +++ b/src/sparsetensors/quantization/__init__.py @@ -18,3 +18,4 @@ from .quant_args import * from .quant_config import * from .quant_scheme import * +from .lifecycle import * diff --git a/src/sparsetensors/quantization/observers/memoryless.py b/src/sparsetensors/quantization/observers/memoryless.py index b69c841d..5fd92a6e 100644 --- a/src/sparsetensors/quantization/observers/memoryless.py +++ b/src/sparsetensors/quantization/observers/memoryless.py @@ -56,6 +56,6 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]: # scales from a 0 range should be set to 1 scale[observed_range == 0] = 1 - zero_point = (0 - min_val) / scale + zero_point = ((0 - min_val) / scale).to(torch.int8) return scale, zero_point diff --git a/src/sparsetensors/quantization/observers/min_max.py b/src/sparsetensors/quantization/observers/min_max.py index 40cde72c..e73805b4 100644 --- a/src/sparsetensors/quantization/observers/min_max.py +++ b/src/sparsetensors/quantization/observers/min_max.py @@ -48,7 +48,7 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]: min_val = torch.tensor([observed.min()]) max_val = torch.tensor([observed.max()]) - # running average + # update running average if self.counter > 0: self.min_val = (self.min_val * self.counter + min_val) / (self.counter + 1) self.max_val = (self.max_val * self.counter + max_val) / (self.counter + 1) @@ -57,23 +57,23 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]: self.max_val = max_val # ensure that the zeros are in the range - self.min_val = torch.min(self.min_val, torch.zeros_like(self.min_val)) - self.max_val = torch.max(self.max_val, torch.zeros_like(self.max_val)) + min_val = torch.min(self.min_val, torch.zeros_like(self.min_val)) + max_val = torch.max(self.max_val, torch.zeros_like(self.max_val)) self.counter += 1 if self.quantization_args.symmetric: - symmetric_range = 2 * max(self.min_val.abs(), self.max_val.abs()) + symmetric_range = 2 * max(min_val.abs(), max_val.abs()) scale = symmetric_range / bit_range zero_point = torch.tensor(0).to(torch.int8) else: # non-symmetric - observed_range = self.max_val - self.min_val + observed_range = max_val - min_val scale = observed_range / bit_range # scales from a 0 range should be set to 1 scale[observed_range == 0] = 1 - zero_point = (0 - self.min_val) / scale + zero_point = ((0 - min_val) / scale).to(torch.int8) return scale, zero_point