Skip to content

Commit

Permalink
Update FlashAttention to v2.0.6 to test
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed Aug 14, 2023
1 parent 366754a commit 1ebab9b
Show file tree
Hide file tree
Showing 12 changed files with 270 additions and 272 deletions.
9 changes: 0 additions & 9 deletions .github/workflows/cuda/cu117-Linux-env.sh

This file was deleted.

18 changes: 0 additions & 18 deletions .github/workflows/cuda/cu117-Linux.sh

This file was deleted.

53 changes: 0 additions & 53 deletions .github/workflows/env.sh

This file was deleted.

File renamed without changes.
2 changes: 1 addition & 1 deletion csrc/cutlass
Submodule cutlass updated 735 files
282 changes: 164 additions & 118 deletions csrc/flash_attn/src/flash_bwd_kernel.h

Large diffs are not rendered by default.

69 changes: 39 additions & 30 deletions csrc/flash_attn/src/flash_fwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
flash::reduce_sum(scores, scores_sum);
} else {
Tensor scores_max_prev = make_fragment_like(scores_max);
copy(scores_max, scores_max_prev);
cute::copy(scores_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
Expand All @@ -103,7 +103,7 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T

template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
inline __device__ void write_softmax_to_gmem(
Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_thr_copy_P
Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_tiled_copy_P
) {
// Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
Layout l = tOrP.layout();
Expand All @@ -112,7 +112,7 @@ inline __device__ void write_softmax_to_gmem(
CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
#pragma unroll
for (int mi = 0; mi < size<1>(tPrP); ++mi) {
copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
}
};

Expand Down Expand Up @@ -186,8 +186,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});

auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx);
auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P;
auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx);

Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
Expand All @@ -209,16 +211,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Copy Atom retiling
//

auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
// auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
// if (cute::thread0()) {smem_thr_copy_Q.print_all();}
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
// if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}

auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);

auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx);
auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);

// TODO: this might need to change if we change the mma instruction in SM70
Expand Down Expand Up @@ -269,7 +274,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi

Tensor tQrQ = make_fragment_like(tQgQ);
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
binfo.actual_seqlen_q - m_block * kBlockM);
if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }

Expand All @@ -286,13 +291,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
__syncthreads();
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
__syncthreads();
}

int n_block = n_block_max - 1;
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
flash::copy<Is_even_N, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
flash::copy<Is_even_N, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
binfo.actual_seqlen_k - n_block * kBlockN);
cute::cp_async_fence();
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
Expand All @@ -303,7 +308,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
__syncthreads();
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
}

auto seeds = at::cuda::philox::unpack(params.philox_args);
Expand Down Expand Up @@ -335,17 +340,18 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Advance gV
if (masking_step > 0) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
} else {
// Clear the smem tiles to account for predicated off loads
flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
);
}
cute::cp_async_fence();

flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
smem_thr_copy_Q, smem_thr_copy_K
);
// if (cute::thread0()) { print(acc_s); }

Expand Down Expand Up @@ -382,7 +388,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
if (n_block > 0) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
Expand All @@ -402,12 +408,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
uint32_t block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP);
copy(tOrP, tOrP_copy);
cute::copy(tOrP, tOrP_copy);
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps
);
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
tPgP.data() = tPgP.data() + (-kBlockN);
}
if (Is_dropout) {
Expand All @@ -416,7 +422,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
}
// if (cute::thread0()) { print(tOrP); }

flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
// if (cute::thread0()) { print(scores); }

// This check is at the end of the loop since we always have at least 1 iteration
Expand All @@ -434,19 +440,20 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
__syncthreads();
// Advance gV
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
cute::cp_async_fence();

flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
smem_thr_copy_Q, smem_thr_copy_K
);

flash::cp_async_wait<0>();
__syncthreads();
if (n_block > 0) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
Expand All @@ -464,20 +471,20 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
uint32_t block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP);
copy(tOrP, tOrP_copy);
cute::copy(tOrP, tOrP_copy);
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps
);
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
tPgP.data() = tPgP.data() + (-kBlockN);
}
if (Is_dropout) {
flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps);
}

flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
}

// Epilogue
Expand All @@ -501,15 +508,16 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor rO = flash::convert_type<Element>(acc_o);
Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning
auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
// auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)

// sO has the same size as sQ, so we don't need to sync here.
if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }

copy(smem_thr_copy_O, taccOrO, taccOsO);
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);

const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
Expand All @@ -520,14 +528,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<Int<kBlockM>>{}, Stride<_1>{});

auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);

__syncthreads();

Tensor tOrO = make_tensor<Element>(shape(tOgO));
copy(gmem_thr_copy_O, tOsO, tOrO);
cute::copy(gmem_tiled_copy_O, tOsO, tOrO);

Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
Expand All @@ -554,7 +563,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
}
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
);
}

Expand Down
2 changes: 1 addition & 1 deletion csrc/flash_attn/src/flash_fwd_launch_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
// Will only return softmax if dropout, to reduce compilation time.
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, false, true, true, ReturnSoftmaxConst && Is_dropout>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
Expand Down
Loading

0 comments on commit 1ebab9b

Please sign in to comment.