diff --git a/src/compressed_tensors/base.py b/src/compressed_tensors/base.py index 024e415d..0a17f49e 100644 --- a/src/compressed_tensors/base.py +++ b/src/compressed_tensors/base.py @@ -16,3 +16,4 @@ QUANTIZATION_CONFIG_NAME = "quantization_config" COMPRESSION_CONFIG_NAME = "compression_config" KV_CACHE_SCHEME_NAME = "kv_cache_scheme" +COMPRESSION_VERSION_NAME = "version" diff --git a/src/compressed_tensors/compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressor.py index 46a2d708..2a8e02c3 100644 --- a/src/compressed_tensors/compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressor.py @@ -25,6 +25,7 @@ import compressed_tensors from compressed_tensors.base import ( COMPRESSION_CONFIG_NAME, + COMPRESSION_VERSION_NAME, QUANTIZATION_CONFIG_NAME, SPARSITY_CONFIG_NAME, ) @@ -200,6 +201,7 @@ def parse_quantization_config(compression_config: Dict) -> Union[Dict, None]: # SparseAutoModel format quantization_config = deepcopy(compression_config) quantization_config.pop(SPARSITY_CONFIG_NAME, None) + quantization_config.pop(COMPRESSION_VERSION_NAME, None) if len(quantization_config) == 0: quantization_config = None return quantization_config @@ -313,7 +315,9 @@ def update_config(self, save_directory: str): config_data[COMPRESSION_CONFIG_NAME][ SPARSITY_CONFIG_NAME ] = sparsity_config_data - config_data[COMPRESSION_CONFIG_NAME]["version"] = compressed_tensors.__version__ + config_data[COMPRESSION_CONFIG_NAME][ + COMPRESSION_VERSION_NAME + ] = compressed_tensors.__version__ with open(config_file_path, "w") as config_file: json.dump(config_data, config_file, indent=2, sort_keys=True)