
Commit

remove and add comment info
horheynm committed Nov 22, 2024
1 parent 65f6acf commit 60a6940
Showing 1 changed file with 11 additions and 3 deletions.
@@ -38,6 +38,7 @@
     apply_quantization_config,
     load_pretrained_quantization,
 )
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import (
     is_module_quantized,
     iter_named_leaf_modules,
@@ -267,8 +268,9 @@ def compress(

         compressed_state_dict = state_dict

-        # submodule name to q_args
-        quantized_modules_to_args = map_modules_to_quant_args(model)
+        quantized_modules_to_args: Dict[
+            str, QuantizationArgs
+        ] = map_modules_to_quant_args(model)

         if self.quantization_compressor is not None:
             compressed_state_dict = self.quantization_compressor.compress(
@@ -373,7 +375,13 @@ def _replace_weights(self, dense_weight_generator, model):
             update_parameter_data(module, data, param_name)


-def map_modules_to_quant_args(model: Module) -> Dict:
+def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
+    """
+    Given a pytorch model, map out the submodule name (usually linear layers)
+    to the QuantizationArgs
+
+    :param model: pytorch model
+    """
     quantized_modules_to_args = {}
     for name, submodule in iter_named_leaf_modules(model):
         if is_module_quantized(submodule):
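
For readers without the full file, the sketch below shows a plausible complete body for the helper annotated in this commit. It is a reconstruction, not part of the diff: the truncated tail of the function and the quantization_scheme.weights attribute used to look up each submodule's QuantizationArgs are assumptions based on the visible context.

# Hypothetical reconstruction of the annotated helper, assuming each quantized
# leaf module exposes a `quantization_scheme` whose `weights` field is a
# QuantizationArgs instance (the truncated diff above does not show this).
from typing import Dict

from torch.nn import Module

from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.utils import (
    is_module_quantized,
    iter_named_leaf_modules,
)


def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
    """
    Given a pytorch model, map out the submodule name (usually linear layers)
    to the QuantizationArgs

    :param model: pytorch model
    """
    quantized_modules_to_args = {}
    for name, submodule in iter_named_leaf_modules(model):
        if is_module_quantized(submodule):
            # assumed attribute path for the elided body: keep only modules
            # with weight quantization args attached to their scheme
            if submodule.quantization_scheme.weights is not None:
                quantized_modules_to_args[name] = submodule.quantization_scheme.weights
    return quantized_modules_to_args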
