Commit d608164: working fused op

Signed-off-by: ElizaWszola <eliza@neuralmagic.com>
Committed by ElizaWszola on Jan 29, 2025
1 parent: 1ea7874

Showing 8 changed files with 140 additions and 167 deletions.
csrc/cpu/torch_bindings.cpp (7 additions & 5 deletions)

@@ -126,11 +126,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   //     b_offsets) -> ()");
   // ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm);

-  ops.def(
-      "compute_expert_offsets(Tensor! trg_a_ptrs,"
-      " Tensor! a, Tensor topk_ids,"
-      " Tensor! expert_offsets, SymInt num_experts) -> ()");
-  ops.impl("compute_expert_offsets", torch::kCUDA, &compute_expert_offsets);
+  // ops.def(
+  //     "compute_expert_offsets(Tensor! trg_a_ptrs,"
+  //     " Tensor! a, Tensor topk_ids,"
+  //     " Tensor! expert_offsets, SymInt num_experts) ->
+  //     ()");
+  // ops.impl("compute_expert_offsets", torch::kCUDA,
+  //          &compute_expert_offsets);

   // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
   // quantization.
csrc/ops.h (5 additions & 3 deletions)

@@ -167,10 +167,12 @@ void cutlass_grouped_mm(torch::Tensor& out_tensors,
                         torch::Tensor const& expert_offsets,
                         torch::Tensor const& problem_sizes);

-void compute_expert_offsets(torch::Tensor& trg_a_ptrs, torch::Tensor& a,
-                            const torch::Tensor& topk_ids,
+void compute_expert_offsets(const torch::Tensor& topk_ids,
                             torch::Tensor& expert_offsets,
-                            const int64_t num_experts);
+                            torch::Tensor& problem_sizes1,
+                            torch::Tensor& problem_sizes2,
+                            const int64_t num_experts, const int64_t n,
+                            const int64_t k);

 void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu (37 additions & 34 deletions)

@@ -134,7 +134,7 @@ void cutlass_group_gemm_caller(torch::Tensor& out_tensors,
   using ElementAB = typename Gemm::ElementAB;
   using ElementC = typename Gemm::ElementC;

-  int groups = (int)expert_offsets.size(0) - 1;
+  int groups = (int)expert_offsets.size(0);
   int k_size = a_tensors.size(1);
   int n_size = out_tensors.size(1);

@@ -146,27 +146,22 @@

   auto options_int =
       torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device());
-  torch::Tensor a_ptrs_base =
-      torch::full({groups + 1}, reinterpret_cast<int64_t>(a_tensors.data_ptr()),
-                  options_int);
+  torch::Tensor a_ptrs_base = torch::full(
+      groups, reinterpret_cast<int64_t>(a_tensors.data_ptr()), options_int);
   torch::Tensor out_ptrs_base = torch::full(
-      {groups + 1}, reinterpret_cast<int64_t>(out_tensors.data_ptr()),
-      options_int);
-  torch::Tensor b_ptrs_base =
-      torch::full({groups + 1}, reinterpret_cast<int64_t>(b_tensors.data_ptr()),
-                  options_int);
-  torch::Tensor a_scales_base =
-      torch::full({groups + 1}, reinterpret_cast<int64_t>(a_scales.data_ptr()),
-                  options_int);
-  torch::Tensor b_scales_base =
-      torch::full({groups + 1}, reinterpret_cast<int64_t>(b_scales.data_ptr()),
-                  options_int);
-
-  torch::Tensor b_offsets = torch::arange(0, b_single_size * (groups + 1),
-                                          b_single_size, options_int);
-  torch::Tensor a_scales_offsets = torch::arange(0, groups + 1, options_int);
+      groups, reinterpret_cast<int64_t>(out_tensors.data_ptr()), options_int);
+  torch::Tensor b_ptrs_base = torch::full(
+      groups, reinterpret_cast<int64_t>(b_tensors.data_ptr()), options_int);
+  torch::Tensor a_scales_base = torch::full(
+      groups, reinterpret_cast<int64_t>(a_scales.data_ptr()), options_int);
+  torch::Tensor b_scales_base = torch::full(
+      groups, reinterpret_cast<int64_t>(b_scales.data_ptr()), options_int);
+
+  torch::Tensor b_offsets =
+      torch::arange(0, b_single_size * groups, b_single_size, options_int);
+  torch::Tensor a_scales_offsets = torch::arange(0, groups, options_int);
   torch::Tensor b_scales_offsets = torch::arange(
-      0, b_scale_single_size * (groups + 1), b_scale_single_size, options_int);
+      0, b_scale_single_size * groups, b_scale_single_size, options_int);

   torch::Tensor a_ptrs = a_ptrs_base.add(
       expert_offsets, sizeof(ElementAB_Type) * a_tensors.size(1));
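A note on the pointer set-up above: with the new convention, expert_offsets carries one starting row per group, and each per-group pointer is the corresponding base address plus that row offset in bytes (torch::Tensor::add(other, alpha) computes base + alpha * other). A minimal sketch of the same arithmetic, assuming row-major fp8 A of shape (total_tokens, k) with one-byte elements; the helper name is ours, not from this commit:

    import torch

    def build_a_ptrs(a_base_addr: int, expert_offsets: torch.Tensor,
                     k: int, elem_size: int = 1) -> torch.Tensor:
        # a_ptrs[g] = &A[expert_offsets[g], 0]: each row advances the
        # address by k * elem_size bytes (fp8 elements are 1 byte).
        return a_base_addr + expert_offsets.to(torch.int64) * k * elem_size

    # Three groups whose token rows start at 0, 5, and 9:
    starts = torch.tensor([0, 5, 9])
    a_ptrs = build_a_ptrs(0x7F0000000000, starts, k=128)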

@@ -189,6 +184,7 @@
   std::vector<StrideB> b_stride_host(groups);
   std::vector<StrideC> c_stride_host(groups);

+  // TODO pass strides?
   for (int32_t g = 0; g < groups; ++g) {
     int64_t lda = a_tensors.stride(0);  // row-major (m x k)
     int64_t ldb = a_tensors.stride(0);  // column-major (k x n)

@@ -325,26 +321,31 @@ void cutlass_grouped_mm_sm90(torch::Tensor& out_tensors,
                           problem_sizes);
 }

-__global__ void get_a_expert_offsets(cutlass::float_e4m3_t** trg_a_ptrs,
-                                     cutlass::float_e4m3_t* base_a_ptr,
-                                     const int* __restrict__ topk_ids,
-                                     int64_t* expert_offsets, int topk_length) {
+__global__ void get_a_expert_offsets(const int* __restrict__ topk_ids,
+                                     int32_t* expert_offsets,
+                                     int32_t* problem_sizes1,
+                                     int32_t* problem_sizes2, int topk_length,
+                                     int n, int k) {
   int expert_id = threadIdx.x;
   int num_experts = blockDim.x;

   int occurrences = 0;
   for (int i = 0; i < topk_length; ++i) {
     occurrences += (topk_ids[i] == expert_id);
   }
-  expert_offsets[expert_id + 1] = occurrences;
+  problem_sizes1[expert_id * 3] = occurrences;
+  problem_sizes1[expert_id * 3 + 1] = 2 * n;
+  problem_sizes1[expert_id * 3 + 2] = k;
+  problem_sizes2[expert_id * 3] = occurrences;
+  problem_sizes2[expert_id * 3 + 1] = k;
+  problem_sizes2[expert_id * 3 + 2] = n;
   __syncthreads();

   if (threadIdx.x == 0) {
-    int64_t tot_offset = 0;
+    int32_t tot_offset = 0;
     expert_offsets[0] = 0;
     for (int i = 0; i < num_experts; ++i) {
-      trg_a_ptrs[i] = base_a_ptr + tot_offset;
-      tot_offset += expert_offsets[i + 1];
+      tot_offset += problem_sizes1[i * 3];
       expert_offsets[i + 1] = tot_offset;
     }
   }
 }
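The rewritten kernel is the core of the fused op: one thread per expert counts that expert's entries in topk_ids, writes the two grouped-GEMM problem shapes, and thread 0 turns the counts into an exclusive prefix sum over num_experts + 1 slots. A PyTorch reference of the same computation, offered as a reading aid only (it is not part of this commit):

    import torch

    def expert_offsets_ref(topk_ids: torch.Tensor, num_experts: int,
                           n: int, k: int):
        # Per-expert token counts over the flattened top-k assignments.
        counts = torch.bincount(topk_ids.flatten(), minlength=num_experts)
        # First GEMM solves (m_e x 2n x k), second solves (m_e x k x n).
        problem_sizes1 = torch.stack(
            [counts,
             torch.full_like(counts, 2 * n),
             torch.full_like(counts, k)], dim=1).int()
        problem_sizes2 = torch.stack(
            [counts,
             torch.full_like(counts, k),
             torch.full_like(counts, n)], dim=1).int()
        # Exclusive prefix sum with a leading zero.
        expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int32)
        expert_offsets[1:] = torch.cumsum(counts, 0)
        return expert_offsets, problem_sizes1, problem_sizes2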

@@ -394,14 +395,16 @@
 //   };
 // }

-void compute_expert_offsets_caller(torch::Tensor& trg_a_ptrs, torch::Tensor& a,
-                                   const torch::Tensor& topk_ids,
+void compute_expert_offsets_caller(const torch::Tensor& topk_ids,
                                    torch::Tensor& expert_offsets,
-                                   const int64_t num_experts) {
+                                   torch::Tensor& problem_sizes1,
+                                   torch::Tensor& problem_sizes2,
+                                   const int64_t num_experts, const int64_t n,
+                                   const int64_t k) {
   get_a_expert_offsets<<<1, num_experts>>>(
-      (cutlass::float_e4m3_t**)trg_a_ptrs.data_ptr(),
-      (cutlass::float_e4m3_t*)a.data_ptr(), (const int*)topk_ids.data_ptr(),
-      (int64_t*)expert_offsets.data_ptr(), topk_ids.numel());
+      (const int32_t*)topk_ids.data_ptr(), (int32_t*)expert_offsets.data_ptr(),
+      (int32_t*)problem_sizes1.data_ptr(), (int32_t*)problem_sizes2.data_ptr(),
+      topk_ids.numel(), n, k);
 }

 // void permute_fp8_rows(torch::Tensor& a_ptr,
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu (15 additions & 14 deletions)

@@ -38,12 +38,12 @@ void cutlass_grouped_mm_sm90(torch::Tensor& out_tensors,
                              torch::Tensor const& expert_offsets,
                              torch::Tensor const& problem_sizes);

-
-void compute_expert_offsets_caller(torch::Tensor& trg_a_ptrs,
-                                   torch::Tensor& a,
-                                   const torch::Tensor& topk_ids,
+void compute_expert_offsets_caller(const torch::Tensor& topk_ids,
                                    torch::Tensor& expert_offsets,
-                                   const int64_t num_experts);
+                                   torch::Tensor& problem_sizes1,
+                                   torch::Tensor& problem_sizes2,
+                                   const int64_t num_experts, const int64_t n,
+                                   const int64_t k);

 #endif

@@ -166,17 +166,18 @@ void cutlass_grouped_mm(torch::Tensor& out_tensors,
                         torch::Tensor const& b_scales,
                         torch::Tensor const& expert_offsets,
                         torch::Tensor const& problem_sizes) {
-  cutlass_grouped_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales,
-                          b_scales, expert_offsets, problem_sizes);
+  cutlass_grouped_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
+                          expert_offsets, problem_sizes);
 }

-void compute_expert_offsets(torch::Tensor& trg_a_ptrs,
-                            torch::Tensor& a,
-                            const torch::Tensor& topk_ids,
-                            torch::Tensor& expert_offsets,
-                            const int64_t num_experts) {
-  compute_expert_offsets_caller(trg_a_ptrs, a, topk_ids, expert_offsets,
-                                num_experts);
+void compute_expert_offsets(const torch::Tensor& topk_ids,
+                            torch::Tensor& expert_offsets,
+                            torch::Tensor& problem_sizes1,
+                            torch::Tensor& problem_sizes2,
+                            const int64_t num_experts, const int64_t n,
+                            const int64_t k) {
+  compute_expert_offsets_caller(topk_ids, expert_offsets, problem_sizes1,
+                                problem_sizes2, num_experts, n, k);
 }

 void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
csrc/torch_bindings.cpp (3 additions & 3 deletions)

@@ -334,9 +334,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm);

ops.def(
"compute_expert_offsets(Tensor! trg_a_ptrs,"
" Tensor! a, Tensor topk_ids,"
" Tensor! expert_offsets, SymInt num_experts) -> ()");
"compute_expert_offsets(Tensor topk_ids, Tensor! expert_offsets, "
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" SymInt num_experts, SymInt n, SymInt k) -> ()");
ops.impl("compute_expert_offsets", torch::kCUDA, &compute_expert_offsets);

// Check if cutlass sparse scaled_mm is supported for CUDA devices of the
Expand Down
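Given this schema, a caller allocates the in-place outputs to match what the kernel writes: num_experts + 1 int32 offsets and one (num_experts, 3) int32 problem-shape row per GEMM. A hypothetical invocation sketch; the buffer shapes are inferred from get_a_expert_offsets above rather than spelled out in this commit:

    import torch

    num_experts, topk, m, n, k = 8, 2, 16, 128, 256
    topk_ids = torch.randint(0, num_experts, (m, topk),
                             dtype=torch.int32, device="cuda")

    expert_offsets = torch.empty(num_experts + 1, dtype=torch.int32,
                                 device="cuda")
    problem_sizes1 = torch.empty((num_experts, 3), dtype=torch.int32,
                                 device="cuda")
    problem_sizes2 = torch.empty((num_experts, 3), dtype=torch.int32,
                                 device="cuda")

    torch.ops._C.compute_expert_offsets(topk_ids, expert_offsets,
                                        problem_sizes1, problem_sizes2,
                                        num_experts, n, k)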
tests/kernels/test_cutlass.py (5 additions & 5 deletions)

@@ -504,14 +504,14 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool,
     a_tensors_stacked = torch.empty((expert_offsets[num_groups], k_g),
                                     device=device,
                                     dtype=torch.float8_e4m3fn)
-    b_tensors_stacked = torch.empty((n_g * num_groups, k_g),
+    b_tensors_stacked = torch.empty((num_groups, n_g, k_g),
                                     device=device,
                                     dtype=torch.float8_e4m3fn)
     for g in range(num_groups):
         a_tensors_stacked[expert_offsets[g]:expert_offsets[g +
                                                            1]] = a_tensors[g]
-        b_tensors_stacked[g * n_g:(g + 1) * n_g, :] = b_tensors[g].t()
-    b_tensors_stacked = b_tensors_stacked.t()
+        b_tensors_stacked[g] = b_tensors[g].t()
+    b_tensors_stacked = b_tensors_stacked.transpose(1, 2)

     a_scales_tensors_stacked = torch.empty(
         (expert_offsets[num_groups] if per_act_token else num_groups, 1),

@@ -538,8 +538,8 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool,
     torch.ops._C.cutlass_grouped_mm(out_tensors_stacked, a_tensors_stacked,
                                     b_tensors_stacked,
                                     a_scales_tensors_stacked,
-                                    b_scales_tensors_stacked, expert_offsets,
-                                    problem_sizes)
+                                    b_scales_tensors_stacked,
+                                    expert_offsets[:-1], problem_sizes)

     # Validate each group's result against the baseline
     for g in range(num_groups):
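The expert_offsets[:-1] slice follows from the caller change in grouped_mm_c3x.cu: groups is now expert_offsets.size(0) rather than size(0) - 1, so the op takes one start offset per group instead of the full prefix-sum array. Our reading of the diff, illustrated:

    import torch

    expert_offsets = torch.tensor([0, 5, 9, 12])  # num_groups + 1 prefix sums
    starts = expert_offsets[:-1]                  # one start per group: 0, 5, 9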
tests/kernels/test_cutlass_moe.py (29 additions & 28 deletions)

@@ -3,15 +3,16 @@

 from tests.kernels.utils import torch_moe
 from vllm import _custom_ops as ops
-from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk, cutlass_moe
-from vllm.platforms import current_platform
+from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
+from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe,
+                                                            fused_topk)
+from vllm.platforms import current_platform

 NUM_EXPERTS = [8, 64]
 TOP_KS = [2, 6]


-@pytest.mark.parametrize("m", [16, 32, 64, 224])
+@pytest.mark.parametrize("m", [2, 16, 32, 64, 224])
 @pytest.mark.parametrize("n", [128, 2048])
 @pytest.mark.parametrize("k", [128, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)

@@ -36,39 +37,39 @@ def test_cutlass_moe(

     a_q, a_scale = ops.scaled_fp8_quant(a)

-    w1_qs = []
-    w2_qs = []
-    w1_scales = []
-    w2_scales = []
+    w1_q = torch.empty((e, 2 * n, k),
+                       device="cuda",
+                       dtype=torch.float8_e4m3fn)
+    w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn)
+    w1_scale = torch.empty((e, 1, 1), device="cuda", dtype=torch.float32)
+    w2_scale = torch.empty((e, 1, 1), device="cuda", dtype=torch.float32)

     for expert in range(e):
-        w1_q, w1_scale = ops.scaled_fp8_quant(w1[expert])
-        w2_q, w2_scale = ops.scaled_fp8_quant(w2[expert])
-        w1_qs.append(w1_q.t())
-        w2_qs.append(w2_q.t())
-        w1_scales.append(w1_scale.reshape((1, 1)))
-        w2_scales.append(w2_scale.reshape((1, 1)))
-
-    score = torch.randn((m, e), device="cuda", dtype=dtype)
-
-    topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
-
+        w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
+        w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])
+    w1_q = w1_q.transpose(1, 2)
+    w2_q = w2_q.transpose(1, 2)
     a_d = (a_q.float() * a_scale).half()
+    w1_d = (w1_q.transpose(1, 2).float() * w1_scale).half()
+    w2_d = (w2_q.transpose(1, 2).float() * w2_scale).half()

     w1_d = torch.empty_like(w1)
     w2_d = torch.empty_like(w2)
     for expert in range(e):
-        w1_d[expert] = (w1_qs[expert].t().float() *
-                        w1_scales[expert]).half()
-        w2_d[expert] = (w2_qs[expert].t().float() *
-                        w2_scales[expert]).half()
+        w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half()
+        w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half()

+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+    topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
+
     torch_output = torch_moe(a_d, w1_d, w2_d, score, topk)
-    cutlass_output = cutlass_moe(a_q, a_scale, w1_qs, w2_qs, w1_scales,
-                                 w2_scales, topk_weights, topk_ids, m, n,
-                                 k)
+    cutlass_output = cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale,
+                                 w2_scale, topk_weights, topk_ids, m, n, k,
+                                 e)

-    # print(torch_output)
-    # print(cutlass_output)
-    # print(torch_output / cutlass_output)
+    print(torch_output)
+    print(cutlass_output)
+    print("*")

     torch.testing.assert_close(torch_output,
                                cutlass_output,
(The diff for the eighth changed file did not load and is not shown here.)