diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index 598dea3d91fd1..5e3b96d15e85e 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -25,6 +25,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.platforms import current_platform
 from compressed_tensors.compressors import ModelCompressor
+from compressed_tensors.config import SparsityCompressionConfig, SparsityStructure
 
 __all__ = ["CompressedTensorsLinearMethod"]
 
@@ -37,6 +38,7 @@ def __init__(self,
                  quant_format: str,
                  kv_cache_scheme: Optional[Dict[str, Any]] = None,
                  model_compressor: Optional[ModelCompressor] = None,
+                 sparsity_scheme_map: Optional[Dict[str, Any]] = None
                  ):
         self.ignore = ignore
@@ -45,6 +47,7 @@ def __init__(self,
         self.target_scheme_map = target_scheme_map
         self.kv_cache_scheme = kv_cache_scheme
         self.model_compressor = model_compressor
+        self.sparsity_scheme_map = sparsity_scheme_map
 
     def get_linear_method(self) -> "CompressedTensorsLinearMethod":
         return CompressedTensorsLinearMethod(self)
@@ -52,7 +55,7 @@ def get_linear_method(self) -> "CompressedTensorsLinearMethod":
     def get_scaled_act_names(self) -> List[str]:
         return []
 
-    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+    def get_supported_act_dtypes(self) -> List[torch.dtype]:
         return [torch.float16, torch.bfloat16]
 
     @classmethod
@@ -85,10 +88,35 @@ def get_quant_method(
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
-        target_scheme_map: Dict[str, Any] = dict()
         ignore = cast(List[str], config.get("ignore"))
         quant_format = cast(str, config.get("format"))
+
+        target_scheme_map = cls._quantization_scheme_map_from_config(
+            config=config)
+        model_compressor = ModelCompressor.from_compression_config(config)
+        sparsity_scheme_map = cls._sparsity_scheme_map_from_config(
+            sparsity_config=model_compressor.sparsity_config)
+
+        return cls(
+            target_scheme_map=target_scheme_map,
+            ignore=ignore,
+            quant_format=quant_format,
+            model_compressor=model_compressor,
+            sparsity_scheme_map=sparsity_scheme_map,
+        )
+
+    @classmethod
+    def _sparsity_scheme_map_from_config(
+        cls, sparsity_config: Optional[SparsityCompressionConfig]
+    ) -> Dict[str, SparsityCompressionConfig]:
+        sparse_targets = (cast(List[str], sparsity_config.targets)
+                          if sparsity_config else [])
+        sparse_scheme_map: Dict[str, SparsityCompressionConfig] = {
+            target: sparsity_config
+            for target in sparse_targets
+        }
+        return sparse_scheme_map
+
+    @classmethod
+    def _quantization_scheme_map_from_config(
+        cls, config: Dict[str, Any]) -> Dict[str, Any]:
+        target_scheme_map: Dict[str, Any] = dict()
+        quant_format = cast(str, config.get("format"))
+
         # The quant_config has multiple config_groups, each containing
         # an input_activations key with details about how the activations are
         # quantized, a weights key indicating how the weights are quantized,
         # and a targets key specifying the layers the scheme applies to. The
         # details follow the structure defined by the QuantizationArgs
         # pydantic model, which is used to verify the structure of the
         # quant_config and also store the details for later use.
- """ - for _, quant_config in config["config_groups"].items(): + + config_groups = config.get("config_groups", dict()) + for _, quant_config in config_groups.items(): targets = quant_config.get("targets") for target in targets: target_scheme_map[target] = {} target_scheme_map[target][ - "weights"] = QuantizationArgs.parse_obj( + "weights"] = QuantizationArgs.model_validate( quant_config.get("weights")) target_scheme_map[target]["input_activations"] = None @@ -118,31 +147,9 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": "weights"].type == QuantizationType.FLOAT else: target_scheme_map[target][ - "input_activations"] = QuantizationArgs.parse_obj( + "input_activations"] = QuantizationArgs.model_validate( quant_config.get("input_activations")) - """ - sparsity_config = config.get("sparsity_config") - targets = sparsity_config.get("targets") - sparsity_format = sparsity_config.get("format") - ignore = sparsity_config.get("ignore") - - for t in targets: - target_scheme_map[t] = sparsity_format - - model_compressor = ModelCompressor.from_compression_config(config) - - """ - return cls(target_scheme_map=target_scheme_map, - ignore=ignore, - quant_format=quant_format, - kv_cache_scheme=config.get("kv_cache_scheme")) - """ - return cls( - target_scheme_map=target_scheme_map, - ignore=ignore, - quant_format=sparsity_format, - model_compressor=model_compressor, - ) + return target_scheme_map @classmethod def get_config_filenames(cls) -> List[str]: @@ -341,29 +348,94 @@ def get_scheme( # TODO (@robertgshaw): add compressed-tensors as dep # so we do not have to re-write these functions # need to make accelerate optional in ct to do this - """ + matched_target = find_matched_target( layer_name=layer_name, module=layer, targets=self.target_scheme_map.keys()) - # Find the quant_scheme scheme_dict = self.target_scheme_map[matched_target] - scheme = self._get_scheme_from_parts( - weight_quant=scheme_dict["weights"], - input_quant=scheme_dict["input_activations"]) + weight_quant = scheme_dict["weights"] + input_quant = scheme_dict["input_activations"] + sparsity_scheme: Optional[SparsityCompressionConfig] = self.sparsity_scheme_map.get(matched_target) + + if self.supports_cutlass_24( + weight_quant=weight_quant, + input_quant=input_quant, + sparsity_scheme=sparsity_scheme + ): + # Have a valid sparsity scheme and the layer is supported by the Cutlass 2:4 Kernel + needs_decompression = sparsity_scheme.format != CompressionFormat.dense.value + is_quantized = weight_quant is not None or input_quant is not None + + scheme = CompressedTensors24( + model_compressor=self.model_compressor, + layer_name=layer_name, + quantized=is_quantized, + do_decompress=needs_decompression, + ) + else: + # Find the quant_scheme + scheme = self._get_scheme_from_parts( + weight_quant=weight_quant, + input_quant=input_quant, + ) # Raise error if device does not support the scheme # (e.g. 
         self._check_scheme_supported(scheme.get_min_capability())
-        """
-        scheme = CompressedTensors24(
-            model_compressor=self.model_compressor,
-            layer_name=layer_name
-        )
-
         return scheme
+
+    @staticmethod
+    def supports_cutlass_24(
+        weight_quant: Optional[QuantizationArgs],
+        input_quant: Optional[QuantizationArgs],
+        sparsity_scheme: Optional[SparsityCompressionConfig] = None
+    ) -> bool:
+        """
+        Check if the layer is supported by the Cutlass 2:4 kernel.
+        Conditions:
+            - Overarching condition: the sparsity structure is 2:4
+            - Unquantized cases are supported
+            - Weight-only quantization is not supported
+            - Supported weight quantization strategies are TENSOR and CHANNEL
+            - Supported input quantization strategies are TENSOR and TOKEN
+
+        :return: True if the layer is supported by the Cutlass 2:4 kernel,
+            False otherwise
+        """
+        if (
+            sparsity_scheme is None or
+            sparsity_scheme.sparsity_structure != SparsityStructure.TWO_FOUR.value
+        ):
+            return False
+
+        # Unquantized cases are supported
+        if weight_quant is None and input_quant is None:
+            return True
+
+        # Weight-only quantization is not supported
+        if weight_quant is not None and input_quant is None:
+            return False
+
+        supported_weight_quant_strategies = [
+            QuantizationStrategy.TENSOR.value,
+            QuantizationStrategy.CHANNEL.value
+        ]
+        if weight_quant.strategy not in supported_weight_quant_strategies:
+            return False
+
+        supported_input_quant_strategies = [
+            QuantizationStrategy.TENSOR.value,
+            QuantizationStrategy.TOKEN.value
+        ]
+        if input_quant.strategy not in supported_input_quant_strategies:
+            return False
+
+        return True
 
 
 class CompressedTensorsLinearMethod(LinearMethodBase):
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
index de2644758edad..5b6b84fd5380a 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
@@ -12,21 +12,38 @@
 __all__ = ["CompressedTensors24"]
 
 
 class CompressedTensors24(CompressedTensorsScheme):
-    def __init__(self, model_compressor: Optional[ModelCompressor] = None, layer_name = None):
+    def __init__(
+        self,
+        model_compressor: Optional[ModelCompressor] = None,
+        layer_name: Optional[str] = None,
+        quantized: bool = False,
+        do_decompress: bool = False,
+    ):
+
         self.model_compressor = model_compressor
         self.layer_name = layer_name
-        self.quantized = True # toggle based on the case we're running
-        self.compressed = False # toggle based on the case we're running
+        self.quantized = quantized
+        self.do_decompress = do_decompress
 
     @classmethod
     def get_min_capability(cls) -> int:
+        """
+        Since this scheme uses the cutlass library with FP8, it requires
+        a minimum capability of 90.
+
+        :return: The minimum capability required for this scheme
+        """
         return 90
 
-    def create_weights(self, layer: torch.nn.Module, input_size: int,
-                       output_partition_sizes: List[int],
-                       input_size_per_partition: int,
-                       params_dtype: torch.dtype, weight_loader: Callable,
-                       **kwargs):
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size: int,
+        output_partition_sizes: List[int],
+        input_size_per_partition: int,
+        params_dtype: torch.dtype,
+        weight_loader: Callable,
+        **kwargs
+    ):
         layer.logical_widths = output_partition_sizes
         self.params_dtype=params_dtype
 
@@ -34,18 +51,54 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
         weight = ModelWeightParameter(
             data=torch.empty(sum(output_partition_sizes),
                              input_size_per_partition,
-                             dtype=torch.float8_e4m3fn),
+                             dtype=params_dtype),
             input_dim=1,
             output_dim=0,
             weight_loader=weight_loader)
+
+        if self.do_decompress:
+            # store compression specific things to be used
+            # later during decompression
+            bits_per_weight_element = weight.itemsize * 8
+            meta_dtype = (torch.int32
+                          if bits_per_weight_element == 8 else torch.int16)
+
+            # compressed weight for 2:4 sparse
+            weight_packed = ModelWeightParameter(
+                data=torch.empty(sum(output_partition_sizes),
+                                 input_size_per_partition // 2,
+                                 dtype=params_dtype),
+                input_dim=1,
+                output_dim=0,
+                weight_loader=weight_loader)
+
+            meta_input_size = (
+                input_size_per_partition // 32
+                if bits_per_weight_element == 8
+                else input_size_per_partition // 16
+            )
+            # meta tensor for 2:4 decompression
+            meta = ModelWeightParameter(
+                data=torch.empty(sum(output_partition_sizes),
+                                 meta_input_size,
+                                 dtype=meta_dtype),
+                input_dim=1,
+                output_dim=0,
+                weight_loader=weight_loader)
+
+            layer.register_parameter("weight_packed", weight_packed)
+            layer.register_parameter("meta", meta)
+
+        if self.quantized:
+            weight_scale = ChannelQuantScaleParameter(
+                data=torch.empty((sum(output_partition_sizes), 1),
+                                 dtype=torch.float32),
+                output_dim=0,
+                weight_loader=weight_loader)
+
+            layer.register_parameter("weight_scale", weight_scale)
 
-        weight_scale = ChannelQuantScaleParameter(
-            data=torch.empty((sum(output_partition_sizes), 1),
-                             dtype=torch.float32),
-            output_dim=0,
-            weight_loader=weight_loader)
-
-        layer.register_parameter("weight_scale", weight_scale)
         layer.register_parameter("weight", weight)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@@ -56,7 +109,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
 
         :param layer: The layer with the weights to be processed
         """
-        w_compressed, meta = ops.cutlass_compress_entry(layer.weight)
+        decompressed_weight = (
+            layer.weight if not self.do_decompress
+            else self._decompress_24_weight(layer.weight_packed.data,
+                                            layer.meta.data)
+        )
+        w_compressed, meta = ops.cutlass_compress_entry(decompressed_weight)
         layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
         layer.meta = torch.nn.Parameter(meta, requires_grad=False)
 
@@ -101,6 +158,43 @@ def apply_weights(self,
         out = out.contiguous()
         return out
 
+    def _decompress_24_weight(self, weight_packed: torch.Tensor,
+                              meta: torch.Tensor) -> torch.Tensor:
+        # NOTE: hardcoded output-partition split sizes for the fused qkv and
+        # gate_up projections; only valid for the model these values were
+        # taken from
+        qkv_sizes = [2048, 256, 256]
+        gate_up_sizes = [5632, 5632]
+        split_weights = None
+        split_meta = None
+
+        def _process_split(input_weight, input_meta):
+            weight_data = {
+                "weight_packed": input_weight,
+                "meta": input_meta
+            }
+            return self.model_compressor.sparsity_compressor.decompress_weight(
+                weight_data)
+
+        if "qkv" in self.layer_name:
+            split_weights = torch.split(weight_packed, qkv_sizes)
+            split_meta = torch.split(meta, qkv_sizes)
+        elif "gate_up" in self.layer_name:
+            split_weights = torch.split(weight_packed, gate_up_sizes)
+            split_meta = torch.split(meta, gate_up_sizes)
+
+        if split_weights:
+            # decompress each logical partition separately, then re-fuse
+            decompressed_splits = []
+            for i in range(len(split_weights)):
+                decompressed_splits.append(
+                    _process_split(split_weights[i], split_meta[i]))
+            decompressed = torch.cat(decompressed_splits)
+        else:
+            decompressed = _process_split(weight_packed, meta)
+
+        return decompressed
+
 
 def check_24(tensor):
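# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the diff above: the refactored
# from_config() splits parsing into _quantization_scheme_map_from_config(),
# which reads "config_groups", and _sparsity_scheme_map_from_config(), which
# reads the sparsity section via ModelCompressor.sparsity_config. The dict
# below shows the rough shape of a quantization_config those helpers expect.
# The top-level keys mirror what the code reads ("format", "ignore",
# "config_groups", "targets", "weights", "input_activations",
# "sparsity_config"); the concrete values (group name, num_bits, strategies,
# format strings) and the inner compressed-tensors field names are
# hypothetical and not taken from any real checkpoint.
example_quantization_config = {
    "format": "float-quantized",
    "ignore": ["lm_head"],
    "config_groups": {
        "group_0": {
            # modules matching a target get this scheme
            "targets": ["Linear"],
            # parsed with QuantizationArgs.model_validate(...)
            "weights": {
                "num_bits": 8,
                "type": "float",
                "strategy": "channel",
                "symmetric": True,
                "dynamic": False,
            },
            "input_activations": {
                "num_bits": 8,
                "type": "float",
                "strategy": "token",
                "dynamic": True,
            },
        },
    },
    # With a 2:4 sparsity structure and the channel/token strategies above,
    # supports_cutlass_24() would route matching layers to CompressedTensors24;
    # a non-"dense" format would additionally set do_decompress=True.
    "sparsity_config": {
        "targets": ["Linear"],
        "format": "dense",
        "sparsity_structure": "2:4",
        "ignore": ["lm_head"],
    },
}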