
feat(bench): Add pipeline FlashAttention-2 implementation. #23

Merged: 7 commits merged from cutlass_flash_attn into master, Jan 3, 2025

Conversation

@KuangjuX (Collaborator) commented on Dec 21, 2024:

This is a basic version of the pipelined FlashAttention-2 implementation; I would like to merge these changes into the master branch first.

The current version of the implementation has the following features:

  • A software pipeline is used to build a multi-stage buffer in shared memory, so that upper-level stages can be filled via async_copy while computation proceeds, which improves utilization of the compute units (Tensor Cores on the Ampere architecture). A minimal sketch of this double-buffering pattern appears below.
  • Currently, only the load_q_once case has been implemented, where kTK == kK. In this situation, the k dimension is not partitioned within a single thread block, and the Q matrix only needs to be loaded once.
  • In FractalTensor, the N dimension is partitioned twice: once as kN in the outer loop and once as kTN in the inner loop to load the V matrix. The inner-loop partitioning has not been implemented yet.
  • In FractalTensor, the last iteration of the outer loop over the N dimension is supposed to be unrolled; this has not been implemented yet.
  • The naming of some device functions still needs to be revised and organized.

The current implementation is not a final version; I will continue to add more features in subsequent PRs.
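
As a rough illustration of the pipelining described in the first bullet, below is a minimal, self-contained CUDA sketch of two-stage (double-buffered) shared-memory loads with cp.async, reduced to a simple tile-wise sum rather than the full attention kernel. All identifiers here (pipelined_row_sum, kTN, kThreads) are illustrative and are not the names used in this PR; the sketch assumes an Ampere-class (sm_80+) GPU and 16-byte-aligned global memory.

#include <cuda_pipeline.h>

constexpr int kTN = 128;      // elements per tile (illustrative size)
constexpr int kThreads = 32;  // one warp per block (illustrative size)

// Sums n_tiles consecutive tiles of x into *out using two shared-memory
// stages: while one stage is being consumed, cp.async fills the other.
__global__ void pipelined_row_sum(const float* __restrict__ x, float* out,
                                  int n_tiles) {
    __shared__ __align__(16) float smem[2][kTN];
    const int tid = threadIdx.x;
    float acc = 0.f;

    // Prefetch tile 0 into stage 0 (one 16-byte cp.async per thread here).
    for (int i = tid * 4; i < kTN; i += kThreads * 4)
        __pipeline_memcpy_async(&smem[0][i], &x[i], 4 * sizeof(float));
    __pipeline_commit();

    int stage = 0;
    for (int t = 0; t < n_tiles; ++t) {
        if (t + 1 < n_tiles) {
            // Issue the next tile's copies before consuming the current
            // tile, then wait only for the current stage to arrive.
            const float* src = x + (t + 1) * kTN;
            for (int i = tid * 4; i < kTN; i += kThreads * 4)
                __pipeline_memcpy_async(&smem[stage ^ 1][i], &src[i],
                                        4 * sizeof(float));
            __pipeline_commit();
            __pipeline_wait_prior(1);
        } else {
            __pipeline_wait_prior(0);
        }
        __syncthreads();

        // "Compute" on the current stage; in the real kernel this is where
        // the Q*K^T GEMM, online softmax, and P*V GEMM would run.
        for (int i = tid; i < kTN; i += kThreads) acc += smem[stage][i];

        // All threads must finish reading this stage before the next
        // iteration's prefetch overwrites it.
        __syncthreads();
        stage ^= 1;
    }
    atomicAdd(out, acc);  // *out must be zero-initialized by the host
}

In the actual kernel, Q would stay resident in shared memory for the whole loop (because kTK == kK), and the consumed tiles would be the K and V blocks for each step of the outer loop over N.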

@KuangjuX KuangjuX marked this pull request as draft December 21, 2024 15:45
@KuangjuX (Collaborator, Author) commented:

@microsoft-github-policy-service agree company="Microsoft"

"gotoSymbolStack.currentStackPosition": 0,
"gotoSymbolStack.maxStackPosition": 0,
"gotoSymbolStack.filePositionInfo": []
}
Contributor commented:

I am curious why the pre-commit hooks (see: https://github.com/microsoft/TileFusion/blob/master/.pre-commit-config.yaml#L28) do not catch these invisible characters, which are often introduced by differences between IDEs. I have observed this issue several times; the hook is supposed to fix it automatically before a PR is filed.

@KuangjuX (Collaborator, Author) replied:

I just ran pre-commit run --all-files to fix the issues automatically, but it seems that when I commit with Git, the pre-commit hook does not fix all files automatically beforehand. I will look into the cause of this issue later.

@KuangjuX KuangjuX force-pushed the cutlass_flash_attn branch from 44fba6a to e57fa5c Compare January 3, 2025 02:19
@KuangjuX KuangjuX changed the title 🚧 feat(bench): Add pipeline FlashAttention-2 implementation. feat(bench): Add pipeline FlashAttention-2 implementation. Jan 3, 2025
@KuangjuX KuangjuX marked this pull request as ready for review January 3, 2025 03:00
@KuangjuX KuangjuX requested a review from lcy-seso January 3, 2025 03:00
# --------------------------------------------------------------------------

cmake_minimum_required(VERSION 3.25 FATAL_ERROR)
project(gemm_bench LANGUAGES C CXX CUDA)
Contributor commented:

The project name gemm_bench should be updated.

@KuangjuX (Collaborator, Author) replied:

Oops! I forgot to make this change; it has been updated now.

include_directories("${THIRD_PARTY_DIR}/cutlass/include")

add_executable(flash_attn main.cu)
target_link_libraries(flash_attn ${CUDA_CUBLAS_LIBRARIES})
Contributor commented:

Is cuBLAS actually used in this code? It doesn't appear to be. Do we need to link against it?

@KuangjuX (Collaborator, Author) replied:

Fixed.
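
(Presumably the fix simply drops the cuBLAS link line, leaving something along these lines; the exact change is not shown in this thread.)

include_directories("${THIRD_PARTY_DIR}/cutlass/include")

# No target_link_libraries against cuBLAS is needed, since the benchmark
# does not call any cuBLAS routines.
add_executable(flash_attn main.cu)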

@lcy-seso (Contributor) left a review comment:

LGTM. 😊

@lcy-seso lcy-seso merged commit b586a02 into microsoft:master Jan 3, 2025
3 checks passed