From b0d36af7e221aa0c0f0cd4cb6d3addcb3b8ed2e6 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Mon, 6 Jan 2025 22:25:03 +0000 Subject: [PATCH] review comments from @kylesayrs --- .../compressors/model_compressors/model_compressor.py | 4 +--- .../compressors/quantized_compressors/naive_quantized.py | 2 +- .../compressors/quantized_compressors/pack_quantized.py | 2 +- .../compressors/sparse_compressors/base.py | 7 +++---- src/compressed_tensors/quantization/lifecycle/forward.py | 2 +- src/compressed_tensors/utils/safetensors_load.py | 4 +++- 6 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 65e808b4..5ea26705 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -411,9 +411,7 @@ def _replace_weights(self, dense_weight_generator, model: Module): prefix, param_name = ".".join(split_name[:-1]), split_name[-1] module = operator.attrgetter(prefix)(model) if hasattr(module, param_name): - update_parameter_data( - module=module, new_param_data=data, param_name=param_name - ) + update_parameter_data(module, data, param_name) def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]: diff --git a/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py index eea50848..69d9d596 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py @@ -78,8 +78,8 @@ def compress_weight( :param weight: uncompressed weight tensor :param scale: quantization scale for weight - :param zero_point: quantization zero point for weight :param quantization_args: quantization parameters for weight + :param zero_point: quantization zero point for weight :param g_idx: optional mapping from column index to group index :param device: optional device to move compressed output to :return: dictionary of compressed weight data diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py index 8f694c8b..629ef37e 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py @@ -78,9 +78,9 @@ def compress_weight( :param weight: uncompressed weight tensor :param scale: quantization scale for weight + :param quantization_args: quantization parameters for weight :param zero_point: quantization zero point for weight :param g_idx: optional mapping from column index to group index - :param quantization_args: quantization parameters for weight :param device: optional device to move compressed output to :return: dictionary of compressed weight data """ diff --git a/src/compressed_tensors/compressors/sparse_compressors/base.py b/src/compressed_tensors/compressors/sparse_compressors/base.py index 72232801..7cd6e8e8 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/base.py +++ b/src/compressed_tensors/compressors/sparse_compressors/base.py @@ -69,8 +69,8 @@ def compress( Compresses a dense state dict using bitmask compression :param model_state: state dict of uncompressed model - :param compression_targets: optional set of 
layer prefixes to compress, if None
- compress all layers (for backwards compatibility)
+ :param compression_targets: optional set of layer prefixes to compress,
+ otherwise compress all layers (for backwards compatibility)
 :return: compressed state dict
 """
 compressed_dict = {}
@@ -78,8 +78,7 @@
 f"Compressing model with {len(model_state)} parameterized layers..."
 )
 for name, value in tqdm(model_state.items(), desc="Compressing model"):
- ignored = not self.should_compress(name, compression_targets)
- if ignored:
+ if not self.should_compress(name, compression_targets):
 compressed_dict[name] = value
 continue
 prefix = name
diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index 5426a226..f4f93f27 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -82,7 +82,7 @@ def quantize(
 def dequantize(
 x_q: torch.Tensor,
 scale: torch.Tensor,
- zero_point: torch.Tensor = None,
+ zero_point: Optional[torch.Tensor] = None,
 args: Optional[QuantizationArgs] = None,
 dtype: Optional[torch.dtype] = None,
 g_idx: Optional[torch.Tensor] = None,
diff --git a/src/compressed_tensors/utils/safetensors_load.py b/src/compressed_tensors/utils/safetensors_load.py
index f7569b98..ab4d04bf 100644
--- a/src/compressed_tensors/utils/safetensors_load.py
+++ b/src/compressed_tensors/utils/safetensors_load.py
@@ -246,7 +246,9 @@ def get_nested_weight_mappings(
 return nested_weight_mappings


-def get_nested_mappings_from_state_dict(state_dict, params_to_nest):
+def get_nested_mappings_from_state_dict(
+ state_dict, params_to_nest
+) -> NestedWeightMappingType:
 """
 Takes a state dict and returns a nested mapping from uncompressed
 parameterized layer names to the value of
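
Note (editorial, not part of the patch): the `dequantize` hunk above tightens the annotation of `zero_point` from a bare `torch.Tensor` with a `None` default to `Optional[torch.Tensor]`, the PEP 484 convention for parameters that accept `None`; recent mypy versions flag the old form under the default no-implicit-Optional behavior. Below is a minimal, self-contained sketch of that contract; `dequantize_sketch` and its dequantization math are hypothetical stand-ins for illustration, not the library's actual `dequantize` implementation.

from typing import Optional

import torch


def dequantize_sketch(
    x_q: torch.Tensor,
    scale: torch.Tensor,
    zero_point: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Symmetric schemes store no zero point, so callers may omit it;
    # Optional[torch.Tensor] makes that contract explicit to type checkers.
    if zero_point is None:
        return x_q.to(scale.dtype) * scale
    return (x_q.to(scale.dtype) - zero_point.to(scale.dtype)) * scale


x_q = torch.tensor([[10, -10]], dtype=torch.int8)
scale = torch.tensor(0.05)
print(dequantize_sketch(x_q, scale))                     # symmetric call
print(dequantize_sketch(x_q, scale, torch.tensor(2.0)))  # asymmetric call

Both call shapes type-check cleanly with the `Optional` annotation, which is why the patch changes only the signature and not the function body.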