diff --git a/auto_fp8/__init__.py b/auto_fp8/__init__.py index ea4fbb6..d463cc8 100644 --- a/auto_fp8/__init__.py +++ b/auto_fp8/__init__.py @@ -1,5 +1,4 @@ -from .config import BaseQuantizeConfig -from .modeling import AutoFP8ForCausalLM +from .modeling import AutoFP8ForCausalLM, BaseQuantizeConfig __all__ = [ "AutoFP8ForCausalLM", diff --git a/auto_fp8/config.py b/auto_fp8/config.py deleted file mode 100644 index 24c6200..0000000 --- a/auto_fp8/config.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import List, Optional, Tuple - - -class BaseQuantizeConfig: - """Configuration for model quantization. - - Args: - quant_method: Type/precision of quantization method to use. - At the moment, this is just "fp8" which specifically means - the fp8_e4m3 format in pytorch. - activation_scheme: Choice of either "dynamic" or "static" quantization - of activtions. If "static", then calibration samples are required - during quantization to produce accurate per-tensor scales for - activations of Linear modules. - ignore_patterns: List of patterns used to ignore layers. If a string - starts with "re:", then everything afterwards is used as python - regex style matching i.e. re.search(), for each Linear layer. - By default, "re:.*lm_head" is included to ignore the embedding - Linear layer usually at the end of decoder LLMs - kv_cache_quant_targets: Tuple of Linear module names to target for - calibration of the output scales for KV cache quantization. - Usually, these should be `("k_proj", "v_proj")`. - """ - - def __init__( - self, - quant_method: str = "fp8", - activation_scheme: str = "static", - ignore_patterns: List[str] = ["re:.*lm_head"], - kv_cache_quant_targets: Optional[Tuple[str]] = None, - ): - if quant_method != "fp8": - raise ValueError("Only FP8 quantization is supported.") - if activation_scheme not in ["static", "dynamic"]: - raise ValueError( - "Invalid activation_scheme. Choose either 'static' or 'dynamic'." - ) - self.quant_method = quant_method - self.activation_scheme = activation_scheme - self.ignore_patterns = ignore_patterns - self.kv_cache_quant_targets = kv_cache_quant_targets - self.ignored_layers = [] diff --git a/auto_fp8/modeling.py b/auto_fp8/modeling.py index 04a9e71..79d80b2 100644 --- a/auto_fp8/modeling.py +++ b/auto_fp8/modeling.py @@ -1,42 +1,55 @@ -import re -from typing import List, Optional, Tuple +import os +from typing import List, Optional + +from transformers import AutoConfig, AutoTokenizer +from datasets import Dataset +from llmcompressor.transformers import SparseAutoModelForCausalLM +from llmcompressor.transformers import oneshot +from llmcompressor.modifiers.quantization import QuantizationModifier +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationType, + QuantizationScheme, + QuantizationStrategy, +) -import torch -from transformers import AutoModelForCausalLM -from auto_fp8.config import BaseQuantizeConfig -from auto_fp8.quantize import ( - quantize_activations, - quantize_weights, - save_quantized_model, -) +class BaseQuantizeConfig: + """Configuration for model quantization. + + Args: + quant_method: Type/precision of quantization method to use. + At the moment, this is just "fp8" which specifically means + the fp8_e4m3 format in pytorch. + activation_scheme: Choice of either "dynamic" or "static" quantization + of activtions. If "static", then calibration samples are required + during quantization to produce accurate per-tensor scales for + activations of Linear modules. 
+ ignore_patterns: List of patterns used to ignore layers. If a string + starts with "re:", then everything afterwards is used as python + regex style matching i.e. re.search(), for each Linear layer. + By default, "lm_head" is included to ignore the embedding + Linear layer usually at the end of decoder LLMs + """ + + def __init__( + self, + quant_method: str = "fp8", + activation_scheme: str = "static", + ignore_patterns: List[str] = ["lm_head"], + ): + self.quant_method = quant_method + self.activation_scheme = activation_scheme + self.ignore_patterns = ignore_patterns class AutoFP8ForCausalLM: def __init__( - self, - model: AutoModelForCausalLM, - quantize_config: BaseQuantizeConfig, + self, model: SparseAutoModelForCausalLM, quantize_config: BaseQuantizeConfig ): self.model = model self.model_type = self.model.config.model_type self.config = self.model.config - - # Gather the Linear module names that we want to ignore - quantize_config.ignored_layers = get_layers_to_ignore( - self.model, quantize_config.ignore_patterns - ) - - if quantize_config.kv_cache_quant_targets: - kv_cache_quant_layers = get_kv_cache_quant_layers( - self.model, quantize_config.kv_cache_quant_targets - ) - if len(kv_cache_quant_layers) == 0: - raise ValueError( - f"Could not find any kv cache layers using kv_cache_quant_targets={quantize_config.kv_cache_quant_targets}, please fix your argument." - ) - quantize_config.kv_cache_quant_layers = kv_cache_quant_layers - self.quantize_config = quantize_config @classmethod @@ -44,130 +57,94 @@ def from_pretrained( cls, pretrained_model_name_or_path: str, quantize_config: BaseQuantizeConfig, - **model_init_kwargs, + **kwargs, ): - """Load the un-quantized pretrained model""" - - def skip(*args, **kwargs): - pass - - torch.nn.init.kaiming_uniform_ = skip - torch.nn.init.uniform_ = skip - torch.nn.init.normal_ = skip - - # Parameters related to loading from Hugging Face Hub - cache_dir = model_init_kwargs.pop("cache_dir", None) - force_download = model_init_kwargs.pop("force_download", False) - resume_download = model_init_kwargs.pop("resume_download", False) - proxies = model_init_kwargs.pop("proxies", None) - local_files_only = model_init_kwargs.pop("local_files_only", False) - use_auth_token = model_init_kwargs.pop("use_auth_token", None) - revision = model_init_kwargs.pop("revision", None) - subfolder = model_init_kwargs.pop("subfolder", "") - commit_hash = model_init_kwargs.pop("_commit_hash", None) - - cached_file_kwargs = { - "cache_dir": cache_dir, - "force_download": force_download, - "proxies": proxies, - "resume_download": resume_download, - "local_files_only": local_files_only, - "use_auth_token": use_auth_token, - "revision": revision, - "subfolder": subfolder, - "_commit_hash": commit_hash, - } - - torch.cuda.empty_cache() - - # Important defaults - if "torch_dtype" not in model_init_kwargs: - model_init_kwargs["torch_dtype"] = "auto" - - if "device_map" not in model_init_kwargs: - model_init_kwargs["device_map"] = "auto" - - merged_kwargs = {**model_init_kwargs, **cached_file_kwargs} - print("Loading model with the following kwargs:", merged_kwargs) - model = AutoModelForCausalLM.from_pretrained( - pretrained_model_name_or_path, **merged_kwargs + config = AutoConfig.from_pretrained(pretrained_model_name_or_path) + model = SparseAutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path, + config=config, + device_map="auto", + torch_dtype="auto", + **kwargs, ) - - model_config = model.config.to_dict() - seq_len_keys = ["max_position_embeddings", 
"seq_length", "n_positions"] - if any(k in model_config for k in seq_len_keys): - for key in seq_len_keys: - if key in model_config: - model.seqlen = model_config[key] - break - else: - print("Can't get model's sequence length, setting to 2048.") - model.seqlen = 2048 - model.eval() - return cls(model, quantize_config) - def quantize(self, calibration_tokens: Optional[torch.Tensor] = None): + def quantize(self, dataset: Optional[Dataset] = None): + if self.quantize_config.activation_scheme == "dynamic": + if dataset is None: + # For dynamic activations, we don't care about calibration data + # being provided. However, we need to pass something + # TODO(mgoin): Remove once llmcompressor allows no dataset + from datasets import load_dataset + dataset = load_dataset("openai/openai_humaneval", split="test").select(range(1)) + dataset = dataset.rename_column("prompt", "text") + + FP8_W8 = QuantizationScheme( + targets=["Linear"], + weights=QuantizationArgs( + num_bits=8, + type=QuantizationType.FLOAT, + strategy=QuantizationStrategy.TENSOR, + symmetric=True, + dynamic=False, + ), + ) - # Always quantize the weights as they do not require calibration data - quantize_weights(self.model, self.quantize_config) + recipe = QuantizationModifier( + config_groups={"group_0": FP8_W8}, + ignore=self.quantize_config.ignore_patterns, + ) - if self.quantize_config.activation_scheme == "static": + oneshot( + model=self.model, + dataset=dataset, + recipe=recipe, + num_calibration_samples=dataset.shape[0], + ) + elif self.quantize_config.activation_scheme == "static": assert ( - calibration_tokens is not None - ), "Calibration tokens required for activation quantization" - - - def _prepare_calibration_data(calibration_tokens): - if hasattr(calibration_tokens, "input_ids"): - return calibration_tokens.input_ids - return calibration_tokens - - quantize_activations( - self.model, - self.quantize_config, - _prepare_calibration_data(calibration_tokens), + dataset is not None + ), "Calibration tokens required for static activation quantization" + + FP8_W8A8 = QuantizationScheme( + targets=["Linear"], + weights=QuantizationArgs( + num_bits=8, + type=QuantizationType.FLOAT, + strategy=QuantizationStrategy.TENSOR, + symmetric=True, + dynamic=False, + ), + input_activations=QuantizationArgs( + num_bits=8, + type=QuantizationType.FLOAT, + strategy=QuantizationStrategy.TENSOR, + symmetric=True, + dynamic=False, + ), ) - def save_quantized(self, save_dir): - save_quantized_model( - self.model, - quant_config=self.quantize_config, - save_dir=save_dir, - ) - - -def get_layers_to_ignore(model, ignore_patterns) -> List[str]: - ignored_layers = set() - - for name, linear in model.named_modules(): - if not isinstance(linear, torch.nn.Linear): - continue - - for ignore_pattern in ignore_patterns: - regex_prefix = "re:" - if ignore_pattern.startswith(regex_prefix): - # check if name matches regex and add to set if true - regex_pattern = ignore_pattern[len(regex_prefix) :] - if re.search(regex_pattern, name): - ignored_layers.add(name) - else: - # else, exact match - if ignore_pattern == name: - ignored_layers.add(name) - - return list(ignored_layers) - - -def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List[str]: - kv_cache_quant_layers = [] + recipe = QuantizationModifier( + config_groups={"group_0": FP8_W8A8}, + ignore=self.quantize_config.ignore_patterns, + ) - for name, linear in model.named_modules(): - if not isinstance(linear, torch.nn.Linear): - continue + oneshot( + model=self.model, + 
dataset=dataset, + recipe=recipe, + num_calibration_samples=dataset.shape[0], + ) + else: + raise ValueError( + f"Unsupported activation_scheme={self.quantize_config.activation_scheme}" + ) - for output_quant_target in kv_cache_quant_targets: - if name.endswith(output_quant_target): - kv_cache_quant_layers.append(name) + def save_quantized(self, save_directory: str): + self.save_pretrained(save_directory, save_compressed=True) - return kv_cache_quant_layers + def save_pretrained(self, save_directory: str, save_compressed: bool = True): + self.model.save_pretrained(save_directory, save_compressed=save_compressed) + tokenizer = AutoTokenizer.from_pretrained(self.model.config._name_or_path) + tokenizer.save_pretrained(save_directory) + print(f"Saved final checkpoint to {os.path.abspath(save_directory)}") diff --git a/auto_fp8/quantize.py b/auto_fp8/quantize.py deleted file mode 100644 index 38a4de6..0000000 --- a/auto_fp8/quantize.py +++ /dev/null @@ -1,344 +0,0 @@ -import gc -import re -from typing import Optional, Tuple -import copy - -import torch -import tqdm -import transformers -from transformers import AutoModelForCausalLM, AutoTokenizer - -from .config import BaseQuantizeConfig - - -# HACK: Override the dtype_byte_size function in transformers to support float8 types -# Fix is posted upstream https://github.com/huggingface/transformers/pull/30488 -def new_dtype_byte_size(dtype): - if dtype == torch.bool: - return 1 / 8 - bit_search = re.search(r"[^\d](\d+)_?", str(dtype)) - if bit_search is None: - raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") - bit_size = int(bit_search.groups()[0]) - return bit_size // 8 - - -transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size - - -def cleanup_memory(): - gc.collect() - torch.cuda.empty_cache() - - -def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]: - """Quantize a tensor using per-tensor static scaling factor. - Args: - tensor: The input tensor. - """ - finfo = torch.finfo(torch.float8_e4m3fn) - # Calculate the scale as dtype max divided by absmax. - # Since .abs() creates a new tensor, we use aminmax to get - # the min and max first and then calculate the absmax. 
- if tensor.numel() == 0: - # Deal with empty tensors (triggered by empty MoE experts) - min_val, max_val = ( - torch.tensor(-16.0, dtype=tensor.dtype), - torch.tensor(16.0, dtype=tensor.dtype), - ) - else: - min_val, max_val = tensor.aminmax() - amax = torch.maximum(min_val.abs(), max_val.abs()) - scale = finfo.max / amax.clamp(min=1e-12) - # scale and clamp the tensor to bring it to - # the representative range of float8 data type - # (as default cast is unsaturated) - qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max) - # Return both float8 data and the inverse scale (as float), - # as both required as inputs to torch._scaled_mm - qweight = qweight.to(torch.float8_e4m3fn) - scale = scale.float().reciprocal() - return qweight, scale - - -def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor: - finfo = torch.finfo(torch.float8_e4m3fn) - qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max) - return qweight.to(torch.float8_e4m3fn) - - -def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype): - if A.numel() == 0: - # Deal with empty tensors (triggeted by empty MoE experts) - return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device) - - # TODO: Disable native fp8 gemm for now, always just dequantize - # native_fp8_support = ( - # torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) - # ) - native_fp8_support = False - if native_fp8_support: - need_reshape = A.dim() == 3 - if need_reshape: - batch_size = A.shape[0] - A_input = A.reshape(-1, A.shape[-1]) - else: - batch_size = None - A_input = A - output, _ = torch._scaled_mm( - A_input, - B.t(), - out_dtype=out_dtype, - scale_a=A_scale, - scale_b=B_scale, - bias=bias, - ) - if need_reshape: - output = output.reshape( - batch_size, output.shape[0] // batch_size, output.shape[1] - ) - else: - output = torch.nn.functional.linear( - A.to(out_dtype) * A_scale, - B.to(out_dtype) * B_scale.to(out_dtype), - bias=bias, - ) - return output - - -# Class responsible for quantizing weights -class FP8DynamicLinear(torch.nn.Module): - def __init__( - self, - weight: torch.Tensor, - weight_scale: torch.Tensor, - bias: torch.nn.Parameter, - ): - super().__init__() - self.weight = torch.nn.Parameter(weight, requires_grad=False) - self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) - self.bias = bias - - def forward(self, x): - qinput, x_scale = per_tensor_quantize(x) - output = fp8_gemm( - A=qinput, - A_scale=x_scale, - B=self.weight, - B_scale=self.weight_scale, - bias=self.bias, - out_dtype=x.dtype, - ) - return output - - -# Module responsible for taking already quantized weights, and recording input scales (and possibly output scales) using an activation observer -class FP8StaticLinearQuantizer(torch.nn.Module): - def __init__( - self, - weight: torch.Tensor, - weight_scale: torch.Tensor, - bias: torch.nn.Parameter, - quantize_output: bool = False, - ): - super().__init__() - self.weight = torch.nn.Parameter(weight, requires_grad=False) - self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) - self.bias = bias - self.input_scale = None - self.output_scale = None - self.quantize_output = quantize_output - - def forward(self, x): - qinput, x_input_scale = per_tensor_quantize(x) - if self.input_scale is None: - self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False) - elif x_input_scale > self.input_scale: - self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False) - output = fp8_gemm( - 
A=qinput, - A_scale=self.input_scale, - B=self.weight, - B_scale=self.weight_scale, - bias=self.bias, - out_dtype=x.dtype, - ) - - # Optionally, quantize output and record scale - if self.quantize_output: - qoutput, output_scale = per_tensor_quantize(output) - if self.output_scale is None: - self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False) - elif output_scale > self.output_scale: - self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False) - output = qoutput.to(output.dtype) * output_scale - - return output - - -# Module responsible for representing the final checkpoint representation -class FP8StaticLinear(torch.nn.Module): - def __init__( - self, - weight: torch.nn.Parameter, - weight_scale: torch.nn.Parameter, - bias: torch.nn.Parameter, - input_scale: torch.nn.Parameter, - output_scale: Optional[torch.nn.Parameter] = None, - ): - super().__init__() - self.weight = weight - self.weight_scale = weight_scale - self.bias = bias - self.input_scale = input_scale - self.output_scale = output_scale - - def forward(self, x): - qinput = static_per_tensor_quantize(x, self.input_scale) - output = fp8_gemm( - A=qinput, - A_scale=self.input_scale, - B=self.weight, - B_scale=self.weight_scale, - bias=self.bias, - out_dtype=x.dtype, - ) - - if self.output_scale: - qoutput = static_per_tensor_quantize(output, self.output_scale) - output = qoutput.to(output.dtype) * self.output_scale - - return output - - -def replace_module(model: AutoModelForCausalLM, name: str, new_module: torch.nn.Module): - if "." in name: - parent_name = name.rsplit(".", 1)[0] - child_name = name[len(parent_name) + 1 :] - parent = model.get_submodule(parent_name) - else: - parent_name = "" - parent = model - child_name = name - setattr(parent, child_name, new_module) - - -def quantize_weights( - model: AutoModelForCausalLM, - quantize_config: BaseQuantizeConfig, -): - named_modules = list(model.named_modules()) - for name, linear in tqdm.tqdm(named_modules, desc="Quantizing weights"): - if ( - not isinstance(linear, torch.nn.Linear) - or name in quantize_config.ignored_layers - ): - continue - quant_weight, weight_scale = per_tensor_quantize(linear.weight) - bias = copy.deepcopy(linear.bias) if linear.bias is not None else None - quant_linear = FP8DynamicLinear( - weight=quant_weight, weight_scale=weight_scale, bias=bias - ) - replace_module(model, name, quant_linear) - del linear.weight - del linear.bias - del linear - cleanup_memory() - - -def quantize_activations( - model: AutoModelForCausalLM, - quantize_config: BaseQuantizeConfig, - calibration_tokens, -): - # Replace weight quantizer with a dynamic activation quantizer observer - for name, dynamic_quant_linear in model.named_modules(): - if ( - not isinstance(dynamic_quant_linear, FP8DynamicLinear) - or name in quantize_config.ignored_layers - ): - continue - quantizer = FP8StaticLinearQuantizer( - weight=dynamic_quant_linear.weight, - weight_scale=dynamic_quant_linear.weight_scale, - bias=dynamic_quant_linear.bias, - quantize_output=( - hasattr(quantize_config, "kv_cache_quant_layers") - and name in quantize_config.kv_cache_quant_layers - ), - ) - replace_module(model, name, quantizer) - del dynamic_quant_linear - cleanup_memory() - - # Pass through calibration data to measure activation scales - with torch.inference_mode(): - with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating activation scales") as pbar: - for row_idx in range(calibration_tokens.shape[0]): - model(calibration_tokens[row_idx].reshape(1, -1)) - 
cleanup_memory() - pbar.update(1) - - # Replace dynamic quantizer observer with StaticLinear for export - for name, quantizer in model.named_modules(): - if ( - not isinstance(quantizer, FP8StaticLinearQuantizer) - or name in quantize_config.ignored_layers - ): - continue - static_proj = FP8StaticLinear( - weight=quantizer.weight, - weight_scale=quantizer.weight_scale, - bias=quantizer.bias, - input_scale=quantizer.input_scale, - output_scale=quantizer.output_scale, - ) - replace_module(model, name, static_proj) - del quantizer - cleanup_memory() - - # Post-process step for kv cache scales to take the k/v module - # `output_scale` parameters, take the max of them, and store them in - # the parent attention module as `kv_scale` - # NOTE: if we want to switch to the `output_scale` representation, we can simply remove this block - if hasattr(quantize_config, "kv_cache_quant_layers"): - # Assumes that list is ordered such that [layer0.k_proj, layer0.v_proj, layer1.k_proj, layer1.v_proj, ...] - # so we make a list of tuples [(layer0.k_proj, layer0.v_proj), (layer1.k_proj, layer1.v_proj), ...] - kv_proj_pairs = zip(*[iter(quantize_config.kv_cache_quant_layers)]*2) - for k_proj_name, v_proj_name in kv_proj_pairs: - parent_module_name = ".".join(k_proj_name.split(".")[:-1]) - assert parent_module_name == ".".join(v_proj_name.split(".")[:-1]) - parent_module = dict(model.named_modules())[parent_module_name] - - k_proj = dict(model.named_modules())[k_proj_name] - v_proj = dict(model.named_modules())[v_proj_name] - - kv_scale = max(k_proj.output_scale, v_proj.output_scale) - parent_module.kv_scale = torch.nn.Parameter(kv_scale, requires_grad=False) - - # Remove output_scale from k_proj and v_proj - k_proj.output_scale = None - v_proj.output_scale = None - cleanup_memory() - - -def save_quantized_model( - model: AutoModelForCausalLM, - quant_config: BaseQuantizeConfig, - save_dir: str, -): - print(model) - print(f"Saving the model to {save_dir}") - static_q_dict = { - "quantization_config": { - "quant_method": "fp8", - "activation_scheme": quant_config.activation_scheme, - "ignored_layers": quant_config.ignored_layers, - } - } - if hasattr(quant_config, "kv_cache_quant_layers"): - static_q_dict["quantization_config"]["kv_cache_scheme"] = "static" - model.config.update(static_q_dict) - model.save_pretrained(save_dir) - tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path) - tokenizer.save_pretrained(save_dir) diff --git a/example_dataset.py b/example_dataset.py index 204345f..bf6b6fd 100644 --- a/example_dataset.py +++ b/example_dataset.py @@ -3,20 +3,20 @@ from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig -pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct" -quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8" +pretrained_model_dir = "facebook/opt-125m" +quantized_model_dir = "opt-125m-FP8" tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True) tokenizer.pad_token = tokenizer.eos_token -ds = load_dataset("mgoin/ultrachat_2k", split="train_sft") -examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds] -examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda") +ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(512)) +def preprocess(example): + example = tokenizer.apply_chat_template(example["messages"], tokenize=False) + return tokenizer(example, max_length=2048, truncation=True, add_special_tokens=False) +ds = ds.map(preprocess, 
remove_columns=ds.column_names) quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static") -model = AutoFP8ForCausalLM.from_pretrained( - pretrained_model_dir, quantize_config=quantize_config -) -model.quantize(examples) +model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config) +model.quantize(ds) model.save_quantized(quantized_model_dir) diff --git a/requirements.txt b/requirements.txt index f40dfeb..191d853 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ transformers datasets accelerate tqdm +llmcompressor @ git+https://github.com/vllm-project/llm-compressor.git diff --git a/setup.py b/setup.py index 7417754..3dcd85f 100644 --- a/setup.py +++ b/setup.py @@ -16,6 +16,7 @@ "datasets", "accelerate", "tqdm", + "llmcompressor @ git+https://github.com/vllm-project/llm-compressor.git" ], classifiers=[ "Programming Language :: Python :: 3", diff --git a/tests/test_auto_fp8.py b/tests/test_auto_fp8.py index 6045d84..0322c2d 100644 --- a/tests/test_auto_fp8.py +++ b/tests/test_auto_fp8.py @@ -3,13 +3,14 @@ import pytest import safetensors.torch +from datasets import load_dataset from transformers import AutoTokenizer from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig MODELS = [ ("facebook/opt-125m", 160), - ("Qwen/Qwen2-0.5B-Instruct", 620), + # ("Qwen/Qwen2-0.5B-Instruct", 620), ] @pytest.mark.parametrize("model_id,target_size", MODELS) @@ -21,8 +22,6 @@ def test_dynamic_quantization(model_id, target_size): ) model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config) - model.model.to("cpu") - model.quantize() model.save_quantized(quantized_model_dir) @@ -40,15 +39,22 @@ def test_static_quantization(model_id, target_size): quantized_model_dir = model_id.split("/")[-1] + "-fp8-static" tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) - examples = ["auto-fp8 is an easy-to-use model quantization library"] - examples = tokenizer(examples, return_tensors="pt") + ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(2)) + def preprocess(example): + example = tokenizer.apply_chat_template(example["messages"], tokenize=False) + return tokenizer( + example, + padding=False, + max_length=32, + truncation=True, + add_special_tokens=False, + ) + ds = ds.map(preprocess, remove_columns=ds.column_names) quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static") model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config) - model.model.to("cpu") - - model.quantize(examples) + model.quantize(ds) model.save_quantized(quantized_model_dir) # Measure checkpoint size and cleanup @@ -59,40 +65,40 @@ def test_static_quantization(model_id, target_size): target_size = target_size * (1024 * 1024) assert model_size < target_size -@pytest.mark.parametrize("model_id,target_size", MODELS) -def test_kv_cache_static_quantization(model_id, target_size): - quantized_model_dir = model_id.split("/")[-1] + "-fp8-static-kv" - - tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) - examples = ["auto-fp8 is an easy-to-use model quantization library"] - examples = tokenizer(examples, return_tensors="pt") - - quantize_config = BaseQuantizeConfig( - quant_method="fp8", - activation_scheme="static", - kv_cache_quant_targets=("k_proj", "v_proj"), - ) - - model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config) - model.model.to("cpu") - - model.quantize(examples) - model.save_quantized(quantized_model_dir) - - tensors = 
safetensors.torch.load_file(f"{quantized_model_dir}/model.safetensors") - proj_linear_count = 0 - kv_scale_count = 0 - for name, _ in tensors.items(): - if name.endswith("k_proj.weight") or name.endswith("v_proj.weight"): - proj_linear_count += 1 - if name.endswith("kv_scale"): - kv_scale_count += 1 - assert proj_linear_count // 2 == kv_scale_count - - # Measure checkpoint size and cleanup - model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors") - shutil.rmtree(quantized_model_dir) - - # We expect the quantized model to be a certain size - target_size = target_size * (1024 * 1024) - assert model_size < target_size +# @pytest.mark.parametrize("model_id,target_size", MODELS) +# def test_kv_cache_static_quantization(model_id, target_size): +# quantized_model_dir = model_id.split("/")[-1] + "-fp8-static-kv" + +# tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) +# examples = ["auto-fp8 is an easy-to-use model quantization library"] +# examples = tokenizer(examples, return_tensors="pt") + +# quantize_config = BaseQuantizeConfig( +# quant_method="fp8", +# activation_scheme="static", +# kv_cache_quant_targets=("k_proj", "v_proj"), +# ) + +# model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config) +# model.model.to("cpu") + +# model.quantize(examples) +# model.save_quantized(quantized_model_dir) + +# tensors = safetensors.torch.load_file(f"{quantized_model_dir}/model.safetensors") +# proj_linear_count = 0 +# kv_scale_count = 0 +# for name, _ in tensors.items(): +# if name.endswith("k_proj.weight") or name.endswith("v_proj.weight"): +# proj_linear_count += 1 +# if name.endswith("kv_scale"): +# kv_scale_count += 1 +# assert proj_linear_count // 2 == kv_scale_count + +# # Measure checkpoint size and cleanup +# model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors") +# shutil.rmtree(quantized_model_dir) + +# # We expect the quantized model to be a certain size +# target_size = target_size * (1024 * 1024) +# assert model_size < target_size \ No newline at end of file
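
For reference, a minimal usage sketch of the llmcompressor-backed flow this patch introduces, combining the updated example_dataset.py with the revised tests above. The model id ("facebook/opt-125m"), calibration dataset ("mgoin/ultrachat_2k"), sample count, and output directory names are placeholders copied from the patch's own example and tests; only the API surface added in this diff (BaseQuantizeConfig, AutoFP8ForCausalLM.from_pretrained / quantize / save_quantized) is assumed, so treat this as a sketch rather than a canonical recipe.

from datasets import load_dataset
from transformers import AutoTokenizer

from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

model_id = "facebook/opt-125m"  # placeholder; any decoder-only causal LM

# Dynamic activation scheme: no calibration data is required and only the
# Linear weights are converted to FP8 (quantize() feeds llmcompressor a
# one-sample dummy dataset internally).
dynamic_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="dynamic")
dynamic_model = AutoFP8ForCausalLM.from_pretrained(model_id, dynamic_config)
dynamic_model.quantize()
dynamic_model.save_quantized(model_id.split("/")[-1] + "-fp8-dynamic")

# Static activation scheme: calibration samples are required so that
# oneshot() can record per-tensor input scales for the Linear modules.
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token

ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(512))

def preprocess(example):
    text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return tokenizer(text, max_length=2048, truncation=True, add_special_tokens=False)

ds = ds.map(preprocess, remove_columns=ds.column_names)

static_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")
static_model = AutoFP8ForCausalLM.from_pretrained(model_id, static_config)
static_model.quantize(ds)
static_model.save_quantized(model_id.split("/")[-1] + "-fp8-static")

In practice you would run only one of the two branches per process, since each from_pretrained call loads a full copy of the model.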