From 3f683f8617b8baaace7bd21f6c6ed36fa3ee7f0a Mon Sep 17 00:00:00 2001
From: mgoin
Date: Fri, 19 Jul 2024 10:30:19 -0400
Subject: [PATCH] Add support for dynamic activation quantization

---
 auto_fp8/modeling.py   | 92 ++++++++++++++++++++++++++++++++++--------
 tests/test_auto_fp8.py | 34 ++++++++--------
 2 files changed, 91 insertions(+), 35 deletions(-)

diff --git a/auto_fp8/modeling.py b/auto_fp8/modeling.py
index 0e4e8cc..79d80b2 100644
--- a/auto_fp8/modeling.py
+++ b/auto_fp8/modeling.py
@@ -6,6 +6,12 @@
 from llmcompressor.transformers import SparseAutoModelForCausalLM
 from llmcompressor.transformers import oneshot
 from llmcompressor.modifiers.quantization import QuantizationModifier
+from compressed_tensors.quantization import (
+    QuantizationArgs,
+    QuantizationType,
+    QuantizationScheme,
+    QuantizationStrategy,
+)
 
 
 class BaseQuantizeConfig:
@@ -64,23 +70,75 @@ def from_pretrained(
         return cls(model, quantize_config)
 
     def quantize(self, dataset: Optional[Dataset] = None):
-        assert (
-            self.quantize_config.activation_scheme == "static"
-        ), "Dynamic isn't supported yet"
-        assert (
-            dataset is not None
-        ), "Calibration tokens required for static activation quantization"
-
-        recipe = QuantizationModifier(
-            targets="Linear", scheme="FP8", ignore=self.quantize_config.ignore_patterns
-        )
+        if self.quantize_config.activation_scheme == "dynamic":
+            if dataset is None:
+                # For dynamic activations, we don't care about calibration data
+                # being provided. However, we need to pass something
+                # TODO(mgoin): Remove once llmcompressor allows no dataset
+                from datasets import load_dataset
+                dataset = load_dataset("openai/openai_humaneval", split="test").select(range(1))
+                dataset = dataset.rename_column("prompt", "text")
 
-        oneshot(
-            model=self.model,
-            dataset=dataset,
-            recipe=recipe,
-            num_calibration_samples=dataset.shape[0],
-        )
+            FP8_W8 = QuantizationScheme(
+                targets=["Linear"],
+                weights=QuantizationArgs(
+                    num_bits=8,
+                    type=QuantizationType.FLOAT,
+                    strategy=QuantizationStrategy.TENSOR,
+                    symmetric=True,
+                    dynamic=False,
+                ),
+            )
+
+            recipe = QuantizationModifier(
+                config_groups={"group_0": FP8_W8},
+                ignore=self.quantize_config.ignore_patterns,
+            )
+
+            oneshot(
+                model=self.model,
+                dataset=dataset,
+                recipe=recipe,
+                num_calibration_samples=dataset.shape[0],
+            )
+        elif self.quantize_config.activation_scheme == "static":
+            assert (
+                dataset is not None
+            ), "Calibration tokens required for static activation quantization"
+
+            FP8_W8A8 = QuantizationScheme(
+                targets=["Linear"],
+                weights=QuantizationArgs(
+                    num_bits=8,
+                    type=QuantizationType.FLOAT,
+                    strategy=QuantizationStrategy.TENSOR,
+                    symmetric=True,
+                    dynamic=False,
+                ),
+                input_activations=QuantizationArgs(
+                    num_bits=8,
+                    type=QuantizationType.FLOAT,
+                    strategy=QuantizationStrategy.TENSOR,
+                    symmetric=True,
+                    dynamic=False,
+                ),
+            )
+
+            recipe = QuantizationModifier(
+                config_groups={"group_0": FP8_W8A8},
+                ignore=self.quantize_config.ignore_patterns,
+            )
+
+            oneshot(
+                model=self.model,
+                dataset=dataset,
+                recipe=recipe,
+                num_calibration_samples=dataset.shape[0],
+            )
+        else:
+            raise ValueError(
+                f"Unsupported activation_scheme={self.quantize_config.activation_scheme}"
+            )
 
     def save_quantized(self, save_directory: str):
         self.save_pretrained(save_directory, save_compressed=True)
@@ -89,4 +147,4 @@ def save_pretrained(self, save_directory: str, save_compressed: bool = True):
         self.model.save_pretrained(save_directory, save_compressed=save_compressed)
         tokenizer = AutoTokenizer.from_pretrained(self.model.config._name_or_path)
         tokenizer.save_pretrained(save_directory)
-        print(f"Saved final checkpoint to {os.path.abspath(save_directory)}")
\ No newline at end of file
+        print(f"Saved final checkpoint to {os.path.abspath(save_directory)}")
diff --git a/tests/test_auto_fp8.py b/tests/test_auto_fp8.py
index 6717ae1..0322c2d 100644
--- a/tests/test_auto_fp8.py
+++ b/tests/test_auto_fp8.py
@@ -10,30 +10,28 @@
 MODELS = [
     ("facebook/opt-125m", 160),
-    ("Qwen/Qwen2-0.5B-Instruct", 620),
+    # ("Qwen/Qwen2-0.5B-Instruct", 620),
 ]
 
 
-# @pytest.mark.parametrize("model_id,target_size", MODELS)
-# def test_dynamic_quantization(model_id, target_size):
-#     quantized_model_dir = model_id.split("/")[-1] + "-fp8-dynamic"
+@pytest.mark.parametrize("model_id,target_size", MODELS)
+def test_dynamic_quantization(model_id, target_size):
+    quantized_model_dir = model_id.split("/")[-1] + "-fp8-dynamic"
 
-#     quantize_config = BaseQuantizeConfig(
-#         quant_method="fp8", activation_scheme="dynamic"
-#     )
+    quantize_config = BaseQuantizeConfig(
+        quant_method="fp8", activation_scheme="dynamic"
+    )
 
-#     model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config)
-#     model.model.to("cpu")
-
-#     model.quantize()
-#     model.save_quantized(quantized_model_dir)
+    model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config)
+    model.quantize()
+    model.save_quantized(quantized_model_dir)
 
-#     # Measure checkpoint size and cleanup
-#     model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors")
-#     shutil.rmtree(quantized_model_dir)
+    # Measure checkpoint size and cleanup
+    model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors")
+    shutil.rmtree(quantized_model_dir)
 
-#     # We expect the quantized model to be a certain size
-#     target_size = target_size * (1024 * 1024)
-#     assert model_size < target_size
+    # We expect the quantized model to be a certain size
+    target_size = target_size * (1024 * 1024)
+    assert model_size < target_size
 
 
 @pytest.mark.parametrize("model_id,target_size", MODELS)