diff --git a/src/llmcompressor/transformers/compression/helpers.py b/src/llmcompressor/transformers/compression/helpers.py
index d81c9dfab..ba7d127be 100644
--- a/src/llmcompressor/transformers/compression/helpers.py
+++ b/src/llmcompressor/transformers/compression/helpers.py
@@ -138,7 +138,13 @@ def quantization_memory_requirement(model: torch.nn.Module) -> int:
             for param in module.parameters():
                 # assume the max of group 128 and static scale/zp
                 # TODO: base this on the recipe instead instead of assuming max
-                max_quant_shape = param.shape[0] * param.shape[1] // 128
+
+                # potentially just bias term
+                max_quant_shape = param.shape[0] // 128
+
+                if len(param.size()) > 1:  # weights
+                    max_quant_shape *= param.shape[1]
+
                 total_elements += max_quant_shape * 4
 
     bytes_ratio = 32 // 16  # assuming float16
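
A minimal sketch of what the updated estimate computes for 2-D weight versus 1-D bias parameters. The standalone `estimate_quant_elements` helper and the tensor shapes below are illustrative only, not part of the diff; the key point is that the old expression indexed `param.shape[1]`, which raises an `IndexError` on 1-D bias tensors, while the new logic handles both cases:

```python
import torch


def estimate_quant_elements(param: torch.Tensor) -> int:
    # mirrors the updated logic: assume group-128 quantization, so start
    # with one group per 128 elements along the first dimension
    max_quant_shape = param.shape[0] // 128

    if len(param.size()) > 1:  # 2-D weight tensor
        max_quant_shape *= param.shape[1]

    # factor of 4 carried over from the original scale/zero-point estimate
    return max_quant_shape * 4


weight = torch.empty(4096, 11008)  # hypothetical Linear weight
bias = torch.empty(4096)           # hypothetical 1-D bias

print(estimate_quant_elements(weight))  # (4096 // 128) * 11008 * 4
print(estimate_quant_elements(bias))    # (4096 // 128) * 4
```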