Commit

Small cleanup
Signed-off-by: ElizaWszola <eliza@neuralmagic.com>
ElizaWszola committed Dec 17, 2024
1 parent 6ed63f2 commit e2b1fc0
Showing 4 changed files with 10 additions and 55 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -264,7 +264,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
     set(SRCS
       "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
-      "csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu")
+      "csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
@@ -241,10 +241,6 @@ struct Sm90RowOrScalarBroadcastArray {
     auto [m, n, k, l] = args.tile_coord_mnkl;
     using ThreadCount = decltype(size(args.tiled_copy));

-    if (threadIdx.x ==128){
-      printf("ROW M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l);
-    }
-
     Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
     Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n));  // (CTA_M, CTA_N)
     Tensor sRow = make_tensor(make_smem_ptr(smem),
@@ -435,9 +431,6 @@ struct Sm90ColOrScalarBroadcastArray {
     auto [M, N, K, L] = args.problem_shape_mnkl;
     auto [m, n, k, l] = args.tile_coord_mnkl;

-    // if (threadIdx.x ==128){
-    //   printf("COL M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l);
-    // }
     Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
     Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>(  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
         mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
@@ -5,7 +5,7 @@

 #include "cutlass/cutlass.h"

-// TODO let's see which of these we'll need
+// TODO clean up the includes we no longer need

 #include "cute/tensor.hpp"
 #include "cutlass/tensor_ref.h"
@@ -26,10 +26,6 @@

 #include "common.hpp"

-// get rid of these?
-// #include "helper.h"
-// using namespace cute;
-
 using namespace cute;

 #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
@@ -129,20 +125,6 @@ cutlass::platform::unique_ptr<T, ItemDeleter<T>> make_device_ptr(
   return cutlass::platform::unique_ptr<T, ItemDeleter<T>>(data_device);
 }

-///////////////
-template <class TupType, size_t... I>
-void print(const TupType& _tup, std::index_sequence<I...>) {
-  std::cout << "(";
-  (..., (std::cout << (I == 0 ? "" : ", ") << std::get<I>(_tup)));
-  std::cout << ")\n";
-}
-
-template <class... T>
-void print(const std::tuple<T...>& _tup) {
-  print(_tup, std::make_index_sequence<sizeof...(T)>());
-}
-////////////
-
 template <typename Gemm>
 void cutlass_group_gemm_caller(c10::List<at::Tensor> const& out_tensors,
                                c10::List<at::Tensor> const& a_tensors,
@@ -242,6 +224,8 @@ void cutlass_group_gemm_caller(c10::List<at::Tensor> const& out_tensors,
   typename GemmKernel::MainloopArguments mainloop_args{
       a_ptrs_ptr.get(), a_stride_ptr.get(), b_ptrs_ptr.get(),
       b_stride_ptr.get()};
+  // Currently, we are only able to do broadcast on either all or none a_scales
+  // and on either all or none b_scales
   typename GemmKernel::EpilogueArguments epilogue_args{
       Gemm::Epilogue::prepare_args(
           a_scales_ptrs_ptr.get(), b_scales_ptrs_ptr.get(),
@@ -255,8 +239,6 @@ void cutlass_group_gemm_caller(c10::List<at::Tensor> const& out_tensors,

   using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
   GemmOp gemm_op;
-  // std::cout << "gemm_op.can_implement(args): "
-  //           << (int)gemm_op.can_implement(args) << std::endl;
   CUTLASS_CHECK(gemm_op.can_implement(args));

   size_t workspace_size = gemm_op.get_workspace_size(args);
@@ -336,27 +318,6 @@ void cutlass_grouped_mm_sm90(c10::List<at::Tensor> const& out_tensors,
   using Cutlass3xGemmDefault = typename sm90_fp8_config_default<
       ElementAB_Type, ElementC_Type,
       vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
-  // using Cutlass3xGemmM64 =
-  //     typename sm90_fp8_config_M64<ElementAB_Type, ElementC_Type,
-  //                                  vllm::c3x::ScaledEpilogue>::Cutlass3xGemm;
-  // using Cutlass3xGemmM128 =
-  //     typename sm90_fp8_config_M128<ElementAB_Type, ElementC_Type,
-  //                                   vllm::c3x::ScaledEpilogue>::Cutlass3xGemm;
-
-  // // uint32_t const m = a.size(0);
-  // uint32_t const mp2 =
-  //     std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2
-
-  // if (mp2 <= 64) {
-  //   // m in [1, 64]
-  //   cutlass_group_gemm_caller<Cutlass3xGemmM64>(out, a, b, a_scales,
-  //                                               b_scales);
-  // } else if (mp2 <= 128) {
-  //   // m in (64, 128]
-  //   cutlass_group_gemm_caller<Cutlass3xGemmM128>(out, a, b, a_scales,
-  //                                                b_scales);
-  // } else {
-  //   // m in (128, inf)
   cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
       out_tensors, a_tensors, b_tensors, a_scales, b_scales);
 }
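
The comment added alongside the epilogue arguments above records a current limitation: within one grouped call, either every group's a_scales is a broadcast scalar or none is, and the same holds for b_scales. A minimal sketch of what that means for the per-group scale tensors follows; the shapes and list layout are illustrative assumptions, not the kernel's actual argument types.

```python
import torch

# Illustrative shapes only (assumed layout, not the actual kernel API).
num_groups, m_g, n_g = 4, 64, 128

# Supported: every group uses a single scalar scale for A and for B.
a_scales = [torch.ones(1, 1) for _ in range(num_groups)]
b_scales = [torch.ones(1, 1) for _ in range(num_groups)]

# Also supported: every group uses full per-token / per-channel scales.
a_scales = [torch.ones(m_g, 1) for _ in range(num_groups)]  # per activation token
b_scales = [torch.ones(1, n_g) for _ in range(num_groups)]  # per output channel

# Not currently supported: mixing modes across groups, e.g. group 0 with a
# scalar a_scale and group 1 with a per-token a_scale.
```

In other words, the per_act_token / per_out_ch choice is made once for the whole call, which is also how the test below parametrizes it.
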
11 changes: 6 additions & 5 deletions tests/kernels/test_cutlass.py
@@ -457,10 +457,11 @@ def test_cutlass_support_opcheck():
     opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, ))


+# TODO add bias
 @pytest.mark.parametrize("num_groups", [8])
-@pytest.mark.parametrize("per_act_token", [True, False])  # [True, False])
-@pytest.mark.parametrize("per_out_ch", [True, False])  # [True, False])
-@pytest.mark.parametrize("use_bias", [False])  # [True, False])
+@pytest.mark.parametrize("per_act_token", [True, False])
+@pytest.mark.parametrize("per_out_ch", [True, False])
+@pytest.mark.parametrize("use_bias", [False])
 @pytest.mark.skipif(not current_platform.has_device_capability(89),
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool,
@@ -479,9 +480,9 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool,
     baseline_tensors = []

     alignment = 16  # 128 // 8
-    # For variation, each group g has dimensions
+    # For variation, each group has dimensions
     # (m_g = m/(g+1), n_g = n/(g+1), k_g = k/(g+1))
-    for g in range(num_groups):
+    for _ in range(num_groups):
         m_g = alignment * random.randint(1, 64)
         n_g = alignment * random.randint(1, 64)
         k_g = alignment * random.randint(1, 64)
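
For context, a rough sketch of what one iteration of this loop sets up: a randomly sized group whose dimensions are aligned to 16 elements, plus a scaled baseline result. The scale shapes and the baseline formula below are assumptions for illustration, not the actual test helpers.

```python
import random
import torch

alignment = 16  # 128-bit alignment with 8-bit elements -> 16 FP8 elements
per_act_token, per_out_ch = True, False

# One group's dimensions, each a multiple of the alignment.
m_g = alignment * random.randint(1, 64)
n_g = alignment * random.randint(1, 64)
k_g = alignment * random.randint(1, 64)

a = torch.randn(m_g, k_g)
b = torch.randn(k_g, n_g)

# Scale shapes depend on the parametrized flags (assumed convention):
# per-act-token scales are (m_g, 1), per-out-channel scales are (1, n_g),
# otherwise a single scalar per group.
a_scale = torch.rand(m_g if per_act_token else 1, 1)
b_scale = torch.rand(1, n_g if per_out_ch else 1)

# Baseline the grouped CUTLASS kernel would be checked against.
baseline = (a * a_scale) @ (b * b_scale)
```
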
