From 7546f765a528e18532af164f8f39455ea849e050 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Thu, 18 Jul 2024 16:55:45 -0400
Subject: [PATCH] Switch backend to use llm-compressor

---
 auto_fp8/__init__.py |   3 +-
 auto_fp8/config.py   |  42 ------------
 auto_fp8/modeling.py | 156 ++++++++++++++++++++-----------------------
 example_dataset.py   |  25 ++++---
 4 files changed, 90 insertions(+), 136 deletions(-)
 delete mode 100644 auto_fp8/config.py

diff --git a/auto_fp8/__init__.py b/auto_fp8/__init__.py
index ea4fbb6..d463cc8 100644
--- a/auto_fp8/__init__.py
+++ b/auto_fp8/__init__.py
@@ -1,5 +1,4 @@
-from .config import BaseQuantizeConfig
-from .modeling import AutoFP8ForCausalLM
+from .modeling import AutoFP8ForCausalLM, BaseQuantizeConfig
 
 __all__ = [
     "AutoFP8ForCausalLM",
diff --git a/auto_fp8/config.py b/auto_fp8/config.py
deleted file mode 100644
index 24c6200..0000000
--- a/auto_fp8/config.py
+++ /dev/null
@@ -1,42 +0,0 @@
-from typing import List, Optional, Tuple
-
-
-class BaseQuantizeConfig:
-    """Configuration for model quantization.
-
-    Args:
-        quant_method: Type/precision of quantization method to use.
-            At the moment, this is just "fp8" which specifically means
-            the fp8_e4m3 format in pytorch.
-        activation_scheme: Choice of either "dynamic" or "static" quantization
-            of activtions. If "static", then calibration samples are required
-            during quantization to produce accurate per-tensor scales for
-            activations of Linear modules.
-        ignore_patterns: List of patterns used to ignore layers. If a string
-            starts with "re:", then everything afterwards is used as python
-            regex style matching i.e. re.search(), for each Linear layer.
-            By default, "re:.*lm_head" is included to ignore the embedding
-            Linear layer usually at the end of decoder LLMs
-        kv_cache_quant_targets: Tuple of Linear module names to target for
-            calibration of the output scales for KV cache quantization.
-            Usually, these should be `("k_proj", "v_proj")`.
-    """
-
-    def __init__(
-        self,
-        quant_method: str = "fp8",
-        activation_scheme: str = "static",
-        ignore_patterns: List[str] = ["re:.*lm_head"],
-        kv_cache_quant_targets: Optional[Tuple[str]] = None,
-    ):
-        if quant_method != "fp8":
-            raise ValueError("Only FP8 quantization is supported.")
-        if activation_scheme not in ["static", "dynamic"]:
-            raise ValueError(
-                "Invalid activation_scheme. Choose either 'static' or 'dynamic'."
-            )
-        self.quant_method = quant_method
-        self.activation_scheme = activation_scheme
-        self.ignore_patterns = ignore_patterns
-        self.kv_cache_quant_targets = kv_cache_quant_targets
-        self.ignored_layers = []
diff --git a/auto_fp8/modeling.py b/auto_fp8/modeling.py
index a5aa2b0..2a4637a 100644
--- a/auto_fp8/modeling.py
+++ b/auto_fp8/modeling.py
@@ -1,26 +1,46 @@
-import re
-from typing import List, Optional, Tuple
-
-import torch
-from transformers import AutoModelForCausalLM
-
-from auto_fp8.config import BaseQuantizeConfig
-from auto_fp8.quantize import (
-    quantize_activations,
-    quantize_weights,
-    save_quantized_model,
-)
-
-
-class AutoFP8ForCausalLM:
+import os
+from typing import List, Optional
+
+from transformers import AutoConfig, AutoTokenizer
+from datasets import Dataset
+from llmcompressor.transformers import SparseAutoModelForCausalLM
+from llmcompressor.transformers import oneshot
+from llmcompressor.modifiers.quantization import QuantizationModifier
+
+class BaseQuantizeConfig:
+    """Configuration for model quantization.
+
+    Args:
+        quant_method: Type/precision of quantization method to use.
+            At the moment, this is just "fp8" which specifically means
+            the fp8_e4m3 format in pytorch.
+        activation_scheme: Choice of either "dynamic" or "static" quantization
+            of activations. If "static", then calibration samples are required
+            during quantization to produce accurate per-tensor scales for
+            activations of Linear modules.
+        ignore_patterns: List of patterns used to ignore layers. If a string
+            starts with "re:", then everything afterwards is used as python
+            regex style matching i.e. re.search(), for each Linear layer.
+            By default, "lm_head" is included to ignore the embedding
+            Linear layer usually at the end of decoder LLMs.
+    """
     def __init__(
         self,
-        model: AutoModelForCausalLM,
-        quantize_config: BaseQuantizeConfig,
+        quant_method: str = "fp8",
+        activation_scheme: str = "static",
+        ignore_patterns: List[str] = ["lm_head"],
     ):
+        self.quant_method = quant_method
+        self.activation_scheme = activation_scheme
+        self.ignore_patterns = ignore_patterns
+
+
+class AutoFP8ForCausalLM:
+    def __init__(self, model: SparseAutoModelForCausalLM, quantize_config: BaseQuantizeConfig):
         self.model = model
         self.model_type = self.model.config.model_type
         self.config = self.model.config
-
-        # Gather the Linear module names that we want to ignore
-        quantize_config.ignored_layers = get_layers_to_ignore(
@@ -45,76 +65,23 @@ def __init__(
-            )
-            quantize_config.kv_cache_quant_layers = kv_cache_quant_layers
-
         self.quantize_config = quantize_config
 
     @classmethod
-    def from_pretrained(
-        cls,
-        pretrained_model_name_or_path: str,
-        quantize_config: BaseQuantizeConfig,
-        **model_init_kwargs,
-    ):
-        """Load the un-quantized pretrained model"""
-
-        def skip(*args, **kwargs):
-            pass
-
-        torch.nn.init.kaiming_uniform_ = skip
-        torch.nn.init.uniform_ = skip
-        torch.nn.init.normal_ = skip
-
-        # Parameters related to loading from Hugging Face Hub
-        cache_dir = model_init_kwargs.pop("cache_dir", None)
-        force_download = model_init_kwargs.pop("force_download", False)
-        resume_download = model_init_kwargs.pop("resume_download", False)
-        proxies = model_init_kwargs.pop("proxies", None)
-        local_files_only = model_init_kwargs.pop("local_files_only", False)
-        use_auth_token = model_init_kwargs.pop("use_auth_token", None)
-        revision = model_init_kwargs.pop("revision", None)
-        subfolder = model_init_kwargs.pop("subfolder", "")
-        commit_hash = model_init_kwargs.pop("_commit_hash", None)
-
-        cached_file_kwargs = {
-            "cache_dir": cache_dir,
-            "force_download": force_download,
-            "proxies": proxies,
-            "resume_download": resume_download,
-            "local_files_only": local_files_only,
-            "use_auth_token": use_auth_token,
-            "revision": revision,
-            "subfolder": subfolder,
-            "_commit_hash": commit_hash,
-        }
-
-        torch.cuda.empty_cache()
-
-        # Important defaults
-        if "torch_dtype" not in model_init_kwargs:
-            model_init_kwargs["torch_dtype"] = "auto"
-
-        if "device_map" not in model_init_kwargs:
-            model_init_kwargs["device_map"] = "auto"
-
-        merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
-        print("Loading model with the following kwargs:", merged_kwargs)
-        model = AutoModelForCausalLM.from_pretrained(
-            pretrained_model_name_or_path, **merged_kwargs
+    def from_pretrained(cls, pretrained_model_name_or_path: str, quantize_config: BaseQuantizeConfig, **kwargs):
+        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+        model = SparseAutoModelForCausalLM.from_pretrained(
+            pretrained_model_name_or_path,
+            config=config,
+            device_map="auto",
+            torch_dtype="auto",
+            **kwargs
         )
-
-        model_config = model.config.to_dict()
-        seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
-        if any(k in model_config for k in seq_len_keys):
-            for key in seq_len_keys:
-                if key in model_config:
-                    model.seqlen = model_config[key]
-                    break
-        else:
-            print("Can't get model's sequence length, setting to 2048.")
-            model.seqlen = 2048
-        model.eval()
-
         return cls(model, quantize_config)
 
-    def quantize(self, calibration_tokens: Optional[torch.Tensor] = None):
@@ -161,12 +128,28 @@ def save_quantized(self, save_dir):
-            self.model,
-            quant_config=self.quantize_config,
-            save_dir=save_dir,
+    def quantize(self, dataset: Optional[Dataset] = None):
+        assert self.quantize_config.activation_scheme == "static"
+        assert dataset is not None, "Calibration tokens required for static activation quantization"
+
+        recipe = QuantizationModifier(
+            targets="Linear",
+            scheme="FP8",
+            ignore=self.quantize_config.ignore_patterns
         )
 
+        oneshot(
+            model=self.model,
+            dataset=dataset,
+            recipe=recipe,
+        )
 
-def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
-    ignored_layers = set()
+    def save_quantized(self, save_directory: str):
+        self.save_pretrained(save_directory, save_compressed=True)
 
-    for name, linear in model.named_modules():
-        if not isinstance(linear, torch.nn.Linear):
-            continue
@@ -220,3 +203,10 @@ def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List[str]:
 
-    return kv_cache_quant_layers
+    def save_pretrained(self, save_directory: str, save_compressed: bool = True):
+        self.model.save_pretrained(save_directory, save_compressed=save_compressed)
+        tokenizer = AutoTokenizer.from_pretrained(self.model.config._name_or_path)
+        tokenizer.save_pretrained(save_directory)
+        print(f"Saved final checkpoint to {os.path.abspath(save_directory)}")
diff --git a/example_dataset.py b/example_dataset.py
index 204345f..82d336e 100644
--- a/example_dataset.py
+++ b/example_dataset.py
@@ -3,20 +3,27 @@
 
 from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
 
-pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
-quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8"
+pretrained_model_dir = "facebook/opt-125m"
+quantized_model_dir = "opt-125m-FP8"
 
 tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
 tokenizer.pad_token = tokenizer.eos_token
 
-ds = load_dataset("mgoin/ultrachat_2k", split="train_sft")
-examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds]
-examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda")
+MAX_SEQUENCE_LENGTH = 2048
+ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(512))
+def preprocess(example):
+    example = tokenizer.apply_chat_template(example["messages"], tokenize=False)
+    return tokenizer(
+        example,
+        padding=False,
+        max_length=MAX_SEQUENCE_LENGTH,
+        truncation=True,
+        add_special_tokens=False,
+    )
+ds = ds.map(preprocess, remove_columns=ds.column_names)
 
 quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")
 
-model = AutoFP8ForCausalLM.from_pretrained(
-    pretrained_model_dir, quantize_config=quantize_config
-)
-model.quantize(examples)
+model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
+model.quantize(ds)
 model.save_quantized(quantized_model_dir)
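
Reviewer note (not part of the patch): for context, below is a minimal sketch of the raw llm-compressor flow that AutoFP8ForCausalLM now delegates to. It reuses only the model, calibration set, and calls that already appear in the diff above (SparseAutoModelForCausalLM, QuantizationModifier, oneshot, save_pretrained with save_compressed); model name and sample count are illustrative.

    from datasets import load_dataset
    from transformers import AutoTokenizer
    from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
    from llmcompressor.modifiers.quantization import QuantizationModifier

    model_id = "facebook/opt-125m"
    model = SparseAutoModelForCausalLM.from_pretrained(
        model_id, device_map="auto", torch_dtype="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Calibration data: 512 chat samples, tokenized the same way as example_dataset.py.
    ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(512))
    ds = ds.map(
        lambda ex: tokenizer(
            tokenizer.apply_chat_template(ex["messages"], tokenize=False),
            max_length=2048,
            truncation=True,
            add_special_tokens=False,
        ),
        remove_columns=ds.column_names,
    )

    # FP8 weight quantization plus static activation scales for every Linear layer
    # except lm_head, matching the "static" activation_scheme in BaseQuantizeConfig.
    recipe = QuantizationModifier(targets="Linear", scheme="FP8", ignore=["lm_head"])
    oneshot(model=model, dataset=ds, recipe=recipe)

    model.save_pretrained("opt-125m-FP8", save_compressed=True)
    tokenizer.save_pretrained("opt-125m-FP8")

Keeping the AutoFP8ForCausalLM/BaseQuantizeConfig surface on top of these calls means existing scripts only need the small call-site changes shown in example_dataset.py.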