Skip to content

Commit

Permalink
Use the default GEMM shapes provided by CUTLASS instead of hardcoding…
Browse files Browse the repository at this point in the history
… ours (make CUTLASS match cuBLAS on Ampere)
  • Loading branch information
Ivan Komarov committed Jul 5, 2024
1 parent 3be87fb commit 223cd55
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions csrc/grouped_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,36 +35,45 @@ namespace grouped_gemm {
template <bool trans>
using GroupedGemmInputLayout = std::conditional_t<trans, ::cutlass::layout::ColumnMajor, ::cutlass::layout::RowMajor>;

using GroupedGemmConfig = ::cutlass::gemm::device::DefaultGemmConfiguration<
::cutlass::arch::OpClassTensorOp,
::cutlass::arch::Sm80,
::cutlass::bfloat16_t,
::cutlass::bfloat16_t,
::cutlass::bfloat16_t,
float
>;

// TODO(tgale): Update this for SM90 when it's supported by CUTLASS.
template <bool trans_a, bool trans_b>
using GroupedGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
// A operand.
::cutlass::bfloat16_t,
GroupedGemmInputLayout<trans_a>,
::cutlass::ComplexTransform::kNone,
8,
GroupedGemmConfig::kAlignmentA,
// B operand.
::cutlass::bfloat16_t,
GroupedGemmInputLayout<trans_b>,
::cutlass::ComplexTransform::kNone,
8,
GroupedGemmConfig::kAlignmentB,
// C operand.
::cutlass::bfloat16_t,
::cutlass::layout::RowMajor,
float,
::cutlass::arch::OpClassTensorOp,
::cutlass::arch::Sm80,
::cutlass::gemm::GemmShape<128, 128, 32>,
::cutlass::gemm::GemmShape<64, 64, 32>,
::cutlass::gemm::GemmShape<16, 8, 16>,
::cutlass::epilogue::thread::LinearCombination<::cutlass::bfloat16_t, 8, float, float>,
GroupedGemmConfig::ThreadblockShape,
GroupedGemmConfig::WarpShape,
GroupedGemmConfig::InstructionShape,
GroupedGemmConfig::EpilogueOutputOp,
// NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels.
// This parameter is passed in at present to match the APIs of other kernels. The parameter
// is unused within the kernel.
::cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
// TODO(tgale): Experiment with GroupScheduleMode.
// TODO(tgale): Tune this for SM90.
4>::GemmKernel;
GroupedGemmConfig::kStages>::GemmKernel;

template <bool trans_a, bool trans_b>
using GemmGrouped = ::cutlass::gemm::device::GemmGrouped<GroupedGemmKernel<trans_a, trans_b>>;
Expand Down

1 comment on commit 223cd55

@chenhongyu2048
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello, I have some doubts about this, is the default shape provided by CUTLASS usually stronger than the hand-coded shape?

Please sign in to comment.