
Commit b0d36af
review comments from @kylesayrs
rahul-tuli committed Jan 6, 2025
1 parent 4afe1f0 commit b0d36af
Showing 6 changed files with 10 additions and 11 deletions.
@@ -411,9 +411,7 @@ def _replace_weights(self, dense_weight_generator, model: Module):
             prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
             module = operator.attrgetter(prefix)(model)
             if hasattr(module, param_name):
-                update_parameter_data(
-                    module=module, new_param_data=data, param_name=param_name
-                )
+                update_parameter_data(module, data, param_name)


 def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
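Note on the hunk above: the keyword-argument call is collapsed into a positional one. For context, a minimal sketch of the dotted-name traversal pattern this code uses, assuming plain torch modules; set_param below is an illustrative stand-in for update_parameter_data, not the library's API:

import operator

import torch
from torch.nn import Linear, Module, Sequential


def set_param(module: Module, data: torch.Tensor, param_name: str) -> None:
    # Stand-in for update_parameter_data: swap the parameter's data in place.
    param = getattr(module, param_name)
    param.data = data.to(dtype=param.dtype, device=param.device)


model = Sequential(Linear(4, 4))
name = "0.weight"  # a dotted parameter name, as found in a state dict

split_name = name.split(".")
prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
module = operator.attrgetter(prefix)(model)  # walk the dotted path to the submodule
if hasattr(module, param_name):
    set_param(module, torch.zeros(4, 4), param_name)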
@@ -78,8 +78,8 @@ def compress_weight(
         :param weight: uncompressed weight tensor
         :param scale: quantization scale for weight
-        :param zero_point: quantization zero point for weight
         :param quantization_args: quantization parameters for weight
+        :param zero_point: quantization zero point for weight
         :param g_idx: optional mapping from column index to group index
         :param device: optional device to move compressed output to
         :return: dictionary of compressed weight data
@@ -78,9 +78,9 @@ def compress_weight(
         :param weight: uncompressed weight tensor
         :param scale: quantization scale for weight
+        :param quantization_args: quantization parameters for weight
         :param zero_point: quantization zero point for weight
         :param g_idx: optional mapping from column index to group index
-        :param quantization_args: quantization parameters for weight
         :param device: optional device to move compressed output to
         :return: dictionary of compressed weight data
         """
7 changes: 3 additions & 4 deletions src/compressed_tensors/compressors/sparse_compressors/base.py
@@ -69,17 +69,16 @@ def compress(
         Compresses a dense state dict using bitmask compression

         :param model_state: state dict of uncompressed model
-        :param compression_targets: optional set of layer prefixes to compress, if None
-            compress all layers (for backwards compatibility)
+        :param compression_targets: optional set of layer prefixes to compress,
+            otherwise compress all layers (for backwards compatibility)
         :return: compressed state dict
         """
         compressed_dict = {}
         _LOGGER.debug(
             f"Compressing model with {len(model_state)} parameterized layers..."
         )
         for name, value in tqdm(model_state.items(), desc="Compressing model"):
-            ignored = not self.should_compress(name, compression_targets)
-            if ignored:
+            if not self.should_compress(name, compression_targets):
                 compressed_dict[name] = value
                 continue
             prefix = name
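Note on the hunk above: the single-use `ignored` temporary is folded into the if condition, and the docstring is reworded. A minimal sketch of the skip-or-compress loop, with an illustrative prefix-matching should_compress and a trivial stand-in for bitmask compression (neither is the library's implementation):

from typing import Dict, Optional, Set

import torch


def should_compress(name: str, targets: Optional[Set[str]] = None) -> bool:
    # With no targets given, compress every layer (backwards compatible).
    if targets is None:
        return True
    return any(name.startswith(prefix) for prefix in targets)


def compress(
    model_state: Dict[str, torch.Tensor],
    compression_targets: Optional[Set[str]] = None,
) -> Dict[str, torch.Tensor]:
    compressed_dict: Dict[str, torch.Tensor] = {}
    for name, value in model_state.items():
        if not should_compress(name, compression_targets):
            # Untargeted parameters pass through unchanged.
            compressed_dict[name] = value
            continue
        # Stand-in for bitmask compression: store only the nonzero values.
        compressed_dict[name] = value[value != 0]
    return compressed_dict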
2 changes: 1 addition & 1 deletion src/compressed_tensors/quantization/lifecycle/forward.py
@@ -82,7 +82,7 @@ def quantize(
 def dequantize(
     x_q: torch.Tensor,
     scale: torch.Tensor,
-    zero_point: torch.Tensor = None,
+    zero_point: Optional[torch.Tensor] = None,
     args: Optional[QuantizationArgs] = None,
     dtype: Optional[torch.dtype] = None,
     g_idx: Optional[torch.Tensor] = None,
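Note on the hunk above: a default of None is only well-typed under Optional[...], so static checkers flag the old annotation. A minimal sketch of affine dequantization with an optional zero point (symmetric schemes omit it); illustrative only, since the real dequantize also handles args, dtype, and g_idx:

from typing import Optional

import torch


def dequantize_sketch(
    x_q: torch.Tensor,
    scale: torch.Tensor,
    zero_point: Optional[torch.Tensor] = None,  # was: torch.Tensor = None
) -> torch.Tensor:
    if zero_point is None:
        # Symmetric quantization: no zero point to subtract.
        return x_q.to(scale.dtype) * scale
    return (x_q.to(scale.dtype) - zero_point.to(scale.dtype)) * scale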
4 changes: 3 additions & 1 deletion src/compressed_tensors/utils/safetensors_load.py
@@ -246,7 +246,9 @@ def get_nested_weight_mappings(
     return nested_weight_mappings


-def get_nested_mappings_from_state_dict(state_dict, params_to_nest):
+def get_nested_mappings_from_state_dict(
+    state_dict, params_to_nest
+) -> NestedWeightMappingType:
     """
     Takes a state dict and returns a nested mapping from uncompressed
     parameterized layer names to the value of
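Note on the hunk above: the wrapped signature gains an explicit NestedWeightMappingType return annotation. A minimal sketch of what such nesting does, assuming the nested type groups parameter tensors under their layer prefix (the grouping logic is illustrative, not the library's):

from typing import Dict, List

import torch

# Assumed shape of the nested mapping: layer name -> param name -> tensor.
NestedMapping = Dict[str, Dict[str, torch.Tensor]]


def nest_state_dict(
    state_dict: Dict[str, torch.Tensor], params_to_nest: List[str]
) -> NestedMapping:
    nested: NestedMapping = {}
    for name, value in state_dict.items():
        prefix, _, param_name = name.rpartition(".")
        if param_name in params_to_nest:
            nested.setdefault(prefix, {})[param_name] = value
    return nested


sd = {"layer.weight": torch.ones(2, 2), "layer.bias": torch.zeros(2)}
print(nest_state_dict(sd, ["weight"]))  # {'layer': {'weight': tensor(...)}}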
