From 935dd70d298569ac7944e49bde19fabd09a75d7d Mon Sep 17 00:00:00 2001
From: mgoin
Date: Fri, 10 May 2024 18:16:32 -0400
Subject: [PATCH] Update

---
 README.md                           |  65 ++++--
 examples/README.md                  |  23 +++
 examples/original_quantize.py       | 295 ----------------------------
 quantize.py => examples/quantize.py |   0
 4 files changed, 68 insertions(+), 315 deletions(-)
 create mode 100644 examples/README.md
 delete mode 100644 examples/original_quantize.py
 rename quantize.py => examples/quantize.py (100%)

diff --git a/README.md b/README.md
index ddfb7a2..6e79275 100644
--- a/README.md
+++ b/README.md
@@ -1,32 +1,57 @@
 # AutoFP8
-Open-source FP8 quantization project for producing compressed checkpoints for running in vLLM - see https://github.com/vllm-project/vllm/pull/4332 for implementation.
+Open-source FP8 quantization library for producing compressed checkpoints that run in vLLM - see https://github.com/vllm-project/vllm/pull/4332 for details on the inference implementation.
-## How to quantize a model
+## Installation
-Install this repo's requirements:
+Clone this repo and install it from source:
 ```bash
-pip install -r requirements.txt
+git clone https://github.com/neuralmagic/AutoFP8.git
+pip install -e AutoFP8
 ```
-Command to produce a `Meta-Llama-3-8B-Instruct-FP8` quantized LLM:
-```bash
-python quantize.py --model-id meta-llama/Meta-Llama-3-8B-Instruct --save-dir Meta-Llama-3-8B-Instruct-FP8
-```
+A stable release will be published.
+
+## Quickstart
+
+This package introduces the `AutoFP8ForCausalLM` and `BaseQuantizeConfig` objects for managing how your model will be compressed.
+
+Once you load your `AutoFP8ForCausalLM`, you can tokenize your data and pass it to the `model.quantize(tokenized_text)` function to calibrate and compress the model.
+
+Finally, you can save your quantized model in a compressed checkpoint format compatible with vLLM using `model.save_quantized("my_model_fp8")`.
+
+Here is a full example covering that flow:
+
+```python
+from transformers import AutoTokenizer
+from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
-Example model checkpoint with FP8 static scales for activations and weights: https://huggingface.co/nm-testing/Meta-Llama-3-8B-Instruct-FP8
+pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
+quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8"
-All arguments available for `quantize.py`:
+tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
+examples = ["auto_fp8 is an easy-to-use model quantization library"]
+examples = tokenizer(examples, return_tensors="pt").to("cuda")
+
+quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="dynamic")
+
+model = AutoFP8ForCausalLM.from_pretrained(
+    pretrained_model_dir, quantize_config=quantize_config
+)
+model.quantize(examples)
+model.save_quantized(quantized_model_dir)
 ```
-usage: quantize.py [-h] [--model-id MODEL_ID] [--save-dir SAVE_DIR] [--activation-scheme {static,dynamic}] [--num-samples NUM_SAMPLES] [--max-seq-len MAX_SEQ_LEN]
-
-options:
-  -h, --help            show this help message and exit
-  --model-id MODEL_ID
-  --save-dir SAVE_DIR
-  --activation-scheme {static,dynamic}
-  --num-samples NUM_SAMPLES
-  --max-seq-len MAX_SEQ_LEN
+
+Once saved, load the checkpoint into vLLM for inference! Support was added in vLLM v0.4.2 (`pip install "vllm>=0.4.2"`). Note that the GPU you are using must have hardware support for FP8 tensor cores (Ada Lovelace, Hopper, and newer).
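+
+If you are unsure whether your GPU qualifies, a quick sanity check (plain PyTorch, not part of this package) is to read the CUDA compute capability; Ada Lovelace reports 8.9 and Hopper reports 9.0:
+
+```python
+import torch
+
+# FP8 tensor cores require an NVIDIA GPU with compute capability >= 8.9 (Ada Lovelace or newer)
+print(torch.cuda.get_device_capability() >= (8, 9))
+```
+
+Then load the quantized checkpoint: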
+ +```python +from vllm import LLM + +model = LLM("Meta-Llama-3-8B-Instruct-FP8") +# INFO 05-10 18:02:40 model_runner.py:175] Loading model weights took 8.4595 GB + +print(model.generate("Once upon a time")) +# [RequestOutput(request_id=0, prompt='Once upon a time', prompt_token_ids=[128000, 12805, 5304, 264, 892], prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=' there was a man who fell in love with a woman. The man was so', token_ids=[1070, 574, 264, 893, 889, 11299, 304, 3021, 449, 264, 5333, 13, 578, 893, 574, 779], cumulative_logprob=-21.314169232733548, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=RequestMetrics(arrival_time=1715378569.478381, last_token_time=1715378569.478381, first_scheduled_time=1715378569.480648, first_token_time=1715378569.7070432, time_in_queue=0.002267122268676758, finished_time=1715378570.104807), lora_request=None)] ``` ## How to run FP8 quantized models @@ -36,7 +61,7 @@ options: Then simply pass the quantized checkpoint directly to vLLM's entrypoints! It will detect the checkpoint format using the `quantization_config` in the `config.json`. ```python from vllm import LLM -model = LLM("nm-testing/Meta-Llama-3-8B-Instruct-FP8") +model = LLM("neuralmagic/Meta-Llama-3-8B-Instruct-FP8") # INFO 05-06 10:06:23 model_runner.py:172] Loading model weights took 8.4596 GB outputs = model.generate("Once upon a time,") diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..347e1c6 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,23 @@ +## FP8 Quantization + +This folder holds the original `quantize.py` example. + +Command to produce a `Meta-Llama-3-8B-Instruct-FP8` quantized LLM: +```bash +python quantize.py --model-id meta-llama/Meta-Llama-3-8B-Instruct --save-dir Meta-Llama-3-8B-Instruct-FP8 +``` + +Example model checkpoint with FP8 static scales for activations and weights: https://huggingface.co/nm-testing/Meta-Llama-3-8B-Instruct-FP8 + +All arguments available for `quantize.py`: +``` +usage: quantize.py [-h] [--model-id MODEL_ID] [--save-dir SAVE_DIR] [--activation-scheme {static,dynamic}] [--num-samples NUM_SAMPLES] [--max-seq-len MAX_SEQ_LEN] + +options: + -h, --help show this help message and exit + --model-id MODEL_ID + --save-dir SAVE_DIR + --activation-scheme {static,dynamic} + --num-samples NUM_SAMPLES + --max-seq-len MAX_SEQ_LEN +``` \ No newline at end of file diff --git a/examples/original_quantize.py b/examples/original_quantize.py deleted file mode 100644 index 5b35d69..0000000 --- a/examples/original_quantize.py +++ /dev/null @@ -1,295 +0,0 @@ -import argparse -import gc -import re -from typing import Tuple - -import torch -import torch.functional as F -import transformers -import tqdm -from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer - - -# HACK: override the dtype_byte_size function in transformers to support float8 types -# Fix is posted upstream https://github.com/huggingface/transformers/pull/30488 -def new_dtype_byte_size(dtype): - if dtype == torch.bool: - return 1 / 8 - bit_search = re.search(r"[^\d](\d+)_?", str(dtype)) - if bit_search is None: - raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") - bit_size = int(bit_search.groups()[0]) - return bit_size // 8 - - -transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size - - -def cleanup_memory(): - gc.collect() - torch.cuda.empty_cache() - - -def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]: - """Quantize a tensor using 
per-tensor static scaling factor. - - Args: - tensor: The input tensor. - """ - finfo = torch.finfo(torch.float8_e4m3fn) - # Calculate the scale as dtype max divided by absmax. - # Since .abs() creates a new tensor, we use aminmax to get - # the min and max first and then calculate the absmax. - if tensor.numel() == 0: - # Deal with empty tensors (triggered by empty MoE experts) - min_val, max_val = ( - torch.tensor(0.0, dtype=tensor.dtype), - torch.tensor(1.0, dtype=tensor.dtype), - ) - else: - min_val, max_val = tensor.aminmax() - amax = min_val.abs().max(max_val.abs()) - scale = finfo.max / amax.clamp(min=1e-12) - # scale and clamp the tensor to bring it to - # the representative range of float8 data type - # (as default cast is unsaturated) - qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max) - # Return both float8 data and the inverse scale (as float), - # as both required as inputs to torch._scaled_mm - qweight = qweight.to(torch.float8_e4m3fn) - scale = scale.float().reciprocal() - return qweight, scale - - -def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype): - cuda_compute_capability = torch.cuda.get_device_capability() - if cuda_compute_capability >= (9, 0): - output, _ = torch._scaled_mm( - A, - B.t(), - out_dtype=out_dtype, - scale_a=A_scale, - scale_b=B_scale, - bias=bias, - ) - else: - output = torch.nn.functional.linear( - A.to(out_dtype) * A_scale, - B.to(out_dtype) * B_scale.to(out_dtype), - bias=bias, - ) - return output - - -class FP8StaticLinearQuantizer(torch.nn.Module): - def __init__(self, qweight, weight_scale): - super().__init__() - self.weight = torch.nn.Parameter(qweight, requires_grad=False) - self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) - self.act_scale = None - - def forward(self, x): - # Dynamically quantize - qinput, x_act_scale = per_tensor_quantize(x) - - # Update scale if needed. - if self.act_scale is None: - self.act_scale = torch.nn.Parameter(x_act_scale) - elif x_act_scale > self.act_scale: - self.act_scale = torch.nn.Parameter(x_act_scale) - - # Pass quantized to next layer so it has realistic data. 
- output = fp8_gemm( - A=qinput, - A_scale=self.act_scale, - B=self.weight, - B_scale=self.weight_scale, - bias=None, - out_dtype=x.dtype, - ) - return output - - -class FP8StaticLinear(torch.nn.Module): - def __init__(self, qweight, weight_scale, act_scale=0.0): - super().__init__() - self.weight = torch.nn.Parameter(qweight, requires_grad=False) - self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) - self.act_scale = torch.nn.Parameter(act_scale, requires_grad=False) - - def per_tensor_quantize( - self, tensor: torch.Tensor, inv_scale: float - ) -> torch.Tensor: - # Scale and clamp the tensor to bring it to - # the representative range of float8 data type - # (as default cast is unsaturated) - finfo = torch.finfo(torch.float8_e4m3fn) - qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max) - return qweight.to(torch.float8_e4m3fn) - - def forward(self, x): - qinput = self.per_tensor_quantize(x, inv_scale=self.act_scale) - output = fp8_gemm( - A=qinput, - A_scale=self.act_scale, - B=self.weight, - B_scale=self.weight_scale, - bias=None, - out_dtype=x.dtype, - ) - return output - - -class FP8DynamicLinear(torch.nn.Module): - def __init__(self, qweight, scale): - super().__init__() - self.weight = torch.nn.Parameter(qweight, requires_grad=False) - self.weight_scale = torch.nn.Parameter(scale, requires_grad=False) - - def forward(self, x): - qinput, x_scale = per_tensor_quantize(x) - output = fp8_gemm( - A=qinput, - A_scale=x_scale, - B=self.weight, - B_scale=self.weight_scale, - bias=None, - out_dtype=x.dtype, - ) - return output - - -def replace_module(model, name, new_module): - if "." in name: - parent_name = name.rsplit(".", 1)[0] - child_name = name[len(parent_name) + 1 :] - parent = model.model.get_submodule(parent_name) - else: - parent_name = "" - parent = model.model - child_name = name - setattr(parent, child_name, new_module) - - -def quantize_weights(model): - for name, linear in model.model.named_modules(): - # if "gate" in name or not isinstance(linear, torch.nn.Linear): - if not isinstance(linear, torch.nn.Linear): - continue - quant_weight, quant_scale = per_tensor_quantize(linear.weight) - quant_linear = FP8DynamicLinear(quant_weight, quant_scale) - replace_module(model, name, quant_linear) - del linear - cleanup_memory() - - -def quantize_activations(model, calibration_tokens): - # Replace layers with quantizer. - for name, dynamic_quant_linear in model.model.named_modules(): - # if "gate" in name or not isinstance(dynamic_quant_linear, FP8DynamicLinear): - if not isinstance(dynamic_quant_linear, FP8DynamicLinear): - continue - quantizer = FP8StaticLinearQuantizer( - dynamic_quant_linear.weight, dynamic_quant_linear.weight_scale - ) - replace_module(model, name, quantizer) - del dynamic_quant_linear - cleanup_memory() - - # Calibration. - with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating") as pbar: - for row_idx in range(calibration_tokens.shape[0]): - model(calibration_tokens[row_idx].reshape(1, -1)) - torch.cuda.empty_cache() - pbar.update(1) - - # Replace quantizer with StaticLayer. 
- for name, quantizer in model.model.named_modules(): - # if "gate" in name or not isinstance(quantizer, FP8StaticLinearQuantizer): - if not isinstance(quantizer, FP8StaticLinearQuantizer): - continue - static_proj = FP8StaticLinear( - quantizer.weight, quantizer.weight_scale, quantizer.act_scale - ) - replace_module(model, name, static_proj) - del quantizer - cleanup_memory() - - -def save_quantized_model(model, activation_scheme, save_dir): - print(f"Saving the model to {save_dir}") - static_q_dict = { - "quantization_config": { - "quant_method": "fp8", - "activation_scheme": activation_scheme, - } - } - model.config.update(static_q_dict) - model.save_pretrained(save_dir) - tokenizer.save_pretrained(save_dir) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--model-id", type=str) - parser.add_argument("--save-dir", type=str) - parser.add_argument( - "--activation-scheme", type=str, default="static", choices=["static", "dynamic"] - ) - parser.add_argument("--num-samples", type=int, default=512) - parser.add_argument("--max-seq-len", type=int, default=512) - args = parser.parse_args() - - tokenizer = AutoTokenizer.from_pretrained(args.model_id) - sample_input_tokens = tokenizer.apply_chat_template( - [{"role": "user", "content": "What is your name?"}], - add_generation_prompt=True, - return_tensors="pt", - ).to("cuda") - - ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft") - ds = ds.shuffle(seed=42).select(range(args.num_samples)) - ds = ds.map( - lambda batch: { - "text": tokenizer.apply_chat_template(batch["messages"], tokenize=False) - } - ) - tokenizer.pad_token_id = tokenizer.eos_token_id - calibration_tokens = tokenizer( - ds["text"], - return_tensors="pt", - truncation=True, - padding="max_length", - max_length=args.max_seq_len, - add_special_tokens=False, - ).input_ids.to("cuda") - print("Calibration tokens:", calibration_tokens.shape) - - # Load and test the model - model = AutoModelForCausalLM.from_pretrained( - args.model_id, torch_dtype="auto", device_map="auto" - ) - print("Original model graph:\n", model) - output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20) - print("ORIGINAL OUTPUT:\n", tokenizer.decode(output[0]), "\n\n") - - # Quantize weights. - quantize_weights(model) - print("Weight-quantized model graph:\n", model) - output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20) - print("WEIGHT QUANT OUTPUT:\n", tokenizer.decode(output[0]), "\n\n") - - if args.activation_scheme in "dynamic": - print("Exporting model with static weights and dynamic activations") - save_quantized_model(model, args.activation_scheme, args.save_dir) - else: - assert args.activation_scheme in "static" - # Quantize activations. - quantize_activations(model, calibration_tokens=calibration_tokens) - print("Weight and activation quantized model graph:\n", model) - output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20) - print("ACT QUANT OUTPUT:\n", tokenizer.decode(output[0]), "\n\n") - - print("Exporting model with static weights and static activations") - save_quantized_model(model, args.activation_scheme, args.save_dir) diff --git a/quantize.py b/examples/quantize.py similarity index 100% rename from quantize.py rename to examples/quantize.py