diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
index b5707336..8f694c8b 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
@@ -92,7 +92,7 @@ def compress_weight(
                 zero_point=zero_point,
                 g_idx=g_idx,
                 args=quantization_args,
-                dtype=quantization_args.pytorch_dtype(),
+                dtype=torch.int8,
             )
         else:
             quantized_weight = weight
diff --git a/src/compressed_tensors/utils/safetensors_load.py b/src/compressed_tensors/utils/safetensors_load.py
index de386678..f7569b98 100644
--- a/src/compressed_tensors/utils/safetensors_load.py
+++ b/src/compressed_tensors/utils/safetensors_load.py
@@ -188,7 +188,7 @@ def get_nested_weight_mappings(
     each layer's compression parameters.
 
     Example of the nested mapping:
-        layer.weight: {
+        layer: {
             bitmask: file_location,
             row_offsets: file_location,
             shape: file_location,
@@ -247,6 +247,23 @@
 
 
 def get_nested_mappings_from_state_dict(state_dict, params_to_nest):
+    """
+    Takes a state dict and returns a nested mapping from uncompressed
+    parameterized layer names to the values of each layer's compression
+    parameters.
+
+    Example of the nested mapping:
+        layer: {
+            weight_scale: ...,
+            weight: ...,
+            zero_point: ...,
+        }
+
+    :param state_dict: State dict of the model.
+    :param params_to_nest: List of parameter names to nest.
+    :return: Nested mapping of parameterized layer names to the values of
+        each layer's compression parameters.
+    """
     nested_weight_mappings = {}
     for key in state_dict.keys():
         for param_name in params_to_nest:
diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py b/tests/test_compressors/model_compressors/test_model_compressor.py
index 44bf651d..bbde3011 100644
--- a/tests/test_compressors/model_compressors/test_model_compressor.py
+++ b/tests/test_compressors/model_compressors/test_model_compressor.py
@@ -157,7 +157,7 @@ def get_bitmask_sparsity_config():
     )
 
 
-def get_quantization_config(bits=8, type="int", strategy="tensor"):
+def create_quantization_config(bits=8, type="int", strategy="tensor"):
 
     config_dict = {
         "format": "int-quantized",
@@ -183,8 +183,8 @@ def get_quantization_config(bits=8, type="int", strategy="tensor"):
 @pytest.mark.parametrize(
     "quantization_config",
     [
-        get_quantization_config(bits=8, type="int", strategy="channel"),
-        get_quantization_config(bits=8, type="float", strategy="channel"),
+        create_quantization_config(bits=8, type="int", strategy="channel"),
+        create_quantization_config(bits=8, type="float", strategy="channel"),
     ],
 )
 def test_composability(
@@ -211,12 +211,15 @@ def test_composability(
         args=quantization_args,
     )
 
-    model = fake_model_class(quantized_weights, scale, zero_point)
-    model.linear.quantization_scheme = quantization_config.config_groups["group_0"]
+    fake_oneshot_model = fake_model_class(quantized_weights, scale, zero_point)
+    fake_oneshot_model.linear.quantization_scheme = quantization_config.config_groups[
+        "group_0"
+    ]
     model_compressor = ModelCompressor(
         sparsity_config=sparsity_config, quantization_config=quantization_config
     )
-    compressed_state_dict = model_compressor.compress(model)
+    # performs both sparsity and quantization compression
+    compressed_state_dict = model_compressor.compress(fake_oneshot_model)
 
     save_dir = tmp_path / "model"
     save_dir = _create_dummy_checkpoint(
@@ -227,7 +230,7 @@
     model_compressor.decompress(model=decompressed_model, model_path=save_dir)
 
     # check that the decompressed model is the same as the original model
-    _check_state_dicts(model.state_dict(), decompressed_model.state_dict())
+    _check_state_dicts(fake_oneshot_model.state_dict(), decompressed_model.state_dict())
 
 
 def _create_dummy_checkpoint(state_dict, save_dir, model_compressor):
diff --git a/tests/test_utils/test_safetensors_load.py b/tests/test_utils/test_safetensors_load.py
index 95cd6347..932a8926 100644
--- a/tests/test_utils/test_safetensors_load.py
+++ b/tests/test_utils/test_safetensors_load.py
@@ -38,6 +38,12 @@ def mock_get_weight_mappings():
 
 @pytest.mark.usefixtures("mock_get_weight_mappings")
 class TestGetNestedWeightMappings:
+    """
+    Tests for the get_nested_weight_mappings function
+    in different scenarios, such as single and multiple
+    parameters to nest, and returning other parameters.
+    """
+
     def test_single_param(self):
         params_to_nest = ["weight"]
         result = get_nested_weight_mappings("dummy_path", params_to_nest)
diff --git a/tests/testing_utils.py b/tests/testing_utils.py
index 0cbca8a8..fe11c8a9 100644
--- a/tests/testing_utils.py
+++ b/tests/testing_utils.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+# flake8: noqa
 import pytest
 
 
@@ -54,7 +54,13 @@ def requires_accelerate():
     )
 
 
-def get_random_mat(M, K, dtype):
+def get_random_mat(M, K, dtype) -> "torch.Tensor":
+    """
+    :param M: number of rows
+    :param K: number of columns
+    :param dtype: data type of the matrix
+    :return: random matrix of shape (M, K) with non-zero values
+    """
     import torch
     from compressed_tensors.quantization import FP8_DTYPE
 
@@ -66,7 +72,13 @@ def get_random_mat(M, K, dtype):
     return mat.to(dtype)
 
 
-def generate_pruned_semi_structured_mat(M, K, dtype):
+def generate_pruned_semi_structured_mat(M, K, dtype) -> "torch.Tensor":
+    """
+    :param M: number of rows
+    :param K: number of columns
+    :param dtype: data type of the matrix
+    :return: random matrix of shape (M, K) with 2:4 sparsity pattern
+    """
     import torch
     from compressed_tensors.quantization import FP8_DTYPE
 
@@ -84,15 +96,14 @@ def generate_pruned_semi_structured_mat(M, K, dtype):
     return mat.to(dtype)
 
 
-def induce_sparsity(tensor, sparsity_ratio):
+def induce_sparsity(tensor, sparsity_ratio) -> "torch.Tensor":
     """
     Makes a tensor sparse by zeroing out a given fraction
     of its smallest absolute values.
 
     :param: weight_tensor (torch.Tensor): The input weight tensor.
     :param: sparsity_ratio (float): Fraction of weights to be zeroed
-        (0 <= sparsity_ratio <= 1).
-
+        (0 <= sparsity_ratio <= 1).
     :returns: torch.Tensor: Sparse version of the input tensor.
     """
     import torch
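
The first hunk pins the dtype of the quantized weight to torch.int8 rather than quantization_args.pytorch_dtype(). As background, packed formats keep each quantized value in an int8 container before bit-packing several values into a wider word. The sketch below illustrates that general technique under the assumption of 4-bit values packed into int32 words; pack_int4_to_int32 is a hypothetical illustration, not the library's packing routine.

import torch


def pack_int4_to_int32(x: torch.Tensor) -> torch.Tensor:
    # assumes x holds int8 values already quantized to [-8, 7],
    # with numel divisible by 8 (eight 4-bit nibbles per int32 word)
    nibbles = (x.to(torch.int32) & 0xF).reshape(-1, 8)
    packed = torch.zeros(nibbles.shape[0], dtype=torch.int32)
    for i in range(8):
        # place the i-th 4-bit value into bits [4*i, 4*i + 4) of each word
        packed |= nibbles[:, i] << (4 * i)
    return packed


vals = torch.randint(-8, 8, (16,), dtype=torch.int8)
packed = pack_int4_to_int32(vals)  # shape (2,): eight nibbles per int32 word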
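The docstring added to get_nested_mappings_from_state_dict describes collapsing flat "layer.param" keys into one dict per layer. Below is a minimal standalone sketch of that documented behavior; nest_params is a hypothetical stand-in for illustration, not the library function, whose matching logic may differ.

import re


def nest_params(state_dict, params_to_nest):
    # group every "<layer>.<param_name>" entry under its "<layer>" prefix
    nested = {}
    for key, value in state_dict.items():
        for param_name in params_to_nest:
            match = re.match(rf"(.+)\.{param_name}$", key)
            if match:
                nested.setdefault(match.group(1), {})[param_name] = value
    return nested


sd = {"linear.weight": 0, "linear.weight_scale": 1, "linear.weight_zero_point": 2}
assert nest_params(sd, ["weight", "weight_scale", "weight_zero_point"]) == {
    "linear": {"weight": 0, "weight_scale": 1, "weight_zero_point": 2}
}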
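The final hunk stops at the top of induce_sparsity, so its body is not shown above; its docstring describes magnitude pruning, i.e. zeroing the given fraction of smallest-magnitude entries. A sketch of that behavior under that reading; induce_sparsity_sketch is illustrative, not the repository's implementation.

import torch


def induce_sparsity_sketch(tensor: torch.Tensor, sparsity_ratio: float) -> torch.Tensor:
    num_zeros = int(sparsity_ratio * tensor.numel())
    if num_zeros == 0:
        return tensor
    # threshold at the num_zeros-th smallest absolute value
    threshold = tensor.abs().flatten().kthvalue(num_zeros).values
    # zero every entry at or below the threshold, keep the rest unchanged
    return torch.where(tensor.abs() <= threshold, torch.zeros_like(tensor), tensor)


mat = torch.randn(8, 8)
sparse = induce_sparsity_sketch(mat, 0.5)
assert (sparse == 0).sum() >= 32  # at least half of the 64 entries are zero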