From 4c927a0a815cdaa6d5b29479d2baa1879ceb038e Mon Sep 17 00:00:00 2001
From: Faraz Shahsavan
Date: Fri, 13 Dec 2024 09:42:14 +0000
Subject: [PATCH] Address reviews; one compression test left to pass

---
 .../cutlass_benchmarks/sparse_benchmarks.py   | 10 ++--
 csrc/core/math.hpp                            |  7 +++
 csrc/cutlass_extensions/common.hpp            | 19 ++------
 csrc/ops.h                                    |  2 +-
 .../cutlass_w8a8/scaled_mm_c2x.cuh            |  1 +
 .../cutlass_w8a8/scaled_mm_c3x.cu             |  1 +
 csrc/sparse/cutlass/sparse_compressor.cu      | 34 ++-----------
 csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu   | 48 +++++++------------
 csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh  | 36 +++++++-------
 csrc/sparse/cutlass/sparse_scaled_mm_entry.cu | 10 ++--
 vllm/_custom_ops.py                           | 13 +++--
 11 files changed, 73 insertions(+), 108 deletions(-)
 create mode 100644 csrc/core/math.hpp

diff --git a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py
index eec6e6134a0cf..b48fa501942d4 100644
--- a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py
+++ b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py
@@ -46,15 +46,14 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
     bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
 
-    out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, torch.bfloat16)
+    out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
+                                       torch.bfloat16)
     out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
 
     if not torch.allclose(out, out_ref):
         print("Incorrect results")
         print(out)
         print(out_ref)
-    else:
-        print("Correct results")
 
     timers = []
 
     # pytorch impl - bfloat16
@@ -105,15 +104,14 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
     bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
 
-    out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, torch.bfloat16)
+    out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
+                                       torch.bfloat16)
     out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
 
     if not torch.allclose(out, out_ref):
         print("Incorrect results")
         print(out)
         print(out_ref)
-    else:
-        print("Correct results")
 
     timers = []
 
diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp
new file mode 100644
index 0000000000000..ba9f40a230c8e
--- /dev/null
+++ b/csrc/core/math.hpp
@@ -0,0 +1,7 @@
+#include <climits>
+#include <cstdint>
+
+inline uint32_t next_pow_2(uint32_t const num) {
+  if (num <= 1) return num;
+  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
+}
\ No newline at end of file
diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp
index 5c1098971b462..11c8486647c5e 100644
--- a/csrc/cutlass_extensions/common.hpp
+++ b/csrc/cutlass_extensions/common.hpp
@@ -11,25 +11,16 @@
 #define CUTLASS_CHECK(status)                          \
   {                                                    \
     TORCH_CHECK(status == cutlass::Status::kSuccess,   \
-                cutlassGetStatusString(status))        \
+                cutlassGetStatusString(status));       \
   }
 
-inline uint32_t next_pow_2(uint32_t const num) {
-  if (num <= 1) return num;
-  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
-}
-
 /**
  * Panic wrapper for unwinding CUDA runtime errors
  */
-#define CUDA_CHECK(status)                                              \
-  {                                                                     \
-    cudaError_t error = status;                                         \
-    if (error != cudaSuccess) {                                         \
-      std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \
-                << " at line: " << __LINE__ << std::endl;               \
-      exit(EXIT_FAILURE);                                               \
-    }                                                                   \
+#define CUDA_CHECK(status)                                        \
+  {                                                               \
+    cudaError_t error = status;                                   \
+    TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \
   }
 
 inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
diff --git a/csrc/ops.h b/csrc/ops.h
index d43f495aabd80..d1b2e212f8a44 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -162,7 +162,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
                               c10::optional<torch::Tensor> const& bias);
 
 bool cutlass_sparse_compress(torch::Tensor& a_compressed, torch::Tensor& e,
-                            torch::Tensor const& a);
+                             torch::Tensor const& a);
 #endif
 
 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
index 6e72aff89f2e4..75681f7f37820 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
@@ -21,6 +21,7 @@
 #include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
 #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
 
+#include "core/math.hpp"
 #include "cutlass_extensions/common.hpp"
 // clang-format on
 
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
index d118f0e070a63..8190277997161 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
@@ -24,6 +24,7 @@
 #include "cutlass/gemm/collective/collective_builder.hpp"
 
 #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
+#include "core/math.hpp"
 #include "cutlass_extensions/common.hpp"
 // clang-format on
 
diff --git a/csrc/sparse/cutlass/sparse_compressor.cu b/csrc/sparse/cutlass/sparse_compressor.cu
index 30b78054f300e..d8f1e5e852a40 100644
--- a/csrc/sparse/cutlass/sparse_compressor.cu
+++ b/csrc/sparse/cutlass/sparse_compressor.cu
@@ -1,46 +1,22 @@
+// clang-format will break include orders
+// clang-format off
 #include <cudaTypedefs.h>
 
-#include <torch/all.h>
-
-#include <ATen/cuda/CUDAContext.h>
-
-#include <iostream>
-#include <sstream>
-#include <vector>
-
-#include "cutlass/cutlass.h"
+#include "sparse_scaled_mm_c3x.cuh"
 
 #include "cute/tensor.hpp"
-#include "cute/atom/mma_atom.hpp"
 #include "cutlass/numeric_types.h"
 #include "cutlass/numeric_conversion.h"
-#include "cutlass/detail/dependent_false.hpp"
-
-#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
-#include "cutlass_extensions/common.hpp"
 
 #include "cutlass/transform/device/transform_universal_adapter.hpp"
 #include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
-
 #include "cutlass/epilogue/collective/default_epilogue.hpp"
-#include "cutlass/epilogue/thread/linear_combination.h"
-#include "cutlass/gemm/collective/collective_builder.hpp"
-#include "cutlass/gemm/device/gemm_universal_adapter.h"
-#include "cutlass/gemm/kernel/gemm_universal.hpp"
-
-#include <iostream>
-
-#include "cutlass/cutlass.h"
-
-#include "cutlass/tensor_ref.h"
-#include "cutlass/epilogue/collective/collective_builder.hpp"
-#include "cutlass/gemm/dispatch_policy.hpp"
 
 #include "cutlass/util/host_tensor.h"
 #include "cutlass/util/packed_stride.hpp"
 
 #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
-#include "sparse_scaled_mm_c3x.cuh"
+// clang-format on
 
 using namespace cute;
 using namespace vllm;
@@ -153,7 +129,7 @@ bool sparsify_and_compress(torch::Tensor& a_compressed, torch::Tensor& e,
 }
 
 bool cutlass_sparse_compress(torch::Tensor& a_compressed, torch::Tensor& e,
-                            torch::Tensor const& a) {
+                             torch::Tensor const& a) {
   if (a.dtype() == torch::kBFloat16) {
     return sparsify_and_compress<cutlass::bfloat16_t>(a_compressed, e, a);
   } else if (a.dtype() == torch::kFloat16) {
diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
index 4537d31c54eb1..2001fca9ed1a2 100644
--- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
+++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
@@ -3,32 +3,10 @@
 #include <cudaTypedefs.h>
 
 #if defined CUDA_VERSION && CUDA_VERSION >= 12000
-
-#include <torch/all.h>
-
-#include <ATen/cuda/CUDAContext.h>
-
-#include <iostream>
-#include <sstream>
-#include <vector>
-
-#include "cutlass/cutlass.h"
-
-#include "cute/tensor.hpp"
-#include "cute/atom/mma_atom.hpp"
-#include "cutlass/numeric_types.h"
-
-#include "cutlass/gemm/device/gemm_universal_adapter.h"
-#include "cutlass/gemm/kernel/gemm_universal.hpp"
-#include "cutlass/epilogue/collective/collective_builder.hpp"
-#include "cutlass/gemm/collective/collective_builder.hpp"
-
+#include "sparse_scaled_mm_c3x.cuh"
 #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
-#include "cutlass_extensions/common.hpp"
 // clang-format on
-
-#include "sparse_scaled_mm_c3x.cuh"
-
 using namespace cute;
 using namespace vllm;
 
@@ -247,11 +225,13 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
 
     if (out.dtype() == torch::kBFloat16) {
       return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                              Epilogue>(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(epilogue_args)...);
+          out, a, bt_nzs, bt_meta,
+          std::forward<EpilogueArgs>(epilogue_args)...);
     } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
       return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t,
                                              Epilogue>(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(epilogue_args)...);
+          out, a, bt_nzs, bt_meta,
+          std::forward<EpilogueArgs>(epilogue_args)...);
     }
   } else if (a.dtype() == torch::kFloat8_e4m3fn) {
     TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn);
@@ -259,12 +239,14 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
 
     if (out.dtype() == torch::kBFloat16) {
       return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                             cutlass::bfloat16_t, Epilogue>(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(epilogue_args)...);
+          out, a, bt_nzs, bt_meta,
+          std::forward<EpilogueArgs>(epilogue_args)...);
     } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
       return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                             cutlass::half_t, Epilogue>(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(epilogue_args)...);
+          out, a, bt_nzs, bt_meta,
+          std::forward<EpilogueArgs>(epilogue_args)...);
     }
   } else if (a.dtype() == torch::kFloat16) {
     TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);
@@ -272,12 +254,14 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
 
     if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_fp16_dispatch<cutlass::half_t,
                                             cutlass::bfloat16_t, Epilogue>(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(epilogue_args)...);
+          out, a, bt_nzs, bt_meta,
+          std::forward<EpilogueArgs>(epilogue_args)...);
     } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
       return cutlass_gemm_sm90_fp16_dispatch<cutlass::half_t, cutlass::half_t,
                                              Epilogue>(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(epilogue_args)...);
+          out, a, bt_nzs, bt_meta,
+          std::forward<EpilogueArgs>(epilogue_args)...);
     }
   } else {  // a.dtype() == torch::kBFloat16
     TORCH_CHECK(a.dtype() == torch::kBFloat16);
@@ -286,12 +270,14 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
     TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);
 
     if (out.dtype() == torch::kBFloat16) {
       return cutlass_gemm_sm90_bf16_dispatch<cutlass::bfloat16_t,
                                              cutlass::bfloat16_t, Epilogue>(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(epilogue_args)...);
+          out, a, bt_nzs, bt_meta,
+          std::forward<EpilogueArgs>(epilogue_args)...);
     } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
       return cutlass_gemm_sm90_bf16_dispatch<cutlass::bfloat16_t,
                                              cutlass::half_t, Epilogue>(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(epilogue_args)...);
+          out, a, bt_nzs, bt_meta,
+          std::forward<EpilogueArgs>(epilogue_args)...);
     }
   }
 }
diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
index 1cef3d9c0de8e..9267b87cd3cc7 100644
--- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
+++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
@@ -1,29 +1,24 @@
+// clang-format will break include orders
+// clang-format off
 #include <cudaTypedefs.h>
 
 #include <torch/all.h>
 
 #include <ATen/cuda/CUDAContext.h>
 
-#include <iostream>
-#include <sstream>
-#include <vector>
-
 #include "cutlass/cutlass.h"
 
-#include "cute/tensor.hpp"
-#include "cute/atom/mma_atom.hpp"
-#include "cutlass/numeric_types.h"
-
 #include "cutlass/gemm/device/gemm_universal_adapter.h"
 #include "cutlass/gemm/kernel/gemm_universal.hpp"
-#include "cutlass/gemm/kernel/tile_scheduler_params.h"
 #include "cutlass/epilogue/collective/collective_builder.hpp"
 #include "cutlass/gemm/collective/collective_builder.hpp"
 
 #include "cutlass_extensions/cute_utils.cuh"
 #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
+#include "core/math.hpp"
 #include "cutlass_extensions/common.hpp"
 #include "cutlass_extensions/torch_utils.hpp"
+// clang-format on
 
 using namespace cute;
@@ -94,8 +89,9 @@ struct cutlass_sparse_3x_gemm {
       typename cutlass::epilogue::collective::CollectiveBuilder<
           cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
-          ElementAcc, ElementAcc, ElementC, LayoutC_Transpose, AlignmentCD, ElementD,
-          LayoutD_Transpose, AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp;
+          ElementAcc, ElementAcc, ElementC, LayoutC_Transpose, AlignmentCD,
+          ElementD, LayoutD_Transpose, AlignmentCD, EpilogueSchedule,
+          EVTCompute>::CollectiveOp;
 
   static constexpr size_t CEStorageSize =
       sizeof(typename CollectiveEpilogue::SharedStorage);
@@ -121,8 +117,7 @@ struct cutlass_sparse_3x_gemm {
 };
 
 template <typename Gemm, typename... EpilogueArgs>
-void cutlass_sparse_gemm_caller(torch::Tensor& out,
-                                torch::Tensor const& a,
+void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                                 torch::Tensor const& bt_nzs,
                                 torch::Tensor const& bt_meta,
                                 EpilogueArgs&&... epilogue_params) {
@@ -145,7 +140,8 @@ void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
   auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
 
   using GemmKernel = typename Gemm::GemmKernel;
-  typename GemmKernel::ProblemShape prob_shape{(int) bt_nzs.size(0), (int) size<0>(layout_A), (int) size<1>(layout_A), 1};
+  typename GemmKernel::ProblemShape prob_shape{
+      (int)bt_nzs.size(0), (int)size<0>(layout_A), (int)size<1>(layout_A), 1};
 
   using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
   using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;
@@ -198,7 +194,7 @@ struct sm90_config_default {
   using ClusterShape = Shape<_2, _1, _1>;
   using Cutlass3xGemm =
       cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule>;
+                             KernelSchedule, EpilogueSchedule, float>;
 };
 
 template <typename InType, typename OutType,
           template <typename, typename, typename> typename Epilogue>
 struct sm90_config_default {
   using TileShape = Shape<_128, _128, _128>;
   using ClusterShape = Shape<_2, _1, _1>;
   using Cutlass3xGemm =
-      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule>;
+      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                             KernelSchedule, EpilogueSchedule, float>;
 };
 
 //////////////////////// Cherry-Picking Kernels ////////////////////////
@@ -337,8 +334,9 @@ struct sm90_config_default {
   using TileShape = Shape<_128, _128, _128>;
   using ClusterShape = Shape<_1, _2, _1>;
   using Cutlass3xGemm =
-      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule>;
+      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape,
+                             ClusterShape, KernelSchedule, EpilogueSchedule,
+                             float>;
 };
 
 template <typename InType, typename OutType,
diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_entry.cu b/csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
--- a/csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
+++ b/csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
@@ -43,7 +43,8 @@ void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
   if (version_num >= 90) {
-    cutlass_scaled_sparse_mm_sm90(c, a, bt_nzs, bt_meta, a_scales, b_scales, bias);
+    cutlass_scaled_sparse_mm_sm90(c, a, bt_nzs, bt_meta, a_scales, b_scales,
+                                  bias);
     return;
   }
 #endif
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index a14b1f3bbf45b..dc22d90bd0a5c 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -567,20 +567,25 @@ def cutlass_sparse_compress(a: torch.Tensor) \
     assert (a.dtype in [
         torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16
     ])
+    assert (a.is_contiguous())
 
     # a_meta.dtype: torch.uint8 so elemsPerMetaElem = 8b / 2b_per_nz = 4
     elemsPerMetaElem = 4
 
     m = a.shape[0]
     k = a.shape[1]
+    assert (k % 2 == 0)
     a_nzs = torch.empty((m, k // 2), dtype=a.dtype, device=a.device)
     a_meta = torch.empty((m, k // 2 // elemsPerMetaElem),
-                        dtype=torch.uint8,
-                        device=a.device)
+                         dtype=torch.uint8,
+                         device=a.device)
 
     if not (torch.ops._C.cutlass_sparse_compress(a_nzs, a_meta, a)):
         raise ValueError
 
+    assert (a_nzs.is_contiguous())
+    assert (a_meta.is_contiguous())
+
     return a_nzs, a_meta
 
 
@@ -624,8 +629,8 @@ def cutlass_scaled_sparse_mm(
     n = bt_nzs.shape[0]
     out = torch.empty((m, n), dtype=out_dtype, device=a.device)
 
-    torch.ops._C.cutlass_scaled_sparse_mm(out, a, bt_nzs, bt_meta, scale_a, scale_b,
-                                          bias)
+    torch.ops._C.cutlass_scaled_sparse_mm(out, a, bt_nzs, bt_meta, scale_a,
+                                          scale_b, bias)
 
     return out
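
Usage sketch for reviewers (illustrative, not part of the patch): the two ops land as a pair, compress the weight once, then run the scaled sparse GEMM, as in the bfloat16 path of benchmarks/cutlass_benchmarks/sparse_benchmarks.py. The prune_24() helper below is hypothetical (the benchmarks build 2:4 tensors with their own make_rand_sparse_tensors() utility), shapes are arbitrary, and the sparse path requires an SM90 GPU:

    import torch
    from vllm import _custom_ops as ops

    def prune_24(w: torch.Tensor) -> torch.Tensor:
        # Keep the 2 largest-magnitude entries in each group of 4 along the
        # last dim so the tensor satisfies the 2:4 structured-sparsity pattern.
        g = w.reshape(-1, 4)
        keep = g.abs().topk(2, dim=-1).indices
        mask = torch.zeros_like(g, dtype=torch.bool).scatter_(-1, keep, True)
        return (g * mask).reshape(w.shape)

    m, k, n = 32, 1024, 512
    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
    bt = prune_24(torch.randn((n, k), device="cuda", dtype=torch.bfloat16))

    # Compress B^T (contiguous, k even): returns the nonzeros plus uint8
    # metadata, four nonzeros per metadata byte per cutlass_sparse_compress().
    bt_nzs, bt_meta = ops.cutlass_sparse_compress(bt)

    scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    out = ops.cutlass_scaled_sparse_mm(a, bt_nzs, bt_meta, scale_a, scale_b,
                                       torch.bfloat16)

    # Dense bf16 reference; loose tolerances since accumulation orders differ.
    out_ref = torch.mm(a, bt.t())
    assert torch.allclose(out, out_ref, rtol=1e-2, atol=1e-2)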