Commit
Some cleanup
rahul-tuli committed Nov 15, 2024
1 parent 86716f8 commit c796ac8
Showing 2 changed files with 223 additions and 57 deletions.
@@ -25,6 +25,7 @@
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.config import SparsityCompressionConfig, SparsityStructure

__all__ = ["CompressedTensorsLinearMethod"]

@@ -37,6 +38,7 @@ def __init__(self,
quant_format: str,
kv_cache_scheme: Optional[Dict[str, Any]] = None,
model_compressor: Optional[ModelCompressor] = None,
sparsity_scheme_map: Optional[Dict[str, Any]] = None
):

self.ignore = ignore
@@ -45,14 +47,15 @@ def __init__(self,
self.target_scheme_map = target_scheme_map
self.kv_cache_scheme = kv_cache_scheme
self.model_compressor = model_compressor
self.sparsity_scheme_map = sparsity_scheme_map

def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self)

def get_scaled_act_names(self) -> List[str]:
return []

def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.float16, torch.bfloat16]

@classmethod
@@ -85,10 +88,35 @@ def get_quant_method(

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
target_scheme_map: Dict[str, Any] = dict()
ignore = cast(List[str], config.get("ignore"))
quant_format = cast(str, config.get("format"))

target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
model_compressor = ModelCompressor.from_compression_config(config)
sparsity_scheme_map = cls._sparsity_scheme_map_from_config(sparsity_config=model_compressor.sparsity_config)

return cls(
target_scheme_map=target_scheme_map,
ignore=ignore,
quant_format=quant_format,
model_compressor=model_compressor,
sparsity_scheme_map=sparsity_scheme_map,
)

@classmethod
def _sparsity_scheme_map_from_config(cls, sparsity_config: SparsityCompressionConfig):
sparse_targets = cast(List[str], sparsity_config.targets) if sparsity_config else []
sparse_scheme_map: Dict[str, SparsityCompressionConfig] = {
target: sparsity_config
for target in sparse_targets
}
return sparse_scheme_map
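
Concretely, this helper fans a single shared config object out to every listed target. A minimal sketch, assuming SparsityCompressionConfig (already imported above) can be built directly with these fields; the target name "Linear" is purely illustrative:

# Hedged example: one shared SparsityCompressionConfig per listed target
sparsity_config = SparsityCompressionConfig(
    format="dense", targets=["Linear"], sparsity_structure="2:4")
scheme_map = CompressedTensorsConfig._sparsity_scheme_map_from_config(
    sparsity_config=sparsity_config)
assert scheme_map == {"Linear": sparsity_config}  # the same object is reused for each target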

@classmethod
def _quantization_scheme_map_from_config(cls, config: Dict[str, Any]):
target_scheme_map: Dict[str, Any] = dict()
quant_format = cast(str, config.get("format"))

# The quant_config has multiple config_groups, each containing
# an input_activations key with details about how the activations are
# quantized, a weights key indicating how the weights are quantized,
@@ -97,13 +125,14 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
# details follow the structure defined by the QuantizationArgs
# pydantic model, which is used to verify the structure of the
# quant_config and also store the details for later use.
"""
for _, quant_config in config["config_groups"].items():

config_groups = config.get("config_groups", dict())
for _, quant_config in config_groups.items():
targets = quant_config.get("targets")
for target in targets:
target_scheme_map[target] = {}
target_scheme_map[target][
"weights"] = QuantizationArgs.parse_obj(
"weights"] = QuantizationArgs.model_validate(
quant_config.get("weights"))

target_scheme_map[target]["input_activations"] = None
@@ -118,31 +147,9 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
"weights"].type == QuantizationType.FLOAT
else:
target_scheme_map[target][
"input_activations"] = QuantizationArgs.parse_obj(
"input_activations"] = QuantizationArgs.model_validate(
quant_config.get("input_activations"))
"""
sparsity_config = config.get("sparsity_config")
targets = sparsity_config.get("targets")
sparsity_format = sparsity_config.get("format")
ignore = sparsity_config.get("ignore")

for t in targets:
target_scheme_map[t] = sparsity_format

model_compressor = ModelCompressor.from_compression_config(config)

"""
return cls(target_scheme_map=target_scheme_map,
ignore=ignore,
quant_format=quant_format,
kv_cache_scheme=config.get("kv_cache_scheme"))
"""
return cls(
target_scheme_map=target_scheme_map,
ignore=ignore,
quant_format=sparsity_format,
model_compressor=model_compressor,
)
return target_scheme_map
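
For orientation, here is a hedged sketch of the shape of quantization config this helper consumes; the group name, target, and numeric values are invented for illustration, and only the overall layout (config_groups containing targets, weights, and input_activations) is implied by the parsing code above:

example_config = {
    "format": "naive-quantized",   # assumed format string
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],
            "weights": {"num_bits": 8, "type": "float", "strategy": "channel",
                        "symmetric": True, "dynamic": False},
            "input_activations": {"num_bits": 8, "type": "float",
                                  "strategy": "token", "dynamic": True},
        }
    },
}

scheme_map = CompressedTensorsConfig._quantization_scheme_map_from_config(
    config=example_config)
# scheme_map["Linear"]["weights"] and scheme_map["Linear"]["input_activations"]
# are QuantizationArgs instances built by model_validate above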

@classmethod
def get_config_filenames(cls) -> List[str]:
@@ -341,29 +348,94 @@ def get_scheme(
# TODO (@robertgshaw): add compressed-tensors as dep
# so we do not have to re-write these functions
# need to make accelerate optional in ct to do this
"""

matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.target_scheme_map.keys())

# Find the quant_scheme
scheme_dict = self.target_scheme_map[matched_target]
scheme = self._get_scheme_from_parts(
weight_quant=scheme_dict["weights"],
input_quant=scheme_dict["input_activations"])
weight_quant = scheme_dict["weights"]
input_quant = scheme_dict["input_activations"]
sparsity_scheme: Optional[SparsityCompressionConfig] = self.sparsity_scheme_map.get(matched_target)

if self.supports_cutlass_24(
weight_quant=weight_quant,
input_quant=input_quant,
sparsity_scheme=sparsity_scheme
):
# The layer has a valid 2:4 sparsity scheme and is supported by the Cutlass 2:4 kernel
needs_decompression = sparsity_scheme.format != CompressionFormat.dense.value
is_quantized = weight_quant is not None or input_quant is not None

scheme = CompressedTensors24(
model_compressor=self.model_compressor,
layer_name=layer_name,
quantized=is_quantized,
do_decompress=needs_decompression,
)
else:
# Find the quant_scheme
scheme = self._get_scheme_from_parts(
weight_quant=weight_quant,
input_quant=input_quant,
)

# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
self._check_scheme_supported(scheme.get_min_capability())
"""
scheme = CompressedTensors24(
model_compressor=self.model_compressor,
layer_name=layer_name
)

return scheme

@staticmethod
def supports_cutlass_24(
weight_quant: Optional[QuantizationArgs],
input_quant: Optional[QuantizationArgs],
sparsity_scheme: Optional[SparsityCompressionConfig]=None
) -> bool:
"""
Check if the layer is supported by the Cutlass 2:4 Kernel
Conditions:
- Overarching condition Sparsity Structure is 2:4
- Unquantized cases are supported
- Weight only quantization is not-supported
- Supported weight quantization strategies are TENSOR and CHANNEL
- Supported input quantization strategies are TENSOR and TOKEN
:return: True if the layer is supported by the Cutlass 2:4 Kernel
False otherwise
"""

if (
sparsity_scheme is None or
sparsity_scheme.sparsity_structure != SparsityStructure.TWO_FOUR.value
):
return False

# Unquantized cases are supported
if weight_quant is None and input_quant is None:
return True

# Weight-only quantization is not supported
if weight_quant is not None and input_quant is None:
return False

supported_weight_quant_strategies = [
QuantizationStrategy.TENSOR.value,
QuantizationStrategy.CHANNEL.value
]

if weight_quant.strategy not in supported_weight_quant_strategies:
return False

supported_input_quant_strategies = [
QuantizationStrategy.TENSOR.value,
QuantizationStrategy.TOKEN.value
]

if input_quant.strategy not in supported_input_quant_strategies:
return False

return True
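
A short, hedged illustration of how the checks above combine. The QuantizationArgs values are assumptions chosen to exercise each branch, and a SimpleNamespace stands in for a SparsityCompressionConfig because only its sparsity_structure attribute is read here:

from types import SimpleNamespace
from compressed_tensors.quantization import QuantizationArgs

sparse_24 = SimpleNamespace(sparsity_structure="2:4")  # assumed 2:4 structure string

w_channel = QuantizationArgs(num_bits=8, type="float", strategy="channel")
a_token = QuantizationArgs(num_bits=8, type="float", strategy="token", dynamic=True)

CompressedTensorsConfig.supports_cutlass_24(None, None, sparse_24)          # True: unquantized 2:4
CompressedTensorsConfig.supports_cutlass_24(w_channel, None, sparse_24)     # False: weight-only quant
CompressedTensorsConfig.supports_cutlass_24(w_channel, a_token, sparse_24)  # True: channel weights + token activations
CompressedTensorsConfig.supports_cutlass_24(w_channel, a_token, None)       # False: no 2:4 sparsity scheme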

class CompressedTensorsLinearMethod(LinearMethodBase):

@@ -12,40 +12,93 @@
__all__ = ["CompressedTensors24"]

class CompressedTensors24(CompressedTensorsScheme):
def __init__(self, model_compressor: Optional[ModelCompressor] = None, layer_name = None):
def __init__(
self,
model_compressor: Optional[ModelCompressor] = None,
layer_name: Optional[str] = None,
quantized: bool = False,
do_decompress: bool = False,
):

self.model_compressor = model_compressor
self.layer_name = layer_name
self.quantized = True # toggle based on the case we're running
self.compressed = False # toggle based on the case we're running
self.quantized = quantized
self.do_decompress = do_decompress

@classmethod
def get_min_capability(cls) -> int:
"""
Since this scheme uses the CUTLASS library with FP8, it requires
a minimum compute capability of 90 (9.0, i.e. Hopper-class GPUs)
:return: The minimum capability required for this scheme
"""
return 90

def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
def create_weights(
self,
layer: torch.nn.Module,
input_size: int,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs
):
layer.logical_widths = output_partition_sizes
self.params_dtype = params_dtype

# parameter to store uncompressed weight or decompressed weight
weight = ModelWeightParameter(
data=torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=torch.float8_e4m3fn),
dtype=params_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)

if self.do_decompress:
# store compression-specific parameters to be used
# later during decompression

bits_per_weight_element = weight.itemsize * 8
meta_dtype = torch.int32 if bits_per_weight_element == 8 else torch.int16

# compressed weight for 2:4 sparse
weight_packed = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // 2,
dtype=params_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader
)

meta_input_size = (
input_size_per_partition // 32
if bits_per_weight_element == 8
else input_size_per_partition // 16
)
# meta tensor for 2:4 decompression
meta = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
meta_input_size,
dtype=meta_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)

layer.register_parameter("weight_packed", weight_packed)
layer.register_parameter("meta", meta)

if self.quantized:
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)

layer.register_parameter("weight_scale", weight_scale)

weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)

layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight", weight)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@@ -56,7 +109,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
:param layer: The layer with the weights to be processed
"""

w_compressed, meta = ops.cutlass_compress_entry(layer.weight)
decompressed_weight = (
layer.weight if not self.do_decompress
else self._decompress_24_weight(layer.weight_packed.data, layer.meta.data)
)
w_compressed, meta = ops.cutlass_compress_entry(decompressed_weight)
layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
layer.meta = torch.nn.Parameter(meta, requires_grad=False)

@@ -101,6 +158,43 @@ def apply_weights(self,
out = out.contiguous()

return out


def _decompress_24_weight(self, weight_packed: torch.Tensor, meta: torch.Tensor) -> torch.Tensor:
qkv_sizes = [2048, 256, 256]
gate_up_sizes = [5632, 5632]
split_weights = None
split_meta = None

def _process_split(input_weight, input_meta):
weight_data = {
"weight_packed": input_weight,
"meta": input_meta
}
decompress = self.model_compressor.sparsity_compressor.decompress_weight(weight_data)
return decompress

print(self.layer_name)
if "qkv" in self.layer_name:
split_weights = torch.split(weight_packed, qkv_sizes)
split_meta = torch.split(meta, qkv_sizes)
elif "gate_up" in self.layer_name:
split_weights = torch.split(weight_packed, gate_up_sizes)
split_meta = torch.split(meta, gate_up_sizes)

if split_weights:
all_compress = []
for i in range(len(split_weights)):
print(split_weights[i].shape, split_meta[i].shape)
compress_i = _process_split(split_weights[i], split_meta[i])
all_compress.append(compress_i)

decompressed = torch.cat(all_compress)
else:
decompressed = _process_split(weight_packed, meta)

return decompressed
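
The branching above appears to handle fused projections (a stacked qkv_proj or gate_up_proj weight) by cutting the packed tensor back into its sub-projections along the output dimension and decompressing each piece separately; the hard-coded row counts look specific to one model configuration. A minimal sketch of that split, with the column count assumed for illustration:

# Illustration only: splitting a fused QKV packed weight by output rows
qkv_sizes = [2048, 256, 256]               # q, k, v row counts (taken from the code above)
fused = torch.empty(sum(qkv_sizes), 1024)  # stand-in for weight_packed; 1024 columns assumed
q, k, v = torch.split(fused, qkv_sizes)    # shapes (2048, 1024), (256, 1024), (256, 1024)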



def check_24(tensor):
