
Commit

group size full lifecycle run
horheynm committed Apr 23, 2024
1 parent 81954b6 commit 803f495
Showing 3 changed files with 23 additions and 13 deletions.
21 changes: 15 additions & 6 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -33,9 +33,7 @@ def quantize(
     q_max: torch.Tensor,
 ) -> torch.Tensor:
     return torch.clamp(
-        torch.round(
-            x / scale + zero_point,
-        ),
+        torch.round(x / scale + zero_point),
         q_min,
         q_max,
     )
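For reference, quantize() implements the standard affine formula Q = clamp(round(x / scale + zero_point), q_min, q_max); the hunk above only collapses the round call onto one line. A minimal sketch of that round-and-clamp step, using made-up values rather than anything from the commit:

```python
import torch

x = torch.tensor([0.30, -0.71, 1.25])        # hypothetical activations
scale, zero_point = torch.tensor(0.1), torch.tensor(0)
q_min, q_max = -8.0, 7.0                      # 4-bit signed integer range

q = torch.clamp(torch.round(x / scale + zero_point), q_min, q_max)
print(q)  # tensor([ 3., -7.,  7.]) -- 1.25 / 0.1 = 12.5 saturates to q_max
```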
@@ -60,9 +58,20 @@ def fake_quantize(
     bit_range = 2**args.num_bits
     max_q = torch.tensor(bit_range / 2 - 1, device=x.device)
     min_q = torch.tensor(-bit_range / 2, device=x.device)
-    Q = torch.zeros_like(x)
-    Q = quantize(x, scale, zero_point, min_q, max_q)
-    return dequantize(Q, scale, zero_point)
+    # Q = torch.zeros_like(x)
+    DQ = torch.zeros_like(x)
+    num_groups = len(scale)
+    group_size = int(x.shape[1] / num_groups)
+    for i in range(num_groups):
+        sc = scale[i]
+        zp = zero_point[i]
+
+        idx = i * group_size
+        Q = quantize(x[:, idx : (idx + group_size)], sc, zp, min_q, max_q)
+        DQ[:, idx : (idx + group_size)] = dequantize(Q, sc, zp)
+    breakpoint()
+    # Q = quantize(x, scale, zero_point, min_q, max_q)
+    return DQ
 
 
 def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
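The rewritten fake_quantize() now walks the input group by group: one (scale, zero_point) pair per slice of group_size columns, with each slice quantized and dequantized independently (the leftover breakpoint() and commented-out lines are debugging artifacts of this work-in-progress commit). A self-contained sketch of the same loop, assuming a 2-D tensor whose column count divides evenly by the number of groups:

```python
import torch

def group_fake_quantize(x, scales, zero_points, num_bits=8):
    """Per-group fake quantization sketch; assumes x.shape[1] % len(scales) == 0."""
    q_max = 2**num_bits / 2 - 1
    q_min = -(2**num_bits) / 2
    out = torch.zeros_like(x)
    group_size = x.shape[1] // len(scales)
    for i, (sc, zp) in enumerate(zip(scales, zero_points)):
        cols = slice(i * group_size, (i + 1) * group_size)
        q = torch.clamp(torch.round(x[:, cols] / sc + zp), q_min, q_max)
        out[:, cols] = (q - zp) * sc  # dequantize back to float
    return out

w = torch.randn(4, 8)
scales = torch.tensor([0.05, 0.02])     # one scale per group of 4 columns
zps = torch.zeros(2, dtype=torch.int8)  # one zero point per group
print(group_fake_quantize(w, scales, zps).shape)  # torch.Size([4, 8])
```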
13 changes: 7 additions & 6 deletions src/compressed_tensors/quantization/observers/base.py
@@ -66,8 +66,14 @@ def get_qparams(
         """
         if observed is not None:
             group_size = self.quantization_args.group_size
+            if group_size is None:
 
-            if group_size > 0:  # quantize by groups
+                # re-calcualte scale and zero point, update the stored value
+                self._scale, self._zero_point = self.calculate_qparams(observed)
+                if hasattr(self, "inc"):
+                    self.inc()
+
+            elif group_size > 0:  # quantize by groups
                 columns = observed.shape[1]
                 scales, zero_points = [], []
                 for i in range(0, columns, self.quantization_args.group_size):
@@ -89,9 +95,4 @@ def get_qparams(
                 if hasattr(self, "inc"):
                     self.inc()
 
-            else:
-                # re-calcualte scale and zero point, update the stored value
-                self._scale, self._zero_point = self.calculate_qparams(observed)
-                if hasattr(self, "inc"):
-                    self.inc()
         return self._scale, self._zero_point
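In the grouped branch of get_qparams() above, the observer computes one (scale, zero_point) pair per group_size-wide column slice and concatenates the results, so each group can later be dequantized with its own parameters. A rough sketch of that pattern, using a hypothetical symmetric min-max routine in place of the observer's real calculate_qparams:

```python
import torch

def calculate_qparams_minmax(t: torch.Tensor, num_bits: int = 8):
    # Hypothetical stand-in for Observer.calculate_qparams (symmetric min-max).
    scale = 2 * t.abs().max() / 2**num_bits
    zero_point = torch.tensor([0], dtype=torch.int8)
    return scale.reshape(1), zero_point

def grouped_qparams(observed: torch.Tensor, group_size: int):
    scales, zero_points = [], []
    for i in range(0, observed.shape[1], group_size):
        scale, zp = calculate_qparams_minmax(observed[:, i : i + group_size])
        scales.append(scale)
        zero_points.append(zp)
    return torch.cat(scales), torch.cat(zero_points)

scales, zps = grouped_qparams(torch.randn(16, 128), group_size=32)
print(scales.shape, zps.shape)  # torch.Size([4]) torch.Size([4])
```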
2 changes: 1 addition & 1 deletion src/compressed_tensors/quantization/observers/helpers.py
@@ -38,7 +38,7 @@ def calculate_qparams(
     if quantization_args.symmetric:
         symmetric_range = 2 * max(min_vals.abs(), max_vals.abs())
         scales = symmetric_range / bit_range
-        zero_points = torch.tensor(0).to(torch.int8)
+        zero_points = torch.tensor([0]).to(torch.int8)
     else:
         # non-symmetric
         observed_range = max_vals - min_vals
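The one-line change above turns the symmetric zero point from a 0-dim scalar into a shape-[1] tensor. That plausibly matters for the grouped observer path shown earlier, which collects per-group zero points with torch.cat, and torch.cat rejects zero-dimensional tensors:

```python
import torch

zp_scalar = torch.tensor(0).to(torch.int8)    # 0-dim tensor (old behavior)
zp_vector = torch.tensor([0]).to(torch.int8)  # shape [1] tensor (new behavior)

print(torch.cat([zp_vector, zp_vector]))  # tensor([0, 0], dtype=torch.int8)
# torch.cat([zp_scalar, zp_scalar])       # RuntimeError: zero-dimensional
#                                         # tensor cannot be concatenated
```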
