diff --git a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py
index f93ab89049360..0bbed68a71e67 100644
--- a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py
+++ b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py
@@ -70,19 +70,18 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
         bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias",
                  ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
                  bias))
-    
+
     # cutlass sparse impl
     timers.append(
         bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm",
-                 ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, scale_b,
-                 torch.bfloat16))
+                 ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
+                 scale_b, torch.bfloat16))
 
     # cutlass sparse with bias
     timers.append(
         bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
-                 ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, scale_b, torch.bfloat16,
-                 bias))
-
+                 ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
+                 scale_b, torch.bfloat16, bias))
 
     return timers
 
@@ -90,7 +89,8 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
               sub_label: str) -> Iterable[TMeasurement]:
     assert dtype == torch.float8_e4m3fn
-    b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
+    b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n,
+                                                     k)
     scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
     bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
 
@@ -158,29 +158,32 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
         bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm",
                  ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
                  torch.bfloat16))
-    
+
     # cutlass impl: bf16 output
     timers.append(
         bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm",
-                 ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, scale_b,
-                 torch.bfloat16))
+                 ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
+                 scale_b, torch.bfloat16))
 
     # cutlass impl: fp16 output
     timers.append(
         bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm",
-                 ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, scale_b, torch.float16))
+                 ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
+                 scale_b, torch.float16))
 
     # cutlass impl: bf16 output, with bias
     timers.append(
-        bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
-                 ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, scale_b, torch.bfloat16,
-                 bias))
+        bench_fn(label, sub_label,
+                 "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
+                 ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
+                 scale_b, torch.bfloat16, bias))
 
     # cutlass impl: fp16 output, with bias
     timers.append(
-        bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
-                 ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, scale_b, torch.float16,
-                 bias.to(dtype=torch.float16)))
+        bench_fn(label, sub_label,
+                 "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
+                 ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
+                 scale_b, torch.float16, bias.to(dtype=torch.float16)))
 
     return timers
 
@@ -356,4 +359,4 @@ def to_torch_dtype(dt):
     model_parser.set_defaults(func=run_model_bench)
 
     args = parser.parse_args()
-    args.func(args)
\ No newline at end of file
+    args.func(args)
diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
index 76b6b2e395c04..8d36ece6d79c9 100644
--- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
+++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
@@ -45,8 +45,7 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
   TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
 
   using Cutlass3xGemmDefault =
-      typename sm90_config_default<InType, OutType,
-                                   Epilogue>::Cutlass3xGemm;
+      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
   using Cutlass3xGemmM64 =
       typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
   using Cutlass3xGemmM128 =
@@ -151,8 +150,7 @@ void cutlass_gemm_sm90_fp16_dispatch(torch::Tensor& out, torch::Tensor const& a,
   TORCH_CHECK(b.dtype() == torch::kFloat16);
 
   using Cutlass3xGemmDefault =
-      typename sm90_config_default<InType, OutType,
-                                   Epilogue>::Cutlass3xGemm;
+      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
 
   // m in (128, inf)
   return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
@@ -172,8 +170,7 @@ void cutlass_gemm_sm90_bf16_dispatch(torch::Tensor& out, torch::Tensor const& a,
   TORCH_CHECK(b.dtype() == torch::kBFloat16);
 
   using Cutlass3xGemmDefault =
-      typename sm90_config_default<InType, OutType,
-                                   Epilogue>::Cutlass3xGemm;
+      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
 
   // m in (128, inf)
   return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
@@ -193,8 +190,7 @@ void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
   TORCH_CHECK(b.dtype() == torch::kInt8);
 
   using Cutlass3xGemmDefault =
-      typename sm90_config_default<InType, OutType,
-                                   Epilogue>::Cutlass3xGemm;
+      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
   using Cutlass3xGemmM128 =
       typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
   using Cutlass3xGemmM64 =
diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
index 59027d61debae..aa90388295492 100644
--- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
+++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
@@ -185,7 +185,6 @@ template <typename InType, typename OutType,
           template <typename, typename> typename Epilogue>
 struct sm90_config_default {};
 
-
 template <typename OutType,
           template <typename, typename> typename Epilogue>
 struct sm90_config_default<half_t, OutType, Epilogue> {
@@ -196,7 +195,7 @@ struct sm90_config_default<half_t, OutType, Epilogue> {
   using ClusterShape = Shape<_2, _1, _1>;
   using Cutlass3xGemm =
       cutlass_sparse_3x_gemm<half_t, OutType, Epilogue, TileShape, ClusterShape,
-          KernelSchedule, EpilogueSchedule, float>;
+                             KernelSchedule, EpilogueSchedule, float>;
 };
 
 template <typename OutType,
@@ -210,8 +209,9 @@ struct sm90_config_default<cutlass::bfloat16_t, OutType, Epilogue> {
   using TileShape = Shape<_128, _128, _128>;
   using ClusterShape = Shape<_2, _1, _1>;
   using Cutlass3xGemm =
-      cutlass_sparse_3x_gemm<cutlass::bfloat16_t, OutType, Epilogue, TileShape, ClusterShape,
-          KernelSchedule, EpilogueSchedule, float>;
+      cutlass_sparse_3x_gemm<cutlass::bfloat16_t, OutType, Epilogue,
+                             TileShape, ClusterShape, KernelSchedule,
+                             EpilogueSchedule, float>;
 };
 
 //////////////////////// Cherry-Picking Kernels ////////////////////////
@@ -335,8 +335,9 @@ struct sm90_config_default<int8_t, OutType, Epilogue> {
   using TileShape = Shape<_128, _128, _128>;
   using ClusterShape = Shape<_1, _2, _1>;
   using Cutlass3xGemm =
-      cutlass_sparse_3x_gemm<int8_t, OutType, Epilogue, TileShape, ClusterShape,
-          KernelSchedule, EpilogueSchedule, float>;
+      cutlass_sparse_3x_gemm<int8_t, OutType, Epilogue, TileShape,
+                             ClusterShape, KernelSchedule, EpilogueSchedule,
+                             float>;
 };
 
 template <typename InType, typename OutType,