Add ignore_patterns arg for ignoring layers #7

Merged · 3 commits · May 23, 2024
2 changes: 1 addition & 1 deletion auto_fp8/__init__.py
@@ -1,5 +1,5 @@
from .modeling import AutoFP8ForCausalLM
from .config import BaseQuantizeConfig
from .modeling import AutoFP8ForCausalLM

__all__ = [
"AutoFP8ForCausalLM",
12 changes: 11 additions & 1 deletion auto_fp8/config.py
@@ -1,5 +1,13 @@
from typing import List


class BaseQuantizeConfig:
def __init__(self, quant_method="fp8", activation_scheme="static"):
def __init__(
self,
quant_method: str = "fp8",
activation_scheme: str = "static",
ignore_patterns: List[str] = [],
):
if quant_method != "fp8":
raise ValueError("Only FP8 quantization is supported.")
if activation_scheme not in ["static", "dynamic"]:
@@ -8,3 +16,5 @@ def __init__(self, quant_method="fp8", activation_scheme="static"):
)
self.quant_method = quant_method
self.activation_scheme = activation_scheme
self.ignore_patterns = ignore_patterns
self.ignored_layers = []
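
Note (illustrative, not part of the diff): the new ignore_patterns argument accepts exact module names or regular expressions prefixed with "re:". A minimal sketch of constructing the config, reusing the patterns that appear in the examples later in this PR:

from auto_fp8 import BaseQuantizeConfig

# "re:"-prefixed entries are treated as regexes; plain entries must match a module name exactly.
quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="static",
    ignore_patterns=["re:.*lm_head", "re:.*gate"],
)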
57 changes: 49 additions & 8 deletions auto_fp8/modeling.py
@@ -1,24 +1,34 @@
import re
from typing import List

import torch
from transformers import AutoModelForCausalLM, PreTrainedModel
from transformers import AutoModelForCausalLM

from auto_fp8.config import BaseQuantizeConfig
from auto_fp8.quantize import (
quantize_weights,
quantize_activations,
quantize_weights,
save_quantized_model,
)
from auto_fp8.config import BaseQuantizeConfig


class AutoFP8ForCausalLM:
def __init__(
self,
model: PreTrainedModel,
model: AutoModelForCausalLM,
quantize_config: BaseQuantizeConfig,
):
self.model = model
self.model_type = self.model.config.model_type
self.quantize_config = quantize_config
self.config = self.model.config

# Gather the Linear module names that we want to ignore
quantize_config.ignored_layers = get_layers_to_ignore(
self.model, quantize_config.ignore_patterns
)

self.quantize_config = quantize_config

@classmethod
def from_pretrained(
cls,
@@ -94,16 +104,47 @@ def _prepare_calibration_data(calibration_tokens):
return calibration_tokens

# Always quantize the weights as they do not require calibration data
quantize_weights(self.model)
quantize_weights(self.model, self.quantize_config)

if self.quantize_config.activation_scheme == "static":
quantize_activations(
self.model, _prepare_calibration_data(calibration_tokens)
self.model,
self.quantize_config,
_prepare_calibration_data(calibration_tokens),
)

# import copy
# for layer in self.model.model.layers:
# layer.self_attn.kv_scale = copy.deepcopy(layer.self_attn.k_proj.act_scale)

def save_quantized(self, save_dir):
save_quantized_model(
self.model,
activation_scheme=self.quantize_config.activation_scheme,
quant_config=self.quantize_config,
save_dir=save_dir,
)


def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
ignored_layers = set()

# TODO: don't always ignore lm_head
ignore_patterns.append("re:.*lm_head")

for name, linear in model.named_modules():
if not isinstance(linear, torch.nn.Linear):
continue

for ignore_pattern in ignore_patterns:
regex_prefix = "re:"
if ignore_pattern.startswith(regex_prefix):
# check if name matches regex and add to set if true
regex_pattern = ignore_pattern[len(regex_prefix) :]
if re.search(regex_pattern, name):
ignored_layers.add(name)
else:
# else, exact match
if ignore_pattern == name:
ignored_layers.add(name)

return list(ignored_layers)
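
For reference, an illustrative standalone sketch (not part of the diff) of the matching rule that get_layers_to_ignore applies to each torch.nn.Linear module name:

import re

def matches_ignore_pattern(pattern: str, name: str) -> bool:
    # Mirrors get_layers_to_ignore: "re:"-prefixed patterns are applied with re.search,
    # anything else must equal the fully qualified module name exactly.
    prefix = "re:"
    if pattern.startswith(prefix):
        return re.search(pattern[len(prefix):], name) is not None
    return pattern == name

# Module names below are hypothetical Llama-style names, used only to illustrate the rule.
assert matches_ignore_pattern("re:.*lm_head", "lm_head")
assert matches_ignore_pattern("model.layers.0.mlp.gate_proj", "model.layers.0.mlp.gate_proj")
assert not matches_ignore_pattern("re:gate$", "model.layers.0.mlp.gate_proj")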
77 changes: 56 additions & 21 deletions auto_fp8/quantize.py
@@ -1,10 +1,13 @@
import gc
import re
from typing import Tuple
from typing import List, Tuple

import torch
import transformers
import tqdm
from transformers import AutoTokenizer
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

from .config import BaseQuantizeConfig


# HACK: Override the dtype_byte_size function in transformers to support float8 types
@@ -39,8 +42,8 @@ def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
if tensor.numel() == 0:
# Deal with empty tensors (triggered by empty MoE experts)
min_val, max_val = (
torch.tensor(0.0, dtype=tensor.dtype),
torch.tensor(1.0, dtype=tensor.dtype),
torch.tensor(-16.0, dtype=tensor.dtype),
torch.tensor(16.0, dtype=tensor.dtype),
)
else:
min_val, max_val = tensor.aminmax()
@@ -80,7 +83,9 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):


class FP8StaticLinearQuantizer(torch.nn.Module):
def __init__(self, qweight, weight_scale, bias):
def __init__(
self, qweight: torch.Tensor, weight_scale: torch.Tensor, bias: torch.Tensor
):
super().__init__()
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
@@ -105,7 +110,13 @@ def forward(self, x):


class FP8StaticLinear(torch.nn.Module):
def __init__(self, qweight, weight_scale, bias, act_scale=0.0):
def __init__(
self,
qweight: torch.Tensor,
weight_scale: torch.Tensor,
bias: torch.Tensor,
act_scale: float = 1.0,
):
super().__init__()
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
@@ -133,7 +144,7 @@ def forward(self, x):


class FP8DynamicLinear(torch.nn.Module):
def __init__(self, qweight, scale, bias):
def __init__(self, qweight: torch.Tensor, scale: torch.Tensor, bias: torch.Tensor):
super().__init__()
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(scale, requires_grad=False)
@@ -152,21 +163,28 @@ def forward(self, x):
return output


def replace_module(model, name, new_module):
def replace_module(model: AutoModelForCausalLM, name: str, new_module: torch.nn.Module):
if "." in name:
parent_name = name.rsplit(".", 1)[0]
child_name = name[len(parent_name) + 1 :]
parent = model.model.get_submodule(parent_name)
parent = model.get_submodule(parent_name)
else:
parent_name = ""
parent = model.model
parent = model
child_name = name
setattr(parent, child_name, new_module)


def quantize_weights(model):
for name, linear in model.model.named_modules():
if not isinstance(linear, torch.nn.Linear):
def quantize_weights(
model: AutoModelForCausalLM,
quantize_config: BaseQuantizeConfig,
ignored_layers: List[str] = [],
):
for name, linear in model.named_modules():
if (
not isinstance(linear, torch.nn.Linear)
or name in quantize_config.ignored_layers
):
continue
quant_weight, quant_scale = per_tensor_quantize(linear.weight)
quant_linear = FP8DynamicLinear(quant_weight, quant_scale, linear.bias)
@@ -175,9 +193,17 @@ def quantize_weights(model):
cleanup_memory()


def quantize_activations(model, calibration_tokens):
for name, dynamic_quant_linear in model.model.named_modules():
if not isinstance(dynamic_quant_linear, FP8DynamicLinear):
def quantize_activations(
model: AutoModelForCausalLM,
quantize_config: BaseQuantizeConfig,
calibration_tokens,
ignored_layers: List[str] = [],
):
for name, dynamic_quant_linear in model.named_modules():
if (
not isinstance(dynamic_quant_linear, FP8DynamicLinear)
or name in quantize_config.ignored_layers
):
continue
quantizer = FP8StaticLinearQuantizer(
dynamic_quant_linear.weight,
@@ -196,8 +222,11 @@ def quantize_activations(model, calibration_tokens):
pbar.update(1)

# Replace dynamic quantizer with StaticLinear for export
for name, quantizer in model.model.named_modules():
if not isinstance(quantizer, FP8StaticLinearQuantizer):
for name, quantizer in model.named_modules():
if (
not isinstance(quantizer, FP8StaticLinearQuantizer)
or name in quantize_config.ignored_layers
):
continue
static_proj = FP8StaticLinear(
quantizer.weight,
@@ -210,13 +239,19 @@ def quantize_activations(model, calibration_tokens):
cleanup_memory()


def save_quantized_model(model, activation_scheme, save_dir):
def save_quantized_model(
model: AutoModelForCausalLM,
quant_config: BaseQuantizeConfig,
save_dir: str,
ignored_layers: List[str] = [],
):
print(model)
print(f"Saving the model to {save_dir}")
static_q_dict = {
"quantization_config": {
"quant_method": "fp8",
"activation_scheme": activation_scheme,
"activation_scheme": quant_config.activation_scheme,
"ignored_layers": quant_config.ignored_layers,
}
}
model.config.update(static_q_dict)
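
As a rough illustration (not part of the diff), the quantization_config block that save_quantized_model now writes into the saved model's config should have roughly this shape; the resolved ignored_layers list depends on the model and on the supplied ignore_patterns, and the names below are hypothetical:

# Hypothetical result for a model quantized with ignore_patterns=["re:.*lm_head"].
static_q_dict = {
    "quantization_config": {
        "quant_method": "fp8",
        "activation_scheme": "static",
        "ignored_layers": ["lm_head"],
    }
}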
7 changes: 5 additions & 2 deletions example.py
@@ -1,4 +1,5 @@
from transformers import AutoTokenizer

from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
@@ -9,8 +10,10 @@
examples = tokenizer(examples, return_tensors="pt").to("cuda")

quantize_config = BaseQuantizeConfig(
quant_method="fp8", activation_scheme="dynamic"
) # or "static"
quant_method="fp8",
activation_scheme="dynamic", # or "static"
ignore_patterns=["re:.*lm_head"],
)

model = AutoFP8ForCausalLM.from_pretrained(
pretrained_model_dir, quantize_config=quantize_config
22 changes: 22 additions & 0 deletions example_dataset.py
@@ -0,0 +1,22 @@
from datasets import load_dataset
from transformers import AutoTokenizer

from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token

ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(512))
examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds]
examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda")

quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")

model = AutoFP8ForCausalLM.from_pretrained(
pretrained_model_dir, quantize_config=quantize_config
)
model.quantize(examples)
model.save_quantized(quantized_model_dir)
26 changes: 26 additions & 0 deletions examples/example_mixtral.py
@@ -0,0 +1,26 @@
from datasets import load_dataset
from transformers import AutoTokenizer

from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "mistralai/Mixtral-8x7B-Instruct-v0.1"
quantized_model_dir = "Mixtral-8x7B-Instruct-v0.1-FP8"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token

ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(10))
examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds]
examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda")

quantize_config = BaseQuantizeConfig(
quant_method="fp8",
activation_scheme="static",
ignore_patterns=["re:.*lm_head", "re:.*gate"],
)

model = AutoFP8ForCausalLM.from_pretrained(
pretrained_model_dir, quantize_config=quantize_config
)
model.quantize(examples)
model.save_quantized(quantized_model_dir)
2 changes: 1 addition & 1 deletion examples/quantize.py
@@ -4,8 +4,8 @@
from typing import Tuple

import torch
import transformers
import tqdm
import transformers
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

2 changes: 1 addition & 1 deletion setup.py
@@ -1,4 +1,4 @@
from setuptools import setup, find_packages
from setuptools import find_packages, setup

setup(
name="auto_fp8",
4 changes: 3 additions & 1 deletion tests/test_auto_fp8.py
@@ -1,7 +1,9 @@
import os
import shutil

from transformers import AutoTokenizer

from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
import shutil


def test_quantization():