Skip to content

Commit

Permalink
feat(bench): Add pipeline FlashAttention-2 implementation. (#23)
Browse files Browse the repository at this point in the history
* Add pipeline FlashAttention-2 implementation.

* pre-commit fix.

* Add comments and fix some bugs.

* Add copyright.

* follow comments.
  • Loading branch information
KuangjuX authored Jan 3, 2025
1 parent 36cf17a commit b586a02
Show file tree
Hide file tree
Showing 10 changed files with 1,361 additions and 18 deletions.
35 changes: 17 additions & 18 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
{
"gotoSymbolStack.currentStackPosition": 0,
"gotoSymbolStack.maxStackPosition": 0,
"gotoSymbolStack.filePositionInfo": [],
"files.associations": {
"*.tcc": "cpp",
"optional": "cpp",
"ratio": "cpp",
"system_error": "cpp",
"array": "cpp",
"functional": "cpp",
"tuple": "cpp",
"type_traits": "cpp",
"utility": "cpp",
"variant": "cpp",
"compare": "cpp",
"concepts": "cpp",
"random": "cpp"
}
"files.associations": {
"array": "cpp",
"string": "cpp",
"string_view": "cpp",
"span": "cpp",
"bitset": "cpp",
"initializer_list": "cpp",
"utility": "cpp",
"*.tcc": "cpp",
"chrono": "cpp",
"random": "cpp",
"limits": "cpp",
"semaphore": "cpp"
},
"gotoSymbolStack.currentStackPosition": 0,
"gotoSymbolStack.maxStackPosition": 0,
"gotoSymbolStack.filePositionInfo": []
}
19 changes: 19 additions & 0 deletions benchmarks/cpp/flashattention/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the
# MIT License.
# --------------------------------------------------------------------------

# Build script for the FlashAttention benchmark: one CUDA executable
# (`flash_attn`) compiled from main.cu.
cmake_minimum_required(VERSION 3.25 FATAL_ERROR)
project(flash_attention_bench LANGUAGES C CXX CUDA)

# Add the `cmake` directory three levels above this project to the module
# search path so that `include(generic)` below can be resolved by name.
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH}
"${PROJECT_SOURCE_DIR}/../../../cmake")
# Location of vendored dependencies; used below for the CUTLASS headers.
# NOTE(review): the `generic` module may also read this variable - confirm
# before renaming or moving it after the include.
set(THIRD_PARTY_DIR "${PROJECT_SOURCE_DIR}/../../../3rd-party")

# Shared project settings. The module is presumably found through the
# CMAKE_MODULE_PATH entry added above; its contents are not visible here.
include(generic)

# Header search paths, applied at directory scope and therefore seen by the
# target declared afterwards: the top-level `include` directory, the benchmark
# utilities under `utils/cpp`, and the CUTLASS headers.
include_directories("${PROJECT_SOURCE_DIR}/../../../include")
include_directories("${PROJECT_SOURCE_DIR}/../../utils/cpp")
include_directories("${THIRD_PARTY_DIR}/cutlass/include")

# The benchmark binary itself.
add_executable(flash_attn main.cu)
16 changes: 16 additions & 0 deletions benchmarks/cpp/flashattention/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------


# Directory that holds the out-of-source CMake build tree.
BUILD_DIR := build

# Number of parallel compile jobs. Defaults to the number of available
# processors (falling back to 1 when `nproc` is unavailable) and can be
# overridden on the command line, e.g. `make proc=8`.
# Without a default, `-j$(proc)` expands to a bare `-j`, which tells make to
# run an unlimited number of jobs at once.
proc ?= $(shell nproc 2>/dev/null || echo 1)

.PHONY: build clean

# Configure with CMake inside $(BUILD_DIR) and compile with $(proc) jobs.
build:
	@mkdir -p $(BUILD_DIR)
	@cd $(BUILD_DIR) && cmake .. && make -j$(proc)

# Remove the whole build tree.
clean:
	@rm -rf $(BUILD_DIR)
71 changes: 71 additions & 0 deletions benchmarks/cpp/flashattention/convert.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "cuda_utils.cuh"

#include <cute/layout.hpp>
#include <cute/tensor.hpp>
#include <cutlass/numeric_conversion.h>

namespace benchmarks {
namespace cutlass_wrapper {

using namespace cute;

// Converts every element of `tensor` to `To_type` and returns the result as a
// tensor with the same layout.
//
// Template parameters:
//   To_type - destination element type.
//   Engine  - storage engine of the input; its `value_type` is the source
//             element type.
//   Layout  - layout of the input; must be static so the element count is
//             known at compile time.
//
// The input storage is reinterpreted as a flat `cutlass::Array` of
// `size(tensor)` elements, so the layout has to address exactly that many
// consecutive elements. This is enforced with a static_assert.
//
// Returns an *owning* tensor. The previous implementation returned a
// non-owning view of a function-local array, which left the result dangling
// after return and was only correct when the call was fully inlined.
template <typename To_type, typename Engine, typename Layout>
CUTE_DEVICE auto convert_type(cute::Tensor<Engine, Layout> const& tensor) {
    using From_type = typename Engine::value_type;
    constexpr int numel = decltype(size(tensor))::value;

    // The flat reinterpretation below reads `numel` consecutive elements; a
    // layout whose codomain has a different extent would be read (or later
    // indexed) out of bounds.
    static_assert(decltype(cosize(tensor.layout()))::value == numel,
                  "convert_type requires a contiguous tensor: "
                  "cosize(layout) must equal size(layout).");

    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
    auto frag =
        convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel>*>(
            tensor.data()));

