Derive the grouped GEMM kernel configuration from CUTLASS defaults.

Adds GroupedGemmConfig, an alias of
::cutlass::gemm::device::DefaultGemmConfiguration instantiated with
OpClassTensorOp, Sm80, three bfloat16_t element types and float.
GroupedGemmKernel now takes its operand alignments, threadblock / warp /
instruction shapes, epilogue output op and stage count from that alias
instead of the previous literals (8, 8, GemmShape<128, 128, 32>,
GemmShape<64, 64, 32>, GemmShape<16, 8, 16>,
LinearCombination<bfloat16_t, 8, float, float>, 4).

NOTE(review): the template parameter lists on the unchanged context lines
(<bool trans>, <trans_a>, <trans_b>, <bool trans_a, bool trans_b> and the
std::conditional_t arguments) were missing from the copy of this patch that
was reviewed and have been reconstructed; confirm them against
csrc/grouped_gemm.cu before applying.

diff --git a/csrc/grouped_gemm.cu b/csrc/grouped_gemm.cu
index f7a0410..63b95f7 100644
--- a/csrc/grouped_gemm.cu
+++ b/csrc/grouped_gemm.cu
@@ -35,6 +35,15 @@ namespace grouped_gemm {
 template <bool trans>
 using GroupedGemmInputLayout = std::conditional_t<trans, ::cutlass::layout::ColumnMajor, ::cutlass::layout::RowMajor>;
 
+using GroupedGemmConfig = ::cutlass::gemm::device::DefaultGemmConfiguration<
+  ::cutlass::arch::OpClassTensorOp,
+  ::cutlass::arch::Sm80,
+  ::cutlass::bfloat16_t,
+  ::cutlass::bfloat16_t,
+  ::cutlass::bfloat16_t,
+  float
+>;
+
 // TODO(tgale): Update this for SM90 when it's supported by CUTLASS.
 template <bool trans_a, bool trans_b>
 using GroupedGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
@@ -42,29 +51,29 @@ using GroupedGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
   ::cutlass::bfloat16_t,
   GroupedGemmInputLayout<trans_a>,
   ::cutlass::ComplexTransform::kNone,
-  8,
+  GroupedGemmConfig::kAlignmentA,
   // B operand.
   ::cutlass::bfloat16_t,
   GroupedGemmInputLayout<trans_b>,
   ::cutlass::ComplexTransform::kNone,
-  8,
+  GroupedGemmConfig::kAlignmentB,
   // C operand.
   ::cutlass::bfloat16_t,
   ::cutlass::layout::RowMajor,
   float,
   ::cutlass::arch::OpClassTensorOp,
   ::cutlass::arch::Sm80,
-  ::cutlass::gemm::GemmShape<128, 128, 32>,
-  ::cutlass::gemm::GemmShape<64, 64, 32>,
-  ::cutlass::gemm::GemmShape<16, 8, 16>,
-  ::cutlass::epilogue::thread::LinearCombination<::cutlass::bfloat16_t, 8, float, float>,
+  GroupedGemmConfig::ThreadblockShape,
+  GroupedGemmConfig::WarpShape,
+  GroupedGemmConfig::InstructionShape,
+  GroupedGemmConfig::EpilogueOutputOp,
   // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels.
   // This parameter is passed in at present to match the APIs of other kernels. The parameter
   // is unused within the kernel.
   ::cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
   // TODO(tgale): Experiment with GroupScheduleMode.
   // TODO(tgale): Tune this for SM90.
-  4>::GemmKernel;
+  GroupedGemmConfig::kStages>::GemmKernel;
 
 template <bool trans_a, bool trans_b>
 using GemmGrouped = ::cutlass::gemm::device::GemmGrouped<GroupedGemmKernel<trans_a, trans_b>>;