apply style
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
kylesayrs committed Jan 15, 2025
1 parent 2d4791c commit a28f231
Showing 5 changed files with 68 additions and 7 deletions.
55 changes: 55 additions & 0 deletions examples/quantization_w4a16/gptj_example.py
@@ -0,0 +1,55 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant.base import SmoothQuantModifier
from llmcompressor.transformers import oneshot

# Select model and load it.
MODEL_ID = "EleutherAI/gpt-j-6B"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Configure the algorithms to run.
#   * apply SmoothQuant to make the activations easier to quantize
#   * quantize the weights to 4 bits with GPTQ, using a group size of 128
recipe = [
    SmoothQuantModifier(smoothing_strength=0.8),
    GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]

# Apply algorithms.
oneshot(
    model=model,
    dataset="ultrachat-200k",
    splits={"calibration": f"train_sft[:{NUM_CALIBRATION_SAMPLES}]"},
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
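
As a quick sanity check outside this script, the compressed checkpoint can be reloaded for inference. A minimal sketch, assuming a transformers build that understands compressed-tensors checkpoints (the reload path is an assumption, not part of this commit):

from transformers import AutoModelForCausalLM, AutoTokenizer

# "gpt-j-6B-W4A16-G128" matches the SAVE_DIR computed in the example above.
model = AutoModelForCausalLM.from_pretrained(
    "gpt-j-6B-W4A16-G128",
    device_map="auto",
    torch_dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained("gpt-j-6B-W4A16-G128")

# Re-run the same sample prompt against the reloaded weights.
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
print(tokenizer.decode(model.generate(input_ids, max_new_tokens=20)[0]))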
10 changes: 6 additions & 4 deletions examples/quantizing_moe/deepseek_moe_w4a16.py
@@ -1,11 +1,13 @@
-from llmcompressor.transformers.tracing.deepseek_v2.configuration_deepseek import DeepseekV2Config
 import torch
 from datasets import load_dataset
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoTokenizer
 
 from llmcompressor.transformers import oneshot
 from llmcompressor.transformers.compression.helpers import calculate_offload_device_map
 from llmcompressor.transformers.tracing import TraceableDeepseekV2ForCausalLM
+from llmcompressor.transformers.tracing.deepseek_v2.configuration_deepseek import (
+    DeepseekV2Config,
+)
 
 # NOTE: transformers 4.48.0 has an import error with DeepSeek.
 # Please consider either downgrading your transformers version to a
@@ -24,15 +26,15 @@
     trust_remote_code=True,
 )
 
-#model = AutoModelForCausalLM.from_pretrained(
+# model = AutoModelForCausalLM.from_pretrained(
 config = DeepseekV2Config.from_pretrained(MODEL_ID)
 config.moe_top_k_activation = True
 model = TraceableDeepseekV2ForCausalLM.from_pretrained(
     MODEL_ID,
     device_map=device_map,
     torch_dtype=torch.bfloat16,
     trust_remote_code=True,
-    config=config
+    config=config,
 )
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

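The transformers 4.48.0 caveat in the diff above can also be enforced at runtime instead of relying on the comment. A minimal sketch (treating 4.48.0 as the only affected release is an assumption taken from that comment):

import transformers
from packaging import version

# Fail fast with a clear message rather than hitting the DeepSeek
# import error partway through a run.
if version.parse(transformers.__version__) == version.parse("4.48.0"):
    raise RuntimeError(
        "transformers 4.48.0 has a known import error with DeepSeek; "
        "please install an earlier or later release."
    )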
4 changes: 3 additions & 1 deletion src/llmcompressor/transformers/tracing/__init__.py
@@ -5,7 +5,9 @@
 from .mllama import (
     MllamaForConditionalGeneration as TraceableMllamaForConditionalGeneration,
 )
-from .deepseek_v2.modeling_deepseek import DeepseekV2ForCausalLM as TraceableDeepseekV2ForCausalLM
+from .deepseek_v2.modeling_deepseek import (
+    DeepseekV2ForCausalLM as TraceableDeepseekV2ForCausalLM
+)
 
 __all__ = [
     "TraceableLlavaForConditionalGeneration",
3 changes: 2 additions & 1 deletion src/llmcompressor/transformers/tracing/deepseek_v2/configuration_deepseek.py
@@ -1,3 +1,4 @@
+# flake8: noqa
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
@@ -206,4 +207,4 @@ def __init__(
             eos_token_id=eos_token_id,
             tie_word_embeddings=tie_word_embeddings,
             **kwargs,
-        )
+        )
3 changes: 2 additions & 1 deletion src/llmcompressor/transformers/tracing/deepseek_v2/modeling_deepseek.py
@@ -1,3 +1,4 @@
+# flake8: noqa
 # coding=utf-8
 # Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
 #
@@ -1925,4 +1926,4 @@ def forward(
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )
