Skip to content

Commit

Permalink
initial commit (#174)
Browse files Browse the repository at this point in the history
Co-authored-by: kylesayrs <kyle@neuralmagic.com>
  • Loading branch information
kylesayrs and kylesayrs authored Oct 1, 2024
1 parent 672551d commit 4a09744
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/compressed_tensors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
COMPRESSION_CONFIG_NAME = "compression_config"
KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
COMPRESSION_VERSION_NAME = "version"
QUANTIZATION_METHOD_NAME = "quant_method"
48 changes: 43 additions & 5 deletions src/compressed_tensors/compressors/model_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,20 @@
from copy import deepcopy
from typing import Any, Dict, Optional, Union

import compressed_tensors
import torch
import transformers
import compressed_tensors
from compressed_tensors.base import (
COMPRESSION_CONFIG_NAME,
COMPRESSION_VERSION_NAME,
QUANTIZATION_CONFIG_NAME,
QUANTIZATION_METHOD_NAME,
SPARSITY_CONFIG_NAME,
)
from compressed_tensors.compressors import Compressor
from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
from compressed_tensors.quantization import (
DEFAULT_QUANTIZATION_METHOD,
QuantizationConfig,
QuantizationStatus,
apply_quantization_config,
Expand Down Expand Up @@ -186,7 +188,17 @@ def parse_sparsity_config(compression_config: Dict) -> Union[Dict, None]:
return compression_config.get(SPARSITY_CONFIG_NAME, None)

@staticmethod
def parse_quantization_config(compression_config: Dict) -> Union[Dict, None]:
def parse_quantization_config(
compression_config: Dict[str, Any]
) -> Union[Dict[str, Any], None]:
"""
Parse quantization config from quantization/compression config. The
quantization config consists of all the fields that are not the sparsity
config or metadata fields
:param compression_config: quantization/compression config
:return: quantization config without sparsity config or metadata fields
"""
if compression_config is None:
return None

Expand All @@ -201,9 +213,20 @@ def parse_quantization_config(compression_config: Dict) -> Union[Dict, None]:
# SparseAutoModel format
quantization_config = deepcopy(compression_config)
quantization_config.pop(SPARSITY_CONFIG_NAME, None)
quantization_config.pop(COMPRESSION_VERSION_NAME, None)

# some fields are required, even if a qconfig is not present
# pop them off and if nothing remains, then there is no qconfig
quant_method = quantization_config.pop(QUANTIZATION_METHOD_NAME, None)
_ = quantization_config.pop(COMPRESSION_VERSION_NAME, None)

if len(quantization_config) == 0:
quantization_config = None
return None

# restore the popped-off quant_method now that a qconfig is known to exist
# note that version is discarded for now
if quant_method is not None:
quantization_config[QUANTIZATION_METHOD_NAME] = quant_method

return quantization_config

def __init__(
Expand All @@ -216,7 +239,6 @@ def __init__(
self.sparsity_compressor = None
self.quantization_compressor = None


if sparsity_config and sparsity_config.format == CompressionFormat.dense.value:
# ignore dense sparsity config
self.sparsity_config = None
Expand Down Expand Up @@ -300,6 +322,9 @@ def update_config(self, save_directory: str):
:param save_directory: path to a folder containing a HF model config
"""
if self.quantization_config is None and self.sparsity_config is None:
return

config_file_path = os.path.join(save_directory, CONFIG_NAME)
if not os.path.exists(config_file_path):
_LOGGER.warning(
Expand All @@ -311,7 +336,20 @@ def update_config(self, save_directory: str):
with open(config_file_path, "r") as config_file:
config_data = json.load(config_file)

# required metadata whenever a quantization or sparsity config is present
# overwrite previous config and version if already existing
config_data[QUANTIZATION_CONFIG_NAME] = {}
config_data[QUANTIZATION_CONFIG_NAME][
COMPRESSION_VERSION_NAME
] = compressed_tensors.__version__
if self.quantization_config is not None:
self.quantization_config.quant_method = DEFAULT_QUANTIZATION_METHOD
else:
config_data[QUANTIZATION_CONFIG_NAME][
QUANTIZATION_METHOD_NAME
] = DEFAULT_QUANTIZATION_METHOD

# quantization and sparsity configs
if self.quantization_config is not None:
quant_config_data = self.quantization_config.model_dump()
config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
Expand Down
2 changes: 1 addition & 1 deletion src/compressed_tensors/quantization/quant_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def is_preset_scheme(name: str) -> bool:
"W4A16": W4A16,
# Integer weight and activation schemes
"W8A8": INT8_W8A8,
"INT8": INT8_W8A8, # alias for W8A8
"INT8": INT8_W8A8, # alias for W8A8
"W4A8": INT8_W4A8,
# Float weight and activation schemes
"FP8": FP8,
Expand Down

0 comments on commit 4a09744

Please sign in to comment.