diff --git a/README.md b/README.md
index faef4a2..39736cc 100644
--- a/README.md
+++ b/README.md
@@ -4,4 +4,27 @@ A lightweight library exposing grouped GEMM kernels in PyTorch.
 
 # Installation
 
-Run `pip install grouped_gemm` to install the package.
\ No newline at end of file
+Run `pip install grouped_gemm` to install the package.
+
+# Compiling from source
+
+By default, the installed package runs in conservative (`cuBLAS`) mode:
+it launches one GEMM kernel per batch element instead of using a single
+grouped GEMM kernel for the whole batch.
+
+To enable the grouped GEMM kernels, switch to `CUTLASS` mode by setting
+the `GROUPED_GEMM_CUTLASS` environment variable to `1` when building the
+library. For example, to build the library in `CUTLASS` mode for Ampere
+(SM 8.0), clone the repository and run the following:
+
+```bash
+$ TORCH_CUDA_ARCH_LIST=8.0 GROUPED_GEMM_CUTLASS=1 pip install .
+```
+
+See [this comment](https://github.com/tgale96/grouped_gemm/pull/14#issuecomment-2211362572)
+for some performance measurements on A100 and H100.
+
+# Upcoming features
+
+* Running grouped GEMM kernels without GPU<->CPU synchronization points.
+* Hopper-optimized grouped GEMM kernels.
diff --git a/csrc/grouped_gemm.cu b/csrc/grouped_gemm.cu
index 21229a0..63b95f7 100644
--- a/csrc/grouped_gemm.cu
+++ b/csrc/grouped_gemm.cu
@@ -11,6 +11,8 @@
 #include "cutlass/gemm/kernel/default_gemm_grouped.h"
 #include "cutlass/gemm/device/gemm_grouped.h"
 
+#include <c10/cuda/CUDAStream.h>
+
 namespace grouped_gemm {
 
 #define CUDA_CALL(code) \
@@ -30,43 +32,62 @@ namespace grouped_gemm {
 #define GROUPED_GEMM_STRINGIFY(x) \
   GROUPED_GEMM_STRINGIFY_HELPER(x)
 
+template <bool trans>
+using GroupedGemmInputLayout = std::conditional_t<trans, ::cutlass::layout::ColumnMajor, ::cutlass::layout::RowMajor>;
+
+using GroupedGemmConfig = ::cutlass::gemm::device::DefaultGemmConfiguration<
+  ::cutlass::arch::OpClassTensorOp,
+  ::cutlass::arch::Sm80,
+  ::cutlass::bfloat16_t,
+  ::cutlass::bfloat16_t,
+  ::cutlass::bfloat16_t,
+  float
+>;
+
 // TODO(tgale): Update this for SM90 when it's supported by CUTLASS.
-using GroupedGemmKernelNN = typename cutlass::gemm::kernel::DefaultGemmGrouped<
-  // Non-transposed A operand.
+template <bool trans_a, bool trans_b>
+using GroupedGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
+  // A operand.
   ::cutlass::bfloat16_t,
-  ::cutlass::layout::RowMajor,
+  GroupedGemmInputLayout<trans_a>,
   ::cutlass::ComplexTransform::kNone,
-  8,
-  // Non-transposed B operand.
+  GroupedGemmConfig::kAlignmentA,
+  // B operand.
   ::cutlass::bfloat16_t,
-  ::cutlass::layout::RowMajor,
+  GroupedGemmInputLayout<trans_b>,
   ::cutlass::ComplexTransform::kNone,
-  8,
+  GroupedGemmConfig::kAlignmentB,
   // C operand.
   ::cutlass::bfloat16_t,
   ::cutlass::layout::RowMajor,
   float,
   ::cutlass::arch::OpClassTensorOp,
   ::cutlass::arch::Sm80,
-  ::cutlass::gemm::GemmShape<128, 128, 32>,
-  ::cutlass::gemm::GemmShape<64, 64, 32>,
-  ::cutlass::gemm::GemmShape<16, 8, 16>,
-  ::cutlass::epilogue::thread::LinearCombination<::cutlass::bfloat16_t, 8, float, float>,
+  GroupedGemmConfig::ThreadblockShape,
+  GroupedGemmConfig::WarpShape,
+  GroupedGemmConfig::InstructionShape,
+  GroupedGemmConfig::EpilogueOutputOp,
   // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels.
   // This parameter is passed in at present to match the APIs of other kernels. The parameter
   // is unused within the kernel.
   ::cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
   // TODO(tgale): Experiment with GroupScheduleMode.
   // TODO(tgale): Tune this for SM90.
-    4>::GemmKernel;
-using GemmGroupedNN = ::cutlass::gemm::device::GemmGrouped<GroupedGemmKernelNN>;
+    GroupedGemmConfig::kStages>::GemmKernel;
+
+template <bool trans_a, bool trans_b>
+using GemmGrouped = ::cutlass::gemm::device::GemmGrouped<GroupedGemmKernel<trans_a, trans_b>>;
 
-std::vector<cutlass::gemm::GemmCoord> MakeProblemSizes(torch::Tensor b, torch::Tensor batch_sizes) {
+template <bool trans_a, bool trans_b>
+std::vector<cutlass::gemm::GemmCoord> MakeProblemSizes(torch::Tensor a, torch::Tensor b, torch::Tensor batch_sizes) {
   const size_t num_experts = batch_sizes.size(0);
-  const size_t k = b.size(1), n = b.size(2);
+  const size_t hidden_in = a.size(1), hidden_out = (trans_a || trans_b) ? b.size(1) : b.size(2);
   std::vector<cutlass::gemm::GemmCoord> problem_sizes(num_experts);
   for (int i = 0; i < num_experts; ++i) {
-    problem_sizes[i] = cutlass::gemm::GemmCoord(batch_sizes.data_ptr<int64_t>()[i], n, k);
+    int64_t bs = batch_sizes.data_ptr<int64_t>()[i];
+    problem_sizes[i] = trans_a
+      ? cutlass::gemm::GemmCoord(hidden_in, hidden_out, bs)
+      : cutlass::gemm::GemmCoord(bs, hidden_out, hidden_in);
   }
   return problem_sizes;
 }
@@ -84,57 +105,96 @@ torch::Tensor CopyToDevice(const std::vector<T> &x, const torch::Device &device)
   return out;
 }
 
-template <typename Gemm>
+template <typename T>
+static void ReorderArray(T* data, const std::vector<size_t>& indices) {
+  // For now, simply create a copy of the data and then copy over to the original.
+  std::vector<T> copy(data, data + indices.size());
+  for (size_t i = 0; i < indices.size(); ++i) {
+    data[i] = copy.at(indices[i]);
+  }
+}
+
+template <typename Gemm, bool trans_a, bool trans_b>
 typename Gemm::Arguments MakeArguments(torch::Tensor a,
                                        torch::Tensor b,
                                        torch::Tensor c,
                                        torch::Tensor batch_sizes) {
-  auto problem_sizes_host = MakeProblemSizes(b, batch_sizes);
-
-  // Calculate the number of threadblocks to use and validate the result.
-  int64_t num_experts = problem_sizes_host.size();
+  auto problem_sizes_host = MakeProblemSizes<trans_a, trans_b>(a, b, batch_sizes);
 
-  // NOTE: This is borrowed from FasterTransformer.
-  int threadblock_count = Gemm::sufficient(problem_sizes_host.data(), num_experts);
-  if (!threadblock_count) {
-    TORCH_CHECK(false, "Grouped GEMM execution not possible with HW");
-  }
+  int64_t num_experts_orig = problem_sizes_host.size();
 
   // Create the host arrays of leading dimension data and pointer data.
   using LayoutA = typename Gemm::LayoutA;
   using LayoutB = typename Gemm::LayoutB;
   using LayoutC = typename Gemm::LayoutC;
-  std::vector<int64_t> lda_host(num_experts), offsets_a(num_experts);
-  std::vector<int64_t> ldb_host(num_experts), offsets_b(num_experts);
-  std::vector<int64_t> ldc_host(num_experts), offsets_c(num_experts);
+  std::vector<int64_t> lda_host, ldb_host, ldc_host;
   int64_t elements_a = 0, elements_b = 0, elements_c = 0;
 
   using ElementA = typename Gemm::ElementA;
   using ElementB = typename Gemm::ElementB;
   using ElementC = typename Gemm::ElementC;
-  std::vector<ElementA *> ptr_a_host(num_experts);
-  std::vector<ElementB *> ptr_b_host(num_experts);
-  std::vector<ElementC *> ptr_c_host(num_experts);
+  std::vector<ElementA *> ptr_a_host, ptr_b_host, ptr_c_host;
 
-  for (int i = 0; i < num_experts; ++i) {
-    auto problem = problem_sizes_host[i];
-    lda_host[i] = LayoutA::packed({problem.m(), problem.k()}).stride(0);
-    ldb_host[i] = LayoutB::packed({problem.k(), problem.n()}).stride(0);
-    ldc_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0);
+  lda_host.reserve(num_experts_orig);
+  ldb_host.reserve(num_experts_orig);
+  ldc_host.reserve(num_experts_orig);
 
-    offsets_a[i] = elements_a;
-    offsets_b[i] = elements_b;
-    offsets_c[i] = elements_c;
+  ptr_a_host.reserve(num_experts_orig);
+  ptr_b_host.reserve(num_experts_orig);
+  ptr_c_host.reserve(num_experts_orig);
 
-    ptr_a_host[i] = (ElementA*)a.data_ptr() + offsets_a[i];
-    ptr_b_host[i] = (ElementB*)b.data_ptr() + offsets_b[i];
-    ptr_c_host[i] = (ElementC*)c.data_ptr() + offsets_c[i];
+  // CUTLASS doesn't handle problems with `k=0` correctly, see https://github.com/NVIDIA/cutlass/pull/1593.
+  // Until a fix is available on the CUTLASS side, handle these problems by ourselves.
+  int64_t num_experts = 0;
+  for (int i = 0; i < num_experts_orig; ++i) {
+    auto problem = problem_sizes_host[i];
+    if (problem.k() == 0) {
+      CUDA_CALL(cudaMemsetAsync((ElementC*)c.data_ptr() + elements_c,
+                                0,
+                                problem.m() * problem.n() * sizeof(ElementC),
+                                c10::cuda::getCurrentCUDAStream()));
+    } else {
+      lda_host.push_back(LayoutA::packed({problem.m(), problem.k()}).stride(0));
+      ldb_host.push_back(LayoutB::packed({problem.k(), problem.n()}).stride(0));
+      ldc_host.push_back(LayoutC::packed({problem.m(), problem.n()}).stride(0));
+
+      ptr_a_host.push_back((ElementA*)a.data_ptr() + elements_a);
+      ptr_b_host.push_back((ElementB*)b.data_ptr() + elements_b);
+      ptr_c_host.push_back((ElementC*)c.data_ptr() + elements_c);
+
+      problem_sizes_host[num_experts++] = problem;
+    }
 
     elements_a += problem.m() * problem.k();
     elements_b += problem.k() * problem.n();
     elements_c += problem.m() * problem.n();
   }
+  problem_sizes_host.resize(num_experts);
+
+  // Calculate the number of threadblocks to use and validate the result.
+  // NOTE: This is borrowed from FasterTransformer.
+  int threadblock_count = Gemm::sufficient(problem_sizes_host.data(), num_experts);
+  if (!threadblock_count) {
+    TORCH_CHECK(false, "Grouped GEMM execution not possible with HW");
+  }
+
+  // Only sort the problems when 'trans_a' is set; only in that case do the K dimensions differ.
+  if (trans_a) {
+    std::vector<size_t> indices(num_experts);
+    std::iota(indices.begin(), indices.end(), 0);
+    std::stable_sort(indices.begin(), indices.end(), [&problem_sizes_host](size_t i, size_t j) {
+      return problem_sizes_host[i].k() > problem_sizes_host[j].k();
+    });
+
+    ReorderArray(problem_sizes_host.data(), indices);
+    ReorderArray(lda_host.data(), indices);
+    ReorderArray(ldb_host.data(), indices);
+    ReorderArray(ldc_host.data(), indices);
+    ReorderArray(ptr_a_host.data(), indices);
+    ReorderArray(ptr_b_host.data(), indices);
+    ReorderArray(ptr_c_host.data(), indices);
+  }
 
   // Copy the problem sizes, pointers and leading dimension data to the device.
   torch::Tensor lda = CopyToDevice(lda_host, a.device());
@@ -162,14 +222,15 @@ typename Gemm::Arguments MakeArguments(torch::Tensor a,
   return arguments;
 }
 
+template <bool trans_a, bool trans_b>
 torch::Tensor CutlassGroupedGemm(torch::Tensor a,
                                  torch::Tensor b,
                                  torch::Tensor c,
                                  torch::Tensor batch_sizes) {
-  using Gemm = GemmGroupedNN;
+  using Gemm = GemmGrouped<trans_a, trans_b>;
   Gemm gemm;
 
-  auto arguments = MakeArguments<Gemm>(a, b, c, batch_sizes);
+  auto arguments = MakeArguments<Gemm, trans_a, trans_b>(a, b, c, batch_sizes);
   int64_t workspace_size = gemm.get_workspace_size(arguments);
   auto options = torch::TensorOptions().dtype(torch::kInt8).device(a.device());
   torch::Tensor workspace = torch::empty(workspace_size, options);
@@ -302,49 +363,63 @@ void GroupedGemm(torch::Tensor a,
   TORCH_CHECK(a.ndimension() == 2);
   TORCH_CHECK(a.scalar_type() == torch::kBFloat16);
 
-  // Defer to the variable 'k' helper for the rest of the op.
+#if !defined(GROUPED_GEMM_CUTLASS)
   if (trans_a) {
+    // If we can't use CUTLASS for the transposed case, defer to the cuBLAS-based
+    // variable 'k' helper for the rest of the op.
    GroupedGemmVariableK(a, b, c, batch_sizes);
    return;
  }
+#endif
 
-  // We expected a CUDA tensor with three dimensions and shape
-  // (num_experts, hidden_in, hidden_out) for 'b'.
   TORCH_CHECK(b.is_cuda());
-  TORCH_CHECK(b.ndimension() == 3);
-  TORCH_CHECK(b.scalar_type() == torch::kBFloat16);
-
-  // Validate the contraction dimensions match.
-  int64_t tokens = a.size(0), num_experts = b.size(0);
-  int64_t hidden_in = trans_b ? b.size(2) : b.size(1);
-  int64_t hidden_out = trans_b ? b.size(1) : b.size(2);
-  TORCH_CHECK(hidden_in == a.size(1));
-
-  // Validate that we have one size per expert.
-  TORCH_CHECK(batch_sizes.size(0) == num_experts);
-
-  // Validate the output shape.
   TORCH_CHECK(c.is_cuda());
-  TORCH_CHECK(c.ndimension() == 2);
+  TORCH_CHECK(b.scalar_type() == torch::kBFloat16);
   TORCH_CHECK(c.scalar_type() == torch::kBFloat16);
-  TORCH_CHECK(c.size(0) == tokens);
-  TORCH_CHECK(c.size(1) == hidden_out);
+
+  // The expected shapes of 'b' and 'c' are:
+  //   * when 'trans_a' is set: b=(tokens, hidden_out), c=(num_experts, hidden_in, hidden_out)
+  //   * when 'trans_b' is set: b=(num_experts, hidden_out, hidden_in), c=(tokens, hidden_out)
+  //   * otherwise: b=(num_experts, hidden_in, hidden_out), c=(tokens, hidden_out)
+  if (trans_a) {
+    TORCH_CHECK(b.ndimension() == 2);
+    TORCH_CHECK(c.ndimension() == 3);
+    TORCH_CHECK(b.size(0) == a.size(0));
+    TORCH_CHECK(c.size(0) == batch_sizes.size(0));
+    TORCH_CHECK(c.size(1) == a.size(1));
+    TORCH_CHECK(c.size(2) == b.size(1));
+  } else {
+    TORCH_CHECK(b.ndimension() == 3);
+    TORCH_CHECK(c.ndimension() == 2);
+
+    // Validate the contraction dimensions match.
+    int64_t tokens = a.size(0), num_experts = b.size(0);
+    int64_t hidden_in = trans_b ? b.size(2) : b.size(1);
+    int64_t hidden_out = trans_b ? b.size(1) : b.size(2);
+    TORCH_CHECK(hidden_in == a.size(1));
+
+    // Validate that we have one size per expert.
+    TORCH_CHECK(batch_sizes.size(0) == num_experts);
+  }
 
   // NOTE: We support transposition through the 'trans_b' flag.
   TORCH_CHECK(a.is_contiguous());
   TORCH_CHECK(b.is_contiguous());
+  TORCH_CHECK(c.is_contiguous());
 
-  // NOTE: Use cuBLAS for SM90 until CUTLASS supports SM90-optimized grouped-gemm.
-#if !defined(GROUPED_GEMM_DEVICE_CAPABILITY) || GROUPED_GEMM_DEVICE_CAPABILITY != 80
+#if !defined(GROUPED_GEMM_CUTLASS)
   CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
   return;
 #else
-  // TODO(tgale): Support transposition with CUTLASS grouped GEMM.
+  if (trans_a) {
+    CutlassGroupedGemm<true, false>(a, b, c, batch_sizes);
+    return;
+  }
   if (trans_b) {
-    CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
+    CutlassGroupedGemm<false, true>(a, b, c, batch_sizes);
     return;
   }
-  CutlassGroupedGemm(a, b, c, batch_sizes);
+  CutlassGroupedGemm<false, false>(a, b, c, batch_sizes);
   return;
 #endif
 }
diff --git a/grouped_gemm/ops_test.py b/grouped_gemm/ops_test.py
index b607a46..1c11be2 100644
--- a/grouped_gemm/ops_test.py
+++ b/grouped_gemm/ops_test.py
@@ -104,6 +104,49 @@ def testGroupedGemm_VariableSizes(self, z, m, k, n, trans_b):
         self.assertTrue(allclose(b.grad, b_ref.grad))
 
 
+class EdgeCasesTest(unittest.TestCase):
+
+    def testGroupedGemm_ZeroSize(self):
+        torch.manual_seed(0)
+        m = 16384
+        k = 4096
+        n = 14336
+        num_experts = 8
+
+        a = randn(num_experts, m // num_experts, k).view(-1, k)
+        b = randn(num_experts, k, n)
+        batch_sizes = torch.tensor([219, 2246, 5, 8103, 1, 1117, 4693, 0]).to(torch.long)
+
+        a.requires_grad_(True)
+        b.requires_grad_(True)
+        a_ref = a.detach().clone().requires_grad_(True)
+        b_ref = b.detach().clone().requires_grad_(True)
+
+        out = ops.gmm(a, b, batch_sizes)
+        expected_out = gmm(a_ref, b_ref, batch_sizes)
+        self.assertTrue(allclose(out, expected_out))
+
+        # Check gradients.
+        out.sum().backward()
+        expected_out.sum().backward()
+        self.assertTrue(allclose(a.grad, a_ref.grad))
+        self.assertTrue(allclose(b.grad, b_ref.grad))
+
+    def testGroupedGemm_ZeroK(self):
+        sz = 128
+        total_tokens = 192
+
+        a = torch.ones(total_tokens, sz).cuda().to(torch.bfloat16)
+        b = torch.ones(total_tokens, sz).cuda().to(torch.bfloat16)
+        c = torch.ones(4, sz, sz).cuda().to(torch.bfloat16)
+        batch_sizes = torch.tensor([0, 128, 0, 64]).to(torch.long)
+
+        ops.backend.gmm(a, b, batch_sizes, trans_a=True, c=c)
+        self.assertTrue((c[0] == 0).all())
+        self.assertTrue((c[1] == 128).all())
+        self.assertTrue((c[2] == 0).all())
+        self.assertTrue((c[3] == 64).all())
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/setup.py b/setup.py
index 4f9c40e..230fa25 100644
--- a/setup.py
+++ b/setup.py
@@ -20,9 +20,11 @@
 if device_capability:
     nvcc_flags.extend([
         f"--generate-code=arch=compute_{device_capability},code=sm_{device_capability}",
-        f"-DGROUPED_GEMM_DEVICE_CAPABILITY={device_capability}",
     ])
 
+if os.environ.get("GROUPED_GEMM_CUTLASS", "0") == "1":
+    nvcc_flags.extend(["-DGROUPED_GEMM_CUTLASS"])
+
 ext_modules = [
     CUDAExtension(
         "grouped_gemm_backend",
@@ -50,7 +52,7 @@
 
 setup(
     name="grouped_gemm",
-    version="0.1.4",
+    version="0.1.5",
     author="Trevor Gale",
     author_email="tgale@stanford.edu",
     description="Grouped GEMM",
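
For context on how the op touched by this patch is invoked from Python, here is a minimal usage sketch modeled on the new tests. It assumes `ops` refers to `grouped_gemm.ops` (as imported in `ops_test.py`) and that the package was built as described in the README hunk above; the tensor sizes are made up for illustration.

```python
import torch
from grouped_gemm import ops  # assumes the package was installed/built as in the README

num_experts, tokens, hidden_in, hidden_out = 4, 1024, 512, 1024

# Tokens for all experts, concatenated along dim 0 (bfloat16, on GPU).
a = torch.randn(tokens, hidden_in, device="cuda", dtype=torch.bfloat16)
# One (hidden_in, hidden_out) weight matrix per expert.
b = torch.randn(num_experts, hidden_in, hidden_out, device="cuda", dtype=torch.bfloat16)
# Number of tokens routed to each expert (CPU int64, as in the tests);
# zero-sized groups are allowed and exercised by EdgeCasesTest.
batch_sizes = torch.tensor([128, 0, 640, 256], dtype=torch.long)  # sums to `tokens`

c = ops.gmm(a, b, batch_sizes)  # -> (tokens, hidden_out)
```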
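The shape conventions documented in the new comment in `csrc/grouped_gemm.cu`, and the `k=0` behaviour checked by `testGroupedGemm_ZeroK`, can be described by a plain-PyTorch reference. The sketch below is an illustration only, not code from the patch, and the helper name `reference_gmm` is invented here.

```python
import torch

def reference_gmm(a, b, batch_sizes, trans_a=False, trans_b=False):
    """Loop-of-matmuls reference for the grouped GEMM semantics (illustration only).

    Shapes, following the comment in csrc/grouped_gemm.cu:
      * trans_a: a=(tokens, hidden_in), b=(tokens, hidden_out) -> (num_experts, hidden_in, hidden_out)
      * trans_b: a=(tokens, hidden_in), b=(num_experts, hidden_out, hidden_in) -> (tokens, hidden_out)
      * default: a=(tokens, hidden_in), b=(num_experts, hidden_in, hidden_out) -> (tokens, hidden_out)
    """
    sizes = batch_sizes.tolist()
    a_chunks = torch.split(a, sizes, dim=0)
    if trans_a:
        b_chunks = torch.split(b, sizes, dim=0)
        # A group with zero tokens contracts over k=0 and yields an all-zero
        # (hidden_in, hidden_out) block -- the case the cudaMemsetAsync workaround covers.
        return torch.stack([ai.t() @ bi for ai, bi in zip(a_chunks, b_chunks)])
    out = []
    for ai, bi in zip(a_chunks, b.unbind(0)):
        out.append(ai @ (bi.t() if trans_b else bi))
    return torch.cat(out, dim=0)
```

With all-ones inputs and `batch_sizes = [0, 128, 0, 64]`, the `trans_a` path of this reference produces per-expert blocks filled with 0, 128, 0, and 64, matching what `testGroupedGemm_ZeroK` asserts for the CUTLASS path.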