diff --git a/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py b/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py index 4e1b5c47..0e28f004 100644 --- a/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +++ b/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py @@ -18,7 +18,7 @@ from compressed_tensors import Sparse24BitMaskTensor from compressed_tensors.quantization import FP8_DTYPE from compressed_tensors.utils import combine_shards, shard_tensor -from tests.testing_utils import generate_pruned_semi_structured_mat +from tests.testing_utils import generate_pruned_semi_structured_mat, requires_gpu @pytest.fixture @@ -45,8 +45,16 @@ def _validate_shard_shapes(sharded_values, sharded_bitmask, expected_shapes): return _validate_shard_shapes -@pytest.mark.parametrize("dtype", [FP8_DTYPE]) -def test_bitmask_compress_decompress_fp8(dense_matrix_fixture, dtype): +def validate_compression(dense_matrix, decompressed_tensor): + """Validate that the decompressed tensor matches the original dense matrix.""" + dense_matrix = dense_matrix.to(decompressed_tensor.device) + assert dense_matrix.dtype == decompressed_tensor.dtype, "Dtype mismatch" + assert dense_matrix.shape == decompressed_tensor.shape, "Shape mismatch" + assert torch.equal(dense_matrix, decompressed_tensor), "Decompression failed" + + +@pytest.mark.parametrize("dtype", [torch.int8]) +def test_bitmask_compress_decompress(dense_matrix_fixture, dtype): M, K = 1024, 1024 dense_matrix = dense_matrix_fixture(M, K, dtype) @@ -55,18 +63,14 @@ def test_bitmask_compress_decompress_fp8(dense_matrix_fixture, dtype): ) decompressed_tensor = bitmask_tensor.decompress() - dense_matrix = dense_matrix.to(decompressed_tensor.device) - - assert dense_matrix.dtype == decompressed_tensor.dtype, "Dtype mismatch" - assert dense_matrix.shape == decompressed_tensor.shape, "Shape mismatch" - assert torch.equal(dense_matrix, decompressed_tensor), "Decompression failed" + validate_compression(dense_matrix, decompressed_tensor) @pytest.mark.parametrize( "dtype, M, K, shard_sizes, shard_dim, expected_shapes", [ ( - FP8_DTYPE, + torch.int8, 2560, 2048, [2048, 256, 256], @@ -78,7 +82,7 @@ def test_bitmask_compress_decompress_fp8(dense_matrix_fixture, dtype): ], ), ( - FP8_DTYPE, + torch.int8, 2048, 2048, [1024, 1024], @@ -136,7 +140,62 @@ def test_bitmask_compress_decompress_sharded( ] decompressed_combined = combine_shards(decompressed_shards, dim=shard_dim) + validate_compression(dense_matrix, decompressed_combined) - assert dense_matrix.dtype == decompressed_combined.dtype, "Dtype mismatch" - assert dense_matrix.shape == decompressed_combined.shape, "Shape mismatch" - assert torch.equal(dense_matrix, decompressed_combined), "Decompression failed" + +# GPU-Specific Tests for FP8_DTYPE +@pytest.mark.parametrize("dtype", [FP8_DTYPE]) +@requires_gpu +def test_bitmask_compress_decompress_fp8(dense_matrix_fixture, dtype): + test_bitmask_compress_decompress(dense_matrix_fixture, dtype) + + +@pytest.mark.parametrize( + "dtype, M, K, shard_sizes, shard_dim, expected_shapes", + [ + ( + FP8_DTYPE, + 2560, + 2048, + [2048, 256, 256], + 0, + [ + {"compressed": (2048, 1024), "bitmask": (2048, 2048 // 8)}, + {"compressed": (256, 1024), "bitmask": (256, 2048 // 8)}, + {"compressed": (256, 1024), "bitmask": (256, 2048 // 8)}, + ], + ), + ( + FP8_DTYPE, + 2048, + 2048, + [1024, 1024], + 1, + [ + {"compressed": (2048, 512), "bitmask": (2048, 2048 // 8 // 2)}, + {"compressed": (2048, 512), "bitmask": (2048, 2048 // 8 // 2)}, + ], + ), + ], +) +@requires_gpu +def test_bitmask_compress_decompress_sharded_fp8( + dense_matrix_fixture, + shard_validation, + dtype, + M, + K, + shard_sizes, + shard_dim, + expected_shapes, +): + test_bitmask_compress_decompress_sharded( + dense_matrix_fixture, + shard_validation, + dtype, + M, + K, + shard_sizes, + shard_dim, + expected_shapes, + ) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index fe11c8a9..9137fdf8 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # flake8: noqa +import unittest + import pytest @@ -124,3 +126,19 @@ def induce_sparsity(tensor, sparsity_ratio) -> "torch.Tensor": sparse_tensor = tensor return sparse_tensor + + +def is_gpu_available(): + """ + Check for GPU and warn if not found + """ + try: + import torch # noqa: F401 + + return torch.cuda.device_count() > 0 + except ImportError: + return False + + +def requires_gpu(test_case): + return unittest.skipUnless(is_gpu_available(), "test requires GPU")(test_case)