Working halfway
Signed-off-by: ElizaWszola <eliza@neuralmagic.com>
ElizaWszola committed Jan 20, 2025
1 parent f1a5666 commit 6414e31
Showing 1 changed file with 60 additions and 13 deletions.
73 changes: 60 additions & 13 deletions csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu
@@ -163,6 +163,8 @@ void cutlass_group_gemm_caller(c10::List<at::Tensor> const& out_tensors,
b_scales_ptrs_host[g] =
reinterpret_cast<const ElementAccumulator*>(b_scales[g].data_ptr());

+      // printf("%p %p %p %p\n", a_ptrs_host[g], b_ptrs_host[g],
+      //        c_ptrs_host[g], d_ptrs_host[g]);
int64_t m = a_tensors[g].size(0);
int64_t k = a_tensors[g].size(1);

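The hunk above sits in the host-side loop of cutlass_group_gemm_caller that gathers one device pointer per group (activations, weights, outputs, scales) into host arrays such as b_scales_ptrs_host; the tables are then copied to the device so the grouped GEMM can index operands by group id. A minimal, self-contained sketch of that pointer-table pattern (hypothetical names, plain float standing in for the fp8 tensors), compilable with nvcc alone:

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// Build a host-side table of per-group device pointers, then ship the whole
// table to the device in one copy so kernels can index it by group id.
int main() {
  const int num_groups = 3, elems = 4;
  std::vector<float*> ptrs_host(num_groups);
  for (int g = 0; g < num_groups; ++g)
    cudaMalloc(&ptrs_host[g], elems * sizeof(float));  // one buffer per group

  float** ptrs_dev;  // device-resident pointer table
  cudaMalloc(&ptrs_dev, num_groups * sizeof(float*));
  cudaMemcpy(ptrs_dev, ptrs_host.data(), num_groups * sizeof(float*),
             cudaMemcpyHostToDevice);

  printf("copied %d group pointers to the device table\n", num_groups);
  for (int g = 0; g < num_groups; ++g) cudaFree(ptrs_host[g]);
  cudaFree(ptrs_dev);
  return 0;
}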
@@ -348,23 +350,68 @@ __global__ void get_a_expert_offsets(cutlass::float_e4m3_t** trg_a_ptrs,
}
}

-// For a given "a" of size [M,K] performs a permutation of the M rows based
-// on the given "perm" indices.
-__global__ void permute_rows_kernel(cutlass::float_e4m3_t const* __restrict__ a_ptr,
-                                    int const* __restrict__ perm_int_ptr,
-                                    cutlass::float_e4m3_t* __restrict__ out_ptr,
-                                    int size_m, int size_k, int block_rows) {
-  // TODO
-}
+// // For a given "a" of size [M,K] performs a permutation of the M rows based
+// // on the given "perm" indices.
+// __global__ void permute_fp8_rows_kernel(
+//     cutlass::float_e4m3_t const* __restrict__ a_ptr,
+//     int const* __restrict__ perm_int_ptr,
+//     cutlass::float_e4m3_t* __restrict__ out_ptr,
+//     int size_m, int size_k, int block_rows) {
+//   int start_row = block_rows * blockIdx.x;
+//   int finish_row = start_row + block_rows;
+//   if (finish_row > size_m) {
+//     finish_row = size_m;
+//   }
+//   int cur_block_rows = finish_row - start_row;
+
+//   // fp8 elements are one byte, so the stride between rows is simply
+//   // size_k elements (dividing by 16 here would misindex the element-wise
+//   // copies below).
+//   int row_stride = size_k;
+
+//   auto permute_row = [&](int row) {
+//     int iters = size_k / blockDim.x;
+//     int rest = size_k % blockDim.x;
+
+//     int a_offset = perm_int_ptr[row] * row_stride;
+//     int out_offset = row * row_stride;
+
+//     cutlass::float_e4m3_t const* a_row_fp8 = a_ptr + a_offset;
+//     cutlass::float_e4m3_t* out_fp8 = out_ptr + out_offset;
+
+//     int base_k = 0;
+
+//     for (int i = 0; i < iters; i++) {
+//       int cur_k = base_k + threadIdx.x;
+//       out_fp8[cur_k] = a_row_fp8[cur_k];
+//       base_k += blockDim.x;
+//     }
+
+//     if (rest) {
+//       if (threadIdx.x < rest) {
+//         int cur_k = base_k + threadIdx.x;
+//         out_fp8[cur_k] = a_row_fp8[cur_k];
+//       }
+//     }
+//   };
+
+//   // Copy every row assigned to this block; without this loop the lambda
+//   // is never invoked and the kernel is a no-op.
+//   for (int i = 0; i < cur_block_rows; i++) {
+//     permute_row(start_row + i);
+//   }
+// }

void compute_expert_offsets_caller(torch::Tensor& trg_a_ptrs,
torch::Tensor& a,
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
const int64_t num_experts) {
-  get_a_expert_offsets<<<1, num_experts>>>((float_e4m3_t**)trg_a_ptrs.data_ptr(),
-                                           (cutlass::float_e4m3_t*)a.data_ptr(),
-                                           (const int*)topk_ids.data_ptr(),
-                                           (int64_t*)expert_offsets.data_ptr(),
-                                           topk_ids.numel());
+  get_a_expert_offsets<<<1, num_experts>>>(
+      (cutlass::float_e4m3_t**)trg_a_ptrs.data_ptr(),
+      (cutlass::float_e4m3_t*)a.data_ptr(),
+      (const int*)topk_ids.data_ptr(),
+      (int64_t*)expert_offsets.data_ptr(),
+      topk_ids.numel());
}
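The body of get_a_expert_offsets is collapsed in this view; only its first parameter (in the hunk header above) and its closing braces are visible. As a minimal sketch of what a kernel with that signature, launched as <<<1, num_experts>>>, could compute (an illustrative assumption, not the committed implementation: the name expert_offsets_sketch and the extra size_k parameter are hypothetical), each thread owns one expert and finds its starting row by counting how many topk_ids entries route to lower-numbered experts:

__global__ void expert_offsets_sketch(cutlass::float_e4m3_t** trg_a_ptrs,
                                      cutlass::float_e4m3_t* a,
                                      const int* topk_ids,
                                      int64_t* expert_offsets,
                                      int topk_length, int size_k) {
  int expert = threadIdx.x;  // one thread per expert, single block
  int64_t offset = 0;
  for (int i = 0; i < topk_length; ++i) {
    if (topk_ids[i] < expert) ++offset;  // rows owned by earlier experts
  }
  expert_offsets[expert] = offset;
  // Point this expert's group-GEMM operand at its first row of "a"
  // (assumes the rows of "a" are already sorted by expert id).
  trg_a_ptrs[expert] = a + offset * size_k;
}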

+// void permute_fp8_rows(torch::Tensor& a_ptr,
+//                       torch::Tensor& perm_ptr,
+//                       torch::Tensor& out_ptr,
+//                       int size_m, int size_k, int topk, int block_rows) {
+//   // Launch configuration: these values are assumptions, since blocks,
+//   // num_threads, and stream are not defined anywhere in this draft.
+//   int num_threads = 256;
+//   int blocks = (size_m * topk + block_rows - 1) / block_rows;
+//   auto stream = at::cuda::getCurrentCUDAStream(a_ptr.device().index());
+//   permute_fp8_rows_kernel<<<blocks, num_threads, 0, stream>>>(
+//       (cutlass::float_e4m3_t const*)a_ptr.data_ptr(),
+//       (const int*)perm_ptr.data_ptr(),
+//       (cutlass::float_e4m3_t*)out_ptr.data_ptr(), size_m * topk,
+//       size_k, block_rows);
+// }
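Together, the commented-out kernel and launcher above are building toward a row gather, out[row] = a[perm[row]], with one thread block per chunk of block_rows rows and the threads of a block striding across the k dimension. A self-contained illustration of that technique (a hypothetical demo, not part of this file: plain uint8_t stands in for cutlass::float_e4m3_t and every name is made up), compilable with nvcc alone:

#include <cstdio>
#include <cstdint>
#include <cuda_runtime.h>

// Row gather: out[row] = a[perm[row]]. Each block handles block_rows rows;
// each thread strides across the k dimension of the current row.
__global__ void permute_rows_demo(const uint8_t* a, const int* perm,
                                  uint8_t* out, int size_m, int size_k,
                                  int block_rows) {
  int start_row = block_rows * blockIdx.x;
  int finish_row = min(start_row + block_rows, size_m);
  for (int row = start_row; row < finish_row; ++row) {
    const uint8_t* src = a + (int64_t)perm[row] * size_k;
    uint8_t* dst = out + (int64_t)row * size_k;
    for (int k = threadIdx.x; k < size_k; k += blockDim.x) dst[k] = src[k];
  }
}

int main() {
  const int m = 4, k = 8, block_rows = 2;
  uint8_t h_a[m * k], h_out[m * k];
  int h_perm[m] = {2, 0, 3, 1};  // arbitrary permutation of the 4 rows
  for (int i = 0; i < m * k; ++i) h_a[i] = (uint8_t)(i / k);  // byte = row id

  uint8_t *d_a, *d_out;
  int* d_perm;
  cudaMalloc(&d_a, m * k);
  cudaMalloc(&d_out, m * k);
  cudaMalloc(&d_perm, m * sizeof(int));
  cudaMemcpy(d_a, h_a, m * k, cudaMemcpyHostToDevice);
  cudaMemcpy(d_perm, h_perm, m * sizeof(int), cudaMemcpyHostToDevice);

  int blocks = (m + block_rows - 1) / block_rows;
  permute_rows_demo<<<blocks, 32>>>(d_a, d_perm, d_out, m, k, block_rows);
  cudaMemcpy(h_out, d_out, m * k, cudaMemcpyDeviceToHost);

  for (int r = 0; r < m; ++r)
    printf("out row %d came from source row %d\n", r, h_out[r * k]);

  cudaFree(d_a); cudaFree(d_out); cudaFree(d_perm);
  return 0;
}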
