From b1c6ad6edad2529cf5acbab05da229f2f67593e3 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Mon, 17 Jun 2024 12:40:56 -0400 Subject: [PATCH] Use `torch.inference_mode()` for lower memory usage during calibration (#20) --- auto_fp8/quantize.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/auto_fp8/quantize.py b/auto_fp8/quantize.py index e16e471..4c1b580 100644 --- a/auto_fp8/quantize.py +++ b/auto_fp8/quantize.py @@ -236,11 +236,12 @@ def quantize_activations( cleanup_memory() # Pass through calibration data to measure activation scales - with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating activation scales") as pbar: - for row_idx in range(calibration_tokens.shape[0]): - model(calibration_tokens[row_idx].reshape(1, -1)) - cleanup_memory() - pbar.update(1) + with torch.inference_mode(): + with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating activation scales") as pbar: + for row_idx in range(calibration_tokens.shape[0]): + model(calibration_tokens[row_idx].reshape(1, -1)) + cleanup_memory() + pbar.update(1) # Replace dynamic quantizer observer with StaticLinear for export for name, quantizer in model.named_modules():