    // Own the converted values instead of pointing at `frag`, whose lifetime
    // ends when this function returns.
    auto converted = make_tensor<To_type>(tensor.layout());
    auto dst = converted.data();

    // Copy by storage offset (not by logical coordinate) so that element k of
    // the source storage maps to element k of the destination storage.
    CUTE_UNROLL
    for (int i = 0; i < numel; ++i) {
        dst[i] = frag[i];
    }
    return converted;
}

// Regroups a two-mode layout into a three-mode layout.
//
// Preconditions (checked at compile time): the first sub-mode of mode 0 and
// the first sub-mode of mode 1 both have size 2.
// NOTE(review): presumably the input is a (row, col) view of an MMA
// accumulator and the result is the matching A-operand register layout -
// confirm against the call sites.
//
// Mode 1 is first split so that its second sub-mode is divided by 2. With
// rows = mode 0 and cols = mode 1 of the divided layout, the result is
//   ((cols.0, rows.0, cols.1.0), rows.1, cols.1.1).
template <typename Layout>
DEVICE auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
    using namespace cute;
    static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
    static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);

    using Divisor = Shape<Underscore, Shape<Underscore, Int<2>>>;
    auto divided = logical_divide(rowcol_layout, Divisor{});

    auto rows = get<0>(divided);
    auto cols = get<1>(divided);
    auto cols_tail = get<1>(cols);

    // Innermost group gathered from both original modes.
    auto inner = make_layout(get<0>(cols), get<0>(rows), get<0>(cols_tail));

    return make_layout(inner, get<1>(rows), get<1>(cols_tail));
}

// Builds a fixed layout derived from the static shape ((2, 2), 2, 16).
//
// The last mode (size 16) is divided by 2. The first piece of that division
// is folded into the leading (2, 2) group, and the remainder becomes the last
// mode of the result:
//   ((2, 2, 2), 2, <rest of the size-16 mode>).
// NOTE(review): the hard-coded shape presumably matches one specific
// accumulator tile configuration - confirm against the kernel using it.
DEVICE auto convert_layout_C_Aregs() {
    using namespace cute;
    using SourceShape = Shape<Shape<_2, _2>, _2, _16>;

    auto divided = logical_divide(Layout<SourceShape>{},
                                  Shape<Underscore, Underscore, _2>{});

    auto head = get<0>(divided);
    auto tail = get<2>(divided);

    auto inner = make_layout(get<0>(head), get<1>(head), get<0>(tail));
    return make_layout(inner, get<1>(divided), get<1>(tail));
}

// Regroups a rank-3 layout whose mode 0 has size 4 into a rank-2 layout.
//
// Mode 0 is divided by 2 into a (first, second) pair. The result is
//   ((second, mode 1), (first, mode 2)).
// NOTE(review): presumably this turns an MMA accumulator layout into a
// (row, col) view for row-wise softmax reductions - confirm with callers.
template <class LayoutType>
DEVICE auto convert_layout_scores(LayoutType layout_s) {
    using namespace cute;
    static_assert(decltype(rank(layout_s))::value == 3);
    static_assert(decltype(size<0>(layout_s))::value == 4);

    auto divided = logical_divide(layout_s, Shape<_2>{});
    auto head = get<0>(divided);

    auto first_group = make_layout(get<1>(head), get<1>(divided));
    auto second_group = make_layout(get<0>(head), get<2>(divided));
    return make_layout(first_group, second_group);
}

// Splits mode 1 of `layout_s` by ATOMNUM and reorders the pieces.
//
// With `tiled` = mode 1 after the division, the result is
//   (tiled.0, mode 0, tiled.1)
// i.e. the ATOMNUM-sized piece moves to the front and the remainder of
// mode 1 goes last.
// NOTE(review): presumably used to view the scores tensor in the shape
// expected by a tiled copy - confirm against the call sites.
template <int ATOMNUM, class LayoutType>
DEVICE auto convert_layout_scores_copyview(LayoutType layout_s) {
    using namespace cute;

    using Divisor = Shape<Underscore, Int<ATOMNUM>>;
    auto divided = logical_divide(layout_s, Divisor{});
    auto tiled = get<1>(divided);

    return make_layout(get<0>(tiled), get<0>(divided), get<1>(tiled));
}

} // namespace cutlass_wrapper
} // namespace benchmarks
Loading

0 comments on commit b586a02

Please sign in to comment.