diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index ccd232fa..67990dde 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -58,10 +58,12 @@ def fake_quantize(
     bit_range = 2**args.num_bits
     max_q = torch.tensor(bit_range / 2 - 1, device=x.device)
     min_q = torch.tensor(-bit_range / 2, device=x.device)
-    # Q = torch.zeros_like(x)
+    DQ = torch.zeros_like(x)
     num_groups = len(scale)
     group_size = int(x.shape[1] / num_groups)
+
+    # TODO: vectorize the for loop
     for i in range(num_groups):
         sc = scale[i]
         zp = zero_point[i]
@@ -69,8 +71,7 @@ def fake_quantize(
         idx = i * group_size
         Q = quantize(x[:, idx : (idx + group_size)], sc, zp, min_q, max_q)
         DQ[:, idx : (idx + group_size)] = dequantize(Q, sc, zp)
-        breakpoint()
-        # Q = quantize(x, scale, zero_point, min_q, max_q)
+    return DQ