From 4a09744f7c8e71c445dbf32348f2d24a17441253 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 1 Oct 2024 13:36:29 -0400 Subject: [PATCH] initial commit (#174) Co-authored-by: kylesayrs --- src/compressed_tensors/base.py | 1 + .../compressors/model_compressor.py | 48 +++++++++++++++++-- .../quantization/quant_scheme.py | 2 +- 3 files changed, 45 insertions(+), 6 deletions(-) diff --git a/src/compressed_tensors/base.py b/src/compressed_tensors/base.py index 0a17f49e..0e073262 100644 --- a/src/compressed_tensors/base.py +++ b/src/compressed_tensors/base.py @@ -17,3 +17,4 @@ COMPRESSION_CONFIG_NAME = "compression_config" KV_CACHE_SCHEME_NAME = "kv_cache_scheme" COMPRESSION_VERSION_NAME = "version" +QUANTIZATION_METHOD_NAME = "quant_method" diff --git a/src/compressed_tensors/compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressor.py index e2346a2f..dfbff781 100644 --- a/src/compressed_tensors/compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressor.py @@ -20,18 +20,20 @@ from copy import deepcopy from typing import Any, Dict, Optional, Union +import compressed_tensors import torch import transformers -import compressed_tensors from compressed_tensors.base import ( COMPRESSION_CONFIG_NAME, COMPRESSION_VERSION_NAME, QUANTIZATION_CONFIG_NAME, + QUANTIZATION_METHOD_NAME, SPARSITY_CONFIG_NAME, ) from compressed_tensors.compressors import Compressor from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig from compressed_tensors.quantization import ( + DEFAULT_QUANTIZATION_METHOD, QuantizationConfig, QuantizationStatus, apply_quantization_config, @@ -186,7 +188,17 @@ def parse_sparsity_config(compression_config: Dict) -> Union[Dict, None]: return compression_config.get(SPARSITY_CONFIG_NAME, None) @staticmethod - def parse_quantization_config(compression_config: Dict) -> Union[Dict, None]: + def parse_quantization_config( + compression_config: Dict[str, Any] + ) -> Union[Dict[str, 
Any], None]: + """ + Parse quantization config from quantization/compression config. The + quantization config is all the fields that are not the sparsity config or + metadata fields + + :param compression_config: quantization/compression config + :return: quantization config without sparsity config or metadata fields + """ if compression_config is None: return None @@ -201,9 +213,20 @@ def parse_quantization_config(compression_config: Dict) -> Union[Dict, None]: # SparseAutoModel format quantization_config = deepcopy(compression_config) quantization_config.pop(SPARSITY_CONFIG_NAME, None) - quantization_config.pop(COMPRESSION_VERSION_NAME, None) + + # some fields are required, even if a qconfig is not present + # pop them off and if nothing remains, then there is no qconfig + quant_method = quantization_config.pop(QUANTIZATION_METHOD_NAME, None) + _ = quantization_config.pop(COMPRESSION_VERSION_NAME, None) + if len(quantization_config) == 0: - quantization_config = None + return None + + # replace popped off values + # note that version is discarded for now + if quant_method is not None: + quantization_config[QUANTIZATION_METHOD_NAME] = quant_method + return quantization_config def __init__( @@ -216,7 +239,6 @@ def __init__( self.sparsity_compressor = None self.quantization_compressor = None - if sparsity_config and sparsity_config.format == CompressionFormat.dense.value: # ignore dense sparsity config self.sparsity_config = None @@ -300,6 +322,9 @@ def update_config(self, save_directory: str): :param save_directory: path to a folder containing a HF model config """ + if self.quantization_config is None and self.sparsity_config is None: + return + config_file_path = os.path.join(save_directory, CONFIG_NAME) if not os.path.exists(config_file_path): _LOGGER.warning( @@ -311,7 +336,20 @@ def update_config(self, save_directory: str): with open(config_file_path, "r") as config_file: config_data = json.load(config_file) + # required metadata whenever a quantization or sparsity 
config is present + # overwrite previous config and version if already existing config_data[QUANTIZATION_CONFIG_NAME] = {} + config_data[QUANTIZATION_CONFIG_NAME][ + COMPRESSION_VERSION_NAME + ] = compressed_tensors.__version__ + if self.quantization_config is not None: + self.quantization_config.quant_method = DEFAULT_QUANTIZATION_METHOD + else: + config_data[QUANTIZATION_CONFIG_NAME][ + QUANTIZATION_METHOD_NAME + ] = DEFAULT_QUANTIZATION_METHOD + + # quantization and sparsity configs if self.quantization_config is not None: quant_config_data = self.quantization_config.model_dump() config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 03748d82..b41eaafb 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -211,7 +211,7 @@ def is_preset_scheme(name: str) -> bool: "W4A16": W4A16, # Integer weight and activation schemes "W8A8": INT8_W8A8, - "INT8": INT8_W8A8, # alias for W8A8 + "INT8": INT8_W8A8, # alias for W8A8 "W4A8": INT8_W4A8, # Float weight and activation schemes "FP8": FP8,