Add automatic batching
mgoin committed Jun 19, 2024
1 parent: 2a9330c · commit: ffe2e88
Showing 1 changed file with 25 additions and 5 deletions.
auto_fp8/quantize.py (30 changes: 25 additions & 5 deletions)
@@ -246,10 +246,27 @@ def quantize_weights(
     cleanup_memory()


+def find_max_batch_size(model: AutoModelForCausalLM, tokens):
+    batch_size = tokens.shape[0]
+    while batch_size > 1:
+        try:
+            with torch.inference_mode():
+                model(tokens[:batch_size].reshape(batch_size, -1))
+            return batch_size
+        except RuntimeError as e:
+            print(e)
+            if 'out of memory' in str(e):
+                cleanup_memory()
+                batch_size //= 2
+            else:
+                raise e
+    return batch_size
+
+
 def quantize_activations(
     model: AutoModelForCausalLM,
     quantize_config: BaseQuantizeConfig,
-    calibration_tokens,
+    calibration_tokens: torch.Tensor,
 ):
     # Replace weight quantizer with a dynamic activation quantizer observer
     for name, dynamic_quant_linear in model.named_modules():
@@ -271,13 +288,16 @@ def quantize_activations(
             del dynamic_quant_linear
     cleanup_memory()

+    # Find the maximum batch size that can be used without going OOM
+    max_batch_size = find_max_batch_size(model, calibration_tokens)
+
     # Pass through calibration data to measure activation scales
     with torch.inference_mode():
         with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating activation scales") as pbar:
-            for row_idx in range(calibration_tokens.shape[0]):
-                model(calibration_tokens[row_idx].reshape(1, -1))
-                cleanup_memory()
-                pbar.update(1)
+            for i in range(0, calibration_tokens.shape[0], max_batch_size):
+                batch = calibration_tokens[i:i + max_batch_size]
+                model(batch.reshape(batch.shape[0], -1))
+                pbar.update(batch.shape[0])

     # Replace dynamic quantizer observer with StaticLinear for export
     for name, quantizer in model.named_modules():
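For context, the probe introduced in this commit follows a common backoff pattern: attempt the largest possible batch, halve it whenever a CUDA out-of-memory error is raised, and return as soon as one forward pass succeeds. The sketch below restates that pattern in isolation; it is illustrative only and not part of the commit. run_forward is a hypothetical stand-in for the model call, and torch.cuda.empty_cache() stands in for the repository's cleanup_memory() helper.

    import torch

    def probe_batch_size(run_forward, samples: torch.Tensor) -> int:
        """Halve the batch size on CUDA OOM until a single forward pass fits."""
        batch_size = samples.shape[0]
        while batch_size > 1:
            try:
                with torch.inference_mode():
                    run_forward(samples[:batch_size])  # one probe forward pass
                return batch_size
            except RuntimeError as e:
                if "out of memory" not in str(e):
                    raise
                torch.cuda.empty_cache()  # drop cached allocations before retrying
                batch_size //= 2          # geometric backoff
        return batch_size

    # Hypothetical usage: probe once, then calibrate in strides of the found size.
    # max_bs = probe_batch_size(lambda x: model(x), calibration_tokens)

Because the batch size is halved on each failure, the probe costs at most roughly log2(N) extra forward passes for N calibration samples, after which the calibration loop walks the token set in fixed strides of the discovered size.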
