diff --git a/csrc/sparse/cutlass/sparse_compressor.cu b/csrc/sparse/cutlass/sparse_compressor.cu
index ba067a82da1d1..d551a71e51601 100644
--- a/csrc/sparse/cutlass/sparse_compressor.cu
+++ b/csrc/sparse/cutlass/sparse_compressor.cu
@@ -78,9 +78,9 @@ bool sparsify_and_compress(torch::Tensor& a_compressed, torch::Tensor& e,
                   c3x::ScaledEpilogue>::Cutlass3xGemm,
               typename std::conditional<
                   std::is_same_v,
-                  typename sm90_fp16_config_default<cutlass::half_t, cutlass::half_t, c3x::ScaledEpilogue>::Cutlass3xGemm,
+                  typename sm90_fp16_config_default<
+                      cutlass::half_t, cutlass::half_t,
+                      c3x::ScaledEpilogue>::Cutlass3xGemm,
                   typename sm90_bf16_config_default<
                       cutlass::bfloat16_t, cutlass::half_t,
                       c3x::ScaledEpilogue>::Cutlass3xGemm>::type>::type>::type;
diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
index 2eaffb11a7b40..ea191015b4159 100644
--- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
+++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
@@ -27,7 +27,7 @@
 #include "cutlass_extensions/common.hpp"
 // clang-format on
 
-#include "sparse_scaled_mm_c3x.cuh"
+ #include "sparse_scaled_mm_c3x.cuh"
 
 using namespace cute;
 using namespace vllm;
diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
index 91aa7da4b9807..a5925c715940b 100644
--- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
+++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
@@ -26,8 +26,8 @@ using namespace cute;
 
 /*
-  This file defines sparse quantized GEMM operations using the CUTLASS 3.x API, for
-  NVIDIA GPUs with sm90a (Hopper) or later.
+  This file defines sparse quantized GEMM operations using the CUTLASS 3.x API,
+  for NVIDIA GPUs with sm90a (Hopper) or later.
 
   Epilogue functions can be defined to post-process the output before it is
   written to GPU memory.
@@ -192,7 +192,7 @@ struct sm90_fp16_config_default {
   using ClusterShape = Shape<_2, _1, _1>;
   using Cutlass3xGemm =
       cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule>;
+                             KernelSchedule, EpilogueSchedule, float>;
 };
@@ ... @@
   using Cutlass3xGemm =
       cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule>;
+                             KernelSchedule, EpilogueSchedule, float>;
 };
 
 //////////////////////// Cherry-Picking Kernels ////////////////////////
@@ -220,7 +220,7 @@ struct sm90_fp8_config_1 {
   using ClusterShape = Shape<_8, _1, _1>;
   using Cutlass3xGemm =
       cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule>;
+                             KernelSchedule, EpilogueSchedule, float>;
 };
@@ ... @@
   using Cutlass3xGemm =
       cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule>;
+                             KernelSchedule, EpilogueSchedule, float>;
 };
@@ ... @@
   using Cutlass3xGemm =
       cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule>;
+                             KernelSchedule, EpilogueSchedule, float>;
 };
@@ ... @@
   using Cutlass3xGemm =
       cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule>;
+                             KernelSchedule, EpilogueSchedule, float>;
 };
@@ ... @@
   using Cutlass3xGemm =
       cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule>;
+                             KernelSchedule, EpilogueSchedule, float>;
 };
@@ ... @@
   using Cutlass3xGemm =
       cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule>;
+                             KernelSchedule, EpilogueSchedule, float>;
 };
@@ ... @@
   using Cutlass3xGemm =
       cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule>;
+                             KernelSchedule, EpilogueSchedule, float>;
 };
@@ ... @@
   using Cutlass3xGemm =
       cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule>;
+                             KernelSchedule, EpilogueSchedule, float>;
 };
 
 ////////////////////////////////////////////////////////////////////////
@@ -334,7 +334,7 @@ struct sm90_fp8_config_default {
   using ClusterShape = Shape<_1, _2, _1>;
   using Cutlass3xGemm =
       cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule>;
+                             KernelSchedule, EpilogueSchedule, float>;
 };
@@ ... @@
   using Cutlass3xGemm =
       cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule, float>;
+                             KernelSchedule, EpilogueSchedule, float,
+                             TileSchedule>;
 };
@@ ... @@
   using Cutlass3xGemm =
       cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule, float>;
+                             KernelSchedule, EpilogueSchedule, float,
+                             TileSchedule>;
 };
@@ ... @@
   using Cutlass3xGemm =
       cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule, float>;
+                             KernelSchedule, EpilogueSchedule, float,
+                             TileSchedule>;
 };
@@ ... @@
   using Cutlass3xGemm =
       cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule, float>;
+                             KernelSchedule, EpilogueSchedule, float,
+                             TileSchedule>;
 };
@@ ... @@
   using Cutlass3xGemm =
       cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule>;
+                             KernelSchedule, EpilogueSchedule, int32_t>;
 };
@@ ... @@
   using Cutlass3xGemm =
       cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule>;
+                             KernelSchedule, EpilogueSchedule, int32_t>;
 };
@@ ... @@
   using Cutlass3xGemm =
       cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule>;
+                             KernelSchedule, EpilogueSchedule, int32_t>;
 };
@@ ... @@
   using Cutlass3xGemm =
       cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule>;
+                             KernelSchedule, EpilogueSchedule, int32_t>;
 };
@@ ... @@
   using Cutlass3xGemm =
       cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule>;
+                             KernelSchedule, EpilogueSchedule, int32_t>;
 };
 
 } // namespace
\ No newline at end of file
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 9573d31e46445..034b3e9493736 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -534,7 +534,9 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
 
 def cutlass_compress_entry(a: torch.Tensor) \
         -> Tuple[torch.Tensor, torch.Tensor]:
-    assert (a.dtype is [torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16])
+    assert (a.dtype in [
+        torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16
+    ])
 
     # e.dtype: torch.uint8 so elemsPerElemE = 8b / 2b_per_nz = 4
     elemsPerElemE = 4
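Note on the final vllm/_custom_ops.py hunk: the original `assert (a.dtype is [...])` compares the dtype object against a freshly built list, so the identity test is always false and the assertion raises for every input; the patch switches it to the intended membership test with `in`. The sketch below is illustrative only and is not vLLM code; the helper name `check_compress_dtype` is made up here to mirror the fixed precondition of `cutlass_compress_entry`.

import torch

def check_compress_dtype(a: torch.Tensor) -> None:
    # Buggy form: `is` tests object identity against a brand-new list, so the
    # result is always False and an assert on it would always raise.
    always_false = a.dtype is [
        torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16
    ]
    assert not always_false

    # Fixed form (as in the hunk above): membership test over supported dtypes.
    assert a.dtype in [
        torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16
    ]

check_compress_dtype(torch.zeros(4, 64, dtype=torch.float16))

The surrounding context line `elemsPerElemE = 4` follows from the comment next to it: the metadata tensor `e` is `torch.uint8`, and each 2:4-sparse nonzero index needs 2 bits, so one byte packs 8 / 2 = 4 entries.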