From b559b6a3fb1f65cb3e89378109c68374dc7cf355 Mon Sep 17 00:00:00 2001 From: Faraz Shahsavan Date: Fri, 13 Dec 2024 07:45:06 +0000 Subject: [PATCH] Push activations and output transposes into CUTLASS code --- benchmarks/benchmark_throughput.py | 3 +- .../cutlass_benchmarks/sparse_benchmarks.py | 20 +++ benchmarks/cutlass_benchmarks/utils.py | 9 +- csrc/ops.h | 4 +- csrc/sparse/cutlass/sparse_compressor.cu | 7 +- csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu | 152 +++++++++--------- csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh | 77 ++++----- csrc/sparse/cutlass/sparse_scaled_mm_entry.cu | 24 +-- csrc/torch_bindings.cpp | 8 +- tests/kernels/test_semi_structured.py | 2 +- vllm/_custom_ops.py | 89 +++++++--- .../schemes/compressed_tensors_24.py | 6 +- 12 files changed, 236 insertions(+), 165 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index e92b5d00dc9f5..1e5967bd9bf8b 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -361,8 +361,7 @@ def main(args: argparse.Namespace): # TODO(vllm-project/vllm/issues/9778): Count molti-modal token length. print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " - f"{total_output_tokens / elapsed_time:.2f} output tokens/s, " - f"{total_num_tokens=} | {total_output_tokens=}") + f"{total_output_tokens / elapsed_time:.2f} output tokens/s") # Output JSON results if specified if args.output_json: diff --git a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py index 0bbed68a71e67..eec6e6134a0cf 100644 --- a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py @@ -46,6 +46,16 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, torch.bfloat16) + out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) + + if not torch.allclose(out, out_ref): + print("Incorrect results") + print(out) + print(out_ref) + else: + print("Correct results") + timers = [] # pytorch impl - bfloat16 timers.append( @@ -95,6 +105,16 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, torch.bfloat16) + out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) + + if not torch.allclose(out, out_ref): + print("Incorrect results") + print(out) + print(out_ref) + else: + print("Correct results") + timers = [] # pytorch impl w. bf16 diff --git a/benchmarks/cutlass_benchmarks/utils.py b/benchmarks/cutlass_benchmarks/utils.py index 84937d1c81bb2..c53cee52642f4 100644 --- a/benchmarks/cutlass_benchmarks/utils.py +++ b/benchmarks/cutlass_benchmarks/utils.py @@ -62,8 +62,11 @@ def prune_to_2_4(tensor): def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, k: int) -> Tuple[torch.Tensor, torch.Tensor]: - a = torch.randn((m, k), device='cuda') * 5 - b = torch.randn((n, k), device='cuda').t() * 5 + # a = torch.randn((m, k), device='cuda') * 5 + # b = torch.randn((n, k), device='cuda').t() * 5 + + a = torch.ones((m, k), device='cuda') + b = torch.ones((n, k), device='cuda').t() b = prune_to_2_4(b.t()).t() @@ -78,7 +81,7 @@ def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, else: raise ValueError("unsupported dtype") - b_compressed, e = ops.cutlass_compress_entry(b.t()) + b_compressed, e = ops.cutlass_sparse_compress(b.t()) # Compressed B, Metadata, Original A, B return b_compressed, e, a, b diff --git a/csrc/ops.h b/csrc/ops.h index 363ddec3d0729..d43f495aabd80 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -156,12 +156,12 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, c10::optional const& bias); void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& e, torch::Tensor const& b, + torch::Tensor const& b, torch::Tensor const& e, torch::Tensor const& a_scales, torch::Tensor const& b_scales, c10::optional const& bias); -bool cutlass_compress_entry(torch::Tensor& a_compressed, torch::Tensor& e, +bool cutlass_sparse_compress(torch::Tensor& a_compressed, torch::Tensor& e, torch::Tensor const& a); #endif diff --git a/csrc/sparse/cutlass/sparse_compressor.cu b/csrc/sparse/cutlass/sparse_compressor.cu index ebb1c975121ac..30b78054f300e 100644 --- a/csrc/sparse/cutlass/sparse_compressor.cu +++ b/csrc/sparse/cutlass/sparse_compressor.cu @@ -73,9 +73,6 @@ bool sparsify_and_compress(torch::Tensor& a_compressed, torch::Tensor& e, using ElementAB = typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; - // Just a dummy value - int32_t n = 128; - int64_t lda = a.stride(0); using StrideA = Stride, int64_t>; @@ -85,7 +82,7 @@ bool sparsify_and_compress(torch::Tensor& a_compressed, torch::Tensor& e, StrideA a_stride{lda, Int<1>{}, 0}; using GemmKernel = typename Gemm::GemmKernel; - typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; + typename GemmKernel::ProblemShape prob_shape{m, 1, k, 1}; using LayoutA = typename GemmKernel::CollectiveMainloop::LayoutA; using LayoutE = typename GemmKernel::CollectiveMainloop::LayoutE; @@ -155,7 +152,7 @@ bool sparsify_and_compress(torch::Tensor& a_compressed, torch::Tensor& e, return true; } -bool cutlass_compress_entry(torch::Tensor& a_compressed, torch::Tensor& e, +bool cutlass_sparse_compress(torch::Tensor& a_compressed, torch::Tensor& e, torch::Tensor const& a) { if (a.dtype() == torch::kBFloat16) { return sparsify_and_compress(a_compressed, e, a); diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu index 8d36ece6d79c9..4537d31c54eb1 100644 --- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu +++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu @@ -36,13 +36,13 @@ template typename Epilogue, typename... EpilogueArgs> void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& e, - torch::Tensor const& b, + torch::Tensor const& bt_nzs, + torch::Tensor const& bt_meta, EpilogueArgs&&... args) { static_assert(std::is_same()); TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(e.dtype() == torch::kUInt8); - TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); + TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn); using Cutlass3xGemmDefault = typename sm90_config_default::Cutlass3xGemm; @@ -72,68 +72,68 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, using Cutlass3xGemm8 = typename sm90_fp8_config_8::Cutlass3xGemm; - uint32_t const n = b.size(1); // Batch size - uint32_t const m = a.size(0); - uint32_t const np2 = - std::max(static_cast(64), next_pow_2(n)); // next power of 2 + uint32_t const n = bt_nzs.size(0); + uint32_t const m = a.size(0); // Batch size + uint32_t const mp2 = + std::max(static_cast(64), next_pow_2(m)); // next power of 2 - if (np2 <= 64) { - if (m == 28672) { + if (mp2 <= 64) { + if (n == 28672) { return cutlass_sparse_gemm_caller( - out, a, e, b, std::forward(args)...); - } else if (m == 4096 || m == 6144) { + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 4096 || n == 6144) { return cutlass_sparse_gemm_caller( - out, a, e, b, std::forward(args)...); + out, a, bt_nzs, bt_meta, std::forward(args)...); } - } else if (np2 <= 128) { - if (m == 4096) { + } else if (mp2 <= 128) { + if (n == 4096) { return cutlass_sparse_gemm_caller( - out, a, e, b, std::forward(args)...); - } else if (m == 28672) { + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 28672) { return cutlass_sparse_gemm_caller( - out, a, e, b, std::forward(args)...); - } else if (m == 6144) { + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 6144) { return cutlass_sparse_gemm_caller( - out, a, e, b, std::forward(args)...); + out, a, bt_nzs, bt_meta, std::forward(args)...); } - } else if (np2 <= 256) { - if (m == 4096) { + } else if (mp2 <= 256) { + if (n == 4096) { return cutlass_sparse_gemm_caller( - out, a, e, b, std::forward(args)...); - } else if (m == 28672) { + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 28672) { return cutlass_sparse_gemm_caller( - out, a, e, b, std::forward(args)...); - } else if (m == 6144) { + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 6144) { return cutlass_sparse_gemm_caller( - out, a, e, b, std::forward(args)...); + out, a, bt_nzs, bt_meta, std::forward(args)...); } } else { - if (m == 6144 || m == 28672) { + if (n == 6144 || n == 28672) { return cutlass_sparse_gemm_caller( - out, a, e, b, std::forward(args)...); - } else if (m == 4096) { + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 4096) { return cutlass_sparse_gemm_caller( - out, a, e, b, std::forward(args)...); + out, a, bt_nzs, bt_meta, std::forward(args)...); } } // Otherwise the default heuristic - if (np2 <= 64) { + if (mp2 <= 64) { // n in [1, 64] return cutlass_sparse_gemm_caller( - out, a, e, b, std::forward(args)...); - } else if (np2 <= 128) { + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (mp2 <= 128) { // n in (64, 128] return cutlass_sparse_gemm_caller( - out, a, e, b, std::forward(args)...); - } else if (np2 <= 256) { + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (mp2 <= 256) { // n in (128, 256] return cutlass_sparse_gemm_caller( - out, a, e, b, std::forward(args)...); + out, a, bt_nzs, bt_meta, std::forward(args)...); } else { // n in (256, inf) return cutlass_sparse_gemm_caller( - out, a, e, b, std::forward(args)...); + out, a, bt_nzs, bt_meta, std::forward(args)...); } } @@ -141,53 +141,53 @@ template typename Epilogue, typename... EpilogueArgs> void cutlass_gemm_sm90_fp16_dispatch(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& e, - torch::Tensor const& b, + torch::Tensor const& bt_nzs, + torch::Tensor const& bt_meta, EpilogueArgs&&... args) { static_assert(std::is_same()); TORCH_CHECK(a.dtype() == torch::kFloat16); - TORCH_CHECK(e.dtype() == torch::kUInt8); - TORCH_CHECK(b.dtype() == torch::kFloat16); + TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); + TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16); using Cutlass3xGemmDefault = typename sm90_config_default::Cutlass3xGemm; // m in (128, inf) return cutlass_sparse_gemm_caller( - out, a, e, b, std::forward(args)...); + out, a, bt_nzs, bt_meta, std::forward(args)...); } template typename Epilogue, typename... EpilogueArgs> void cutlass_gemm_sm90_bf16_dispatch(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& e, - torch::Tensor const& b, + torch::Tensor const& bt_nzs, + torch::Tensor const& bt_meta, EpilogueArgs&&... args) { static_assert(std::is_same()); TORCH_CHECK(a.dtype() == torch::kBFloat16); - TORCH_CHECK(e.dtype() == torch::kUInt8); - TORCH_CHECK(b.dtype() == torch::kBFloat16); + TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); + TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16); using Cutlass3xGemmDefault = typename sm90_config_default::Cutlass3xGemm; // m in (128, inf) return cutlass_sparse_gemm_caller( - out, a, e, b, std::forward(args)...); + out, a, bt_nzs, bt_meta, std::forward(args)...); } template typename Epilogue, typename... EpilogueArgs> void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& e, - torch::Tensor const& b, + torch::Tensor const& bt_nzs, + torch::Tensor const& bt_meta, EpilogueArgs&&... args) { static_assert(std::is_same()); TORCH_CHECK(a.dtype() == torch::kInt8); - TORCH_CHECK(e.dtype() == torch::kUInt8); - TORCH_CHECK(b.dtype() == torch::kInt8); + TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); + TORCH_CHECK(bt_nzs.dtype() == torch::kInt8); using Cutlass3xGemmDefault = typename sm90_config_default::Cutlass3xGemm; @@ -213,23 +213,23 @@ void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a, // m in [1, 32] if (is_small_n) { return cutlass_sparse_gemm_caller( - out, a, e, b, std::forward(args)...); + out, a, bt_nzs, bt_meta, std::forward(args)...); } else { return cutlass_sparse_gemm_caller( - out, a, e, b, std::forward(args)...); + out, a, bt_nzs, bt_meta, std::forward(args)...); } } else if (mp2 <= 64) { // m in (32, 64] return cutlass_sparse_gemm_caller( - out, a, e, b, std::forward(args)...); + out, a, bt_nzs, bt_meta, std::forward(args)...); } else if (mp2 <= 128) { // m in (64, 128] return cutlass_sparse_gemm_caller( - out, a, e, b, std::forward(args)...); + out, a, bt_nzs, bt_meta, std::forward(args)...); } else { // m in (128, inf) return cutlass_sparse_gemm_caller( - out, a, e, b, std::forward(args)...); + out, a, bt_nzs, bt_meta, std::forward(args)...); } } @@ -237,68 +237,68 @@ template