From 3b16e0868801da6fa1741c333c4491f4194fb179 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Fri, 10 May 2024 17:55:08 -0400
Subject: [PATCH] Cleanup

---
 .github/workflows/test.yaml   |  2 +-
 auto_fp8/quantize.py          |  9 +++++++--
 example.py                    | 13 +++++--------
 examples/original_quantize.py |  2 +-
 quantize.py                   | 11 ++++++++---
 setup.py                      |  2 +-
 tests/test_auto_fp8.py        | 25 ++++++++++++++-----------
 7 files changed, 37 insertions(+), 27 deletions(-)

diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index 0a7756a..3ecc403 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -11,7 +11,7 @@ on:
       - main
 
 jobs:
-  ruff:
+  test:
     runs-on: ubuntu-latest
     strategy:
       matrix:
diff --git a/auto_fp8/quantize.py b/auto_fp8/quantize.py
index f12a968..a450196 100644
--- a/auto_fp8/quantize.py
+++ b/auto_fp8/quantize.py
@@ -179,7 +179,9 @@ def quantize_activations(model, calibration_tokens):
         if not isinstance(dynamic_quant_linear, FP8DynamicLinear):
             continue
         quantizer = FP8StaticLinearQuantizer(
-            dynamic_quant_linear.weight, dynamic_quant_linear.weight_scale, dynamic_quant_linear.bias
+            dynamic_quant_linear.weight,
+            dynamic_quant_linear.weight_scale,
+            dynamic_quant_linear.bias,
         )
         replace_module(model, name, quantizer)
         del dynamic_quant_linear
@@ -197,7 +199,10 @@ def quantize_activations(model, calibration_tokens):
         if not isinstance(quantizer, FP8StaticLinearQuantizer):
             continue
         static_proj = FP8StaticLinear(
-            quantizer.weight, quantizer.weight_scale, quantizer.bias, quantizer.act_scale
+            quantizer.weight,
+            quantizer.weight_scale,
+            quantizer.bias,
+            quantizer.act_scale,
         )
         replace_module(model, name, static_proj)
         del quantizer
diff --git a/example.py b/example.py
index dc98aec..8541a6b 100644
--- a/example.py
+++ b/example.py
@@ -5,16 +5,13 @@
 quantized_model_dir = "opt-125m-fp8"
 
 tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
-examples = [
-    "auto-fp8 is an easy-to-use model quantization library"
-]
+examples = ["auto-fp8 is an easy-to-use model quantization library"]
 examples = tokenizer(examples, return_tensors="pt").to("cuda")
 
-quantize_config = BaseQuantizeConfig(
-    quant_method="fp8",
-    activation_scheme="static"
-)
+quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")
 
-model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config=quantize_config)
+model = AutoFP8ForCausalLM.from_pretrained(
+    pretrained_model_dir, quantize_config=quantize_config
+)
 model.quantize(examples)
 model.save_quantized(quantized_model_dir)
diff --git a/examples/original_quantize.py b/examples/original_quantize.py
index cbd61e3..5b35d69 100644
--- a/examples/original_quantize.py
+++ b/examples/original_quantize.py
@@ -163,7 +163,7 @@ def forward(self, x):
 def replace_module(model, name, new_module):
     if "." in name:
         parent_name = name.rsplit(".", 1)[0]
-        child_name = name[len(parent_name) + 1:]
+        child_name = name[len(parent_name) + 1 :]
         parent = model.model.get_submodule(parent_name)
     else:
         parent_name = ""
diff --git a/quantize.py b/quantize.py
index 4d16931..7aacff2 100644
--- a/quantize.py
+++ b/quantize.py
@@ -165,7 +165,7 @@ def forward(self, x):
 def replace_module(model, name, new_module):
     if "." in name:
         parent_name = name.rsplit(".", 1)[0]
-        child_name = name[len(parent_name) + 1:]
+        child_name = name[len(parent_name) + 1 :]
         parent = model.model.get_submodule(parent_name)
     else:
         parent_name = ""
@@ -193,7 +193,9 @@ def quantize_activations(model, calibration_tokens):
         if not isinstance(dynamic_quant_linear, FP8DynamicLinear):
             continue
         quantizer = FP8StaticLinearQuantizer(
-            dynamic_quant_linear.weight, dynamic_quant_linear.weight_scale, dynamic_quant_linear.bias
+            dynamic_quant_linear.weight,
+            dynamic_quant_linear.weight_scale,
+            dynamic_quant_linear.bias,
         )
         replace_module(model, name, quantizer)
         del dynamic_quant_linear
@@ -212,7 +214,10 @@ def quantize_activations(model, calibration_tokens):
         if not isinstance(quantizer, FP8StaticLinearQuantizer):
             continue
         static_proj = FP8StaticLinear(
-            quantizer.weight, quantizer.weight_scale, quantizer.bias, quantizer.act_scale
+            quantizer.weight,
+            quantizer.weight_scale,
+            quantizer.bias,
+            quantizer.act_scale,
         )
         replace_module(model, name, static_proj)
         del quantizer
diff --git a/setup.py b/setup.py
index 6789d7f..1f78329 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
     description="FP8 quantization for Transformers.",
     long_description=open("README.md").read(),
     long_description_content_type="text/markdown",
-    url="https://github.com/neuralmagic/auto_fp8",
+    url="https://github.com/neuralmagic/AutoFP8",
     packages=find_packages(),
     install_requires=[
         "torch>=2.2",
diff --git a/tests/test_auto_fp8.py b/tests/test_auto_fp8.py
index e549993..4a9ae12 100644
--- a/tests/test_auto_fp8.py
+++ b/tests/test_auto_fp8.py
@@ -1,28 +1,31 @@
 import os
 from transformers import AutoTokenizer
 from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
+import shutil
+
 
 def test_quantization():
     model_id = "facebook/opt-125m"
     quantized_model_dir = "opt-125m-fp8"
 
     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
-    examples = [
-        "auto-fp8 is an easy-to-use model quantization library"
-    ]
-    examples = tokenizer(examples, return_tensors="pt").to("cuda")
+    examples = ["auto-fp8 is an easy-to-use model quantization library"]
+    examples = tokenizer(examples, return_tensors="pt")
 
-    quantize_config = BaseQuantizeConfig(
-        quant_method="fp8", activation_scheme="static"
-    )
+    quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")
 
-    model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config=quantize_config, device_map="auto")
+    model = AutoFP8ForCausalLM.from_pretrained(
+        model_id, quantize_config=quantize_config
+    )
+    model.model.to("cpu")
     model.quantize(examples)
     model.save_quantized(quantized_model_dir)
 
-    # We expect the model to be < 160MB
+    # Measure checkpoint size and cleanup
     model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors")
-    target_size = 160 * (1024*1024)
-    assert model_size < target_size
+    shutil.rmtree(quantized_model_dir)
+    # We expect the model to be < 160MB
+    target_size = 160 * (1024 * 1024)
+    assert model_size < target_size
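
For reference, a minimal end-to-end sketch of the flow exercised by the updated example.py and test above (the "facebook/opt-125m" model id is taken from the test; a CUDA device is assumed for the .to("cuda") call, as in example.py):

from transformers import AutoTokenizer

from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "facebook/opt-125m"
quantized_model_dir = "opt-125m-fp8"

# Tokenize a small calibration set; with activation_scheme="static" these
# tokens are used to calibrate per-layer activation scales (act_scale).
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
examples = ["auto-fp8 is an easy-to-use model quantization library"]
examples = tokenizer(examples, return_tensors="pt").to("cuda")

quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")

model = AutoFP8ForCausalLM.from_pretrained(
    pretrained_model_dir, quantize_config=quantize_config
)
model.quantize(examples)                   # quantize weights, then calibrate activations
model.save_quantized(quantized_model_dir)  # writes model.safetensors to the output dir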