Format
mgoin committed Jul 18, 2024
1 parent 00a8cd8 commit c15e352
Showing 1 changed file with 20 additions and 9 deletions.
29 changes: 20 additions & 9 deletions auto_fp8/modeling.py
@@ -7,6 +7,7 @@
 from llmcompressor.transformers import oneshot
 from llmcompressor.modifiers.quantization import QuantizationModifier
 
+
 class BaseQuantizeConfig:
     """Configuration for model quantization.
@@ -24,6 +25,7 @@ class BaseQuantizeConfig:
             By default, "lm_head" is included to ignore the embedding
             Linear layer usually at the end of decoder LLMs
     """
+
     def __init__(
         self,
         quant_method: str = "fp8",
@@ -36,32 +38,41 @@ def __init__(
 
 
 class AutoFP8ForCausalLM:
-    def __init__(self, model: SparseAutoModelForCausalLM, quantize_config: BaseQuantizeConfig):
+    def __init__(
+        self, model: SparseAutoModelForCausalLM, quantize_config: BaseQuantizeConfig
+    ):
         self.model = model
         self.model_type = self.model.config.model_type
         self.config = self.model.config
         self.quantize_config = quantize_config
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: str, quantize_config: BaseQuantizeConfig, **kwargs):
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: str,
+        quantize_config: BaseQuantizeConfig,
+        **kwargs,
+    ):
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
         model = SparseAutoModelForCausalLM.from_pretrained(
             pretrained_model_name_or_path,
             config=config,
             device_map="auto",
             torch_dtype="auto",
-            **kwargs
+            **kwargs,
         )
         return cls(model, quantize_config)
 
     def quantize(self, dataset: Optional[Dataset] = None):
-        assert self.quantize_config.activation_scheme == "static"
-        assert dataset is not None, "Calibration tokens required for static activation quantization"
+        assert (
+            self.quantize_config.activation_scheme == "static"
+        ), "Dynamic isn't supported yet"
+        assert (
+            dataset is not None
+        ), "Calibration tokens required for static activation quantization"
 
         recipe = QuantizationModifier(
-            targets="Linear",
-            scheme="FP8",
-            ignore=self.quantize_config.ignore_patterns
+            targets="Linear", scheme="FP8", ignore=self.quantize_config.ignore_patterns
         )
 
         oneshot(
@@ -73,7 +84,7 @@ def quantize(self, dataset: Optional[Dataset] = None):
     def save_quantized(self, save_directory: str):
         self.save_pretrained(save_directory, save_compressed=True)
 
-    def save_pretrained(self, save_directory: str, save_compressed: bool = True):
+    def save_pretrained(self, save_directory: str, save_compressed: bool = True):
         self.model.save_pretrained(save_directory, save_compressed=save_compressed)
         tokenizer = AutoTokenizer.from_pretrained(self.model.config._name_or_path)
         tokenizer.save_pretrained(save_directory)
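
For reference, a minimal end-to-end usage sketch of the API as it stands after this change. The BaseQuantizeConfig keyword names (activation_scheme, ignore_patterns) are inferred from the attributes read in quantize() and from the docstring, the calibration-data layout is an assumption (the oneshot(...) arguments are collapsed out of this diff), and the model name, sample text, and output directory are illustrative.

# Usage sketch only; not part of this commit.
from datasets import Dataset
from transformers import AutoTokenizer

from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig  # assumed package exports

pretrained_model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # illustrative
quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8"

# Keyword names mirror the attributes used in quantize(); the full
# constructor parameter list sits in the collapsed part of this diff.
quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="static",   # quantize() asserts "static"
    ignore_patterns=["lm_head"],  # docstring: lm_head is ignored by default
)

# Small calibration set of tokenized prompts; the exact column layout
# expected by the (collapsed) oneshot() call is an assumption.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_id)
tokenizer.pad_token = tokenizer.eos_token
samples = ["auto_fp8 is an easy-to-use model quantization library"] * 8
tokens = tokenizer(samples, padding=True, truncation=True, max_length=64)
calibration_ds = Dataset.from_dict({"input_ids": tokens["input_ids"]})

model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_id, quantize_config)
model.quantize(calibration_ds)             # FP8 QuantizationModifier recipe via oneshot
model.save_quantized(quantized_model_dir)  # save_pretrained(..., save_compressed=True)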