diff --git a/auto_fp8/modeling.py b/auto_fp8/modeling.py index 3fa8e75..a4f82cc 100644 --- a/auto_fp8/modeling.py +++ b/auto_fp8/modeling.py @@ -7,6 +7,7 @@ from llmcompressor.transformers import oneshot from llmcompressor.modifiers.quantization import QuantizationModifier + class BaseQuantizeConfig: """Configuration for model quantization. @@ -24,6 +25,7 @@ class BaseQuantizeConfig: By default, "lm_head" is included to ignore the embedding Linear layer usually at the end of decoder LLMs """ + def __init__( self, quant_method: str = "fp8", @@ -36,32 +38,41 @@ def __init__( class AutoFP8ForCausalLM: - def __init__(self, model: SparseAutoModelForCausalLM, quantize_config: BaseQuantizeConfig): + def __init__( + self, model: SparseAutoModelForCausalLM, quantize_config: BaseQuantizeConfig + ): self.model = model self.model_type = self.model.config.model_type self.config = self.model.config self.quantize_config = quantize_config @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: str, quantize_config: BaseQuantizeConfig, **kwargs): + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + quantize_config: BaseQuantizeConfig, + **kwargs, + ): config = AutoConfig.from_pretrained(pretrained_model_name_or_path) model = SparseAutoModelForCausalLM.from_pretrained( pretrained_model_name_or_path, config=config, device_map="auto", torch_dtype="auto", - **kwargs + **kwargs, ) return cls(model, quantize_config) def quantize(self, dataset: Optional[Dataset] = None): - assert self.quantize_config.activation_scheme == "static" - assert dataset is not None, "Calibration tokens required for static activation quantization" + assert ( + self.quantize_config.activation_scheme == "static" + ), "Dynamic isn't supported yet" + assert ( + dataset is not None + ), "Calibration tokens required for static activation quantization" recipe = QuantizationModifier( - targets="Linear", - scheme="FP8", - ignore=self.quantize_config.ignore_patterns + targets="Linear", scheme="FP8", ignore=self.quantize_config.ignore_patterns ) oneshot( @@ -73,7 +84,7 @@ def quantize(self, dataset: Optional[Dataset] = None): def save_quantized(self, save_directory: str): self.save_pretrained(save_directory, save_compressed=True) - def save_pretrained(self, save_directory: str, save_compressed: bool = True): + def save_pretrained(self, save_directory: str, save_compressed: bool = True): self.model.save_pretrained(save_directory, save_compressed=save_compressed) tokenizer = AutoTokenizer.from_pretrained(self.model.config._name_or_path) tokenizer.save_pretrained(save_directory)