diff --git a/.vscode/settings.json b/.vscode/settings.json index 26d7cd7..f082a27 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -7,7 +7,11 @@ "bitset": "cpp", "initializer_list": "cpp", "utility": "cpp", - "*.tcc": "cpp" + "*.tcc": "cpp", + "chrono": "cpp", + "random": "cpp", + "limits": "cpp", + "semaphore": "cpp" }, "gotoSymbolStack.currentStackPosition": 0, "gotoSymbolStack.maxStackPosition": 0, diff --git a/benchmarks/cpp/flashattention/convert.cuh b/benchmarks/cpp/flashattention/convert.cuh new file mode 100644 index 0000000..f212fac --- /dev/null +++ b/benchmarks/cpp/flashattention/convert.cuh @@ -0,0 +1,68 @@ +#pragma once + +#include "cuda_utils.cuh" + +#include +#include +#include + +namespace benchmarks { +namespace cutlass_wrapper { + +using namespace cute; + +template +CUTE_DEVICE auto convert_type(cute::Tensor const& tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + auto frag = + convert_op(*reinterpret_cast*>( + tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +template +DEVICE auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { + using namespace cute; + static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2); + static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2); + auto l = logical_divide(rowcol_layout, + Shape>>{}); + + return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), + get<0>(get<1>(get<1>(l)))), + get<1>(get<0>(l)), get<1>(get<1>(get<1>(l)))); +} + +DEVICE auto convert_layout_C_Aregs() { + using namespace cute; + auto layout_s = Layout, _2, _16>>{}; + auto l = logical_divide(layout_s, Shape{}); + + return make_layout( + make_layout(get<0>(get<0>(l)), get<1>(get<0>(l)), get<0>(get<2>(l))), + get<1>(l), get<1>(get<2>(l))); +} + +template +DEVICE auto convert_layout_scores(LayoutType layout_s) { + using namespace cute; + static_assert(decltype(size<0>(layout_s))::value == 4); + static_assert(decltype(rank(layout_s))::value == 3); + + auto l = logical_divide(layout_s, Shape<_2>{}); + return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), + make_layout(get<0>(get<0>(l)), get<2>(l))); +} + +template +DEVICE auto convert_layout_scores_copyview(LayoutType layout_s) { + using namespace cute; + + auto l = logical_divide(layout_s, Shape>{}); + return make_layout(get<0>(get<1>(l)), get<0>(l), get<1>(get<1>(l))); +} + +} // namespace cutlass_wrapper +} // namespace benchmarks \ No newline at end of file diff --git a/benchmarks/cpp/flashattention/copy.cuh b/benchmarks/cpp/flashattention/copy.cuh index cb5e199..b8b81b5 100644 --- a/benchmarks/cpp/flashattention/copy.cuh +++ b/benchmarks/cpp/flashattention/copy.cuh @@ -40,6 +40,8 @@ class G2SCopyQK { DEVICE void print_gQ() { if (thread0()) { print(gQ), print("\n"); + printf("gQ size<0>: %d, size<1>: %d, size<2>: %d\n", + (int)size<0>(gQ), (int)size<1>(gQ), (int)size<2>(gQ)); } } @@ -336,6 +338,56 @@ class S2RPipelineQK { int cur_iter_sq; }; +template +class S2RPipelineQK { + public: + DEVICE S2RPipelineQK(SVTensor& sV, RVMmaView& rV_mma_view, + RVCopyView& rV_copy_view, RegAcc& acc, + TiledCopy tiled_copy, TiledMma, tiled_mma, + int sV_stride, int num_stage = 2) + : sV(sV), + rV_mma_view(rV_mma_view), + rV_copy_view(rV_copy_view), + acc(acc), + tiled_copy(tiled_copy), + sV_stride(sV_stride), + num_stage(num_stage), + cur_iter(0), + cur_iter_sv(0) {} + + template + DEVICE void prologue(RegValue& value) { + cur_iter = 0; + cute::copy(tiled_copy, sV(_, _, _0{}), rV_copy_view(_, _, _0{})); +#pragma unroll + for (int i = 0; i < size<2>(rV_mma_view); ++i) { + if (i < size<2>(rV_mma_view) - 1) { + cute::copy(tiled_copy, sV(_, _, i + 1), + rV_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, + value(_, _, cur_iter * size<2>(rV_mma_view) + i), + rV_mma_view(_, _, i), acc); + } + + sV.data() = sV.data() + sV_stride; + cur_iter++; + } + + private: + SVTensor& sV; + RVMmaView& rV_mma_view; + RVCopyView& rV_copy_view; + RegAcc& acc; + TiledCopy tiled_copy; + TiledMma tiled_mma; + int sV_stride; + int num_stage; + int cur_iter; + int cur_iter_sv; +} + } // namespace detail template (gQ): %d, size<1>(gQ): %d\n", (int)size<0>(gQ), + (int)size<1>(gQ)); + } + auto gK = make_tensor(make_gmem_ptr(gK_ptr), GlobalKLayout{}); auto sK = make_tensor(make_smem_ptr(sK_ptr), SharedKLayout{}); @@ -417,18 +476,6 @@ DEVICE auto make_s2r_qk(const Element* sQ_ptr, const Element* sK_ptr, // Thread partition for shared to register copy. auto rQ_copy = s2r_thr_copy_q.retile_D(rQ_mma); auto rK_copy = s2r_thr_copy_k.retile_D(rK_mma); - // auto rAcc = get_acc(rQ), size<1>(rK)>(tiled_mma); - - // if (thread0()) { - // printf("thr_mma: \n"); - // print(thr_mma), print("\n"); - // printf("s2r_copy_q: \n"); - // print(s2r_copy_q), print("\n"); - // printf("rQ_org: \n"); - // print(rQ_org), print("\n"); - // printf("rQ: \n"); - // print(rQ), print("\n"); - // } detail::S2RPipelineQK s2r_pipeline_qk(sQ, rQ_mma, rQ_copy, sK, rK_mma, rK_copy, acc, s2r_copy_q, s2r_copy_k, @@ -437,5 +484,27 @@ DEVICE auto make_s2r_qk(const Element* sQ_ptr, const Element* sK_ptr, return s2r_pipeline_qk; } +template +DEVICE void make_s2r_v(const Element* sV_ptr, SVLayout sV_layout, int sV_stride, + SmemCopyAtom copy_atom, TiledMma tiled_mma) { + int tid = threadIdx.x; + + auto sV_ = make_tensor(make_smem_ptr(sV_ptr), sV_layout); + + auto thr_mma = tiled_mma.get_thread_slice(tid); + + auto s2r_copy_v = make_tiled_copy_B(copy_atom, tiled_mma); + auto s2r_thr_copy_v = s2r_copy_v.get_thread_slice(tid); + + auto sV = s2r_thr_copy_v.partition_S(sV_); + + auto rV_mma = thr_mma.partition_fragment_B(sV_); + auto rV_copy = s2r_thr_copy_v.retile_D(rV_mma); +} + +} // namespace cutlass_wrapper +} // namespace benchmarks + } // namespace cutlass_wrapper } // namespace benchmarks \ No newline at end of file diff --git a/benchmarks/cpp/flashattention/cutlass_fa.cuh b/benchmarks/cpp/flashattention/cutlass_fa.cuh index 0c7d575..ae9661a 100644 --- a/benchmarks/cpp/flashattention/cutlass_fa.cuh +++ b/benchmarks/cpp/flashattention/cutlass_fa.cuh @@ -3,10 +3,12 @@ #pragma once +#include "convert.cuh" #include "copy.cuh" #include "cuda_utils.cuh" #include "cutlass/copy.cuh" #include "cutlass/traits_base.cuh" +#include "reduce.cuh" namespace benchmarks { namespace cutlass_wrapper { @@ -17,8 +19,9 @@ using namespace cute; /// @tparam Element_ template > + const int kTP, const int kWarpPerRow, const int kWarpPerCol, + const int kThreads, const int SmemKAtom = 64, const int kSwizzle = 3, + typename Base = AccessBase> struct FATraits : public Base { // Q: [kM, kK] --> [length, hidden_qk] // K: [kN, kK] --> [length, hidden_qk] @@ -27,10 +30,6 @@ struct FATraits : public Base { // assert(kM == kN) using Element = Element_; - // TODO: fix the hardcode. - static constexpr int kWarpPerRow = 1; - static constexpr int kWarpPerCol = 1; - // Declare global to shared memory copy layout. using GmemLayoutQ = Layout, Int>, Stride, _1>>; using GmemLayoutK = Layout, Int>, Stride, _1>>; @@ -100,7 +99,7 @@ template (buf_); @@ -139,6 +138,10 @@ __global__ void __launch_bounds__(Nthreads) #endif auto acc0 = get_acc(mma); + auto acco = get_acc(mma); + + auto m_new = make_tensor(Shape(acc0)>>{}); + auto lse_new = make_fragment_like(m_new); if (thread0()) { printf("acc0 size<0>: %d, size<1>: %d, size<2>: %d\n", @@ -154,9 +157,13 @@ __global__ void __launch_bounds__(Nthreads) // Issue global to shared memory copy before the main loop. g2s_copy_qk.prologue(); - for (int n = 0; n < kN; n += kTN) { - int split_k = kK / kTK - 1; + fill(lse_new, 0.0f); + fill(m_new, -INFINITY); + clear(acco); + int split_n = kN / kTN; + for (int n = 0; n < split_n; ++n) { + int split_k = kK / kTK - 1; // Pipeline for (int k = 0; k < split_k; ++k) { // Barrier to ensure all data are loaded into shared memory. @@ -177,6 +184,42 @@ __global__ void __launch_bounds__(Nthreads) printf("acc0: \n"); print(acc0), print("\n"); } + auto scores = + make_tensor(acc0.data(), convert_layout_scores(acc0.layout())); + + Tensor m_old = make_fragment_like(m_new); + copy(m_new, m_old); + + Tensor scores_max = make_fragment_like(m_new); + + // Compute row max. + reduce_max<4, true>(scores, scores_max); + + // Compute new max vector. + for (int ax0 = 0; ax0 < size<0>(m_new); ++ax0) { + m_new(ax0) = max(m_new(ax0), scores_max(ax0)); + } + + auto acco_rowcol = + make_tensor(acco.data(), convert_layout_scores(acco.layout())); + + // Renormalizatio for the previous block. + for (int ax0 = 0; ax0 < size<0>(acco_rowcol); ++ax0) { + float scale = exp((m_old(ax0) - m_new(ax0)) * softmax_scale); + lse_new(ax0) = lse_new(ax0) * scale; + for (int ax1 = 0; ax1 < size<1>(acco_rowcol); ++ax1) { + acco_rowcol(ax0, ax1) *= scale; + } + } + + // Load V into register and issue MMA. + int split_n = kN / kTN - 1; + for (int n = 0; n < split_n; ++n) { + // Barrier to ensure all data are loaded into shared memory. + cp_async_wait_flash<0>(); + __syncthreads(); + g2s_copy_v.body(); + } } } diff --git a/benchmarks/cpp/flashattention/main.cu b/benchmarks/cpp/flashattention/main.cu index 847711a..b78402a 100644 --- a/benchmarks/cpp/flashattention/main.cu +++ b/benchmarks/cpp/flashattention/main.cu @@ -32,6 +32,8 @@ void run(bool check = true) { static constexpr int kBatch = 1; + static constexpr int kWarpPerRow = 4; + static constexpr int kWarpPerCol = 1; static constexpr int kThreads = 128; static constexpr int kStagesQK = 2; static constexpr int kStagesV = 2; @@ -118,7 +120,8 @@ void run(bool check = true) { using Traits = benchmarks::cutlass_wrapper::FATraits; + kTM, kTN, kTK, kTP, kWarpPerRow, + kWarpPerCol, kThreads>; auto fa_kernel = benchmarks::cutlass_wrapper::fa_kernel + +namespace benchmarks { +namespace cutlass_wrapper { + +using namespace cute; + +struct MaxOp_float { + DEVICE float operator()(float const& x, float const& y) { + return max(x, y); + } +}; + +template +struct SumOp { + DEVICE T operator()(T const& x, T const& y) { return x + y; } +}; + +template +struct SumAbsOp { + DEVICE T operator()(T const& x, T const& y) { return x + abs(y); } +}; + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || + THREADS == 4); + template + static DEVICE T run(T x, Operator& op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +template <> +struct Allreduce<2> { + template + static DEVICE T run(T x, Operator& op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; + } +}; + +template +DEVICE void thread_reduce_(cute::Tensor const& tensor, + cute::Tensor& summary, + Operator& op) { + using namespace cute; + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = + zero_init ? op(0, tensor(mi, 0)) : op(summary(mi), tensor(mi, 0)); +#pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +DEVICE void quad_allreduce_(cute::Tensor& dst, + cute::Tensor& src, Operator& op) { + using namespace cute; + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); +#pragma unroll + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +DEVICE void eight_allreduce_(cute::Tensor& dst, + cute::Tensor& src, + Operator& op) { + using namespace cute; + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); +#pragma unroll + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce<8>::run(src(i), op); + } +} + +template +DEVICE void allreduce_(cute::Tensor& dst, + cute::Tensor& src, Operator& op) { + using namespace cute; + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); +#pragma unroll + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce::run(src(i), op); + } +} + +template +DEVICE void reduce_(cute::Tensor const& tensor, + cute::Tensor& summary, Operator& op) { + thread_reduce_(tensor, summary, op); + allreduce_(summary, summary, op); +} + +template +DEVICE void reduce_max(cute::Tensor const& tensor, + cute::Tensor& max) { + MaxOp_float max_op; + reduce_(tensor, max, max_op); +} + +template +DEVICE void reduce_sum(cute::Tensor const& tensor, + cute::Tensor& sum) { + SumOp sum_op; + reduce_(tensor, sum, sum_op); +} + +template +DEVICE void reduce_sumabs(cute::Tensor const& tensor, + cute::Tensor& sum) { + SumAbsOp sumabs_op; + reduce_(tensor, sum, sumabs_op); +} + +} // namespace cutlass_wrapper +} // namespace benchmarks \ No newline at end of file