diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9d6185e756338..c19812ab54914 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -264,7 +264,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
     set(SRCS
        "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
-       "csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu")
+       "csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
diff --git a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp
index e652179718c95..5c1d6e3f46be0 100644
--- a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp
+++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp
@@ -241,10 +241,6 @@ struct Sm90RowOrScalarBroadcastArray {
     auto [m, n, k, l] = args.tile_coord_mnkl;
     using ThreadCount = decltype(size(args.tiled_copy));
 
-    if (threadIdx.x ==128){
-      printf("ROW M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l);
-    }
-
     Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
     Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
     Tensor sRow = make_tensor(make_smem_ptr(smem),
@@ -435,9 +431,6 @@ struct Sm90ColOrScalarBroadcastArray {
     auto [M, N, K, L] = args.problem_shape_mnkl;
     auto [m, n, k, l] = args.tile_coord_mnkl;
 
-    // if (threadIdx.x ==128){
-    //   printf("COL M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l);
-    // }
     Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
     Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
       mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
diff --git a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu
similarity index 90%
rename from csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu
rename to csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu
index c9d299c111304..b08d67d046643 100644
--- a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu
+++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu
@@ -5,7 +5,7 @@
 
 #include "cutlass/cutlass.h"
 
-// TODO let's see which of these we'll need
+// TODO clean up the includes we no longer need
 #include "cute/tensor.hpp"
 #include "cutlass/tensor_ref.h"
@@ -26,10 +26,6 @@
 
 #include "common.hpp"
 
-// get rid of these?
-// #include "helper.h"
-// using namespace cute;
-
 using namespace cute;
 
 #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
@@ -129,20 +125,6 @@ cutlass::platform::unique_ptr<T, ItemDeleter<T>> make_device_ptr(
   return cutlass::platform::unique_ptr<T, ItemDeleter<T>>(data_device);
 }
 
-///////////////
-template <class TupType, size_t... I>
-void print(const TupType& _tup, std::index_sequence<I...>) {
-  std::cout << "(";
-  (..., (std::cout << (I == 0 ? "" : ", ") << std::get<I>(_tup)));
-  std::cout << ")\n";
-}
-
-template <class... T>
-void print(const std::tuple<T...>& _tup) {
-  print(_tup, std::make_index_sequence<sizeof...(T)>());
-}
-////////////
-
 template <typename Gemm>
 void cutlass_group_gemm_caller(c10::List<at::Tensor> const& out_tensors,
                                c10::List<at::Tensor> const& a_tensors,
@@ -242,6 +224,8 @@ void cutlass_group_gemm_caller(c10::List<at::Tensor> const& out_tensors,
   typename GemmKernel::MainloopArguments mainloop_args{
       a_ptrs_ptr.get(), a_stride_ptr.get(), b_ptrs_ptr.get(),
       b_stride_ptr.get()};
+  // Currently, we only support broadcasting either all or none of the
+  // a_scales, and likewise either all or none of the b_scales.
   typename GemmKernel::EpilogueArguments epilogue_args{
       Gemm::Epilogue::prepare_args(
           a_scales_ptrs_ptr.get(), b_scales_ptrs_ptr.get(),
@@ -255,8 +239,6 @@ void cutlass_group_gemm_caller(c10::List<at::Tensor> const& out_tensors,
   using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
   GemmOp gemm_op;
 
-  // std::cout << "gemm_op.can_implement(args): "
-  //           << (int)gemm_op.can_implement(args) << std::endl;
   CUTLASS_CHECK(gemm_op.can_implement(args));
 
   size_t workspace_size = gemm_op.get_workspace_size(args);
@@ -336,27 +318,6 @@ void cutlass_grouped_mm_sm90(c10::List<at::Tensor> const& out_tensors,
 
   using Cutlass3xGemmDefault = typename sm90_fp8_config_default<
       ElementAB_Type, ElementC_Type, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
-  // using Cutlass3xGemmM64 =
-  //     typename sm90_fp8_config_M64<ElementAB_Type, ElementC_Type, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
-  // using Cutlass3xGemmM128 =
-  //     typename sm90_fp8_config_M128<ElementAB_Type, ElementC_Type, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
-
-  // // uint32_t const m = a.size(0);
-  // uint32_t const mp2 =
-  //     std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2
-
-  // if (mp2 <= 64) {
-  //   // m in [1, 64]
-  //   cutlass_group_gemm_caller<Cutlass3xGemmM64>(out, a, b, a_scales,
-  //                                               b_scales);
-  // } else if (mp2 <= 128) {
-  //   // m in (64, 128]
-  //   cutlass_group_gemm_caller<Cutlass3xGemmM128>(out, a, b, a_scales,
-  //                                                b_scales);
-  // } else {
-  //   // m in (128, inf)
   cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
       out_tensors, a_tensors, b_tensors, a_scales, b_scales);
 }
diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py
index 4c909669aa5d3..445a06f57a965 100644
--- a/tests/kernels/test_cutlass.py
+++ b/tests/kernels/test_cutlass.py
@@ -457,10 +457,11 @@ def test_cutlass_support_opcheck():
     opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, ))
 
 
+# TODO add bias
 @pytest.mark.parametrize("num_groups", [8])
-@pytest.mark.parametrize("per_act_token", [True, False])  # [True, False])
-@pytest.mark.parametrize("per_out_ch", [True, False])  # [True, False])
-@pytest.mark.parametrize("use_bias", [False])  # [True, False])
+@pytest.mark.parametrize("per_act_token", [True, False])
+@pytest.mark.parametrize("per_out_ch", [True, False])
+@pytest.mark.parametrize("use_bias", [False])
 @pytest.mark.skipif(not current_platform.has_device_capability(89),
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool,
@@ -479,9 +480,9 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool,
     baseline_tensors = []
     alignment = 16  # 128 // 8
 
-    # For variation, each group g has dimensions
+    # For variation, each group has dimensions
     # (m_g = m/(g+1), n_g = n/(g+1), k_g = k/(g+1))
-    for g in range(num_groups):
+    for _ in range(num_groups):
        m_g = alignment * random.randint(1, 64)
        n_g = alignment * random.randint(1, 64)
        k_g = alignment * random.randint(1, 64)