Skip to content

Commit

Permalink
Update code, flip sparse op operand order to be consistent with dense
Browse files Browse the repository at this point in the history
  • Loading branch information
Faraz9877 committed Dec 11, 2024
1 parent 81c4360 commit 6d574af
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 13 deletions.
12 changes: 5 additions & 7 deletions benchmarks/cutlass_benchmarks/sp_fp8_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,6 @@ def run_single_benchmark_process(kernel_config: Dict, gpu_id: int,
# Create tensors
BComps, Es, As, Bs = make_n_rand_sparse_tensors(
kernel_config.get('arg_pool_size', 1), dtype, m, n, k)
AsT = [x.t() for x in As]
bf16_As = [x.to(dtype=torch.bfloat16) for x in As]
bf16_Bs = [x.to(dtype=torch.bfloat16) for x in Bs]
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
Expand Down Expand Up @@ -304,8 +303,8 @@ def run_single_benchmark_process(kernel_config: Dict, gpu_id: int,
elif kernel_type == 'cutlass_scaled_sparse_mm':
bench = BenchMM(cuda_graph_params, label, sub_label,
"cutlass_fp8_fp8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm, ArgPool(BComps),
ArgPool(Es), ArgPool(AsT), scale_b, scale_a,
ops.cutlass_scaled_sparse_mm, ArgPool(As),
ArgPool(BComps), ArgPool(Es), scale_a, scale_b,
torch.bfloat16)

# Run the benchmark
Expand Down Expand Up @@ -430,7 +429,6 @@ def run_kernels_on_gpus(
# Create tensors
BComps, Es, As, Bs = make_n_rand_sparse_tensors(
config.get('arg_pool_size', 1), dtype, m, n, k)
AsT = [x.t() for x in As]
bf16_As = [x.to(dtype=torch.bfloat16) for x in As]
bf16_Bs = [x.to(dtype=torch.bfloat16) for x in Bs]
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
Expand Down Expand Up @@ -491,9 +489,9 @@ def run_kernels_on_gpus(
elif kernel_type == 'cutlass_scaled_sparse_mm':
bench = BenchMM(cuda_graph_params, label, sub_label,
"cutlass_fp8_fp8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm,
ArgPool(BComps), ArgPool(Es), ArgPool(AsT),
scale_b, scale_a, torch.bfloat16)
ops.cutlass_scaled_sparse_mm, ArgPool(As),
ArgPool(BComps), ArgPool(Es),
scale_a, scale_b, torch.bfloat16)

# Run the benchmark
result = bench.run()
Expand Down
85 changes: 84 additions & 1 deletion tests/kernels/test_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Run `pytest tests/kernels/test_cutlass.py`.
"""
from typing import Optional, Type
from typing import Optional, Type, Tuple

import pytest
import torch
Expand Down Expand Up @@ -55,6 +55,61 @@ def rand_int8(shape: tuple, device: str = "cuda"):
return to_int8(torch.rand(shape, device=device) * 255 - 128)


def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
    """Return ``tensor`` converted to ``torch.bfloat16``."""
    target_dtype = torch.bfloat16
    return tensor.to(target_dtype)


def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
    """Return ``tensor`` converted to ``torch.float16``."""
    target_dtype = torch.float16
    return tensor.to(target_dtype)


def prune_to_2_4(tensor):
    """Zero the two smallest-magnitude entries in every group of four.

    The tensor is viewed as rows of 4 consecutive elements (so its number
    of elements must be divisible by 4 for the reshape to succeed). In each
    row the 2 entries with the largest absolute value are kept and the
    other 2 are set to zero. The result has the same shape as the input.
    """
    groups = tensor.reshape(-1, 4)

    # Column positions of the 2 largest |values| in each group of 4.
    _, keep_idx = torch.abs(groups).topk(2, dim=1)

    # 0/1 mask with ones at the kept positions.
    keep = torch.zeros_like(groups)
    keep.scatter_(1, keep_idx, 1.0)

    result = groups * keep

    # Multiplying a negative value by 0 yields -0.0; normalise every zero
    # to +0.0 (the comparison matches both signs of zero).
    result[result == 0] = 0.0

    return result.reshape(tensor.shape)


def make_rand_sparse_tensors(
    dtype: torch.dtype, m: int, n: int, k: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build random operands for a 2:4-sparse matmul test on CUDA.

    Args:
        dtype: Target element type; one of ``torch.int8``,
            ``torch.float8_e4m3fn``, ``torch.float16`` or ``torch.bfloat16``.
        m: Number of rows of ``a``.
        n: Number of columns of ``b``.
        k: Shared inner dimension.

    Returns:
        A 4-tuple ``(b_compressed, e, a, b)``: the two outputs of
        ``ops.cutlass_compress_entry(b.t())`` (compressed B and its
        metadata), the ``(m, k)`` tensor ``a``, and the pruned ``(k, n)``
        tensor ``b``.

    Raises:
        ValueError: If ``dtype`` is not one of the supported types.
    """
    a = torch.randn((m, k), device='cuda') * 5
    b = torch.randn((n, k), device='cuda').t() * 5

    # Prune on the (n, k) layout so the groups of 4 are taken along each
    # row of b.t(), then transpose back to (k, n).
    b = prune_to_2_4(b.t()).t()

    if dtype == torch.int8:
        a, b = to_int8(a), to_int8(b)
    elif dtype == torch.float8_e4m3fn:
        a, b = to_fp8(a), to_fp8(b)
    elif dtype == torch.float16:
        a, b = to_fp16(a), to_fp16(b)
    elif dtype == torch.bfloat16:
        a, b = to_bf16(a), to_bf16(b)
    else:
        raise ValueError("unsupported dtype")

    b_compressed, e = ops.cutlass_compress_entry(b.t())

    # Compressed B, Metadata, Original A, B
    return b_compressed, e, a, b

Check failure on line 110 in tests/kernels/test_cutlass.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

Incompatible return value type (got "tuple[Any, Any, Any, Any]", expected "tuple[Any, Any]") [return-value]

Check failure on line 110 in tests/kernels/test_cutlass.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

Incompatible return value type (got "tuple[Any, Any, Any, Any]", expected "tuple[Any, Any]") [return-value]

Check failure on line 110 in tests/kernels/test_cutlass.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

Incompatible return value type (got "tuple[Any, Any, Any, Any]", expected "tuple[Any, Any]") [return-value]

Check failure on line 110 in tests/kernels/test_cutlass.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

Incompatible return value type (got "tuple[Any, Any, Any, Any]", expected "tuple[Any, Any]") [return-value]


def baseline_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
Expand Down Expand Up @@ -403,6 +458,34 @@ def test_cutlass_subset():
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)


# Test working with a subset of A and B for sparse matmul
def test_cutlass_sparse_subset():
    """Run the sparse scaled matmul on a slice of a larger ``A``.

    ``A`` is created with ``big_m`` rows and only its first ``m`` rows are
    passed to the kernel; the result is compared against
    ``baseline_scaled_mm`` on the same operands.
    """
    big_m = 1024
    m, n, k = 512, 512, 512

    # Create tensors
    b_comp, e, whole_a, b = make_rand_sparse_tensors(
        torch.float8_e4m3fn, big_m, n, k)
    a = whole_a[0:m, 0:k]
    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10

    out = ops.cutlass_scaled_sparse_mm(a,
                                       b_comp,
                                       e,
                                       scale_a,
                                       scale_b,
                                       out_dtype=torch.bfloat16)
    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=torch.bfloat16)

    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)


# Test to make sure cuda graphs work
class CutlassLayer(torch.nn.Module):

Expand Down
12 changes: 7 additions & 5 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,9 +560,9 @@ def cutlass_compress_entry(a: torch.Tensor) \


def cutlass_scaled_sparse_mm(
a: torch.Tensor,
a: torch.Tensor, # row-major activations
b: torch.Tensor, # row-major weight matrix
e: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
Expand All @@ -572,11 +572,13 @@ def cutlass_scaled_sparse_mm(
assert bias is None or bias.shape[0] == a.shape[0] \
and bias.dtype == out_dtype

m = a.shape[0]
n = b.shape[1]
a_t = a.t()

m = b.shape[0]
n = a_t.shape[1]
out = torch.empty((n, m), dtype=out_dtype, device=a.device).t()

torch.ops._C.cutlass_scaled_sparse_mm(out, a, e, b, scale_a, scale_b, bias)
torch.ops._C.cutlass_scaled_sparse_mm(out, b, e, a_t, scale_b, scale_a, bias)

return out.t()

Expand Down

0 comments on commit 6d574af

Please sign in to comment.