From b2ba0f992f239ebbf2282f886e12dd8c06ee3562 Mon Sep 17 00:00:00 2001 From: KuangjuX <18630816527@163.com> Date: Wed, 25 Dec 2024 02:46:45 -0800 Subject: [PATCH] fix pipeline. --- .vscode/settings.json | 3 +- benchmarks/cpp/flashattention/copy.cuh | 155 +++++++++++++++---- benchmarks/cpp/flashattention/cutlass_fa.cuh | 32 ++-- benchmarks/cpp/flashattention/main.cu | 8 +- 4 files changed, 155 insertions(+), 43 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index e654000..26d7cd7 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -6,7 +6,8 @@ "span": "cpp", "bitset": "cpp", "initializer_list": "cpp", - "utility": "cpp" + "utility": "cpp", + "*.tcc": "cpp" }, "gotoSymbolStack.currentStackPosition": 0, "gotoSymbolStack.maxStackPosition": 0, diff --git a/benchmarks/cpp/flashattention/copy.cuh b/benchmarks/cpp/flashattention/copy.cuh index 238dbba..cb5e199 100644 --- a/benchmarks/cpp/flashattention/copy.cuh +++ b/benchmarks/cpp/flashattention/copy.cuh @@ -214,19 +214,24 @@ class G2SCopyV { int num_stage; }; -template +template class S2RPipelineQK { public: - DEVICE S2RPipelineQK(SQTensor& sQ, RQTensor& rQ, SKTensor& sK, RKTensor& rK, + DEVICE S2RPipelineQK(SQTensor& sQ, RQMmaView& rQ_mma_view, + RQCopyView& rQ_copy_view, SKTensor& sK, + RKMmaView& rK_mma_view, RKCopyView& rK_copy_view, RAccTensor& acc, TiledCopyQ copy_q, TiledCopyK copy_k, TiledMma tiled_mma, int sQ_stride, int sK_stride, int num_stage = 2) : sQ(sQ), - rQ(rQ), + rQ_mma_view(rQ_mma_view), + rQ_copy_view(rQ_copy_view), sK(sK), - rK(rK), + rK_mma_view(rK_mma_view), + rK_copy_view(rK_copy_view), acc(acc), copy_q(copy_q), copy_k(copy_k), @@ -237,11 +242,89 @@ class S2RPipelineQK { cur_iter(0), cur_iter_sq(0) {} + DEVICE void print_rQ() { + if (thread0()) { + print(rQ_mma_view), print("\n"); + print(rQ_copy_view), print("\n"); + } + } + + DEVICE void prologue() { + cur_iter = 0; + cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{})); + cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{})); + +#pragma unroll + for (int i = 0; i < size<2>(rK_mma_view); ++i) { + if (i < size<2>(rK_mma_view) - 1) { + cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{})); + cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{})); + } + cute::gemm(tiled_mma, rQ_mma_view(_, _, i), rK_mma_view(_, _, i), + acc); + } + sQ.data() = sQ.data() + sQ_stride; + sK.data() = sK.data() + sK_stride; + + cur_iter++; + } + + DEVICE void body() { + cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{})); + cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{})); + +#pragma unroll + for (int i = 0; i < size<2>(rK_mma_view); ++i) { + if (i < size<2>(rK_mma_view) - 1) { + cute::copy(copy_q, sQ(_, _, i + 1), rQ_copy_view(_, _, i + 1)); + cute::copy(copy_k, sK(_, _, i + 1), rK_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, rQ_mma_view(_, _, i), rK_mma_view(_, _, i), + acc); + } + sQ.data() = sQ.data() + sQ_stride; + sK.data() = sK.data() + sK_stride; + + if ((cur_iter + 1) % num_stage == 0) { + sK.data() = sK.data() + (-sK_stride * num_stage); + } + + cur_iter++; + cur_iter_sq++; + } + + DEVICE void epilogue() { + cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{})); + cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{})); + +#pragma unroll + for (int i = 0; i < size<2>(rK_mma_view); ++i) { + if (i < size<2>(rK_mma_view) - 1) { + cute::copy(copy_q, sQ(_, _, i + 1), rQ_copy_view(_, _, i + 1)); + cute::copy(copy_k, sK(_, _, i + 1), rK_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, rQ_mma_view(_, _, i), rK_mma_view(_, _, i), + acc); + } + + sQ.data() = sQ.data() + (-sQ_stride * cur_iter_sq); + sK.data() = sK.data() + sK_stride; + + if ((cur_iter + 1) % num_stage == 0) { + sK.data() = sK.data() + (-sK_stride * num_stage); + } + + cur_iter++; + cur_iter_sq = 0; + } + private: SQTensor& sQ; - RQTensor& rQ; + RQMmaView& rQ_mma_view; + RQCopyView& rQ_copy_view; SKTensor& sK; - RKTensor& rK; + RKMmaView& rK_mma_view; + RKCopyView& rK_copy_view; RAccTensor& acc; TiledCopyQ copy_q; TiledCopyK copy_k; @@ -305,13 +388,18 @@ DEVICE auto make_g2s_v(const Element* gV_ptr, Element* sV_ptr, int gV_stride, return copy_v; } -template -DEVICE auto make_s2r_qk(SQTensor sQ, SKTensor sK, RegAcc acc, int sQ_stride, - int sK_stride, SmemCopyAtom copy_atom = SmemCopyAtom{}, +template +DEVICE auto make_s2r_qk(const Element* sQ_ptr, const Element* sK_ptr, + SQLayout sQ_layout, SKLayout sK_layout, RegAcc acc, + int sQ_stride, int sK_stride, + SmemCopyAtom copy_atom = SmemCopyAtom{}, TiledMma tiled_mma = TiledMma{}) { int tid = threadIdx.x; + auto sQ_ = make_tensor(make_smem_ptr(sQ_ptr), sQ_layout); + auto sK_ = make_tensor(make_smem_ptr(sK_ptr), sK_layout); + auto thr_mma = tiled_mma.get_thread_slice(tid); auto s2r_copy_q = make_tiled_copy_A(copy_atom, tiled_mma); @@ -319,27 +407,34 @@ DEVICE auto make_s2r_qk(SQTensor sQ, SKTensor sK, RegAcc acc, int sQ_stride, auto s2r_thr_copy_q = s2r_copy_q.get_thread_slice(tid); auto s2r_thr_copy_k = s2r_copy_k.get_thread_slice(tid); - auto rQ_org = thr_mma.partition_fragment_A(sQ); - auto rK_org = thr_mma.partition_fragment_B(sK); + auto sQ = s2r_thr_copy_q.partition_S(sQ_); + auto sK = s2r_thr_copy_k.partition_S(sK_); - auto rQ = s2r_thr_copy_q.retile_D(rQ_org); - auto rK = s2r_thr_copy_k.retile_D(rK_org); - // auto rAcc = get_acc(rQ), size<1>(rK)>(tiled_mma); + // Thread partition for mma. + auto rQ_mma = thr_mma.partition_fragment_A(sQ_); + auto rK_mma = thr_mma.partition_fragment_B(sK_); - if (thread0()) { - printf("thr_mma: \n"); - print(thr_mma), print("\n"); - printf("s2r_copy_q: \n"); - print(s2r_copy_q), print("\n"); - printf("rQ_org: \n"); - print(rQ_org), print("\n"); - printf("rQ: \n"); - print(rQ), print("\n"); - } + // Thread partition for shared to register copy. + auto rQ_copy = s2r_thr_copy_q.retile_D(rQ_mma); + auto rK_copy = s2r_thr_copy_k.retile_D(rK_mma); + // auto rAcc = get_acc(rQ), size<1>(rK)>(tiled_mma); - detail::S2RPipelineQK s2r_pipeline_qk(sQ, rQ, sK, rK, acc, s2r_copy_q, - s2r_copy_k, tiled_mma, sQ_stride, - sK_stride); + // if (thread0()) { + // printf("thr_mma: \n"); + // print(thr_mma), print("\n"); + // printf("s2r_copy_q: \n"); + // print(s2r_copy_q), print("\n"); + // printf("rQ_org: \n"); + // print(rQ_org), print("\n"); + // printf("rQ: \n"); + // print(rQ), print("\n"); + // } + + detail::S2RPipelineQK s2r_pipeline_qk(sQ, rQ_mma, rQ_copy, sK, rK_mma, + rK_copy, acc, s2r_copy_q, s2r_copy_k, + tiled_mma, sQ_stride, sK_stride); + + return s2r_pipeline_qk; } } // namespace cutlass_wrapper diff --git a/benchmarks/cpp/flashattention/cutlass_fa.cuh b/benchmarks/cpp/flashattention/cutlass_fa.cuh index 316f3fb..0c7d575 100644 --- a/benchmarks/cpp/flashattention/cutlass_fa.cuh +++ b/benchmarks/cpp/flashattention/cutlass_fa.cuh @@ -95,15 +95,13 @@ struct FATraits : public Base { template + const int kTK, const int kTP, const int Nthreads, const int kStagesQK, + const int kStageV> __global__ void __launch_bounds__(Nthreads) fa_kernel(const Element* dQ, const Element* dK, const Element* dV, Element* dO) { // constexpr float softmax_scale = 1.250000e-01f; - // Q, K: [batch, head, length, hidden_qk] - // V, O: [batch, head, length, hidden_v] - extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; auto* buf = reinterpret_cast(buf_); @@ -114,8 +112,8 @@ __global__ void __launch_bounds__(Nthreads) dO + blockIdx.z * kM * kP + blockIdx.x * (kTM * kP) + blockIdx.y * kTP; Element* sQ_ptr = reinterpret_cast(buf); - Element* sK_ptr = sQ_ptr + kTM * kTK; - Element* sV_ptr = sK_ptr + kTN * kTK; + Element* sK_ptr = sQ_ptr + kTM * kTK * kStagesQK; + Element* sV_ptr = sK_ptr + kTN * kTK * kStagesQK; // Element* sO_ptr = sQ_ptr; typename KeTraits::TiledMma mma; @@ -140,11 +138,18 @@ __global__ void __launch_bounds__(Nthreads) g2s_copy_qk.print_gQ_data(0); #endif - auto sQ = g2s_copy_qk.get_sQ(); - auto sK = g2s_copy_qk.get_sK(); auto acc0 = get_acc(mma); - make_s2r_qk(sQ, sK, acc0, kTK, kTK, typename KeTraits::SmemCopyAtom{}, mma); + if (thread0()) { + printf("acc0 size<0>: %d, size<1>: %d, size<2>: %d\n", + (int)size<0>(acc0), (int)size<1>(acc0), (int)size<2>(acc0)); + } + + auto s2r_pipeline_qk = + make_s2r_qk(sQ_ptr, sK_ptr, typename KeTraits::SmemLayoutQ{}, + typename KeTraits::SmemLayoutK{}, acc0, kTK, kTK, + typename KeTraits::SmemCopyAtom{}, mma); + s2r_pipeline_qk.print_rQ(); // Issue global to shared memory copy before the main loop. g2s_copy_qk.prologue(); @@ -159,12 +164,19 @@ __global__ void __launch_bounds__(Nthreads) __syncthreads(); g2s_copy_qk.body(); // Load data from shared memory into register and issue MMA. + s2r_pipeline_qk.body(); } cp_async_wait_flash<0>(); __syncthreads(); - // g2s_copy_qk.print_sQ_data(0); g2s_copy_v.prologue(); + s2r_pipeline_qk.epilogue(); + + // Print acc0 data. + if (thread0()) { + printf("acc0: \n"); + print(acc0), print("\n"); + } } } diff --git a/benchmarks/cpp/flashattention/main.cu b/benchmarks/cpp/flashattention/main.cu index 32b81bc..847711a 100644 --- a/benchmarks/cpp/flashattention/main.cu +++ b/benchmarks/cpp/flashattention/main.cu @@ -33,6 +33,8 @@ void run(bool check = true) { static constexpr int kBatch = 1; static constexpr int kThreads = 128; + static constexpr int kStagesQK = 2; + static constexpr int kStagesV = 2; static_assert(kK == kTK, "The current implementation requires kTK == K for now."); @@ -118,8 +120,10 @@ void run(bool check = true) { benchmarks::cutlass_wrapper::FATraits; - auto fa_kernel = benchmarks::cutlass_wrapper::fa_kernel< - cutlass::half_t, Traits, kM, kN, kK, kP, kTM, kTN, kTK, kTP, kThreads>; + auto fa_kernel = + benchmarks::cutlass_wrapper::fa_kernel; if (shm_size > 48 * 1024) { cudaFuncSetAttribute(