
Commit

update
KuangjuX committed Dec 26, 2024
1 parent b2ba0f9 commit 1bd3c5c
Showing 6 changed files with 350 additions and 23 deletions.
6 changes: 5 additions & 1 deletion .vscode/settings.json
@@ -7,7 +7,11 @@
"bitset": "cpp",
"initializer_list": "cpp",
"utility": "cpp",
"*.tcc": "cpp"
"*.tcc": "cpp",
"chrono": "cpp",
"random": "cpp",
"limits": "cpp",
"semaphore": "cpp"
},
"gotoSymbolStack.currentStackPosition": 0,
"gotoSymbolStack.maxStackPosition": 0,
68 changes: 68 additions & 0 deletions benchmarks/cpp/flashattention/convert.cuh
@@ -0,0 +1,68 @@
#pragma once

#include "cuda_utils.cuh"

#include <cute/layout.hpp>
#include <cute/tensor.hpp>
#include <cutlass/numeric_conversion.h>

namespace benchmarks {
namespace cutlass_wrapper {

using namespace cute;

template <typename To_type, typename Engine, typename Layout>
CUTE_DEVICE auto convert_type(cute::Tensor<Engine, Layout> const& tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
auto frag =
convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel>*>(
tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}

template <typename Layout>
DEVICE auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
using namespace cute;
static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
auto l = logical_divide(rowcol_layout,
Shape<Underscore, Shape<Underscore, Int<2>>>{});

return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)),
get<0>(get<1>(get<1>(l)))),
get<1>(get<0>(l)), get<1>(get<1>(get<1>(l))));
}

DEVICE auto convert_layout_C_Aregs() {
using namespace cute;
auto layout_s = Layout<Shape<Shape<_2, _2>, _2, _16>>{};
auto l = logical_divide(layout_s, Shape<Underscore, Underscore, _2>{});

return make_layout(
make_layout(get<0>(get<0>(l)), get<1>(get<0>(l)), get<0>(get<2>(l))),
get<1>(l), get<1>(get<2>(l)));
}

template <class LayoutType>
DEVICE auto convert_layout_scores(LayoutType layout_s) {
using namespace cute;
static_assert(decltype(size<0>(layout_s))::value == 4);
static_assert(decltype(rank(layout_s))::value == 3);

auto l = logical_divide(layout_s, Shape<_2>{});
return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)),
make_layout(get<0>(get<0>(l)), get<2>(l)));
}

template <int ATOMNUM, class LayoutType>
DEVICE auto convert_layout_scores_copyview(LayoutType layout_s) {
using namespace cute;

auto l = logical_divide(layout_s, Shape<Underscore, Int<ATOMNUM>>{});
return make_layout(get<0>(get<1>(l)), get<0>(l), get<1>(get<1>(l)));
}

} // namespace cutlass_wrapper
} // namespace benchmarks
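
A rough sketch of how these helpers are typically chained inside a flash-attention inner loop (acc0 and Element follow the names used in cutlass_fa.cuh below; rP, rP_Aregs, and the chaining itself are assumed from common usage, not code in this commit):

    // View the fp32 QK^T accumulator as a (rows, cols) scores matrix for softmax.
    auto scores = make_tensor(acc0.data(), convert_layout_scores(acc0.layout()));
    // ... online softmax over `scores` ...
    // Cast the scores back to the compute type (e.g. half) in registers.
    auto rP = convert_type<Element>(scores);
    // Re-tile the casted scores so they can feed the PV MMA as its A operand.
    auto rP_Aregs =
        make_tensor(rP.data(), convert_layout_rowcol_Aregs(rP.layout()));
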
93 changes: 81 additions & 12 deletions benchmarks/cpp/flashattention/copy.cuh
@@ -40,6 +40,8 @@ class G2SCopyQK
DEVICE void print_gQ() {
if (thread0()) {
print(gQ), print("\n");
printf("gQ size<0>: %d, size<1>: %d, size<2>: %d\n",
(int)size<0>(gQ), (int)size<1>(gQ), (int)size<2>(gQ));
}
}

@@ -336,6 +338,56 @@ class S2RPipelineQK
int cur_iter_sq;
};

template <typename SVTensor, typename RVMmaView, typename RVCopyView,
typename RegAcc, typename TiledCopy, typename TiledMma>
class S2RPipelineV {
  public:
    DEVICE S2RPipelineV(SVTensor& sV, RVMmaView& rV_mma_view,
                        RVCopyView& rV_copy_view, RegAcc& acc,
                        TiledCopy tiled_copy, TiledMma tiled_mma,
                        int sV_stride, int num_stage = 2)
        : sV(sV),
          rV_mma_view(rV_mma_view),
          rV_copy_view(rV_copy_view),
          acc(acc),
          tiled_copy(tiled_copy),
          tiled_mma(tiled_mma),
          sV_stride(sV_stride),
          num_stage(num_stage),
          cur_iter(0),
          cur_iter_sv(0) {}

template <typename RegValue>
DEVICE void prologue(RegValue& value) {
cur_iter = 0;
cute::copy(tiled_copy, sV(_, _, _0{}), rV_copy_view(_, _, _0{}));
#pragma unroll
for (int i = 0; i < size<2>(rV_mma_view); ++i) {
if (i < size<2>(rV_mma_view) - 1) {
cute::copy(tiled_copy, sV(_, _, i + 1),
rV_copy_view(_, _, i + 1));
}
cute::gemm(tiled_mma,
value(_, _, cur_iter * size<2>(rV_mma_view) + i),
rV_mma_view(_, _, i), acc);
}

sV.data() = sV.data() + sV_stride;
cur_iter++;
}

private:
SVTensor& sV;
RVMmaView& rV_mma_view;
RVCopyView& rV_copy_view;
RegAcc& acc;
TiledCopy tiled_copy;
TiledMma tiled_mma;
int sV_stride;
int num_stage;
int cur_iter;
int cur_iter_sv;
};

} // namespace detail

template <typename Element, typename GlobalQLayout, typename SharedQLayout,
@@ -349,6 +401,13 @@ inline __device__ auto make_g2s_qk(const Element* gQ_ptr, Element* sQ_ptr,
auto gQ = make_tensor(make_gmem_ptr(gQ_ptr), GlobalQLayout{});
auto sQ = make_tensor(make_smem_ptr(sQ_ptr), SharedQLayout{});

if (thread0()) {
printf("gQ: \n");
print(gQ), print("\n");
printf("size<0>(gQ): %d, size<1>(gQ): %d\n", (int)size<0>(gQ),
(int)size<1>(gQ));
}

auto gK = make_tensor(make_gmem_ptr(gK_ptr), GlobalKLayout{});
auto sK = make_tensor(make_smem_ptr(sK_ptr), SharedKLayout{});

@@ -417,18 +476,6 @@ DEVICE auto make_s2r_qk(const Element* sQ_ptr, const Element* sK_ptr,
// Thread partition for shared to register copy.
auto rQ_copy = s2r_thr_copy_q.retile_D(rQ_mma);
auto rK_copy = s2r_thr_copy_k.retile_D(rK_mma);
// auto rAcc = get_acc<size<0>(rQ), size<1>(rK)>(tiled_mma);

// if (thread0()) {
// printf("thr_mma: \n");
// print(thr_mma), print("\n");
// printf("s2r_copy_q: \n");
// print(s2r_copy_q), print("\n");
// printf("rQ_org: \n");
// print(rQ_org), print("\n");
// printf("rQ: \n");
// print(rQ), print("\n");
// }

detail::S2RPipelineQK s2r_pipeline_qk(sQ, rQ_mma, rQ_copy, sK, rK_mma,
rK_copy, acc, s2r_copy_q, s2r_copy_k,
@@ -437,5 +484,27 @@
return s2r_pipeline_qk;
}

template <typename Element, typename SVLayout, typename SmemCopyAtom,
typename TiledMma>
DEVICE void make_s2r_v(const Element* sV_ptr, SVLayout sV_layout, int sV_stride,
SmemCopyAtom copy_atom, TiledMma tiled_mma) {
int tid = threadIdx.x;

auto sV_ = make_tensor(make_smem_ptr(sV_ptr), sV_layout);

auto thr_mma = tiled_mma.get_thread_slice(tid);

auto s2r_copy_v = make_tiled_copy_B(copy_atom, tiled_mma);
auto s2r_thr_copy_v = s2r_copy_v.get_thread_slice(tid);

auto sV = s2r_thr_copy_v.partition_S(sV_);

auto rV_mma = thr_mma.partition_fragment_B(sV_);
auto rV_copy = s2r_thr_copy_v.retile_D(rV_mma);
}

} // namespace cutlass_wrapper
} // namespace benchmarks
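
The S2R pipelines in this file share one software-pipelining idea: issue the shared-to-register copy for slice i + 1 before running the MMA on slice i, so data movement overlaps with math. A minimal standalone illustration of that schedule (plain C++ with hypothetical issue_copy/issue_mma stand-ins; not code from this commit):

    #include <cstdio>

    void issue_copy(int slice) { std::printf("copy slice %d\n", slice); }
    void issue_mma(int slice) { std::printf("mma  slice %d\n", slice); }

    // Mirrors the prologue() of the pipelines above: prefetch slice 0, then
    // overlap the copy of slice i + 1 with the MMA on slice i.
    void pipelined_slices(int num_slices) {
        issue_copy(0);
        for (int i = 0; i < num_slices; ++i) {
            if (i < num_slices - 1) issue_copy(i + 1);
            issue_mma(i);
        }
    }
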
61 changes: 52 additions & 9 deletions benchmarks/cpp/flashattention/cutlass_fa.cuh
@@ -3,10 +3,12 @@

#pragma once

#include "convert.cuh"
#include "copy.cuh"
#include "cuda_utils.cuh"
#include "cutlass/copy.cuh"
#include "cutlass/traits_base.cuh"
#include "reduce.cuh"

namespace benchmarks {
namespace cutlass_wrapper {
@@ -17,8 +19,9 @@ using namespace cute;
/// @tparam Element_
template <typename Element_, const int kM, const int kN, const int kK,
const int kP, const int kTM, const int kTN, const int kTK,
const int kTP, const int kThreads, const int SmemKAtom = 64,
const int kSwizzle = 3, typename Base = AccessBase<Element_>>
const int kTP, const int kWarpPerRow, const int kWarpPerCol,
const int kThreads, const int SmemKAtom = 64, const int kSwizzle = 3,
typename Base = AccessBase<Element_>>
struct FATraits : public Base {
// Q: [kM, kK] --> [length, hidden_qk]
// K: [kN, kK] --> [length, hidden_qk]
@@ -27,10 +30,6 @@ struct FATraits : public Base {
// assert(kM == kN)
using Element = Element_;

// TODO: fix the hardcode.
static constexpr int kWarpPerRow = 1;
static constexpr int kWarpPerCol = 1;

// Declare global to shared memory copy layout.
using GmemLayoutQ = Layout<Shape<Int<kTM>, Int<kTK>>, Stride<Int<kK>, _1>>;
using GmemLayoutK = Layout<Shape<Int<kTN>, Int<kTK>>, Stride<Int<kK>, _1>>;
@@ -100,7 +99,7 @@ template <typename Element, typename KeTraits, const int kM, const int kN,
__global__ void __launch_bounds__(Nthreads)
fa_kernel(const Element* dQ, const Element* dK, const Element* dV,
Element* dO) {
// constexpr float softmax_scale = 1.250000e-01f;
constexpr float softmax_scale = 1.250000e-01f;
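    // softmax_scale = 1 / sqrt(d_k); 0.125 corresponds to a head dimension of
    // 64 (assumed here; the value is still hardcoded).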

extern __shared__ __align__(sizeof(double)) unsigned char buf_[];
auto* buf = reinterpret_cast<Element*>(buf_);
@@ -139,6 +138,10 @@
#endif

auto acc0 = get_acc<kTM, kTN>(mma);
auto acco = get_acc<kTM, kTP>(mma);

auto m_new = make_tensor<float>(Shape<Int<2 * size<1>(acc0)>>{});
auto lse_new = make_fragment_like(m_new);
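    // m_new tracks the running row-max and lse_new the running softmax
    // normalizer for the online-softmax update (presumably two rows per
    // accumulator tile along size<1>(acc0)).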

if (thread0()) {
printf("acc0 size<0>: %d, size<1>: %d, size<2>: %d\n",
@@ -154,9 +157,13 @@
// Issue global to shared memory copy before the main loop.
g2s_copy_qk.prologue();

for (int n = 0; n < kN; n += kTN) {
int split_k = kK / kTK - 1;
fill(lse_new, 0.0f);
fill(m_new, -INFINITY);
clear(acco);

int split_n = kN / kTN;
for (int n = 0; n < split_n; ++n) {
int split_k = kK / kTK - 1;
// Pipeline
for (int k = 0; k < split_k; ++k) {
// Barrier to ensure all data are loaded into shared memory.
@@ -177,6 +184,42 @@
printf("acc0: \n");
print(acc0), print("\n");
}
auto scores =
make_tensor(acc0.data(), convert_layout_scores(acc0.layout()));

Tensor m_old = make_fragment_like(m_new);
copy(m_new, m_old);

Tensor scores_max = make_fragment_like(m_new);

// Compute row max.
reduce_max<4, true>(scores, scores_max);

// Compute new max vector.
for (int ax0 = 0; ax0 < size<0>(m_new); ++ax0) {
m_new(ax0) = max(m_new(ax0), scores_max(ax0));
}

auto acco_rowcol =
make_tensor(acco.data(), convert_layout_scores(acco.layout()));

        // Renormalization for the previous block.
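        // With the updated running max m_new = max(m_old, rowmax(scores)),
        // the previously accumulated output and normalizer are rescaled by
        // exp((m_old - m_new) * softmax_scale) so that every block's
        // contribution ends up normalized against the same maximum.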
for (int ax0 = 0; ax0 < size<0>(acco_rowcol); ++ax0) {
float scale = exp((m_old(ax0) - m_new(ax0)) * softmax_scale);
lse_new(ax0) = lse_new(ax0) * scale;
for (int ax1 = 0; ax1 < size<1>(acco_rowcol); ++ax1) {
acco_rowcol(ax0, ax1) *= scale;
}
}

// Load V into register and issue MMA.
int split_n = kN / kTN - 1;
for (int n = 0; n < split_n; ++n) {
// Barrier to ensure all data are loaded into shared memory.
cp_async_wait_flash<0>();
__syncthreads();
g2s_copy_v.body();
}
}
}

5 changes: 4 additions & 1 deletion benchmarks/cpp/flashattention/main.cu
@@ -32,6 +32,8 @@ void run(bool check = true) {

static constexpr int kBatch = 1;

static constexpr int kWarpPerRow = 4;
static constexpr int kWarpPerCol = 1;
static constexpr int kThreads = 128;
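    // 4 x 1 warps x 32 threads/warp = 128 threads, matching kThreads.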
static constexpr int kStagesQK = 2;
static constexpr int kStagesV = 2;
@@ -118,7 +120,8 @@

using Traits =
benchmarks::cutlass_wrapper::FATraits<cutlass::half_t, kM, kN, kK, kP,
kTM, kTN, kTK, kTP, kThreads>;
kTM, kTN, kTK, kTP, kWarpPerRow,
kWarpPerCol, kThreads>;

auto fa_kernel =
benchmarks::cutlass_wrapper::fa_kernel<cutlass::half_t, Traits, kM, kN,