Skip to content

Commit

Permalink
Grouped gemm working
Browse files Browse the repository at this point in the history
Co-authored-by: Lucas Wilkinson <wilkinson.lucas@gmail.com>
Signed-off-by: ElizaWszola <eliza@neuralmagic.com>
  • Loading branch information
ElizaWszola and LucasWilkinson committed Dec 17, 2024
1 parent c570c69 commit 6ed63f2
Show file tree
Hide file tree
Showing 8 changed files with 704 additions and 254 deletions.
464 changes: 464 additions & 0 deletions csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,11 @@ struct Sm90ColOrScalarBroadcast {
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {

auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;

// NOTE(review): removed leftover debug code — a thread-128-gated device
// printf dumping the problem shape (M,N,K,L) and tile coords. Device-side
// printf serializes warps and must not ship in production epilogue code.
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Expand Down
64 changes: 64 additions & 0 deletions csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"

/*
This file defines custom epilogues for fusing channel scales, token scales,
Expand Down Expand Up @@ -45,6 +46,16 @@ struct ScaledEpilogueBase {
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;

// Ptr-array (grouped GEMM) variant of ColOrScalarLoad: broadcasts a
// per-group column vector (stride M-major: Int<1>,Int<0>,Int<0>) or a
// per-group scalar, loaded through an array of device pointers.
template <typename T>
using ColOrScalarLoadArray = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray<
    0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
    Stride<Int<1>, Int<0>, Int<0>>>;

// Ptr-array (grouped GEMM) variant of RowOrScalarLoad: broadcasts a
// per-group row vector (stride N-major: Int<0>,Int<1>,Int<0>) or a
// per-group scalar, loaded through an array of device pointers.
template <typename T>
using RowOrScalarLoadArray = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray<
    0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
    Stride<Int<0>, Int<1>, Int<0>>>;

// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
Expand Down Expand Up @@ -72,6 +83,15 @@ struct ScaledEpilogueBase {
std::is_same_v<Descriptor, RowLoad<T, true>>);
return Arguments{data_ptr};
}

template <typename Descriptor, typename T>
static auto args_from_tensor(const T* const* data_ptr, bool do_broadcast) {
using Arguments = typename Descriptor::Arguments;
static_assert(std::is_same_v<Descriptor, ColOrScalarLoadArray<T>> ||
std::is_same_v<Descriptor, RowOrScalarLoadArray<T>>);
return Arguments{data_ptr, do_broadcast};
}

};

/*
Expand Down Expand Up @@ -312,4 +332,48 @@ struct ScaledEpilogueBiasAzpToken
}
};

/*
  Epilogue for grouped (ptr-array) scaled GEMM:
      D = scale_a * (scale_b * Accum)
  Unlike ScaledEpilogue, the A/B scales arrive as arrays of device pointers
  (one pointer per group), so each problem in the group reads its own scale
  tensor. The *_broadcast flags select between full per-column / per-row
  scale vectors and a single scalar per group.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueArray
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  // Per-group scale loaders backed by pointer arrays (column for A, row for B).
  using ScaleA = typename SUPER::template ColOrScalarLoadArray<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoadArray<float>;

  // Step 1: tmp = scale_b * accumulator, in float, round-to-nearest.
  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  // Step 2: D = scale_a * tmp, converting to the output element type.
  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Public aliases for callers; identical to the private loaders above
  // (previously these re-instantiated the same templates a second time).
  using ScaleAArray = ScaleA;
  using ScaleBArray = ScaleB;

  // Builds the EVT argument tree from per-group scale pointer arrays.
  // a_col_broadcast / b_row_broadcast: true when the scales are full
  // per-column / per-row vectors, false when each group uses one scalar.
  static ArgumentType prepare_args(const float* const* a_scales_ptr,
                                   const float* const* b_scales_ptr,
                                   bool a_col_broadcast,
                                   bool b_row_broadcast) {
    auto a_args = SUPER::template args_from_tensor<ScaleAArray, float>(
        a_scales_ptr, a_col_broadcast);
    auto b_args = SUPER::template args_from_tensor<ScaleBArray, float>(
        b_scales_ptr, b_row_broadcast);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args};
  }
};

}; // namespace vllm::c3x
12 changes: 5 additions & 7 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,11 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias);

void cutlass_grouped_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& problem_sizes,
torch::Tensor const& out_offsets,
torch::Tensor const& a_offsets,
torch::Tensor const& b_offsets);
// Grouped GEMM entry point taking per-group tensor lists: one output, A, B,
// and A/B scale tensor per group. NOTE(review): presumably performs one
// scaled matmul per (a, b) pair into the matching out_tensors entry —
// confirm against the kernel implementation. (Signature replaces the old
// single-flattened-tensor-plus-offsets form.)
void cutlass_grouped_mm(c10::List<at::Tensor> const& out_tensors,
                        c10::List<at::Tensor> const& a_tensors,
                        c10::List<at::Tensor> const& b_tensors,
                        c10::List<at::Tensor> const& a_scales,
                        c10::List<at::Tensor> const& b_scales);

void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
Expand Down
Loading

0 comments on commit 6ed63f2

Please sign in to comment.