fix pipeline.
KuangjuX committed Dec 25, 2024
1 parent f3a8d2c commit b2ba0f9
Showing 4 changed files with 155 additions and 43 deletions.
3 changes: 2 additions & 1 deletion .vscode/settings.json
@@ -6,7 +6,8 @@
"span": "cpp",
"bitset": "cpp",
"initializer_list": "cpp",
"utility": "cpp"
"utility": "cpp",
"*.tcc": "cpp"
},
"gotoSymbolStack.currentStackPosition": 0,
"gotoSymbolStack.maxStackPosition": 0,
155 changes: 125 additions & 30 deletions benchmarks/cpp/flashattention/copy.cuh
@@ -214,19 +214,24 @@ class G2SCopyV {
int num_stage;
};

template <typename SQTensor, typename RQTensor, typename SKTensor,
typename RKTensor, typename RAccTensor, typename TiledCopyQ,
typename TiledCopyK, typename TiledMma>
template <typename SQTensor, typename RQMmaView, typename RQCopyView,
typename SKTensor, typename RKMmaView, typename RKCopyView,
typename RAccTensor, typename TiledCopyQ, typename TiledCopyK,
typename TiledMma>
class S2RPipelineQK {
public:
DEVICE S2RPipelineQK(SQTensor& sQ, RQTensor& rQ, SKTensor& sK, RKTensor& rK,
DEVICE S2RPipelineQK(SQTensor& sQ, RQMmaView& rQ_mma_view,
RQCopyView& rQ_copy_view, SKTensor& sK,
RKMmaView& rK_mma_view, RKCopyView& rK_copy_view,
RAccTensor& acc, TiledCopyQ copy_q, TiledCopyK copy_k,
TiledMma tiled_mma, int sQ_stride, int sK_stride,
int num_stage = 2)
: sQ(sQ),
rQ(rQ),
rQ_mma_view(rQ_mma_view),
rQ_copy_view(rQ_copy_view),
sK(sK),
rK(rK),
rK_mma_view(rK_mma_view),
rK_copy_view(rK_copy_view),
acc(acc),
copy_q(copy_q),
copy_k(copy_k),
@@ -237,11 +242,89 @@ class S2RPipelineQK {
cur_iter(0),
cur_iter_sq(0) {}

DEVICE void print_rQ() {
if (thread0()) {
print(rQ_mma_view), print("\n");
print(rQ_copy_view), print("\n");
}
}

DEVICE void prologue() {
cur_iter = 0;
cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{}));
cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{}));

#pragma unroll
for (int i = 0; i < size<2>(rK_mma_view); ++i) {
if (i < size<2>(rK_mma_view) - 1) {
cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{}));
cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{}));
}
cute::gemm(tiled_mma, rQ_mma_view(_, _, i), rK_mma_view(_, _, i),
acc);
}
sQ.data() = sQ.data() + sQ_stride;
sK.data() = sK.data() + sK_stride;

cur_iter++;
}

DEVICE void body() {
cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{}));
cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{}));

#pragma unroll
for (int i = 0; i < size<2>(rK_mma_view); ++i) {
if (i < size<2>(rK_mma_view) - 1) {
cute::copy(copy_q, sQ(_, _, i + 1), rQ_copy_view(_, _, i + 1));
cute::copy(copy_k, sK(_, _, i + 1), rK_copy_view(_, _, i + 1));
}
cute::gemm(tiled_mma, rQ_mma_view(_, _, i), rK_mma_view(_, _, i),
acc);
}
sQ.data() = sQ.data() + sQ_stride;
sK.data() = sK.data() + sK_stride;

if ((cur_iter + 1) % num_stage == 0) {
sK.data() = sK.data() + (-sK_stride * num_stage);
}

cur_iter++;
cur_iter_sq++;
}

DEVICE void epilogue() {
cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{}));
cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{}));

#pragma unroll
for (int i = 0; i < size<2>(rK_mma_view); ++i) {
if (i < size<2>(rK_mma_view) - 1) {
cute::copy(copy_q, sQ(_, _, i + 1), rQ_copy_view(_, _, i + 1));
cute::copy(copy_k, sK(_, _, i + 1), rK_copy_view(_, _, i + 1));
}
cute::gemm(tiled_mma, rQ_mma_view(_, _, i), rK_mma_view(_, _, i),
acc);
}

sQ.data() = sQ.data() + (-sQ_stride * cur_iter_sq);
sK.data() = sK.data() + sK_stride;

if ((cur_iter + 1) % num_stage == 0) {
sK.data() = sK.data() + (-sK_stride * num_stage);
}

cur_iter++;
cur_iter_sq = 0;
}

private:
SQTensor& sQ;
RQTensor& rQ;
RQMmaView& rQ_mma_view;
RQCopyView& rQ_copy_view;
SKTensor& sK;
RKTensor& rK;
RKMmaView& rK_mma_view;
RKCopyView& rK_copy_view;
RAccTensor& acc;
TiledCopyQ copy_q;
TiledCopyK copy_k;
@@ -305,41 +388,53 @@ DEVICE auto make_g2s_v(const Element* gV_ptr, Element* sV_ptr, int gV_stride,
return copy_v;
}

template <typename SQTensor, typename SKTensor, typename RegAcc,
typename SmemCopyAtom, typename TiledMma>
DEVICE auto make_s2r_qk(SQTensor sQ, SKTensor sK, RegAcc acc, int sQ_stride,
int sK_stride, SmemCopyAtom copy_atom = SmemCopyAtom{},
template <typename Element, typename SQLayout, typename SKLayout,
typename RegAcc, typename SmemCopyAtom, typename TiledMma>
DEVICE auto make_s2r_qk(const Element* sQ_ptr, const Element* sK_ptr,
SQLayout sQ_layout, SKLayout sK_layout, RegAcc acc,
int sQ_stride, int sK_stride,
SmemCopyAtom copy_atom = SmemCopyAtom{},
TiledMma tiled_mma = TiledMma{}) {
int tid = threadIdx.x;

auto sQ_ = make_tensor(make_smem_ptr(sQ_ptr), sQ_layout);
auto sK_ = make_tensor(make_smem_ptr(sK_ptr), sK_layout);

auto thr_mma = tiled_mma.get_thread_slice(tid);

auto s2r_copy_q = make_tiled_copy_A(copy_atom, tiled_mma);
auto s2r_copy_k = make_tiled_copy_B(copy_atom, tiled_mma);
auto s2r_thr_copy_q = s2r_copy_q.get_thread_slice(tid);
auto s2r_thr_copy_k = s2r_copy_k.get_thread_slice(tid);

auto rQ_org = thr_mma.partition_fragment_A(sQ);
auto rK_org = thr_mma.partition_fragment_B(sK);
auto sQ = s2r_thr_copy_q.partition_S(sQ_);
auto sK = s2r_thr_copy_k.partition_S(sK_);

auto rQ = s2r_thr_copy_q.retile_D(rQ_org);
auto rK = s2r_thr_copy_k.retile_D(rK_org);
// auto rAcc = get_acc<size<0>(rQ), size<1>(rK)>(tiled_mma);
// Thread partition for mma.
auto rQ_mma = thr_mma.partition_fragment_A(sQ_);
auto rK_mma = thr_mma.partition_fragment_B(sK_);

if (thread0()) {
printf("thr_mma: \n");
print(thr_mma), print("\n");
printf("s2r_copy_q: \n");
print(s2r_copy_q), print("\n");
printf("rQ_org: \n");
print(rQ_org), print("\n");
printf("rQ: \n");
print(rQ), print("\n");
}
// Thread partition for shared to register copy.
auto rQ_copy = s2r_thr_copy_q.retile_D(rQ_mma);
auto rK_copy = s2r_thr_copy_k.retile_D(rK_mma);
// auto rAcc = get_acc<size<0>(rQ), size<1>(rK)>(tiled_mma);

detail::S2RPipelineQK s2r_pipeline_qk(sQ, rQ, sK, rK, acc, s2r_copy_q,
s2r_copy_k, tiled_mma, sQ_stride,
sK_stride);
// if (thread0()) {
// printf("thr_mma: \n");
// print(thr_mma), print("\n");
// printf("s2r_copy_q: \n");
// print(s2r_copy_q), print("\n");
// printf("rQ_org: \n");
// print(rQ_org), print("\n");
// printf("rQ: \n");
// print(rQ), print("\n");
// }

detail::S2RPipelineQK s2r_pipeline_qk(sQ, rQ_mma, rQ_copy, sK, rK_mma,
rK_copy, acc, s2r_copy_q, s2r_copy_k,
tiled_mma, sQ_stride, sK_stride);

return s2r_pipeline_qk;
}

} // namespace cutlass_wrapper
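
A minimal sketch of how the pipeline object returned by make_s2r_qk is meant to be driven (the tile count and the interleaved global-to-shared prefetching are assumptions based on the kernel changes in cutlass_fa.cuh below, not the exact kernel code):

    auto s2r_pipeline_qk =
        make_s2r_qk(sQ_ptr, sK_ptr, typename KeTraits::SmemLayoutQ{},
                    typename KeTraits::SmemLayoutK{}, acc0, /*sQ_stride=*/kTK,
                    /*sK_stride=*/kTK, typename KeTraits::SmemCopyAtom{}, mma);

    // prologue() consumes the first staged tile, body() each intermediate tile
    // (wrapping sK back every num_stage iterations), and epilogue() drains the
    // last tile and rewinds sQ before the next outer iteration.
    s2r_pipeline_qk.prologue();
    for (int n = 1; n < num_tiles - 1; ++n) {  // num_tiles is hypothetical
        s2r_pipeline_qk.body();
    }
    s2r_pipeline_qk.epilogue();  // acc0 now holds the Q @ K^T partials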
32 changes: 22 additions & 10 deletions benchmarks/cpp/flashattention/cutlass_fa.cuh
@@ -95,15 +95,13 @@ struct FATraits : public Base {

template <typename Element, typename KeTraits, const int kM, const int kN,
const int kK, const int kP, const int kTM, const int kTN,
const int kTK, const int kTP, const int Nthreads>
const int kTK, const int kTP, const int Nthreads, const int kStagesQK,
const int kStageV>
__global__ void __launch_bounds__(Nthreads)
fa_kernel(const Element* dQ, const Element* dK, const Element* dV,
Element* dO) {
// constexpr float softmax_scale = 1.250000e-01f;

// Q, K: [batch, head, length, hidden_qk]
// V, O: [batch, head, length, hidden_v]

extern __shared__ __align__(sizeof(double)) unsigned char buf_[];
auto* buf = reinterpret_cast<Element*>(buf_);

@@ -114,8 +112,8 @@
dO + blockIdx.z * kM * kP + blockIdx.x * (kTM * kP) + blockIdx.y * kTP;

Element* sQ_ptr = reinterpret_cast<Element*>(buf);
Element* sK_ptr = sQ_ptr + kTM * kTK;
Element* sV_ptr = sK_ptr + kTN * kTK;
Element* sK_ptr = sQ_ptr + kTM * kTK * kStagesQK;
Element* sV_ptr = sK_ptr + kTN * kTK * kStagesQK;
// Element* sO_ptr = sQ_ptr;

typename KeTraits::TiledMma mma;
@@ -140,11 +138,18 @@
g2s_copy_qk.print_gQ_data(0);
#endif

auto sQ = g2s_copy_qk.get_sQ();
auto sK = g2s_copy_qk.get_sK();
auto acc0 = get_acc<kTM, kTN>(mma);

make_s2r_qk(sQ, sK, acc0, kTK, kTK, typename KeTraits::SmemCopyAtom{}, mma);
if (thread0()) {
printf("acc0 size<0>: %d, size<1>: %d, size<2>: %d\n",
(int)size<0>(acc0), (int)size<1>(acc0), (int)size<2>(acc0));
}

auto s2r_pipeline_qk =
make_s2r_qk(sQ_ptr, sK_ptr, typename KeTraits::SmemLayoutQ{},
typename KeTraits::SmemLayoutK{}, acc0, kTK, kTK,
typename KeTraits::SmemCopyAtom{}, mma);
s2r_pipeline_qk.print_rQ();

// Issue global to shared memory copy before the main loop.
g2s_copy_qk.prologue();
@@ -159,12 +164,19 @@
__syncthreads();
g2s_copy_qk.body();
// Load data from shared memory into register and issue MMA.
s2r_pipeline_qk.body();
}

cp_async_wait_flash<0>();
__syncthreads();
// g2s_copy_qk.print_sQ_data(0);
g2s_copy_v.prologue();
s2r_pipeline_qk.epilogue();

// Print acc0 data.
if (thread0()) {
printf("acc0: \n");
print(acc0), print("\n");
}
}
}
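
With sK and sV now offset past kStagesQK staged buffers, the dynamic shared-memory budget the kernel expects can be estimated as below. This is a hedged sketch only: the sV tile shape (kTN x kTP) is an assumption not shown in this diff, and it is not necessarily the exact shm_size computation used by the benchmark.

    // Per-stage tile sizes, in elements of cutlass::half_t:
    //   sQ: kTM * kTK and sK: kTN * kTK, each replicated kStagesQK times;
    //   sV: assumed kTN * kTP, replicated kStagesV times.
    constexpr int kSmemElems =
        (kTM * kTK + kTN * kTK) * kStagesQK + kTN * kTP * kStagesV;
    const int shm_size = kSmemElems * sizeof(cutlass::half_t);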

8 changes: 6 additions & 2 deletions benchmarks/cpp/flashattention/main.cu
@@ -33,6 +33,8 @@ void run(bool check = true) {
static constexpr int kBatch = 1;

static constexpr int kThreads = 128;
static constexpr int kStagesQK = 2;
static constexpr int kStagesV = 2;

static_assert(kK == kTK,
"The current implementation requires kTK == K for now.");
@@ -118,8 +120,10 @@
benchmarks::cutlass_wrapper::FATraits<cutlass::half_t, kM, kN, kK, kP,
kTM, kTN, kTK, kTP, kThreads>;

auto fa_kernel = benchmarks::cutlass_wrapper::fa_kernel<
cutlass::half_t, Traits, kM, kN, kK, kP, kTM, kTN, kTK, kTP, kThreads>;
auto fa_kernel =
benchmarks::cutlass_wrapper::fa_kernel<cutlass::half_t, Traits, kM, kN,
kK, kP, kTM, kTN, kTK, kTP,
kThreads, kStagesQK, kStagesV>;

if (shm_size > 48 * 1024) {
cudaFuncSetAttribute(
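
The attribute call above is cut off by the collapsed diff. For reference, a hedged sketch of the usual pattern for opting a kernel into more than 48 KB of dynamic shared memory and launching it (grid, block, and the host-side pointer names are assumptions, not the exact code in main.cu):

    if (shm_size > 48 * 1024) {
        // Raise this kernel's dynamic shared-memory limit above the 48 KB default.
        cudaFuncSetAttribute(fa_kernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             shm_size);
    }
    fa_kernel<<<grid, block, shm_size>>>(dQ, dK, dV, dO);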
