forked from DefTruth/CUDA-Learn-Notes
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathflash_attn_1_fwd_f32.cu
140 lines (121 loc) · 4.93 KB
/
flash_attn_1_fwd_f32.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
// Modified from: https://github.com/tspeterkim/flash-attention-minimal/blob/main/flash.cu
#include <torch/types.h>
#include <cuda.h>
#include <cuda_runtime.h>
#define ENABLE_NOTE_LOG 0
// FlashAttention-1 forward pass (fp32), one thread block per (batch, head).
//
// Launch contract, as implied by the indexing below:
//   grid  : (batch, heads) -> blockIdx.x selects the batch, blockIdx.y the head.
//   block : Bc threads; thread tx loads row tx of each K/V tile and owns row tx
//           of each Q tile, so the kernel expects blockDim.x == Bc == Br.
//   smem  : (3 * Bc * d + Bc * Bc) floats of dynamic shared memory
//           (Qi, Kj, Vj tiles plus the score tile S).
//   Q/K/V/O are indexed as row-major [batch, heads, N, d]; l and m as
//   [batch, heads, N].
//   l, m and O are read before they are written (running softmax state), so the
//   caller must initialise them (the host wrapper uses l = 0, m = -inf, O = 0).
//
// N does not have to be a multiple of Bc/Br: rows and key columns at or beyond N
// are neither loaded, used in the softmax, nor written back.
__global__ void flash_attn_1_fwd_f32_kernel(
const float* Q,
const float* K,
const float* V,
const int N,
const int d,
const int Tc,
const int Tr,
const int Bc,
const int Br,
const float softmax_scale,
float* l,
float *m,
float* O) {
  const int tx = threadIdx.x;

  // Offsets into Q,K,V,O,l,m - different for each batch and head.
  // Computed in size_t so large B * nh * N * d products do not overflow int.
  const size_t head_index = (size_t)blockIdx.x * gridDim.y + blockIdx.y;  // gridDim.y = nh
  const size_t qkv_offset = head_index * (size_t)N * (size_t)d;
  const size_t lm_offset = head_index * (size_t)N;

  // Carve the dynamic shared memory into the Q, K, V tiles and the score tile.
  extern __shared__ float sram[];
  const int tile_size = Bc * d;  // size of Qi, Kj, Vj
  float* Qi = sram;
  float* Kj = &sram[tile_size];
  float* Vj = &sram[tile_size * 2];
  float* S = &sram[tile_size * 3];

  // Each thread only ever touches its own row of Qi and S.
  float* q_row_smem = Qi + tx * d;
  float* s_row = S + Bc * tx;

  for (int j = 0; j < Tc; j++) {
    // Number of key/value rows of this tile that actually exist (< Bc on the tail tile).
    const int kv_base = Bc * j;
    const int kv_left = N - kv_base;
    if (kv_left <= 0) break;  // uniform across the block, so no barrier is skipped by a subset
    const int kv_cols = kv_left < Bc ? kv_left : Bc;

    // Load Kj, Vj to SRAM; threads mapped to rows past N load nothing.
    if (tx < kv_cols) {
      const size_t kv_src = qkv_offset + (size_t)(kv_base + tx) * d;
      for (int x = 0; x < d; x++) {
        Kj[(tx * d) + x] = K[kv_src + x];
        Vj[(tx * d) + x] = V[kv_src + x];
      }
    }
    __syncthreads();  // such that the inner loop can use the correct Kj, Vj

    for (int i = 0; i < Tr; i++) {
      // Global query row owned by this thread for tile i. The same row index is
      // used for Q, O, l and m.
      const int q_row = Br * i + tx;
      // Rows only grow with i, so once past N this thread has nothing left to do
      // for this K/V tile. There is no barrier inside this loop, so leaving it
      // early is safe; the barrier after the loop is still reached.
      if (q_row >= N) break;

      const size_t q_off = qkv_offset + (size_t)q_row * d;
      const size_t lm_idx = lm_offset + (size_t)q_row;

      // Load Qi to SRAM, l and m to registers
      for (int x = 0; x < d; x++) {
        q_row_smem[x] = Q[q_off + x];
      }
      const float row_m_prev = m[lm_idx];
      const float row_l_prev = l[lm_idx];

      // S = QK^T, row_m = rowmax(S) over the valid key columns only
      float row_m = -INFINITY;
      for (int y = 0; y < kv_cols; y++) {
        float sum = 0.0f;
        for (int x = 0; x < d; x++) {
          sum += q_row_smem[x] * Kj[(y * d) + x];
        }
        sum *= softmax_scale;
        s_row[y] = sum;
        if (sum > row_m)
          row_m = sum;
      }

      // P = exp(S - row_m), row_l = rowsum(P)
      float row_l = 0.0f;
      for (int y = 0; y < kv_cols; y++) {
        s_row[y] = __expf(s_row[y] - row_m);
        row_l += s_row[y];
      }

      // Compute new m and l
      const float row_m_new = fmaxf(row_m_prev, row_m);
      const float exp_prev = __expf(row_m_prev - row_m_new);
      const float exp_cur = __expf(row_m - row_m_new);
      const float row_l_new = (exp_prev * row_l_prev) + (exp_cur * row_l);

      // Rescale factors are the same for every output column; compute them once.
      const float prev_weight = row_l_prev * exp_prev;
      const float inv_l_new = 1.0f / row_l_new;

      // Write O, l, m to HBM
      for (int x = 0; x < d; x++) {
        float pv = 0.0f;  // Pij * Vj
        for (int y = 0; y < kv_cols; y++) {
          pv += s_row[y] * Vj[(y * d) + x];
        }
        O[q_off + x] = inv_l_new * ((prev_weight * O[q_off + x]) + (exp_cur * pv));
      }
      m[lm_idx] = row_m_new;
      l[lm_idx] = row_l_new;
    }
    __syncthreads();  // otherwise, thread can use the wrong Kj, Vj in inner loop
  }
}
// Host entry point: runs the fp32 FlashAttention-1 forward kernel.
//
// Q, K, V: float32 CUDA tensors of identical shape [B, nh, N, d] on the same
//          device (self-attention layout). Non-contiguous inputs are copied to a
//          contiguous layout because the kernel indexes raw row-major memory.
// Returns: O with the same shape as Q.
//
// Raises (via TORCH_CHECK) on invalid inputs, when the tile configuration needs
// more shared memory than the device allows per block, or when the launch fails.
//
// NOTE(review): the kernel is launched on the current CUDA device and the
// default stream; callers using non-default PyTorch streams should confirm the
// ordering they need.
torch::Tensor flash_attn_1_fwd_f32(
torch::Tensor Q,
torch::Tensor K,
torch::Tensor V) {
  TORCH_CHECK(Q.is_cuda() && K.is_cuda() && V.is_cuda(),
              "flash_attn_1_fwd_f32: Q, K and V must be CUDA tensors");
  TORCH_CHECK(K.device() == Q.device() && V.device() == Q.device(),
              "flash_attn_1_fwd_f32: Q, K and V must be on the same device");
  TORCH_CHECK(Q.scalar_type() == torch::kFloat32 &&
              K.scalar_type() == torch::kFloat32 &&
              V.scalar_type() == torch::kFloat32,
              "flash_attn_1_fwd_f32: Q, K and V must be float32");
  TORCH_CHECK(Q.dim() == 4, "flash_attn_1_fwd_f32: expected Q of shape [B, nh, N, d]");
  TORCH_CHECK(K.sizes() == Q.sizes() && V.sizes() == Q.sizes(),
              "flash_attn_1_fwd_f32: Q, K and V must have identical shapes");

  // The kernel reads raw pointers assuming a dense row-major layout.
  Q = Q.contiguous();
  K = K.contiguous();
  V = V.contiguous();

  // TODO: determine Bc, Br dynamically
  const int Bc = 32; const int Br = 32;
  const int B = static_cast<int>(Q.size(0)); const int nh = static_cast<int>(Q.size(1));
  const int N = static_cast<int>(Q.size(2)); const int d = static_cast<int>(Q.size(3));
  // Integer ceil-division; avoids float rounding for large N.
  const int Tc = (N + Bc - 1) / Bc; const int Tr = (N + Br - 1) / Br;

  // Initialize O, l, m to HBM (on the same device as Q, not a hard-coded one).
  auto O = torch::zeros_like(Q);
  if (Q.numel() == 0) {
    return O;  // nothing to compute; also avoids a zero-sized grid launch
  }
  const float softmax_scale = static_cast<float>(1.0 / sqrt(static_cast<double>(d)));
  auto options = torch::TensorOptions().dtype(torch::kFloat32).device(Q.device());
  auto l = torch::zeros({B, nh, N}, options);
  auto m = torch::full({B, nh, N}, -INFINITY, options);

  const int device_index = static_cast<int>(Q.get_device());
  int current_device = -1;
  cudaError_t err = cudaGetDevice(&current_device);
  TORCH_CHECK(err == cudaSuccess,
              "flash_attn_1_fwd_f32: cudaGetDevice failed: ", cudaGetErrorString(err));
  // The launch below targets the current device, so the data must live there.
  TORCH_CHECK(current_device == device_index,
              "flash_attn_1_fwd_f32: inputs are on cuda:", device_index,
              " but the current CUDA device is cuda:", current_device);

  // Calculate SRAM size needed per block
  const size_t sram_bytes = (3 * static_cast<size_t>(Bc) * d * sizeof(float))
                          + (static_cast<size_t>(Bc) * Br * sizeof(float));
  int max_sram_size = 0;
  err = cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, device_index);
  TORCH_CHECK(err == cudaSuccess,
              "flash_attn_1_fwd_f32: cudaDeviceGetAttribute failed: ", cudaGetErrorString(err));
#if ENABLE_NOTE_LOG
  printf("Max shared memory: %d, requested shared memory: %d \n", max_sram_size, static_cast<int>(sram_bytes));
#endif
  // Without this check an over-sized request makes the launch fail and O stays all zeros.
  TORCH_CHECK(sram_bytes <= static_cast<size_t>(max_sram_size),
              "flash_attn_1_fwd_f32: head dim ", d, " needs ", sram_bytes,
              " bytes of shared memory per block, device limit is ", max_sram_size);
  const int sram_size = static_cast<int>(sram_bytes);

  dim3 grid_dim(B, nh);  // batch_size x num_heads
  dim3 block_dim(Bc);  // Bc threads per block
  flash_attn_1_fwd_f32_kernel<<<grid_dim, block_dim, sram_size>>>(
    Q.data_ptr<float>(), K.data_ptr<float>(), V.data_ptr<float>(),
    N, d, Tc, Tr, Bc, Br, softmax_scale,
    l.data_ptr<float>(), m.data_ptr<float>(), O.data_ptr<float>()
  );
  // Kernel launches do not return an error directly; surface bad launch configs here.
  err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess,
              "flash_attn_1_fwd_f32: kernel launch failed: ", cudaGetErrorString(err));
  return O;
}