diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index 5d1030d8..68bd52ec 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -38,6 +38,7 @@
     apply_quantization_config,
     load_pretrained_quantization,
 )
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import (
     is_module_quantized,
     iter_named_leaf_modules,
@@ -267,8 +268,9 @@ def compress(
 
         compressed_state_dict = state_dict
 
-        # submodule name to q_args
-        quantized_modules_to_args = map_modules_to_quant_args(model)
+        quantized_modules_to_args: Dict[
+            str, QuantizationArgs
+        ] = map_modules_to_quant_args(model)
 
         if self.quantization_compressor is not None:
             compressed_state_dict = self.quantization_compressor.compress(
@@ -373,7 +375,13 @@ def _replace_weights(self, dense_weight_generator, model):
             update_parameter_data(module, data, param_name)
 
 
-def map_modules_to_quant_args(model: Module) -> Dict:
+def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
+    """
+    Given a PyTorch model, map the name of each quantized submodule
+    (usually a linear layer) to its QuantizationArgs
+
+    :param model: PyTorch model
+    """
     quantized_modules_to_args = {}
     for name, submodule in iter_named_leaf_modules(model):
         if is_module_quantized(submodule):
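
For context, here is a minimal sketch of the full helper after this change. The diff truncates inside the loop, so everything below the `is_module_quantized` check (reading the weight arguments off a `quantization_scheme.weights` attribute) is an assumption, not something the diff itself confirms:

```python
from typing import Dict

from torch.nn import Module

from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.utils import (
    is_module_quantized,
    iter_named_leaf_modules,
)


def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
    """
    Given a PyTorch model, map the name of each quantized submodule
    (usually a linear layer) to its QuantizationArgs

    :param model: PyTorch model
    """
    quantized_modules_to_args: Dict[str, QuantizationArgs] = {}
    for name, submodule in iter_named_leaf_modules(model):
        if is_module_quantized(submodule):
            # Assumption: a quantized module exposes its scheme via a
            # `quantization_scheme` attribute whose `weights` field holds
            # the QuantizationArgs; the diff is cut off before this point.
            weight_args = submodule.quantization_scheme.weights
            if weight_args is not None:
                quantized_modules_to_args[name] = weight_args
    return quantized_modules_to_args
```

In `compress()`, the resulting name-to-args mapping is presumably what lets `self.quantization_compressor.compress(...)` look up the quantization parameters for each layer it packs; the explicit `Dict[str, QuantizationArgs]` annotation added by this diff makes that contract visible at the call site.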