From 763b38df88fe77e3397687689a35770f8fe10dfd Mon Sep 17 00:00:00 2001
From: Ivan Komarov
Date: Mon, 24 Jun 2024 20:52:10 +0000
Subject: [PATCH 1/9] Use CUTLASS for both trans_a and trans_b on Ampere

---
 csrc/grouped_gemm.cu     | 132 +++++++++++++++++++++++++++------------
 grouped_gemm/ops_test.py |  28 +++++++++
 2 files changed, 120 insertions(+), 40 deletions(-)

diff --git a/csrc/grouped_gemm.cu b/csrc/grouped_gemm.cu
index 21229a0..16202f3 100644
--- a/csrc/grouped_gemm.cu
+++ b/csrc/grouped_gemm.cu
@@ -11,6 +11,8 @@
 #include "cutlass/gemm/kernel/default_gemm_grouped.h"
 #include "cutlass/gemm/device/gemm_grouped.h"
 
+#include
+
 namespace grouped_gemm {
 
 #define CUDA_CALL(code) \
@@ -30,16 +32,20 @@ namespace grouped_gemm {
 #define GROUPED_GEMM_STRINGIFY(x) \
   GROUPED_GEMM_STRINGIFY_HELPER(x)
 
+template <bool trans>
+using GroupedGemmInputLayout = std::conditional_t<trans, ::cutlass::layout::ColumnMajor, ::cutlass::layout::RowMajor>;
+
 // TODO(tgale): Update this for SM90 when it's supported by CUTLASS.
-using GroupedGemmKernelNN = typename cutlass::gemm::kernel::DefaultGemmGrouped<
-  // Non-transposed A operand.
+template <bool trans_a, bool trans_b>
+using GroupedGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
+  // A operand.
   ::cutlass::bfloat16_t,
-  ::cutlass::layout::RowMajor,
+  GroupedGemmInputLayout<trans_a>,
   ::cutlass::ComplexTransform::kNone,
   8,
-  // Non-transposed B operand.
+  // B operand.
   ::cutlass::bfloat16_t,
-  ::cutlass::layout::RowMajor,
+  GroupedGemmInputLayout<trans_b>,
   ::cutlass::ComplexTransform::kNone,
   8,
   // C operand.
@@ -59,14 +65,20 @@ using GroupedGemmKernelNN = typename cutlass::gemm::kernel::DefaultGemmGrouped<
   // TODO(tgale): Experiment with GroupScheduleMode.
   // TODO(tgale): Tune this for SM90.
   4>::GemmKernel;
-using GemmGroupedNN = ::cutlass::gemm::device::GemmGrouped<GroupedGemmKernelNN>;
-
-std::vector<cutlass::gemm::GemmCoord> MakeProblemSizes(torch::Tensor b, torch::Tensor batch_sizes) {
+
+template <bool trans_a, bool trans_b>
+using GemmGrouped = ::cutlass::gemm::device::GemmGrouped<GroupedGemmKernel<trans_a, trans_b>>;
+
+template <bool trans_a, bool trans_b>
+std::vector<cutlass::gemm::GemmCoord> MakeProblemSizes(torch::Tensor a, torch::Tensor b, torch::Tensor batch_sizes) {
   const size_t num_experts = batch_sizes.size(0);
-  const size_t k = b.size(1), n = b.size(2);
+  const size_t hidden_in = a.size(1), hidden_out = (trans_a || trans_b) ? b.size(1) : b.size(2);
   std::vector<cutlass::gemm::GemmCoord> problem_sizes(num_experts);
   for (int i = 0; i < num_experts; ++i) {
-    problem_sizes[i] = cutlass::gemm::GemmCoord(batch_sizes.data_ptr<int64_t>()[i], n, k);
+    int64_t bs = batch_sizes.data_ptr<int64_t>()[i];
+    problem_sizes[i] = trans_a
+      ? cutlass::gemm::GemmCoord(hidden_in, hidden_out, bs)
+      : cutlass::gemm::GemmCoord(bs, hidden_out, hidden_in);
   }
   return problem_sizes;
 }
@@ -84,12 +96,23 @@ torch::Tensor CopyToDevice(const std::vector<T> &x, const torch::Device &device)
   return out;
 }
 
-template <typename Gemm>
+template <typename T>
+static void ReorderArray(T* data, const std::vector<size_t>& indices) {
+  // For now, simply create a copy of the data and then copy over to the original.
+  std::vector<T> copy(indices.size());
+  for (size_t i = 0; i < indices.size(); ++i) {
+    copy.at(i) = data[indices[i]];
+  }
+
+  memcpy(data, copy.data(), indices.size() * sizeof(T));
+}
+
+template <typename Gemm, bool trans_a, bool trans_b>
 typename Gemm::Arguments MakeArguments(torch::Tensor a,
                                        torch::Tensor b,
                                        torch::Tensor c,
                                        torch::Tensor batch_sizes) {
-  auto problem_sizes_host = MakeProblemSizes(b, batch_sizes);
+  auto problem_sizes_host = MakeProblemSizes<trans_a, trans_b>(a, b, batch_sizes);
 
   // Calculate the number of threadblocks to use and validate the result.
   int64_t num_experts = problem_sizes_host.size();
@@ -135,6 +158,22 @@ typename Gemm::Arguments MakeArguments(torch::Tensor a,
     elements_b += problem.k() * problem.n();
     elements_c += problem.m() * problem.n();
   }
+  // Only sort the problems when trans_a is set, since that is the only case where K differs across experts.
+  if (trans_a) {
+    std::vector<size_t> indices(num_experts);
+    std::iota(indices.begin(), indices.end(), 0);
+    std::stable_sort(indices.begin(), indices.end(), [&problem_sizes_host](size_t i, size_t j) {
+      return problem_sizes_host[i].k() > problem_sizes_host[j].k();
+    });
+
+    ReorderArray(problem_sizes_host.data(), indices);
+    ReorderArray(lda_host.data(), indices);
+    ReorderArray(ldb_host.data(), indices);
+    ReorderArray(ldc_host.data(), indices);
+    ReorderArray(ptr_a_host.data(), indices);
+    ReorderArray(ptr_b_host.data(), indices);
+    ReorderArray(ptr_c_host.data(), indices);
+  }
 
   // Copy the problem sizes, pointers and leading dimension data to the device.
   torch::Tensor lda = CopyToDevice(lda_host, a.device());
@@ -162,14 +201,15 @@ typename Gemm::Arguments MakeArguments(torch::Tensor a,
   return arguments;
 }
 
+template <bool trans_a, bool trans_b>
 torch::Tensor CutlassGroupedGemm(torch::Tensor a,
                                  torch::Tensor b,
                                  torch::Tensor c,
                                  torch::Tensor batch_sizes) {
-  using Gemm = GemmGroupedNN;
+  using Gemm = GemmGrouped<trans_a, trans_b>;
   Gemm gemm;
 
-  auto arguments = MakeArguments<Gemm>(a, b, c, batch_sizes);
+  auto arguments = MakeArguments<Gemm, trans_a, trans_b>(a, b, c, batch_sizes);
   int64_t workspace_size = gemm.get_workspace_size(arguments);
   auto options = torch::TensorOptions().dtype(torch::kInt8).device(a.device());
   torch::Tensor workspace = torch::empty(workspace_size, options);
@@ -302,49 +342,61 @@ void GroupedGemm(torch::Tensor a,
   TORCH_CHECK(a.ndimension() == 2);
   TORCH_CHECK(a.scalar_type() == torch::kBFloat16);
 
-  // Defer to the variable 'k' helper for the rest of the op.
-  if (trans_a) {
-    GroupedGemmVariableK(a, b, c, batch_sizes);
-    return;
-  }
-
-  // We expected a CUDA tensor with three dimensions and shape
-  // (num_experts, hidden_in, hidden_out) for 'b'.
   TORCH_CHECK(b.is_cuda());
-  TORCH_CHECK(b.ndimension() == 3);
-  TORCH_CHECK(b.scalar_type() == torch::kBFloat16);
-
-  // Validate the contraction dimensions match.
-  int64_t tokens = a.size(0), num_experts = b.size(0);
-  int64_t hidden_in = trans_b ? b.size(2) : b.size(1);
-  int64_t hidden_out = trans_b ? b.size(1) : b.size(2);
-  TORCH_CHECK(hidden_in == a.size(1));
-
-  // Validate that we have one size per expert.
-  TORCH_CHECK(batch_sizes.size(0) == num_experts);
-
-  // Validate the output shape.
   TORCH_CHECK(c.is_cuda());
-  TORCH_CHECK(c.ndimension() == 2);
+  TORCH_CHECK(b.scalar_type() == torch::kBFloat16);
   TORCH_CHECK(c.scalar_type() == torch::kBFloat16);
-  TORCH_CHECK(c.size(0) == tokens);
-  TORCH_CHECK(c.size(1) == hidden_out);
+
+  // The expected shapes of 'b' and 'c' are:
+  //   * when 'trans_a' is set: b=(tokens, hidden_out), c=(num_experts, hidden_in, hidden_out)
+  //   * when 'trans_b' is set: b=(num_experts, hidden_out, hidden_in), c=(tokens, hidden_out)
+  //   * otherwise: b=(num_experts, hidden_in, hidden_out), c=(tokens, hidden_out)
+  if (trans_a) {
+    TORCH_CHECK(b.ndimension() == 2);
+    TORCH_CHECK(c.ndimension() == 3);
+    TORCH_CHECK(b.size(0) == a.size(0));
+    TORCH_CHECK(c.size(0) == batch_sizes.size(0));
+    TORCH_CHECK(c.size(1) == a.size(1));
+    TORCH_CHECK(c.size(2) == b.size(1));
+  } else {
+    TORCH_CHECK(b.ndimension() == 3);
+    TORCH_CHECK(c.ndimension() == 2);
+
+    // Validate the contraction dimensions match.
+    int64_t tokens = a.size(0), num_experts = b.size(0);
+    int64_t hidden_in = trans_b ? b.size(2) : b.size(1);
+    int64_t hidden_out = trans_b ? b.size(1) : b.size(2);
+    TORCH_CHECK(hidden_in == a.size(1));
+
+    // Validate that we have one size per expert.
+    TORCH_CHECK(batch_sizes.size(0) == num_experts);
+  }
 
   // NOTE: We support transposition through the 'trans_b' flag.
   TORCH_CHECK(a.is_contiguous());
   TORCH_CHECK(b.is_contiguous());
+  TORCH_CHECK(c.is_contiguous());
+
   // NOTE: Use cuBLAS for SM90 until CUTLASS supports SM90-optimized grouped-gemm.
 #if !defined(GROUPED_GEMM_DEVICE_CAPABILITY) || GROUPED_GEMM_DEVICE_CAPABILITY != 80
+  // Defer to the variable 'k' helper for the rest of the op.
+  if (trans_a) {
+    GroupedGemmVariableK(a, b, c, batch_sizes);
+    return;
+  }
   CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
   return;
 #else
-  // TODO(tgale): Support transposition with CUTLASS grouped GEMM.
+  if (trans_a) {
+    CutlassGroupedGemm<true, false>(a, b, c, batch_sizes);
+    return;
+  }
   if (trans_b) {
-    CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
+    CutlassGroupedGemm<false, true>(a, b, c, batch_sizes);
     return;
   }
-  CutlassGroupedGemm(a, b, c, batch_sizes);
+  CutlassGroupedGemm<false, false>(a, b, c, batch_sizes);
   return;
 #endif
 }
diff --git a/grouped_gemm/ops_test.py b/grouped_gemm/ops_test.py
index b607a46..55805db 100644
--- a/grouped_gemm/ops_test.py
+++ b/grouped_gemm/ops_test.py
@@ -104,6 +104,34 @@ def testGroupedGemm_VariableSizes(self, z, m, k, n, trans_b):
         self.assertTrue(allclose(b.grad, b_ref.grad))
 
 
+class EdgeCasesTest(unittest.TestCase):
+
+    def testGroupedGemm_ZeroSize(self):
+        torch.manual_seed(0)
+        m = 16384
+        k = 4096
+        n = 14336
+        num_experts = 8
+
+        a = randn(num_experts, m // num_experts, k).view(-1, k)
+        b = randn(num_experts, k, n)
+        batch_sizes = torch.tensor([219, 2246, 5, 8103, 1, 1117, 4693, 0]).to(torch.long)
+
+        a.requires_grad_(True)
+        b.requires_grad_(True)
+        a_ref = a.detach().clone().requires_grad_(True)
+        b_ref = b.detach().clone().requires_grad_(True)
+
+        out = ops.gmm(a, b, batch_sizes)
+        expected_out = gmm(a_ref, b_ref, batch_sizes)
+        self.assertTrue(allclose(out, expected_out))
+
+        # Check gradients.
+        out.sum().backward()
+        expected_out.sum().backward()
+        self.assertTrue(allclose(a.grad, a_ref.grad))
+        self.assertTrue(allclose(b.grad, b_ref.grad))
+
 if __name__ == '__main__':
     unittest.main()

From 398709f0645418ff9dffe1070a222d789ca8e3f1 Mon Sep 17 00:00:00 2001
From: Ivan Komarov
Date: Mon, 1 Jul 2024 18:06:45 +0000
Subject: [PATCH 2/9] Work around the CUTLASS bug for `k=0` problems

---
 csrc/grouped_gemm.cu     | 64 ++++++++++++++++++++++++----------------
 grouped_gemm/ops_test.py | 15 ++++++++++
 2 files changed, 54 insertions(+), 25 deletions(-)

diff --git a/csrc/grouped_gemm.cu b/csrc/grouped_gemm.cu
index 16202f3..d7520e4 100644
--- a/csrc/grouped_gemm.cu
+++ b/csrc/grouped_gemm.cu
@@ -114,50 +114,64 @@ typename Gemm::Arguments MakeArguments(torch::Tensor a,
                                        torch::Tensor batch_sizes) {
   auto problem_sizes_host = MakeProblemSizes<trans_a, trans_b>(a, b, batch_sizes);
 
-  // Calculate the number of threadblocks to use and validate the result.
-  int64_t num_experts = problem_sizes_host.size();
-
-  // NOTE: This is borrowed from FasterTransformer.
-  int threadblock_count = Gemm::sufficient(problem_sizes_host.data(), num_experts);
-  if (!threadblock_count) {
-    TORCH_CHECK(false, "Grouped GEMM execution not possible with HW");
-  }
+  int64_t num_experts_orig = problem_sizes_host.size();
 
   // Create the host arrays of leading dimension data and pointer data.
   using LayoutA = typename Gemm::LayoutA;
   using LayoutB = typename Gemm::LayoutB;
   using LayoutC = typename Gemm::LayoutC;
-  std::vector<int64_t> lda_host(num_experts), offsets_a(num_experts);
-  std::vector<int64_t> ldb_host(num_experts), offsets_b(num_experts);
-  std::vector<int64_t> ldc_host(num_experts), offsets_c(num_experts);
+  std::vector<int64_t> lda_host, ldb_host, ldc_host;
   int64_t elements_a = 0, elements_b = 0, elements_c = 0;
 
   using ElementA = typename Gemm::ElementA;
   using ElementB = typename Gemm::ElementB;
   using ElementC = typename Gemm::ElementC;
-  std::vector<ElementA *> ptr_a_host(num_experts);
-  std::vector<ElementB *> ptr_b_host(num_experts);
-  std::vector<ElementC *> ptr_c_host(num_experts);
+  std::vector<ElementA *> ptr_a_host, ptr_b_host, ptr_c_host;
 
-  for (int i = 0; i < num_experts; ++i) {
-    auto problem = problem_sizes_host[i];
-    lda_host[i] = LayoutA::packed({problem.m(), problem.k()}).stride(0);
-    ldb_host[i] = LayoutB::packed({problem.k(), problem.n()}).stride(0);
-    ldc_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0);
+  lda_host.reserve(num_experts_orig);
+  ldb_host.reserve(num_experts_orig);
+  ldc_host.reserve(num_experts_orig);
 
-    offsets_a[i] = elements_a;
-    offsets_b[i] = elements_b;
-    offsets_c[i] = elements_c;
+  ptr_a_host.reserve(num_experts_orig);
+  ptr_b_host.reserve(num_experts_orig);
+  ptr_c_host.reserve(num_experts_orig);
 
-    ptr_a_host[i] = (ElementA*)a.data_ptr() + offsets_a[i];
-    ptr_b_host[i] = (ElementB*)b.data_ptr() + offsets_b[i];
-    ptr_c_host[i] = (ElementC*)c.data_ptr() + offsets_c[i];
+  // CUTLASS doesn't handle problems with `k=0` correctly, see https://github.com/NVIDIA/cutlass/pull/1593.
+  // Until a fix is available on the CUTLASS side, handle these problems ourselves.
+  int64_t num_experts = 0;
+  for (int i = 0; i < num_experts_orig; ++i) {
+    auto problem = problem_sizes_host[i];
+    if (problem.k() == 0) {
+      CUDA_CALL(cudaMemsetAsync((ElementC*)c.data_ptr() + elements_c,
+                                0,
+                                problem.m() * problem.n() * sizeof(ElementC),
+                                c10::cuda::getCurrentCUDAStream()));
+    } else {
+      lda_host.push_back(LayoutA::packed({problem.m(), problem.k()}).stride(0));
+      ldb_host.push_back(LayoutB::packed({problem.k(), problem.n()}).stride(0));
+      ldc_host.push_back(LayoutC::packed({problem.m(), problem.n()}).stride(0));
+
+      ptr_a_host.push_back((ElementA*)a.data_ptr() + elements_a);
+      ptr_b_host.push_back((ElementB*)b.data_ptr() + elements_b);
+      ptr_c_host.push_back((ElementC*)c.data_ptr() + elements_c);
+
+      problem_sizes_host[num_experts++] = problem;
+    }
 
     elements_a += problem.m() * problem.k();
     elements_b += problem.k() * problem.n();
     elements_c += problem.m() * problem.n();
   }
+  problem_sizes_host.resize(num_experts);
+
+  // Calculate the number of threadblocks to use and validate the result.
+  // NOTE: This is borrowed from FasterTransformer.
+  int threadblock_count = Gemm::sufficient(problem_sizes_host.data(), num_experts);
+  if (!threadblock_count) {
+    TORCH_CHECK(false, "Grouped GEMM execution not possible with HW");
+  }
+
   // Only sort the problems when trans_a is set, since that is the only case where K differs across experts.
   if (trans_a) {
     std::vector<size_t> indices(num_experts);
diff --git a/grouped_gemm/ops_test.py b/grouped_gemm/ops_test.py
index 55805db..1c11be2 100644
--- a/grouped_gemm/ops_test.py
+++ b/grouped_gemm/ops_test.py
@@ -132,6 +132,21 @@ def testGroupedGemm_ZeroSize(self):
         self.assertTrue(allclose(a.grad, a_ref.grad))
         self.assertTrue(allclose(b.grad, b_ref.grad))
 
+    def testGroupedGemm_ZeroK(self):
+        sz = 128
+        total_tokens = 192
+
+        a = torch.ones(total_tokens, sz).cuda().to(torch.bfloat16)
+        b = torch.ones(total_tokens, sz).cuda().to(torch.bfloat16)
+        c = torch.ones(4, sz, sz).cuda().to(torch.bfloat16)
+        batch_sizes = torch.tensor([0, 128, 0, 64]).to(torch.long)
+
+        ops.backend.gmm(a, b, batch_sizes, trans_a=True, c=c)
+        self.assertTrue((c[0] == 0).all())
+        self.assertTrue((c[1] == 128).all())
+        self.assertTrue((c[2] == 0).all())
+        self.assertTrue((c[3] == 64).all())
+
 if __name__ == '__main__':
     unittest.main()

From 5b64e00ab3c9b4ecdceebd619e159c910339510e Mon Sep 17 00:00:00 2001
From: Ivan Komarov
Date: Mon, 1 Jul 2024 18:15:53 +0000
Subject: [PATCH 3/9] Remove an unnecessary memcpy()

---
 csrc/grouped_gemm.cu | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/csrc/grouped_gemm.cu b/csrc/grouped_gemm.cu
index d7520e4..5ae9403 100644
--- a/csrc/grouped_gemm.cu
+++ b/csrc/grouped_gemm.cu
@@ -99,12 +99,10 @@ torch::Tensor CopyToDevice(const std::vector<T> &x, const torch::Device &device)
 template <typename T>
 static void ReorderArray(T* data, const std::vector<size_t>& indices) {
   // For now, simply create a copy of the data and then copy over to the original.
-  std::vector<T> copy(indices.size());
+  std::vector<T> copy(data, data + indices.size());
   for (size_t i = 0; i < indices.size(); ++i) {
-    copy.at(i) = data[indices[i]];
+    data[i] = copy.at(indices[i]);
   }
-
-  memcpy(data, copy.data(), indices.size() * sizeof(T));
 }
 
 template <typename Gemm, bool trans_a, bool trans_b>

From 2285bb4170dbea77b8f48e04f1c021a676a844aa Mon Sep 17 00:00:00 2001
From: Ivan Komarov
Date: Mon, 1 Jul 2024 19:30:40 +0000
Subject: [PATCH 4/9] Gate using CUTLASS for transposed cases behind a compile-time flag

---
 csrc/grouped_gemm.cu | 22 +++++++++++++++++-----
 setup.py             |  3 +++
 2 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/csrc/grouped_gemm.cu b/csrc/grouped_gemm.cu
index 5ae9403..92887bc 100644
--- a/csrc/grouped_gemm.cu
+++ b/csrc/grouped_gemm.cu
@@ -354,6 +354,15 @@ void GroupedGemm(torch::Tensor a,
   TORCH_CHECK(a.ndimension() == 2);
   TORCH_CHECK(a.scalar_type() == torch::kBFloat16);
 
+#if !defined(GROUPED_GEMM_FULL_CUTLASS)
+  if (trans_a) {
+    // If we can't use CUTLASS for the transposed cases, defer to the variable 'k' helper using cuBLAS
+    // for the rest of the op.
+    GroupedGemmVariableK(a, b, c, batch_sizes);
+    return;
+  }
+#endif
+
   TORCH_CHECK(b.is_cuda());
   TORCH_CHECK(c.is_cuda());
   TORCH_CHECK(b.scalar_type() == torch::kBFloat16);
   TORCH_CHECK(c.scalar_type() == torch::kBFloat16);
@@ -392,14 +401,10 @@ void GroupedGemm(torch::Tensor a,
 
   // NOTE: Use cuBLAS for SM90 until CUTLASS supports SM90-optimized grouped-gemm.
 #if !defined(GROUPED_GEMM_DEVICE_CAPABILITY) || GROUPED_GEMM_DEVICE_CAPABILITY != 80
-  // Defer to the variable 'k' helper for the rest of the op.
-  if (trans_a) {
-    GroupedGemmVariableK(a, b, c, batch_sizes);
-    return;
-  }
   CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
   return;
 #else
+#if defined(GROUPED_GEMM_FULL_CUTLASS)
   if (trans_a) {
     CutlassGroupedGemm<true, false>(a, b, c, batch_sizes);
     return;
   }
@@ -408,6 +413,13 @@ void GroupedGemm(torch::Tensor a,
     CutlassGroupedGemm<false, true>(a, b, c, batch_sizes);
     return;
   }
+#else
+  TORCH_CHECK(!trans_a, "The trans_a case should have been handled earlier");
+  if (trans_b) {
+    CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
+    return;
+  }
+#endif
   CutlassGroupedGemm<false, false>(a, b, c, batch_sizes);
   return;
 #endif
 }
diff --git a/setup.py b/setup.py
index 4f9c40e..9087b4e 100644
--- a/setup.py
+++ b/setup.py
@@ -23,6 +23,9 @@
         f"-DGROUPED_GEMM_DEVICE_CAPABILITY={device_capability}",
     ])
 
+if os.environ.get("GROUPED_GEMM_FULL_CUTLASS", "0") == "1":
+    nvcc_flags.extend(["-DGROUPED_GEMM_FULL_CUTLASS"])
+
 ext_modules = [
     CUDAExtension(
         "grouped_gemm_backend",

From 3be87fb350d7dca487e2688d86eda99ae369a31e Mon Sep 17 00:00:00 2001
From: Ivan Komarov
Date: Wed, 3 Jul 2024 20:43:23 +0000
Subject: [PATCH 5/9] Simplify #ifdefs

---
 csrc/grouped_gemm.cu | 14 ++------------
 setup.py             | 17 ++---------------
 2 files changed, 4 insertions(+), 27 deletions(-)

diff --git a/csrc/grouped_gemm.cu b/csrc/grouped_gemm.cu
index 92887bc..f7a0410 100644
--- a/csrc/grouped_gemm.cu
+++ b/csrc/grouped_gemm.cu
@@ -354,7 +354,7 @@ void GroupedGemm(torch::Tensor a,
   TORCH_CHECK(a.ndimension() == 2);
   TORCH_CHECK(a.scalar_type() == torch::kBFloat16);
 
-#if !defined(GROUPED_GEMM_FULL_CUTLASS)
+#if !defined(GROUPED_GEMM_CUTLASS)
   if (trans_a) {
     // If we can't use CUTLASS for the transposed cases, defer to the variable 'k' helper using cuBLAS
     // for the rest of the op.
@@ -398,13 +398,10 @@ void GroupedGemm(torch::Tensor a,
   TORCH_CHECK(b.is_contiguous());
   TORCH_CHECK(c.is_contiguous());
-
-  // NOTE: Use cuBLAS for SM90 until CUTLASS supports SM90-optimized grouped-gemm.
-#if !defined(GROUPED_GEMM_DEVICE_CAPABILITY) || GROUPED_GEMM_DEVICE_CAPABILITY != 80
+#if !defined(GROUPED_GEMM_CUTLASS)
   CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
   return;
 #else
-#if defined(GROUPED_GEMM_FULL_CUTLASS)
   if (trans_a) {
     CutlassGroupedGemm<true, false>(a, b, c, batch_sizes);
     return;
@@ -413,13 +410,6 @@ void GroupedGemm(torch::Tensor a,
     CutlassGroupedGemm<false, true>(a, b, c, batch_sizes);
     return;
   }
-#else
-  TORCH_CHECK(!trans_a, "The trans_a case should have been handled earlier");
-  if (trans_b) {
-    CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
-    return;
-  }
-#endif
   CutlassGroupedGemm<false, false>(a, b, c, batch_sizes);
   return;
 #endif
 }
diff --git a/setup.py b/setup.py
index 9087b4e..20183f5 100644
--- a/setup.py
+++ b/setup.py
@@ -4,27 +4,14 @@
 import torch
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 
-if os.environ.get("TORCH_CUDA_ARCH_LIST"):
-    # Let PyTorch builder to choose device to target for.
-    device_capability = ""
-else:
-    device_capability = torch.cuda.get_device_capability()
-    device_capability = f"{device_capability[0]}{device_capability[1]}"
-
 cwd = Path(os.path.dirname(os.path.abspath(__file__)))
 
 nvcc_flags = [
     "-std=c++17",  # NOTE: CUTLASS requires c++17
 ]
 
-if device_capability:
-    nvcc_flags.extend([
-        f"--generate-code=arch=compute_{device_capability},code=sm_{device_capability}",
-        f"-DGROUPED_GEMM_DEVICE_CAPABILITY={device_capability}",
-    ])
-
-if os.environ.get("GROUPED_GEMM_FULL_CUTLASS", "0") == "1":
-    nvcc_flags.extend(["-DGROUPED_GEMM_FULL_CUTLASS"])
+if os.environ.get("GROUPED_GEMM_CUTLASS", "0") == "1":
+    nvcc_flags.extend(["-DGROUPED_GEMM_CUTLASS"])
 
 ext_modules = [
     CUDAExtension(

From 223cd55f1a7dfecad250c613da5fb7f795588a34 Mon Sep 17 00:00:00 2001
From: Ivan Komarov
Date: Fri, 5 Jul 2024 18:20:00 +0000
Subject: [PATCH 6/9] Use the default GEMM shapes provided by CUTLASS instead of hardcoding ours (make CUTLASS match cuBLAS on Ampere)

---
 csrc/grouped_gemm.cu | 23 ++++++++++++++++-------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/csrc/grouped_gemm.cu b/csrc/grouped_gemm.cu
index f7a0410..63b95f7 100644
--- a/csrc/grouped_gemm.cu
+++ b/csrc/grouped_gemm.cu
@@ -35,6 +35,15 @@ namespace grouped_gemm {
 template <bool trans>
 using GroupedGemmInputLayout = std::conditional_t<trans, ::cutlass::layout::ColumnMajor, ::cutlass::layout::RowMajor>;
 
+using GroupedGemmConfig = ::cutlass::gemm::device::DefaultGemmConfiguration<
+  ::cutlass::arch::OpClassTensorOp,
+  ::cutlass::arch::Sm80,
+  ::cutlass::bfloat16_t,
+  ::cutlass::bfloat16_t,
+  ::cutlass::bfloat16_t,
+  float
+>;
+
 // TODO(tgale): Update this for SM90 when it's supported by CUTLASS.
 template <bool trans_a, bool trans_b>
 using GroupedGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
@@ -42,29 +51,29 @@ using GroupedGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
   ::cutlass::bfloat16_t,
   GroupedGemmInputLayout<trans_a>,
   ::cutlass::ComplexTransform::kNone,
-  8,
+  GroupedGemmConfig::kAlignmentA,
   // B operand.
   ::cutlass::bfloat16_t,
   GroupedGemmInputLayout<trans_b>,
   ::cutlass::ComplexTransform::kNone,
-  8,
+  GroupedGemmConfig::kAlignmentB,
   // C operand.
   ::cutlass::bfloat16_t,
   ::cutlass::layout::RowMajor,
   float,
   ::cutlass::arch::OpClassTensorOp,
   ::cutlass::arch::Sm80,
-  ::cutlass::gemm::GemmShape<128, 128, 32>,
-  ::cutlass::gemm::GemmShape<64, 64, 32>,
-  ::cutlass::gemm::GemmShape<16, 8, 16>,
-  ::cutlass::epilogue::thread::LinearCombination<::cutlass::bfloat16_t, 8, float, float>,
+  GroupedGemmConfig::ThreadblockShape,
+  GroupedGemmConfig::WarpShape,
+  GroupedGemmConfig::InstructionShape,
+  GroupedGemmConfig::EpilogueOutputOp,
   // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels.
   // This parameter is passed in at present to match the APIs of other kernels. The parameter
   // is unused within the kernel.
   ::cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
   // TODO(tgale): Experiment with GroupScheduleMode.
   // TODO(tgale): Tune this for SM90.
-  4>::GemmKernel;
+  GroupedGemmConfig::kStages>::GemmKernel;
 
 template <bool trans_a, bool trans_b>
 using GemmGrouped = ::cutlass::gemm::device::GemmGrouped<GroupedGemmKernel<trans_a, trans_b>>;

From 482c3e0c334ce575ce9c4b1c1bda0a8431e23bd7 Mon Sep 17 00:00:00 2001
From: Ivan Komarov
Date: Tue, 9 Jul 2024 14:54:07 +0000
Subject: [PATCH 7/9] Update the README

---
 README.md | 25 ++++++++++++++++++++++++-
 1 file changed, 24 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index faef4a2..39736cc 100644
--- a/README.md
+++ b/README.md
@@ -4,4 +4,27 @@ A lighweight library exposing grouped GEMM kernels in PyTorch.
 
 # Installation
 
-Run `pip install grouped_gemm` to install the package.
\ No newline at end of file
+Run `pip install grouped_gemm` to install the package.
+
+# Compiling from source
+
+By default, the installed package runs in conservative (`cuBLAS`) mode:
+it launches one GEMM kernel per batch element instead of using a single
+grouped GEMM kernel for the whole batch.
+
+To enable using grouped GEMM kernels, you need to switch to the `CUTLASS`
+mode by setting the `GROUPED_GEMM_CUTLASS` environment variable to `1`
+when building the library. For example, to build the library in `CUTLASS`
+mode for Ampere (SM 8.0), clone the repository and run the following:
+
+```bash
+$ TORCH_CUDA_ARCH_LIST=8.0 GROUPED_GEMM_CUTLASS=1 pip install .
+```
+
+See [this comment](https://github.com/tgale96/grouped_gemm/pull/14#issuecomment-2211362572)
+for some performance measurements on A100 and H100.
+
+# Upcoming features
+
+* Running grouped GEMM kernels without GPU<->CPU synchronization points.
+* Hopper-optimized grouped GEMM kernels.

From d6cc388e59db59a50fc237cef36c4de243cd5aec Mon Sep 17 00:00:00 2001
From: Ivan Komarov
Date: Tue, 16 Jul 2024 16:57:49 +0000
Subject: [PATCH 8/9] Re-introduce the logic selecting `device_capability`

When `TORCH_CUDA_ARCH_LIST` is not set explicitly, use the compute
capability of the current GPU to build against to avoid redundant
compilations.
---
 setup.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/setup.py b/setup.py
index 20183f5..e859abe 100644
--- a/setup.py
+++ b/setup.py
@@ -4,12 +4,24 @@
 import torch
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 
+if os.environ.get("TORCH_CUDA_ARCH_LIST"):
+    # Let PyTorch builder to choose device to target for.
+    device_capability = ""
+else:
+    device_capability = torch.cuda.get_device_capability()
+    device_capability = f"{device_capability[0]}{device_capability[1]}"
+
 cwd = Path(os.path.dirname(os.path.abspath(__file__)))
 
 nvcc_flags = [
     "-std=c++17",  # NOTE: CUTLASS requires c++17
 ]
 
+if device_capability:
+    nvcc_flags.extend([
+        f"--generate-code=arch=compute_{device_capability},code=sm_{device_capability}",
+    ])
+
 if os.environ.get("GROUPED_GEMM_CUTLASS", "0") == "1":
     nvcc_flags.extend(["-DGROUPED_GEMM_CUTLASS"])

From d17ce71d4d3fd83a73e89832185c6f07c6e93a8a Mon Sep 17 00:00:00 2001
From: Ivan Komarov
Date: Tue, 16 Jul 2024 16:59:51 +0000
Subject: [PATCH 9/9] Bump the library version

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index e859abe..230fa25 100644
--- a/setup.py
+++ b/setup.py
@@ -52,7 +52,7 @@
 setup(
     name="grouped_gemm",
-    version="0.1.4",
+    version="0.1.5",
     author="Trevor Gale",
     author_email="tgale@stanford.edu",
     description="Grouped GEMM",
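
Note for reviewers: the snippet below is a minimal, illustrative sketch of how the
code paths touched by this series are exercised end to end. It assumes the package
was built in `CUTLASS` mode as described in the README change from PATCH 7/9, and
that `grouped_gemm.ops.gmm` / `ops.backend.gmm` keep the call signatures used in
`grouped_gemm/ops_test.py`; the tensor shapes follow the shape rules documented in
`GroupedGemm()` in PATCH 1/9.

```python
# Illustrative only; the module path and argument names are assumed to match
# the ones used in grouped_gemm/ops_test.py.
import torch
from grouped_gemm import ops

num_experts, tokens, hidden_in, hidden_out = 4, 192, 128, 64
a = torch.randn(tokens, hidden_in, device="cuda", dtype=torch.bfloat16)
b = torch.randn(num_experts, hidden_in, hidden_out, device="cuda", dtype=torch.bfloat16)
# Zero-sized groups are allowed (see testGroupedGemm_ZeroSize / testGroupedGemm_ZeroK).
batch_sizes = torch.tensor([0, 128, 0, 64], dtype=torch.long)

# Forward pass: trans_a=False, trans_b=False, i.e. the plain grouped GEMM path.
out = ops.gmm(a, b, batch_sizes)

# Weight-gradient-style call: trans_a=True, which on Ampere hits the CUTLASS
# variable-K path, including the k=0 workaround from PATCH 2/9 for empty groups.
c = torch.empty(num_experts, hidden_in, hidden_out, device="cuda", dtype=torch.bfloat16)
ops.backend.gmm(a, out, batch_sizes, trans_a=True, c=c)
```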