diff --git a/auto_fp8/quantize.py b/auto_fp8/quantize.py index 38a4de6..88cd7d1 100644 --- a/auto_fp8/quantize.py +++ b/auto_fp8/quantize.py @@ -246,10 +246,27 @@ def quantize_weights( cleanup_memory() +def find_max_batch_size(model: AutoModelForCausalLM, tokens): + batch_size = tokens.shape[0] + while batch_size > 1: + try: + with torch.inference_mode(): + model(tokens[:batch_size].reshape(batch_size, -1)) + return batch_size + except RuntimeError as e: + print(e) + if 'out of memory' in str(e): + cleanup_memory() + batch_size //= 2 + else: + raise e + return batch_size + + def quantize_activations( model: AutoModelForCausalLM, quantize_config: BaseQuantizeConfig, - calibration_tokens, + calibration_tokens: torch.Tensor, ): # Replace weight quantizer with a dynamic activation quantizer observer for name, dynamic_quant_linear in model.named_modules(): @@ -271,13 +288,16 @@ def quantize_activations( del dynamic_quant_linear cleanup_memory() + # Find the maximum batch size that can be used without going OOM + max_batch_size = find_max_batch_size(model, calibration_tokens) + # Pass through calibration data to measure activation scales with torch.inference_mode(): with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating activation scales") as pbar: - for row_idx in range(calibration_tokens.shape[0]): - model(calibration_tokens[row_idx].reshape(1, -1)) - cleanup_memory() - pbar.update(1) + for i in range(0, calibration_tokens.shape[0], max_batch_size): + batch = calibration_tokens[i:i + max_batch_size] + model(batch.reshape(batch.shape[0], -1)) + pbar.update(batch.shape[0]) # Replace dynamic quantizer observer with StaticLinear for export for name, quantizer in model.named_modules():