From 6d574affde326df01fbfaf4f112da56fd3a1f018 Mon Sep 17 00:00:00 2001
From: Faraz Shahsavan
Date: Wed, 11 Dec 2024 20:04:31 +0000
Subject: [PATCH] Update code, flip sparse op operand order to be consistent
 with dense

---
 .../cutlass_benchmarks/sp_fp8_benchmarks.py | 12 ++-
 tests/kernels/test_cutlass.py               | 85 ++++++++++++++++++-
 vllm/_custom_ops.py                         | 12 +--
 3 files changed, 96 insertions(+), 13 deletions(-)

diff --git a/benchmarks/cutlass_benchmarks/sp_fp8_benchmarks.py b/benchmarks/cutlass_benchmarks/sp_fp8_benchmarks.py
index 644eccdb2d117..0dd59c708d9cd 100644
--- a/benchmarks/cutlass_benchmarks/sp_fp8_benchmarks.py
+++ b/benchmarks/cutlass_benchmarks/sp_fp8_benchmarks.py
@@ -245,7 +245,6 @@ def run_single_benchmark_process(kernel_config: Dict, gpu_id: int,
         # Create tensors
         BComps, Es, As, Bs = make_n_rand_sparse_tensors(
             kernel_config.get('arg_pool_size', 1), dtype, m, n, k)
-        AsT = [x.t() for x in As]
         bf16_As = [x.to(dtype=torch.bfloat16) for x in As]
         bf16_Bs = [x.to(dtype=torch.bfloat16) for x in Bs]
         scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
@@ -304,8 +303,8 @@ def run_single_benchmark_process(kernel_config: Dict, gpu_id: int,
         elif kernel_type == 'cutlass_scaled_sparse_mm':
             bench = BenchMM(cuda_graph_params, label, sub_label,
                             "cutlass_fp8_fp8_bf16_scaled_sparse_mm",
-                            ops.cutlass_scaled_sparse_mm, ArgPool(BComps),
-                            ArgPool(Es), ArgPool(AsT), scale_b, scale_a,
+                            ops.cutlass_scaled_sparse_mm, ArgPool(As),
+                            ArgPool(BComps), ArgPool(Es), scale_a, scale_b,
                             torch.bfloat16)
 
         # Run the benchmark
@@ -430,7 +429,6 @@ def run_kernels_on_gpus(
             # Create tensors
             BComps, Es, As, Bs = make_n_rand_sparse_tensors(
                 config.get('arg_pool_size', 1), dtype, m, n, k)
-            AsT = [x.t() for x in As]
             bf16_As = [x.to(dtype=torch.bfloat16) for x in As]
             bf16_Bs = [x.to(dtype=torch.bfloat16) for x in Bs]
             scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
@@ -491,9 +489,9 @@ def run_kernels_on_gpus(
             elif kernel_type == 'cutlass_scaled_sparse_mm':
                 bench = BenchMM(cuda_graph_params, label, sub_label,
                                 "cutlass_fp8_fp8_bf16_scaled_sparse_mm",
-                                ops.cutlass_scaled_sparse_mm,
-                                ArgPool(BComps), ArgPool(Es), ArgPool(AsT),
-                                scale_b, scale_a, torch.bfloat16)
+                                ops.cutlass_scaled_sparse_mm, ArgPool(As),
+                                ArgPool(BComps), ArgPool(Es),
+                                scale_a, scale_b, torch.bfloat16)
 
                 # Run the benchmark
                 result = bench.run()
diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py
index afe53797322f9..2c5d19cc54c54 100644
--- a/tests/kernels/test_cutlass.py
+++ b/tests/kernels/test_cutlass.py
@@ -2,7 +2,7 @@
 
 Run `pytest tests/kernels/test_cutlass.py`.
""" -from typing import Optional, Type +from typing import Optional, Type, Tuple import pytest import torch @@ -55,6 +55,61 @@ def rand_int8(shape: tuple, device: str = "cuda"): return to_int8(torch.rand(shape, device=device) * 255 - 128) +def to_bf16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.bfloat16) + + +def to_fp16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.float16) + + +def prune_to_2_4(tensor): + # Reshape tensor to [N, 4] where N is number of groups of 4 + original_shape = tensor.shape + reshaped = tensor.reshape(-1, 4) + + # Get indices of top 2 absolute values in each group of 4 + _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1) + + # Create binary mask + mask = torch.zeros_like(reshaped) + mask.scatter_(dim=1, + index=indices, + src=torch.ones_like(indices, dtype=mask.dtype)) + + # Apply mask and reshape back + pruned = reshaped * mask + + # Turn all -0.0 to 0.0 + pruned[pruned == -0.0] = 0.0 + + return pruned.reshape(original_shape) + + +def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, + k: int) -> Tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, k), device='cuda').t() * 5 + + b = prune_to_2_4(b.t()).t() + + if dtype == torch.int8: + a, b = to_int8(a), to_int8(b) + elif dtype == torch.float8_e4m3fn: + a, b = to_fp8(a), to_fp8(b) + elif dtype == torch.float16: + a, b = to_fp16(a), to_fp16(b) + elif dtype == torch.bfloat16: + a, b = to_bf16(a), to_bf16(b) + else: + raise ValueError("unsupported dtype") + + b_compressed, e = ops.cutlass_compress_entry(b.t()) + + # Compressed B, Metadata, Original A, B + return b_compressed, e, a, b + + def baseline_scaled_mm(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, @@ -403,6 +458,34 @@ def test_cutlass_subset(): torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) +# Test working with a subset of A and B for sparse matmul +def test_cutlass_sparse_subset(): + big_m = 1024 + m, n, k = 512, 512, 512 + + # Create tensors + b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, big_m, n, k) + a = whole_a[0:m, 0:k] + scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 + scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 + + print("in test") + + out = ops.cutlass_scaled_sparse_mm(a, + b_comp, + e, + scale_a, + scale_b, + out_dtype=torch.bfloat16) + baseline = baseline_scaled_mm(a, + b, + scale_a, + scale_b, + out_dtype=torch.bfloat16) + + torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) + + # Test to make sure cuda graphs work class CutlassLayer(torch.nn.Module): diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 230a0ad7aaa5c..c89c7d492f75d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -560,9 +560,9 @@ def cutlass_compress_entry(a: torch.Tensor) \ def cutlass_scaled_sparse_mm( - a: torch.Tensor, + a: torch.Tensor, # row-major activations + b: torch.Tensor, # row-major weight matrix e: torch.Tensor, - b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, out_dtype: torch.dtype, @@ -572,11 +572,13 @@ def cutlass_scaled_sparse_mm( assert bias is None or bias.shape[0] == a.shape[0] \ and bias.dtype == out_dtype - m = a.shape[0] - n = b.shape[1] + a_t = a.t() + + m = b.shape[0] + n = a_t.shape[1] out = torch.empty((n, m), dtype=out_dtype, device=a.device).t() - torch.ops._C.cutlass_scaled_sparse_mm(out, a, e, b, scale_a, scale_b, bias) + torch.ops._C.cutlass_scaled_sparse_mm(out, b, e, a_t, scale_b, 
+                                          scale_a, bias)
 
     return out.t()
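
Usage note (editor's addition, not part of the patch): after this change,
cutlass_scaled_sparse_mm takes its operands in the same order as the dense
cutlass_scaled_mm. The sketch below is illustrative only; it assumes an
fp8-capable GPU and borrows the to_fp8 and prune_to_2_4 helpers from
tests/kernels/test_cutlass.py.

    import torch
    from tests.kernels.test_cutlass import prune_to_2_4, to_fp8
    from vllm import _custom_ops as ops

    m, n, k = 16, 32, 64
    # Row-major activations and a 2:4-pruned weight, stored as (k, n)
    a = to_fp8(torch.randn((m, k), device="cuda"))
    b = to_fp8(prune_to_2_4(torch.randn((n, k), device="cuda")).t())

    # Compress the (n, k) view of the weight into values plus metadata
    b_comp, e = ops.cutlass_compress_entry(b.t())

    scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32)
    scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32)

    # Old order: (b_comp, e, a.t(), scale_b, scale_a, out_dtype)
    # New order mirrors the dense op: activations first, then weight
    out = ops.cutlass_scaled_sparse_mm(a, b_comp, e, scale_a, scale_b,
                                       out_dtype=torch.bfloat16)  # (m, n)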