Skip to content

Commit

Permalink
[Bugfix] Fix quant config parsing (#162)
Browse files Browse the repository at this point in the history
* fix quant config parsing

* add file

---------

Co-authored-by: Kyle Sayers <kyle@neuralmagic.com>
  • Loading branch information
kylesayrs and Kyle Sayers authored Sep 25, 2024
1 parent 74f1aa6 commit 44dda93
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/compressed_tensors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
QUANTIZATION_CONFIG_NAME = "quantization_config"
COMPRESSION_CONFIG_NAME = "compression_config"
KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
COMPRESSION_VERSION_NAME = "version"
6 changes: 5 additions & 1 deletion src/compressed_tensors/compressors/model_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import compressed_tensors
from compressed_tensors.base import (
COMPRESSION_CONFIG_NAME,
COMPRESSION_VERSION_NAME,
QUANTIZATION_CONFIG_NAME,
SPARSITY_CONFIG_NAME,
)
Expand Down Expand Up @@ -200,6 +201,7 @@ def parse_quantization_config(compression_config: Dict) -> Union[Dict, None]:
# SparseAutoModel format
quantization_config = deepcopy(compression_config)
quantization_config.pop(SPARSITY_CONFIG_NAME, None)
quantization_config.pop(COMPRESSION_VERSION_NAME, None)
if len(quantization_config) == 0:
quantization_config = None
return quantization_config
Expand Down Expand Up @@ -313,7 +315,9 @@ def update_config(self, save_directory: str):
config_data[COMPRESSION_CONFIG_NAME][
SPARSITY_CONFIG_NAME
] = sparsity_config_data
config_data[COMPRESSION_CONFIG_NAME]["version"] = compressed_tensors.__version__
config_data[COMPRESSION_CONFIG_NAME][
COMPRESSION_VERSION_NAME
] = compressed_tensors.__version__

with open(config_file_path, "w") as config_file:
json.dump(config_data, config_file, indent=2, sort_keys=True)
Expand Down

0 comments on commit 44dda93

Please sign in to comment.