Update code

Faraz9877 committed Dec 12, 2024
1 parent 154814f commit ac059b4
Showing 3 changed files with 32 additions and 32 deletions.
39 changes: 21 additions & 18 deletions benchmarks/cutlass_benchmarks/sparse_benchmarks.py
@@ -70,27 +70,27 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
bias))

# cutlass sparse impl
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm",
- ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, scale_b,
- torch.bfloat16))
+ ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
+ scale_b, torch.bfloat16))

# cutlass sparse with bias
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
- ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, scale_b, torch.bfloat16,
- bias))
+ ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
+ scale_b, torch.bfloat16, bias))

return timers


def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
sub_label: str) -> Iterable[TMeasurement]:
assert dtype == torch.float8_e4m3fn
- b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
+ b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n,
+                                                  k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
@@ -158,29 +158,32 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
torch.bfloat16))

# cutlass impl: bf16 output
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm",
- ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, scale_b,
- torch.bfloat16))
+ ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
+ scale_b, torch.bfloat16))

# cutlass impl: fp16 output
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm",
- ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, scale_b, torch.float16))
+ ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
+ scale_b, torch.float16))

# cutlass impl: bf16 output, with bias
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, scale_b, torch.bfloat16,
bias))
bench_fn(label, sub_label,
"cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
scale_b, torch.bfloat16, bias))

# cutlass impl: fp16 output, with bias
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, scale_b, torch.float16,
bias.to(dtype=torch.float16)))
bench_fn(label, sub_label,
"cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
scale_b, torch.float16, bias.to(dtype=torch.float16)))

return timers

@@ -356,4 +359,4 @@ def to_torch_dtype(dt):
model_parser.set_defaults(func=run_model_bench)

args = parser.parse_args()
- args.func(args)
+ args.func(args)
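
For orientation, here is a minimal, hedged sketch (not part of this commit) of the call pattern the fp8 benchmark above forwards through bench_fn. The import paths are assumptions; make_rand_sparse_tensors is the helper the benchmark itself uses to build the compressed weight b_compressed, its sparsity metadata e, and the dense operands.

import torch

from vllm import _custom_ops as ops  # assumed binding module for the custom ops
# make_rand_sparse_tensors is defined in this benchmark suite; the import path is assumed.
from utils import make_rand_sparse_tensors

m, n, k = 16, 8192, 8192  # illustrative problem size

b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)

# Same argument order the benchmark passes to the op:
# (a, b_compressed, e, scale_a, scale_b, out_dtype[, bias]).
out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
                                   torch.bfloat16, bias)
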
12 changes: 4 additions & 8 deletions csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
@@ -45,8 +45,7 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

using Cutlass3xGemmDefault =
- typename sm90_config_default<InType, OutType,
- Epilogue>::Cutlass3xGemm;
+ typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM64 =
typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
@@ -151,8 +150,7 @@ void cutlass_gemm_sm90_fp16_dispatch(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kFloat16);

using Cutlass3xGemmDefault =
- typename sm90_config_default<InType, OutType,
- Epilogue>::Cutlass3xGemm;
+ typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;

// m in (128, inf)
return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
@@ -172,8 +170,7 @@ void cutlass_gemm_sm90_bf16_dispatch(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kBFloat16);

using Cutlass3xGemmDefault =
- typename sm90_config_default<InType, OutType,
- Epilogue>::Cutlass3xGemm;
+ typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;

// m in (128, inf)
return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
@@ -193,8 +190,7 @@ void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kInt8);

using Cutlass3xGemmDefault =
- typename sm90_config_default<InType, OutType,
- Epilogue>::Cutlass3xGemm;
+ typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM64 =
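
The collapsed hunks above hide the dispatch bodies, so the following is only a schematic sketch of how the fp8 path is expected to pick a kernel configuration by problem size m and forward to cutlass_sparse_gemm_caller. The config aliases, the caller, and the "m in (128, inf)" bucket come from the visible code; the function name, the bt_nzs/bt_meta parameter names, and the exact 64/128 thresholds are assumptions.

// Schematic sketch only -- not the file's actual code.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void sm90_fp8_dispatch_sketch(torch::Tensor& out, torch::Tensor const& a,
                              torch::Tensor const& bt_nzs,
                              torch::Tensor const& bt_meta,
                              EpilogueArgs&&... args) {
  using Cutlass3xGemmDefault =
      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;

  uint32_t const m = a.size(0);  // rows of the activation matrix
  if (m <= 64) {                 // assumed small-m bucket
    return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  } else if (m <= 128) {         // assumed mid-m bucket
    return cutlass_sparse_gemm_caller<Cutlass3xGemmM128>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  }
  // m in (128, inf)
  return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
      out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}
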
13 changes: 7 additions & 6 deletions csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
@@ -185,7 +185,6 @@ template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_config_default {};


template <typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_config_default<half_t, OutType, Epilogue> {
@@ -196,7 +195,7 @@ struct sm90_config_default<half_t, OutType, Epilogue> {
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<half_t, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>;
};

template <typename OutType,
@@ -208,8 +207,9 @@ struct sm90_config_default<cutlass::bfloat16_t, OutType, Epilogue> {
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
- cutlass_sparse_3x_gemm<cutlass::bfloat16_t, OutType, Epilogue, TileShape, ClusterShape,
- KernelSchedule, EpilogueSchedule, float>;
+ cutlass_sparse_3x_gemm<cutlass::bfloat16_t, OutType, Epilogue, TileShape,
+ ClusterShape, KernelSchedule, EpilogueSchedule,
+ float>;
};

//////////////////////// Cherry-Picking Kernels ////////////////////////
@@ -335,8 +335,9 @@ struct sm90_config_default<cutlass::float_e4m3_t, OutType, Epilogue> {
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_1, _2, _1>;
using Cutlass3xGemm =
- cutlass_sparse_3x_gemm<cutlass::float_e4m3_t, OutType, Epilogue, TileShape, ClusterShape,
- KernelSchedule, EpilogueSchedule, float>;
+ cutlass_sparse_3x_gemm<cutlass::float_e4m3_t, OutType, Epilogue,
+ TileShape, ClusterShape, KernelSchedule,
+ EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
