Skip to content

Commit

Permalink
Add tests for int8
Browse files Browse the repository at this point in the history
Add a requires gpu decorator in testing_utils
Enable fp8 tests if gpu available
  • Loading branch information
rahul-tuli committed Jan 8, 2025
1 parent 00e9b9b commit 5f5b5c3
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 13 deletions.
85 changes: 72 additions & 13 deletions tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from compressed_tensors import Sparse24BitMaskTensor
from compressed_tensors.quantization import FP8_DTYPE
from compressed_tensors.utils import combine_shards, shard_tensor
from tests.testing_utils import generate_pruned_semi_structured_mat
from tests.testing_utils import generate_pruned_semi_structured_mat, requires_gpu


@pytest.fixture
Expand All @@ -45,8 +45,16 @@ def _validate_shard_shapes(sharded_values, sharded_bitmask, expected_shapes):
return _validate_shard_shapes


@pytest.mark.parametrize("dtype", [FP8_DTYPE])
def test_bitmask_compress_decompress_fp8(dense_matrix_fixture, dtype):
def validate_compression(dense_matrix, decompressed_tensor):
    """Assert that a decompression round-trip reproduces the original matrix.

    The original is first moved onto the decompressed tensor's device, then
    dtype, shape, and element-wise equality are checked in turn.
    """
    reference = dense_matrix.to(decompressed_tensor.device)
    assert reference.dtype == decompressed_tensor.dtype, "Dtype mismatch"
    assert reference.shape == decompressed_tensor.shape, "Shape mismatch"
    assert torch.equal(reference, decompressed_tensor), "Decompression failed"


@pytest.mark.parametrize("dtype", [torch.int8])
def test_bitmask_compress_decompress(dense_matrix_fixture, dtype):
M, K = 1024, 1024
dense_matrix = dense_matrix_fixture(M, K, dtype)

Expand All @@ -55,18 +63,14 @@ def test_bitmask_compress_decompress_fp8(dense_matrix_fixture, dtype):
)
decompressed_tensor = bitmask_tensor.decompress()

dense_matrix = dense_matrix.to(decompressed_tensor.device)

assert dense_matrix.dtype == decompressed_tensor.dtype, "Dtype mismatch"
assert dense_matrix.shape == decompressed_tensor.shape, "Shape mismatch"
assert torch.equal(dense_matrix, decompressed_tensor), "Decompression failed"
validate_compression(dense_matrix, decompressed_tensor)


@pytest.mark.parametrize(
"dtype, M, K, shard_sizes, shard_dim, expected_shapes",
[
(
FP8_DTYPE,
torch.int8,
2560,
2048,
[2048, 256, 256],
Expand All @@ -78,7 +82,7 @@ def test_bitmask_compress_decompress_fp8(dense_matrix_fixture, dtype):
],
),
(
FP8_DTYPE,
torch.int8,
2048,
2048,
[1024, 1024],
Expand Down Expand Up @@ -136,7 +140,62 @@ def test_bitmask_compress_decompress_sharded(
]

decompressed_combined = combine_shards(decompressed_shards, dim=shard_dim)
validate_compression(dense_matrix, decompressed_combined)

assert dense_matrix.dtype == decompressed_combined.dtype, "Dtype mismatch"
assert dense_matrix.shape == decompressed_combined.shape, "Shape mismatch"
assert torch.equal(dense_matrix, decompressed_combined), "Decompression failed"

# GPU-Specific Tests for FP8_DTYPE
@pytest.mark.parametrize("dtype", [FP8_DTYPE])
@requires_gpu
def test_bitmask_compress_decompress_fp8(dense_matrix_fixture, dtype):
    """GPU-only FP8 variant: delegates to the int8 round-trip test with FP8_DTYPE."""
    test_bitmask_compress_decompress(dense_matrix_fixture, dtype)


@pytest.mark.parametrize(
    "dtype, M, K, shard_sizes, shard_dim, expected_shapes",
    [
        # Shard along rows (dim 0): 2560x2048 split into 2048/256/256-row shards.
        (
            FP8_DTYPE,
            2560,
            2048,
            [2048, 256, 256],
            0,
            [
                {"compressed": (2048, 1024), "bitmask": (2048, 2048 // 8)},
                {"compressed": (256, 1024), "bitmask": (256, 2048 // 8)},
                {"compressed": (256, 1024), "bitmask": (256, 2048 // 8)},
            ],
        ),
        # Shard along columns (dim 1): 2048x2048 split into two 1024-col shards.
        (
            FP8_DTYPE,
            2048,
            2048,
            [1024, 1024],
            1,
            [
                {"compressed": (2048, 512), "bitmask": (2048, 2048 // 8 // 2)},
                {"compressed": (2048, 512), "bitmask": (2048, 2048 // 8 // 2)},
            ],
        ),
    ],
)
@requires_gpu
def test_bitmask_compress_decompress_sharded_fp8(
    dense_matrix_fixture,
    shard_validation,
    dtype,
    M,
    K,
    shard_sizes,
    shard_dim,
    expected_shapes,
):
    """GPU-only FP8 variant: delegates to the sharded round-trip test with FP8_DTYPE."""
    test_bitmask_compress_decompress_sharded(
        dense_matrix_fixture,
        shard_validation,
        dtype,
        M,
        K,
        shard_sizes,
        shard_dim,
        expected_shapes,
    )
18 changes: 18 additions & 0 deletions tests/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
import unittest

import pytest


Expand Down Expand Up @@ -124,3 +126,19 @@ def induce_sparsity(tensor, sparsity_ratio) -> "torch.Tensor":
sparse_tensor = tensor

return sparse_tensor


def is_gpu_available():
    """Return True if torch is importable and at least one CUDA device exists.

    Returns False (rather than raising) when torch is not installed, so that
    GPU-gated tests are simply skipped in CPU-only environments.

    NOTE: the previous docstring claimed this function warns when no GPU is
    found; it never did. The docstring now matches the actual behavior.
    """
    try:
        import torch  # noqa: F401

        return torch.cuda.device_count() > 0
    except ImportError:
        return False


def requires_gpu(test_case):
    """Decorator that skips *test_case* unless a GPU is available."""
    skip_without_gpu = unittest.skipUnless(is_gpu_available(), "test requires GPU")
    return skip_without_gpu(test_case)

0 comments on commit 5f5b5c3

Please sign in to comment.