Switch backend to use llm-compressor
mgoin committed Jul 18, 2024
1 parent 0249168 commit 7546f76
Showing 4 changed files with 90 additions and 136 deletions.
3 changes: 1 addition & 2 deletions auto_fp8/__init__.py
@@ -1,5 +1,4 @@
from .config import BaseQuantizeConfig
from .modeling import AutoFP8ForCausalLM
from .modeling import AutoFP8ForCausalLM, BaseQuantizeConfig

__all__ = [
"AutoFP8ForCausalLM",
42 changes: 0 additions & 42 deletions auto_fp8/config.py

This file was deleted.

156 changes: 73 additions & 83 deletions auto_fp8/modeling.py
@@ -1,26 +1,46 @@
import re
from typing import List, Optional, Tuple

import torch
from transformers import AutoModelForCausalLM

from auto_fp8.config import BaseQuantizeConfig
from auto_fp8.quantize import (
quantize_activations,
quantize_weights,
save_quantized_model,
)


class AutoFP8ForCausalLM:
import os
from typing import List, Optional

from transformers import AutoConfig, AutoTokenizer
from datasets import Dataset
from llmcompressor.transformers import SparseAutoModelForCausalLM
from llmcompressor.transformers import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

class BaseQuantizeConfig:
"""Configuration for model quantization.
Args:
quant_method: Type/precision of quantization method to use.
At the moment, this is just "fp8" which specifically means
the fp8_e4m3 format in pytorch.
activation_scheme: Choice of either "dynamic" or "static" quantization
of activations. If "static", then calibration samples are required
during quantization to produce accurate per-tensor scales for
activations of Linear modules.
ignore_patterns: List of patterns used to ignore layers. If a string
starts with "re:", then everything afterwards is used as python
regex style matching i.e. re.search(), for each Linear layer.
By default, "lm_head" is included to ignore the embedding
Linear layer usually at the end of decoder LLMs
"""
def __init__(
self,
model: AutoModelForCausalLM,
quantize_config: BaseQuantizeConfig,
quant_method: str = "fp8",
activation_scheme: str = "static",
ignore_patterns: List[str] = ["lm_head"],
):
self.quant_method = quant_method
self.activation_scheme = activation_scheme
self.ignore_patterns = ignore_patterns


class AutoFP8ForCausalLM:
def __init__(self, model: SparseAutoModelForCausalLM, quantize_config: BaseQuantizeConfig):
self.model = model
self.model_type = self.model.config.model_type
self.config = self.model.config
<<<<<<< HEAD

# Gather the Linear module names that we want to ignore
quantize_config.ignored_layers = get_layers_to_ignore(
@@ -45,76 +65,23 @@ def __init__(
)
quantize_config.kv_cache_quant_layers = kv_cache_quant_layers

=======
>>>>>>> ba7d420 (Switch backend to use llm-compressor)
self.quantize_config = quantize_config

@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
quantize_config: BaseQuantizeConfig,
**model_init_kwargs,
):
"""Load the un-quantized pretrained model"""

def skip(*args, **kwargs):
pass

torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip

# Parameters related to loading from Hugging Face Hub
cache_dir = model_init_kwargs.pop("cache_dir", None)
force_download = model_init_kwargs.pop("force_download", False)
resume_download = model_init_kwargs.pop("resume_download", False)
proxies = model_init_kwargs.pop("proxies", None)
local_files_only = model_init_kwargs.pop("local_files_only", False)
use_auth_token = model_init_kwargs.pop("use_auth_token", None)
revision = model_init_kwargs.pop("revision", None)
subfolder = model_init_kwargs.pop("subfolder", "")
commit_hash = model_init_kwargs.pop("_commit_hash", None)

cached_file_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"resume_download": resume_download,
"local_files_only": local_files_only,
"use_auth_token": use_auth_token,
"revision": revision,
"subfolder": subfolder,
"_commit_hash": commit_hash,
}

torch.cuda.empty_cache()

# Important defaults
if "torch_dtype" not in model_init_kwargs:
model_init_kwargs["torch_dtype"] = "auto"

if "device_map" not in model_init_kwargs:
model_init_kwargs["device_map"] = "auto"

merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
print("Loading model with the following kwargs:", merged_kwargs)
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path, **merged_kwargs
def from_pretrained(cls, pretrained_model_name_or_path: str, quantize_config: BaseQuantizeConfig, **kwargs):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
model = SparseAutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path,
config=config,
device_map="auto",
torch_dtype="auto",
**kwargs
)

model_config = model.config.to_dict()
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
if any(k in model_config for k in seq_len_keys):
for key in seq_len_keys:
if key in model_config:
model.seqlen = model_config[key]
break
else:
print("Can't get model's sequence length, setting to 2048.")
model.seqlen = 2048
model.eval()

return cls(model, quantize_config)

<<<<<<< HEAD
def quantize(self, calibration_tokens: Optional[torch.Tensor] = None):
<<<<<<< HEAD
<<<<<<< HEAD
@@ -161,12 +128,28 @@ def save_quantized(self, save_dir):
self.model,
quant_config=self.quantize_config,
save_dir=save_dir,
=======
def quantize(self, dataset: Optional[Dataset] = None):
assert self.quantize_config.activation_scheme == "static"
assert dataset is not None, "Calibration tokens required for static activation quantization"

recipe = QuantizationModifier(
targets="Linear",
scheme="FP8",
ignore=self.quantize_config.ignore_patterns
>>>>>>> ba7d420 (Switch backend to use llm-compressor)
)

oneshot(
model=self.model,
dataset=dataset,
recipe=recipe,
)

def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
ignored_layers = set()
def save_quantized(self, save_directory: str):
self.save_pretrained(save_directory, save_compressed=True)

<<<<<<< HEAD
for name, linear in model.named_modules():
if not isinstance(linear, torch.nn.Linear):
continue
@@ -220,3 +203,10 @@ def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List

return kv_cache_quant_layers
>>>>>>> c3acdee (Switch from output_scale to kv_scale)
=======
def save_pretrained(self, save_directory: str, save_compressed: bool = True):
self.model.save_pretrained(save_directory, save_compressed=save_compressed)
tokenizer = AutoTokenizer.from_pretrained(self.model.config._name_or_path)
tokenizer.save_pretrained(save_directory)
print(f"Saved final checkpoint to {os.path.abspath(save_directory)}")
>>>>>>> ba7d420 (Switch backend to use llm-compressor)
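
For orientation, a minimal sketch of how the new llm-compressor-backed API introduced in this file is meant to be driven end to end. The model name, regex ignore pattern, and calibration text are illustrative placeholders chosen for this note, not part of the commit:

    from datasets import Dataset
    from transformers import AutoTokenizer

    from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

    model_id = "facebook/opt-125m"  # small model, purely for illustration
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Per the BaseQuantizeConfig docstring, "re:"-prefixed entries are matched
    # with re.search() against each Linear module name; the gate_proj pattern
    # below is a hypothetical example.
    quantize_config = BaseQuantizeConfig(
        quant_method="fp8",
        activation_scheme="static",
        ignore_patterns=["lm_head", "re:.*gate_proj"],
    )

    # A tiny tokenized calibration set; real runs should use a few hundred
    # representative samples (see example_dataset.py below).
    calibration_ds = Dataset.from_list(
        [dict(tokenizer("The quick brown fox jumps over the lazy dog.")) for _ in range(8)]
    )

    model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config)
    model.quantize(calibration_ds)        # builds an FP8 QuantizationModifier and runs oneshot()
    model.save_quantized("opt-125m-FP8")  # writes compressed weights and the tokenizer
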
25 changes: 16 additions & 9 deletions example_dataset.py
@@ -3,20 +3,27 @@

from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8"
pretrained_model_dir = "facebook/opt-125m"
quantized_model_dir = "opt-125m-FP8"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token

ds = load_dataset("mgoin/ultrachat_2k", split="train_sft")
examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds]
examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda")
MAX_SEQUENCE_LENGTH = 2048
ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(512))
def preprocess(example):
example = tokenizer.apply_chat_template(example["messages"], tokenize=False)
return tokenizer(
example,
padding=False,
max_length=MAX_SEQUENCE_LENGTH,
truncation=True,
add_special_tokens=False,
)
ds = ds.map(preprocess, remove_columns=ds.column_names)

quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")

model = AutoFP8ForCausalLM.from_pretrained(
pretrained_model_dir, quantize_config=quantize_config
)
model.quantize(examples)
model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
model.quantize(ds)
model.save_quantized(quantized_model_dir)
