Commit 4ef03bd: Review Comments!
rahul-tuli committed Dec 20, 2024
1 parent 28c9d99 commit 4ef03bd
Showing 5 changed files with 52 additions and 15 deletions.
@@ -92,7 +92,7 @@ def compress_weight(
            zero_point=zero_point,
            g_idx=g_idx,
            args=quantization_args,
-           dtype=quantization_args.pytorch_dtype(),
+           dtype=torch.int8,
        )
    else:
        quantized_weight = weight
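The hunk above pins the storage dtype for compressed weights to torch.int8 instead of deriving it from the quantization args. As a rough illustration of what packing quantized values into int8 storage looks like, here is a minimal sketch; quantize_to_int8 is a hypothetical helper, not the library's compress_weight:

```python
import torch

def quantize_to_int8(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Scale values into the signed 8-bit range and store them as int8:
    # a sketch of symmetric quantization, not compressed-tensors code.
    q = torch.round(weight / scale)
    return q.clamp(-128, 127).to(torch.int8)
```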
src/compressed_tensors/utils/safetensors_load.py (18 additions & 1 deletion)
@@ -188,7 +188,7 @@ def get_nested_weight_mappings(
    each layer's compression parameters.
    Example of the nested mapping:
-   layer.weight: {
+   layer: {
        bitmask: file_location,
        row_offsets: file_location,
        shape: file_location,
@@ -247,6 +247,23 @@ def get_nested_weight_mappings(


def get_nested_mappings_from_state_dict(state_dict, params_to_nest):
+    """
+    Takes a state dict and returns a nested mapping from uncompressed
+    parameterized layer names to the value of
+    each layer's compression parameters.
+    Example of the nested mapping:
+        layer: {
+            weight_scale: ...,
+            weight: ...,
+            zero_point: ...,
+        }
+    :param state_dict: state dict of the model
+    :param params_to_nest: List of parameter names to nest.
+    :return: Nested mapping of parameterized layer names to the value of
+        each layer's compression parameters.
+    """
    nested_weight_mappings = {}
    for key in state_dict.keys():
        for param_name in params_to_nest:
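For illustration, a minimal sketch of the nesting behavior the new docstring describes, assuming flat keys of the form "<layer>.<param>"; nest_params is a hypothetical stand-in, not the library function:

```python
from typing import Any, Dict, List

def nest_params(
    state_dict: Dict[str, Any], params_to_nest: List[str]
) -> Dict[str, Dict[str, Any]]:
    # Group flat keys like "layer.weight_scale" under "layer",
    # keyed by the matched parameter name.
    nested: Dict[str, Dict[str, Any]] = {}
    for key, value in state_dict.items():
        for param_name in params_to_nest:
            suffix = f".{param_name}"
            if key.endswith(suffix):
                nested.setdefault(key[: -len(suffix)], {})[param_name] = value
    return nested

# -> {"layer": {"weight": 1, "weight_scale": 2}}
print(
    nest_params({"layer.weight": 1, "layer.weight_scale": 2}, ["weight", "weight_scale"])
)
```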
tests/test_compressors/model_compressors/test_model_compressor.py (10 additions & 7 deletions)
@@ -157,7 +157,7 @@ def get_bitmask_sparsity_config():
    )


-def get_quantization_config(bits=8, type="int", strategy="tensor"):
+def create_quantization_config(bits=8, type="int", strategy="tensor"):

    config_dict = {
        "format": "int-quantized",
@@ -183,8 +183,8 @@ def get_quantization_config(bits=8, type="int", strategy="tensor"):
@pytest.mark.parametrize(
    "quantization_config",
    [
-       get_quantization_config(bits=8, type="int", strategy="channel"),
-       get_quantization_config(bits=8, type="float", strategy="channel"),
+       create_quantization_config(bits=8, type="int", strategy="channel"),
+       create_quantization_config(bits=8, type="float", strategy="channel"),
    ],
)
def test_composability(
@@ -211,12 +211,15 @@ def test_composability(
        args=quantization_args,
    )

-   model = fake_model_class(quantized_weights, scale, zero_point)
-   model.linear.quantization_scheme = quantization_config.config_groups["group_0"]
+   fake_oneshot_model = fake_model_class(quantized_weights, scale, zero_point)
+   fake_oneshot_model.linear.quantization_scheme = quantization_config.config_groups[
+       "group_0"
+   ]
    model_compressor = ModelCompressor(
        sparsity_config=sparsity_config, quantization_config=quantization_config
    )
-   compressed_state_dict = model_compressor.compress(model)
+   # does both sparse and quantization compression
+   compressed_state_dict = model_compressor.compress(fake_oneshot_model)

    save_dir = tmp_path / "model"
    save_dir = _create_dummy_checkpoint(
@@ -227,7 +230,7 @@
    model_compressor.decompress(model=decompressed_model, model_path=save_dir)

    # check that the decompressed model is the same as the original model
-   _check_state_dicts(model.state_dict(), decompressed_model.state_dict())
+   _check_state_dicts(fake_oneshot_model.state_dict(), decompressed_model.state_dict())


def _create_dummy_checkpoint(state_dict, save_dir, model_compressor):
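The test's final assertion compares the round-tripped state dicts entry by entry. A minimal sketch of that kind of check, assuming torch tensors; the real _check_state_dicts helper may differ:

```python
import torch

def check_state_dicts(expected: dict, actual: dict, atol: float = 1e-6) -> None:
    # Compare two state dicts key-by-key with a float tolerance.
    assert expected.keys() == actual.keys(), "state dict keys differ"
    for name, tensor in expected.items():
        assert torch.allclose(
            tensor.to(torch.float32), actual[name].to(torch.float32), atol=atol
        ), f"mismatch in {name}"
```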
tests/test_utils/test_safetensors_load.py (6 additions & 0 deletions)
@@ -38,6 +38,12 @@ def mock_get_weight_mappings():

@pytest.mark.usefixtures("mock_get_weight_mappings")
class TestGetNestedWeightMappings:
+    """
+    Tests for the get_nested_weight_mappings function
+    in different scenarios, such as single and multiple
+    parameters to nest, and returning other parameters
+    """

    def test_single_param(self):
        params_to_nest = ["weight"]
        result = get_nested_weight_mappings("dummy_path", params_to_nest)
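A usage sketch matching test_single_param above; the path and the returned mapping are purely illustrative, and the mock fixture patches the weight-mapping loader so no real files are read:

```python
from compressed_tensors.utils.safetensors_load import get_nested_weight_mappings

# Flat keys such as "layer.weight" are nested as
# {"layer": {"weight": <file location>}}.
nested = get_nested_weight_mappings("dummy_path", ["weight"])
```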
tests/testing_utils.py (17 additions & 6 deletions)
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+# flake8: noqa
import pytest


@@ -54,7 +54,13 @@ def requires_accelerate():
    )


-def get_random_mat(M, K, dtype):
+def get_random_mat(M, K, dtype) -> "torch.Tensor":
+    """
+    :param M: number of rows
+    :param K: number of columns
+    :param dtype: data type of the matrix
+    :return: random matrix of shape (M, K) with non-zero values
+    """
    import torch
    from compressed_tensors.quantization import FP8_DTYPE

@@ -66,7 +72,13 @@ def get_random_mat(M, K, dtype):
    return mat.to(dtype)


-def generate_pruned_semi_structured_mat(M, K, dtype):
+def generate_pruned_semi_structured_mat(M, K, dtype) -> "torch.Tensor":
+    """
+    :param M: number of rows
+    :param K: number of columns
+    :param dtype: data type of the matrix
+    :return: random matrix of shape (M, K) with 2:4 sparsity pattern
+    """
    import torch
    from compressed_tensors.quantization import FP8_DTYPE

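A minimal sketch of enforcing the 2:4 pattern the docstring refers to: in every contiguous group of four elements along a row, only the two largest-magnitude values are kept. This is illustrative, not the helper's exact implementation:

```python
import torch

def prune_2_4(mat: torch.Tensor) -> torch.Tensor:
    # Keep the 2 largest-|x| entries in each group of 4 along the last
    # dim; requires the column count to be divisible by 4.
    M, K = mat.shape
    groups = mat.reshape(M, K // 4, 4)
    top2 = groups.abs().topk(2, dim=-1).indices
    mask = torch.zeros_like(groups, dtype=torch.bool).scatter(-1, top2, True)
    return (groups * mask).reshape(M, K)
```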
@@ -84,15 +96,14 @@ def generate_pruned_semi_structured_mat(M, K, dtype):
    return mat.to(dtype)


-def induce_sparsity(tensor, sparsity_ratio):
+def induce_sparsity(tensor, sparsity_ratio) -> "torch.Tensor":
    """
    Makes a tensor sparse by zeroing out a given fraction
    of its smallest absolute values.
    :param: weight_tensor (torch.Tensor): The input weight tensor.
    :param: sparsity_ratio (float): Fraction of weights to be zeroed
-       (0 <= sparsity_ratio <= 1).
+           (0 <= sparsity_ratio <= 1).
    :returns: torch.Tensor: Sparse version of the input tensor.
    """
    import torch
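And a sketch consistent with induce_sparsity's docstring, zeroing out the given fraction of smallest-magnitude entries; illustrative, not the committed implementation:

```python
import torch

def induce_sparsity_sketch(tensor: torch.Tensor, sparsity_ratio: float) -> torch.Tensor:
    # Zero out the sparsity_ratio fraction of smallest-|x| entries
    # (ties at the threshold may zero slightly more).
    k = int(sparsity_ratio * tensor.numel())
    if k == 0:
        return tensor.clone()
    threshold = tensor.abs().flatten().kthvalue(k).values
    return torch.where(tensor.abs() <= threshold, torch.zeros_like(tensor), tensor)
```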
