Use CUTLASS for both trans_a and trans_b on Ampere #14

Merged · 9 commits · Jul 29, 2024
25 changes: 24 additions & 1 deletion README.md
@@ -4,4 +4,27 @@ A lightweight library exposing grouped GEMM kernels in PyTorch.

# Installation

Run `pip install grouped_gemm` to install the package.
Run `pip install grouped_gemm` to install the package.

# Compiling from source

By default, the installed package runs in conservative (`cuBLAS`) mode:
it launches one GEMM kernel per batch element instead of using a single
grouped GEMM kernel for the whole batch.
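
Conceptually, the conservative mode computes the same result as the following
pure-PyTorch sketch (for illustration only; the installed package dispatches to
cuBLAS kernels rather than running this Python loop):

```python
import torch

def grouped_gemm_reference(a, b, batch_sizes):
    """Reference per-expert loop: a is (tokens, hidden_in),
    b is (num_experts, hidden_in, hidden_out), batch_sizes sums to tokens."""
    out, start = [], 0
    for i, size in enumerate(batch_sizes.tolist()):
        # One ordinary GEMM per expert's slice of the token dimension.
        out.append(a[start:start + size] @ b[i])
        start += size
    return torch.cat(out)
```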

To use grouped GEMM kernels, switch to `CUTLASS` mode by setting the
`GROUPED_GEMM_CUTLASS` environment variable to `1` when building the library.
For example, to build the library in `CUTLASS` mode for Ampere (SM 8.0), clone
the repository and run the following:

```bash
$ TORCH_CUDA_ARCH_LIST=8.0 GROUPED_GEMM_CUTLASS=1 pip install .
```

See [this comment](https://github.com/tgale96/grouped_gemm/pull/14#issuecomment-2211362572)
for some performance measurements on A100 and H100.
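
Regardless of the backend, the Python API is the same. A minimal usage sketch,
with shapes and dtypes following those used in `grouped_gemm/ops_test.py` (the
exact sizes here are arbitrary):

```python
import torch
from grouped_gemm import ops

num_experts, hidden_in, hidden_out, tokens = 8, 1024, 2048, 512

# bfloat16 CUDA inputs: stacked token activations and one weight matrix per expert.
a = torch.randn(tokens, hidden_in, device="cuda", dtype=torch.bfloat16)
b = torch.randn(num_experts, hidden_in, hidden_out, device="cuda", dtype=torch.bfloat16)

# One batch size per expert, kept on the CPU; the sizes must sum to `tokens`.
batch_sizes = torch.full((num_experts,), tokens // num_experts, dtype=torch.long)

out = ops.gmm(a, b, batch_sizes)  # shape: (tokens, hidden_out)
```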

# Upcoming features

* Running grouped GEMM kernels without GPU<->CPU synchronization points.
* Hopper-optimized grouped GEMM kernels.
213 changes: 144 additions & 69 deletions csrc/grouped_gemm.cu
@@ -11,6 +11,8 @@
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/gemm/device/gemm_grouped.h"

#include <type_traits>

namespace grouped_gemm {

#define CUDA_CALL(code) \
@@ -30,43 +32,62 @@ namespace grouped_gemm {
#define GROUPED_GEMM_STRINGIFY(x) \
GROUPED_GEMM_STRINGIFY_HELPER(x)

template <bool trans>
using GroupedGemmInputLayout = std::conditional_t<trans, ::cutlass::layout::ColumnMajor, ::cutlass::layout::RowMajor>;

using GroupedGemmConfig = ::cutlass::gemm::device::DefaultGemmConfiguration<
::cutlass::arch::OpClassTensorOp,
::cutlass::arch::Sm80,
::cutlass::bfloat16_t,
::cutlass::bfloat16_t,
::cutlass::bfloat16_t,
float
>;

// TODO(tgale): Update this for SM90 when it's supported by CUTLASS.
using GroupedGemmKernelNN = typename cutlass::gemm::kernel::DefaultGemmGrouped<
// Non-transposed A operand.
template <bool trans_a, bool trans_b>
using GroupedGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
// A operand.
::cutlass::bfloat16_t,
::cutlass::layout::RowMajor,
GroupedGemmInputLayout<trans_a>,
::cutlass::ComplexTransform::kNone,
8,
// Non-transposed B operand.
GroupedGemmConfig::kAlignmentA,
// B operand.
::cutlass::bfloat16_t,
::cutlass::layout::RowMajor,
GroupedGemmInputLayout<trans_b>,
::cutlass::ComplexTransform::kNone,
8,
GroupedGemmConfig::kAlignmentB,
// C operand.
::cutlass::bfloat16_t,
::cutlass::layout::RowMajor,
float,
::cutlass::arch::OpClassTensorOp,
::cutlass::arch::Sm80,
::cutlass::gemm::GemmShape<128, 128, 32>,
::cutlass::gemm::GemmShape<64, 64, 32>,
::cutlass::gemm::GemmShape<16, 8, 16>,
::cutlass::epilogue::thread::LinearCombination<::cutlass::bfloat16_t, 8, float, float>,
GroupedGemmConfig::ThreadblockShape,
GroupedGemmConfig::WarpShape,
GroupedGemmConfig::InstructionShape,
GroupedGemmConfig::EpilogueOutputOp,
// NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels.
// This parameter is passed in at present to match the APIs of other kernels. The parameter
// is unused within the kernel.
::cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
// TODO(tgale): Experiment with GroupScheduleMode.
// TODO(tgale): Tune this for SM90.
4>::GemmKernel;
using GemmGroupedNN = ::cutlass::gemm::device::GemmGrouped<GroupedGemmKernelNN>;
GroupedGemmConfig::kStages>::GemmKernel;

template <bool trans_a, bool trans_b>
using GemmGrouped = ::cutlass::gemm::device::GemmGrouped<GroupedGemmKernel<trans_a, trans_b>>;

std::vector<cutlass::gemm::GemmCoord> MakeProblemSizes(torch::Tensor b, torch::Tensor batch_sizes) {
template <bool trans_a, bool trans_b>
std::vector<cutlass::gemm::GemmCoord> MakeProblemSizes(torch::Tensor a, torch::Tensor b, torch::Tensor batch_sizes) {
const size_t num_experts = batch_sizes.size(0);
const size_t k = b.size(1), n = b.size(2);
const size_t hidden_in = a.size(1), hidden_out = (trans_a || trans_b) ? b.size(1) : b.size(2);
std::vector<cutlass::gemm::GemmCoord> problem_sizes(num_experts);
for (int i = 0; i < num_experts; ++i) {
problem_sizes[i] = cutlass::gemm::GemmCoord(batch_sizes.data_ptr<int64_t>()[i], n, k);
int64_t bs = batch_sizes.data_ptr<int64_t>()[i];
problem_sizes[i] = trans_a
? cutlass::gemm::GemmCoord(hidden_in, hidden_out, bs)
: cutlass::gemm::GemmCoord(bs, hidden_out, hidden_in);
}
return problem_sizes;
}
@@ -84,57 +105,96 @@ torch::Tensor CopyToDevice(const std::vector<T> &x, const torch::Device &device)
return out;
}

template <typename Gemm>
template <typename T>
static void ReorderArray(T* data, const std::vector<size_t>& indices) {
// For now, simply create a copy of the data and then copy over to the original.
std::vector<T> copy(data, data + indices.size());
for (size_t i = 0; i < indices.size(); ++i) {
data[i] = copy.at(indices[i]);
}
}

template <typename Gemm, bool trans_a, bool trans_b>
typename Gemm::Arguments MakeArguments(torch::Tensor a,
torch::Tensor b,
torch::Tensor c,
torch::Tensor batch_sizes) {
auto problem_sizes_host = MakeProblemSizes(b, batch_sizes);

// Calculate the number of threadblocks to use and validate the result.
int64_t num_experts = problem_sizes_host.size();
auto problem_sizes_host = MakeProblemSizes<trans_a, trans_b>(a, b, batch_sizes);

// NOTE: This is borrowed from FasterTransformer.
int threadblock_count = Gemm::sufficient(problem_sizes_host.data(), num_experts);
if (!threadblock_count) {
TORCH_CHECK(false, "Grouped GEMM execution not possible with HW");
}
int64_t num_experts_orig = problem_sizes_host.size();

// Create the host arrays of leading dimension data and pointer data.
using LayoutA = typename Gemm::LayoutA;
using LayoutB = typename Gemm::LayoutB;
using LayoutC = typename Gemm::LayoutC;

std::vector<int64_t> lda_host(num_experts), offsets_a(num_experts);
std::vector<int64_t> ldb_host(num_experts), offsets_b(num_experts);
std::vector<int64_t> ldc_host(num_experts), offsets_c(num_experts);
std::vector<int64_t> lda_host, ldb_host, ldc_host;
int64_t elements_a = 0, elements_b = 0, elements_c = 0;

using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
using ElementC = typename Gemm::ElementC;
std::vector<ElementA *> ptr_a_host(num_experts);
std::vector<ElementB *> ptr_b_host(num_experts);
std::vector<ElementC *> ptr_c_host(num_experts);
std::vector<ElementA *> ptr_a_host, ptr_b_host, ptr_c_host;

for (int i = 0; i < num_experts; ++i) {
auto problem = problem_sizes_host[i];
lda_host[i] = LayoutA::packed({problem.m(), problem.k()}).stride(0);
ldb_host[i] = LayoutB::packed({problem.k(), problem.n()}).stride(0);
ldc_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0);
lda_host.reserve(num_experts_orig);
ldb_host.reserve(num_experts_orig);
ldc_host.reserve(num_experts_orig);

offsets_a[i] = elements_a;
offsets_b[i] = elements_b;
offsets_c[i] = elements_c;
ptr_a_host.reserve(num_experts_orig);
ptr_b_host.reserve(num_experts_orig);
ptr_c_host.reserve(num_experts_orig);

ptr_a_host[i] = (ElementA*)a.data_ptr() + offsets_a[i];
ptr_b_host[i] = (ElementB*)b.data_ptr() + offsets_b[i];
ptr_c_host[i] = (ElementC*)c.data_ptr() + offsets_c[i];
// CUTLASS doesn't handle problems with `k=0` correctly, see https://github.com/NVIDIA/cutlass/pull/1593.
// Until a fix is available on the CUTLASS side, handle these problems ourselves.
int64_t num_experts = 0;
for (int i = 0; i < num_experts_orig; ++i) {
auto problem = problem_sizes_host[i];
if (problem.k() == 0) {
CUDA_CALL(cudaMemsetAsync((ElementC*)c.data_ptr() + elements_c,
0,
problem.m() * problem.n() * sizeof(ElementC),
c10::cuda::getCurrentCUDAStream()));
} else {
lda_host.push_back(LayoutA::packed({problem.m(), problem.k()}).stride(0));
ldb_host.push_back(LayoutB::packed({problem.k(), problem.n()}).stride(0));
ldc_host.push_back(LayoutC::packed({problem.m(), problem.n()}).stride(0));

ptr_a_host.push_back((ElementA*)a.data_ptr() + elements_a);
ptr_b_host.push_back((ElementB*)b.data_ptr() + elements_b);
ptr_c_host.push_back((ElementC*)c.data_ptr() + elements_c);

problem_sizes_host[num_experts++] = problem;
}

elements_a += problem.m() * problem.k();
elements_b += problem.k() * problem.n();
elements_c += problem.m() * problem.n();
}
problem_sizes_host.resize(num_experts);

// Calculate the number of threadblocks to use and validate the result.
// NOTE: This is borrowed from FasterTransformer.
int threadblock_count = Gemm::sufficient(problem_sizes_host.data(), num_experts);
if (!threadblock_count) {
TORCH_CHECK(false, "Grouped GEMM execution not possible with HW");
}

// Only sort the problems when trans_a is true, since only in that case do the K dimensions differ across problems.
if (trans_a) {
std::vector<size_t> indices(num_experts);
std::iota(indices.begin(), indices.end(), 0);
std::stable_sort(indices.begin(), indices.end(), [&problem_sizes_host](size_t i, size_t j) {
return problem_sizes_host[i].k() > problem_sizes_host[j].k();
});

ReorderArray(problem_sizes_host.data(), indices);
ReorderArray(lda_host.data(), indices);
ReorderArray(ldb_host.data(), indices);
ReorderArray(ldc_host.data(), indices);
ReorderArray(ptr_a_host.data(), indices);
ReorderArray(ptr_b_host.data(), indices);
ReorderArray(ptr_c_host.data(), indices);
}

// Copy the problem sizes, pointers and leading dimension data to the device.
torch::Tensor lda = CopyToDevice(lda_host, a.device());
@@ -162,14 +222,15 @@ typename Gemm::Arguments MakeArguments(torch::Tensor a,
return arguments;
}

template <bool trans_a, bool trans_b>
torch::Tensor CutlassGroupedGemm(torch::Tensor a,
torch::Tensor b,
torch::Tensor c,
torch::Tensor batch_sizes) {
using Gemm = GemmGroupedNN;
using Gemm = GemmGrouped<trans_a, trans_b>;
Gemm gemm;

auto arguments = MakeArguments<Gemm>(a, b, c, batch_sizes);
auto arguments = MakeArguments<Gemm, trans_a, trans_b>(a, b, c, batch_sizes);
int64_t workspace_size = gemm.get_workspace_size(arguments);
auto options = torch::TensorOptions().dtype(torch::kInt8).device(a.device());
torch::Tensor workspace = torch::empty(workspace_size, options);
@@ -302,49 +363,63 @@ void GroupedGemm(torch::Tensor a,
TORCH_CHECK(a.ndimension() == 2);
TORCH_CHECK(a.scalar_type() == torch::kBFloat16);

// Defer to the variable 'k' helper for the rest of the op.
#if !defined(GROUPED_GEMM_CUTLASS)
if (trans_a) {
// If we can't use CUTLASS for the transposed cases, defer to the variable 'k' helper using cuBLAS
// for the rest of the op.
GroupedGemmVariableK(a, b, c, batch_sizes);
return;
}
#endif

// We expected a CUDA tensor with three dimensions and shape
// (num_experts, hidden_in, hidden_out) for 'b'.
TORCH_CHECK(b.is_cuda());
TORCH_CHECK(b.ndimension() == 3);
TORCH_CHECK(b.scalar_type() == torch::kBFloat16);

// Validate the contraction dimensions match.
int64_t tokens = a.size(0), num_experts = b.size(0);
int64_t hidden_in = trans_b ? b.size(2) : b.size(1);
int64_t hidden_out = trans_b ? b.size(1) : b.size(2);
TORCH_CHECK(hidden_in == a.size(1));

// Validate that we have one size per expert.
TORCH_CHECK(batch_sizes.size(0) == num_experts);

// Validate the output shape.
TORCH_CHECK(c.is_cuda());
TORCH_CHECK(c.ndimension() == 2);
TORCH_CHECK(b.scalar_type() == torch::kBFloat16);
TORCH_CHECK(c.scalar_type() == torch::kBFloat16);
TORCH_CHECK(c.size(0) == tokens);
TORCH_CHECK(c.size(1) == hidden_out);

// The expected shapes of 'b' and 'c' are:
// * when 'trans_a' is set: b=(tokens, hidden_out), c=(num_experts, hidden_in, hidden_out)
// * when 'trans_b' is set: b=(num_experts, hidden_out, hidden_in), c=(tokens, hidden_out)
// * otherwise: b=(num_experts, hidden_in, hidden_out), c=(tokens, hidden_out)
if (trans_a) {
TORCH_CHECK(b.ndimension() == 2);
TORCH_CHECK(c.ndimension() == 3);
TORCH_CHECK(b.size(0) == a.size(0));
TORCH_CHECK(c.size(0) == batch_sizes.size(0));
TORCH_CHECK(c.size(1) == a.size(1));
TORCH_CHECK(c.size(2) == b.size(1));
} else {
TORCH_CHECK(b.ndimension() == 3);
TORCH_CHECK(c.ndimension() == 2);

// Validate the contraction dimensions match.
int64_t tokens = a.size(0), num_experts = b.size(0);
int64_t hidden_in = trans_b ? b.size(2) : b.size(1);
int64_t hidden_out = trans_b ? b.size(1) : b.size(2);
TORCH_CHECK(hidden_in == a.size(1));

// Validate that we have one size per expert.
TORCH_CHECK(batch_sizes.size(0) == num_experts);
}

// NOTE: We support transposition through the 'trans_b' flag.
TORCH_CHECK(a.is_contiguous());
TORCH_CHECK(b.is_contiguous());
TORCH_CHECK(c.is_contiguous());

// NOTE: Use cuBLAS for SM90 until CUTLASS supports SM90-optimized grouped-gemm.
#if !defined(GROUPED_GEMM_DEVICE_CAPABILITY) || GROUPED_GEMM_DEVICE_CAPABILITY != 80
#if !defined(GROUPED_GEMM_CUTLASS)
CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
return;
#else
// TODO(tgale): Support transposition with CUTLASS grouped GEMM.
if (trans_a) {
CutlassGroupedGemm<true, false>(a, b, c, batch_sizes);
return;
}
if (trans_b) {
CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
CutlassGroupedGemm<false, true>(a, b, c, batch_sizes);
return;
}
CutlassGroupedGemm(a, b, c, batch_sizes);
CutlassGroupedGemm<false, false>(a, b, c, batch_sizes);
return;
#endif
}
43 changes: 43 additions & 0 deletions grouped_gemm/ops_test.py
@@ -104,6 +104,49 @@ def testGroupedGemm_VariableSizes(self, z, m, k, n, trans_b):
self.assertTrue(allclose(b.grad, b_ref.grad))


class EdgeCasesTest(unittest.TestCase):

def testGroupedGemm_ZeroSize(self):
torch.manual_seed(0)
m = 16384
k = 4096
n = 14336
num_experts = 8

a = randn(num_experts, m // num_experts, k).view(-1, k)
b = randn(num_experts, k, n)
batch_sizes = torch.tensor([219, 2246, 5, 8103, 1, 1117, 4693, 0]).to(torch.long)

a.requires_grad_(True)
b.requires_grad_(True)
a_ref = a.detach().clone().requires_grad_(True)
b_ref = b.detach().clone().requires_grad_(True)

out = ops.gmm(a, b, batch_sizes)
expected_out = gmm(a_ref, b_ref, batch_sizes)
self.assertTrue(allclose(out, expected_out))

# Check gradients.
out.sum().backward()
expected_out.sum().backward()
self.assertTrue(allclose(a.grad, a_ref.grad))
self.assertTrue(allclose(b.grad, b_ref.grad))

def testGroupedGemm_ZeroK(self):
sz = 128
total_tokens = 192

a = torch.ones(total_tokens, sz).cuda().to(torch.bfloat16)
b = torch.ones(total_tokens, sz).cuda().to(torch.bfloat16)
c = torch.ones(4, sz, sz).cuda().to(torch.bfloat16)
batch_sizes = torch.tensor([0, 128, 0, 64]).to(torch.long)

ops.backend.gmm(a, b, batch_sizes, trans_a=True, c=c)
self.assertTrue((c[0] == 0).all())
self.assertTrue((c[1] == 128).all())
self.assertTrue((c[2] == 0).all())
self.assertTrue((c[3] == 64).all())


if __name__ == '__main__':
unittest.main()