This repository has been archived by the owner on Jan 8, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathrel_pos_cuda_kernel.cu
361 lines (340 loc) · 17.2 KB
/
rel_pos_cuda_kernel.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <iostream>
#include <stdio.h>
// Report a failed CUDA runtime call.
//
// code       : status returned by a CUDA runtime API call (or by
//              cudaPeekAtLastError()/cudaDeviceSynchronize() after a launch).
// file, line : call site; normally supplied by the gpuErrchk macro.
// abort      : when true (the default) the process exits with `code` as its
//              exit status after printing; when false the error is only printed.
// NOTE(review): exit() terminates the whole host process (e.g. a Python
// interpreter that loaded this extension); raising an exception would let the
// caller recover. Kept as-is to preserve existing behavior.
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
  if (code == cudaSuccess)
    return;
  fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
  if (abort)
    exit(code);
}
// Check an expression yielding cudaError_t, reporting the caller's file/line.
// Wrapped in do { } while (0) so `gpuErrchk(x);` is a single statement and is
// safe as the unbraced body of an if/else.
#define gpuErrchk(ans) do { gpuAssert((ans), __FILE__, __LINE__); } while (0)
namespace {
// Fused attention-logit kernel. For each (kk, i, j) it computes
//   logits[kk][i][j] = sum_dd (q[kk][i][dd] + uk[h][dd]) * k[kk][j][dd]
//                           + (q[kk][i][dd] + uh[h][dd]) * rh[h][r_h_index][dd]
//                           + (q[kk][i][dd] + uw[h][dd]) * rw[h][r_w_index][dd]
// and adds -10000 where m[kk][i][j] is true, with h = kk % num_heads.
//
// Launch layout (see fuse_all_cuda):
//   gridDim  = (h_q * w_q, q.size(0)) : blockIdx.x = query position i,
//                                       blockIdx.y = kk
//   blockDim = h_k * w_k              : threadIdx.x = key position j
//   dynamic shared memory             = 3 * d * sizeof(scalar_t)
//
// r_h_index = j / w_k - i / w_q + h_q - 1 ranges over [0, h_k + h_q - 2] and
// r_w_index = j % w_k - i % w_q + w_q - 1 over [0, w_k + w_q - 2], so rh / rw
// need at least that many rows in their second dimension.
template <typename scalar_t>
__global__ void fuse_all_kernel(
    torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> logits,
    const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> q,
    const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> k,
    const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> rh,
    const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> rw,
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> uk,
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> uh,
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> uw,
    const torch::PackedTensorAccessor<bool, 3, torch::RestrictPtrTraits, size_t> m,
    const int h_q, const int w_q, const int h_k, const int w_k, const int d, const int num_heads) {
  const int kk = blockIdx.y;  // B * N
  const int i = blockIdx.x;   // Hq*Wq
  const int j = threadIdx.x;  // Hk*Wk
  const int head_index = kk % num_heads;
  const int r_h_index = j / w_k - i / w_q + h_q - 1;
  const int r_w_index = j % w_k - i % w_q + w_q - 1;

  extern __shared__ __align__(sizeof(scalar_t)) unsigned char _shared_memory_ptr[];
  scalar_t* _shared_memory = reinterpret_cast<scalar_t *>(_shared_memory_ptr);
  scalar_t* uk_shared = _shared_memory;
  scalar_t* uh_shared = &_shared_memory[d];
  scalar_t* uw_shared = &_shared_memory[2 * d];

  // Stage this head's three bias vectors in shared memory. A block-stride
  // loop is required: the block has only h_k * w_k threads, which may be fewer
  // than d, and a plain `if (j < d)` would leave the tail uninitialized.
  for (int t = j; t < d; t += blockDim.x) {
    uk_shared[t] = uk[head_index][t];
    uh_shared[t] = uh[head_index][t];
    uw_shared[t] = uw[head_index][t];
  }
  __syncthreads();

  scalar_t res = 0;
  for (int dd = 0; dd < d; ++dd) {
    const scalar_t q_d = q[kk][i][dd];
    res += (q_d + uk_shared[dd]) * k[kk][j][dd]
         + (q_d + uh_shared[dd]) * rh[head_index][r_h_index][dd]
         + (q_d + uw_shared[dd]) * rw[head_index][r_w_index][dd];
  }
  // Bias in scalar_t to avoid a silent promotion to double.
  logits[kk][i][j] = res + (m[kk][i][j] ? static_cast<scalar_t>(-10000) : static_cast<scalar_t>(0));
}
// Adds 2-D relative-position terms (and optionally a mask bias) in place:
//   new_logits[k][i][j] += r_h[k][i][r_h_index] + r_w[k][i][r_w_index]
//   new_logits[k][i][j] += -10000   if use_mask && mask[mask_ndim == 2 ? 0 : k][i][j]
// with r_h_index = j / w_k - i / w_q + h_q - 1 and
//      r_w_index = j % w_k - i % w_q + w_q - 1.
//
// Launch layout (see relative_positioning_forward_2d_cuda):
//   gridDim = (h_q * w_q, N), blockDim = h_k * w_k.
// When use_shared_memory is true the caller must supply
// (h_k + h_q - 1 + w_k + w_q - 1) * sizeof(scalar_t) bytes of dynamic shared
// memory, and blockDim.x must be >= max(h_k + h_q - 1, w_k + w_q - 1) so that
// every staged element is loaded (one element per thread).
// use_shared_memory is a kernel argument, so it is uniform across the block
// and the __syncthreads() inside that branch is reached by all threads.
template <typename scalar_t>
__global__ void relative_positioning_forward_2d_kernel(
    const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> r_h,
    const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> r_w,
    torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> new_logits,
    const torch::PackedTensorAccessor<bool, 3, torch::RestrictPtrTraits, size_t> mask,
    const int h_q, const int w_q, const int h_k, const int w_k, const int mask_ndim,
    const bool use_shared_memory, const bool use_mask) {
  const int k = blockIdx.y;   // N
  const int i = blockIdx.x;   // Hq*Wq
  const int j = threadIdx.x;  // Hk*Wk
  const int r_h_index = j / w_k - i / w_q + h_q - 1;
  const int r_w_index = j % w_k - i % w_q + w_q - 1;

  scalar_t rel;
  if (use_shared_memory) {
    extern __shared__ __align__(sizeof(scalar_t)) unsigned char _shared_memory_ptr[];
    scalar_t *_shared_memory = reinterpret_cast<scalar_t *>(_shared_memory_ptr);
    scalar_t* r_h_shared = _shared_memory;
    scalar_t* r_w_shared = &_shared_memory[h_k + h_q - 1];
    if (j < (h_k + h_q - 1))
      r_h_shared[j] = r_h[k][i][j];
    if (j < (w_k + w_q - 1))
      r_w_shared[j] = r_w[k][i][j];
    __syncthreads();
    rel = r_h_shared[r_h_index] + r_w_shared[r_w_index];
  } else {
    rel = r_h[k][i][r_h_index] + r_w[k][i][r_w_index];
  }

  // One global load and one store. The relative term is added first and the
  // mask bias second, as in the original two-step update; the bias is a
  // scalar_t to avoid promoting the arithmetic to double.
  scalar_t out = new_logits[k][i][j] + rel;
  if (use_mask && mask[mask_ndim == 2 ? 0 : k][i][j])
    out += static_cast<scalar_t>(-10000);
  new_logits[k][i][j] = out;
}
// Adds 3-D relative-position terms (and optionally a mask bias) in place:
//   new_logits[k][i][l] += r_t[k][i][r_t_index] + r_h[k][i][r_h_index]
//                        + r_w[k][i][r_w_index] + (-10000 if masked)
// where the query i is decomposed into (q_t, q_h, q_w) and the key
// l = k_t * h_k * w_k + j into (k_t, k_h, k_w), with
//   r_t_index = k_t - q_t + t_q - 1, r_h_index = k_h - q_h + h_q - 1,
//   r_w_index = k_w - q_w + w_q - 1.
//
// Launch layout (see relative_positioning_forward_3d_cuda):
//   gridDim = (t_k, t_q * h_q * w_q, N), blockDim = h_k * w_k.
// Shared-memory requirements are the same as in the 2-D kernel; only r_h / r_w
// are staged, r_t is read directly from global memory. use_shared_memory is a
// kernel argument (uniform per block), so the __syncthreads() is safe.
template <typename scalar_t>
__global__ void relative_positioning_forward_3d_kernel(
    const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> r_t,
    const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> r_h,
    const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> r_w,
    torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> new_logits,
    const torch::PackedTensorAccessor<bool, 3, torch::RestrictPtrTraits, size_t> mask,
    const int t_q, const int h_q, const int w_q, const int t_k, const int h_k, const int w_k,
    const int mask_ndim, const bool use_shared_memory, const bool use_mask) {
  const int k = blockIdx.z;    // N
  const int i = blockIdx.y;    // Tq * Hq * Wq
  const int k_t = blockIdx.x;  // Tk
  const int j = threadIdx.x;   // Hk * Wk
  const int q_t = i / (h_q * w_q);
  const int q_h = (i % (h_q * w_q)) / w_q;
  const int q_w = i % w_q;
  const int k_h = j / w_k;
  const int k_w = j % w_k;
  const int r_t_index = k_t - q_t + t_q - 1;
  const int r_h_index = k_h - q_h + h_q - 1;
  const int r_w_index = k_w - q_w + w_q - 1;
  const int l = k_t * h_k * w_k + j;  // flattened key position

  // Summed as (r_t + r_h) + r_w, the same association as before.
  scalar_t rel = r_t[k][i][r_t_index];
  if (use_shared_memory) {
    extern __shared__ __align__(sizeof(scalar_t)) unsigned char _shared_memory_ptr[];
    scalar_t *_shared_memory = reinterpret_cast<scalar_t *>(_shared_memory_ptr);
    scalar_t* r_h_shared = _shared_memory;
    scalar_t* r_w_shared = &_shared_memory[h_k + h_q - 1];
    if (j < (h_k + h_q - 1))
      r_h_shared[j] = r_h[k][i][j];
    if (j < (w_k + w_q - 1))
      r_w_shared[j] = r_w[k][i][j];
    __syncthreads();
    rel = rel + r_h_shared[r_h_index] + r_w_shared[r_w_index];
  } else {
    rel = rel + r_h[k][i][r_h_index] + r_w[k][i][r_w_index];
  }

  // The mask is only read when use_mask is set (short-circuit). The bias is a
  // scalar_t so the update is not promoted to double.
  const bool masked = use_mask && mask[mask_ndim == 2 ? 0 : k][i][l];
  new_logits[k][i][l] += rel + (masked ? static_cast<scalar_t>(-10000) : static_cast<scalar_t>(0));
}
// Height-table gradient for the 2-D relative-position add. For the query at
// (row i, column l) and key row h, it sums grad_out over the w_k key columns of
// that row and stores the sum at relative offset h - i + h_q - 1:
//   grad_h_r[k][i*w_q + l][h - i + h_q - 1] = sum_w grad_out[k][i*w_q + l][h*w_k + w]
// Launch layout (see relative_positioning_backward_2d_cuda):
//   gridDim = (h_q, N), blockDim = (h_k, w_q)
//   blockIdx.x = query row, blockIdx.y = k, threadIdx.x = key row,
//   threadIdx.y = query column.
// h_k is part of the signature but is not read here.
template <typename scalar_t>
__global__ void relative_positioning_backward_2d_kernel_h(
    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> grad_out,
    torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> grad_h_r,
    const int h_q, const int w_q, const int h_k, const int w_k) {
  const int batch = blockIdx.y;
  const int q_row = blockIdx.x;
  const int q_col = threadIdx.y;
  const int k_row = threadIdx.x;
  const int query = q_row * w_q + q_col;
  const int row_start = k_row * w_k;

  scalar_t acc = 0.0;
  int offset = 0;
  while (offset < w_k) {
    acc += grad_out[batch][query][row_start + offset];
    ++offset;
  }
  grad_h_r[batch][query][k_row - q_row + h_q - 1] = acc;
}
// Width-table gradient for the 2-D relative-position add. For the query at
// (row i, column l) and key column w, it sums grad_out over the h_k key rows
// (a stride-w_k walk) and stores the sum at relative offset w - l + w_q - 1:
//   grad_w_r[k][i*w_q + l][w - l + w_q - 1] = sum_h grad_out[k][i*w_q + l][h*w_k + w]
// Indexing assumes blockIdx.x = query row, blockIdx.y = k, threadIdx.x = key
// column (w_k of them) and threadIdx.y = query column (w_q of them).
// NOTE(review): this kernel is not launched in this file; the 2-D backward
// uses a torch sum followed by the place kernel instead.
// h_q is part of the signature but is not read here.
template <typename scalar_t>
__global__ void relative_positioning_backward_2d_kernel_w(
    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> grad_out,
    torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> grad_w_r,
    const int h_q, const int w_q, const int h_k, const int w_k) {
  const int batch = blockIdx.y;
  const int q_row = blockIdx.x;
  const int q_col = threadIdx.y;
  const int k_col = threadIdx.x;
  const int query = q_row * w_q + q_col;

  scalar_t acc = 0.0;
  for (int row = 0, key = k_col; row < h_k; ++row, key += w_k)
    acc += grad_out[batch][query][key];
  grad_w_r[batch][query][k_col - q_col + w_q - 1] = acc;
}
// Scatters pre-summed gradients into a 2-D relative-offset table:
//   grad_x_r[k][i*w_q + l][x - q_x + x_q - 1] = sum_out[k][i*w_q + l][x]
// mode 0: q_x = i (query row),    callers pass x_q = h_q (key axis = Hk)
// mode 1: q_x = l (query column), callers pass x_q = w_q (key axis = Wk)
// Layout: blockIdx.x = query row i, blockIdx.y = k, threadIdx.x = key
// coordinate x, threadIdx.y = query column l.
template <typename scalar_t>
__global__ void relative_positioning_backward_2d_kernel_place(
    const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> sum_out,
    torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits,size_t> grad_x_r,
    const int w_q, const int x_q, const int mode) {
  const int batch = blockIdx.y;
  const int q_row = blockIdx.x;
  const int q_col = threadIdx.y;
  const int key = threadIdx.x;
  const int query = q_row * w_q + q_col;
  const int q_coord = (mode == 0) ? q_row : q_col;
  grad_x_r[batch][query][key - q_coord + x_q - 1] = sum_out[batch][query][key];
}
// Scatters pre-summed gradients into a 3-D relative-offset table:
//   grad_x_r[k][i*w_q + l][j - q_x + x_q - 1] = sum_out[k][i*w_q + l][j]
// where blockIdx.x = i indexes the flattened (query time, query row) pair:
//   mode 0: q_x = i / h_q (query time),   callers pass x_q = t_q
//   mode 1: q_x = i % h_q (query row),    callers pass x_q = h_q
//   other : q_x = l       (query column), callers pass x_q = w_q
// Layout: gridDim = (t_q * h_q, N), blockDim = (key extent on that axis, w_q).
template <typename scalar_t>
__global__ void relative_positioning_backward_3d_kernel_place(
    const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> sum_out,
    torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits,size_t> grad_x_r,
    const int h_q, const int w_q, const int x_q, const int mode) {
  const int batch = blockIdx.y;
  const int tq_hq = blockIdx.x;
  const int q_col = threadIdx.y;
  const int key = threadIdx.x;

  int q_coord;
  switch (mode) {
    case 0:
      q_coord = tq_hq / h_q;  // query time
      break;
    case 1:
      q_coord = tq_hq % h_q;  // query row
      break;
    default:
      q_coord = q_col;        // query column
      break;
  }
  const int query = tq_hq * w_q + q_col;
  grad_x_r[batch][query][key - q_coord + x_q - 1] = sum_out[batch][query][key];
}
} // namespace
// Host launcher for relative_positioning_forward_2d_kernel.
//
// Returns a copy of `logits` with the relative-position terms gathered from
// r_h / r_w added. When use_mask is true, -10000 is also added wherever `mask`
// is true. The input `logits` is not modified.
//
// Expected shapes, inferred from the kernel's indexing (TODO confirm against
// the Python wrapper):
//   logits (N, h_q*w_q, h_k*w_k); r_h (N, h_q*w_q, h_k+h_q-1);
//   r_w (N, h_q*w_q, w_k+w_q-1); mask is bool, rank 2 (shared by all N) or rank 3.
//
// The kernel runs on PyTorch's current CUDA stream of logits' device, so it is
// ordered after clone() and any producer kernels on that stream.
// Launch-limit violations raise through TORCH_CHECK; CUDA errors abort via gpuErrchk.
torch::Tensor relative_positioning_forward_2d_cuda(
    torch::Tensor logits, torch::Tensor r_h, torch::Tensor r_w, torch::Tensor mask,
    const int h_q, const int w_q, const int h_k, const int w_k, const bool use_mask) {
  const c10::cuda::CUDAGuard device_guard(logits.device());
  auto new_logits = logits.clone();
  const int threads = h_k * w_k;
  // A zero-sized grid or block is an invalid launch configuration. With
  // nothing to add, the copy is already the result.
  if (logits.size(0) == 0 || h_q * w_q == 0 || threads == 0)
    return new_logits;
  TORCH_CHECK(threads <= 1024, "relative_positioning_forward_2d_cuda: h_k * w_k = ", threads,
              " exceeds the 1024 threads-per-block limit");
  TORCH_CHECK(logits.size(0) <= 65535, "relative_positioning_forward_2d_cuda: logits.size(0) = ",
              logits.size(0), " exceeds the 65535 gridDim.y limit");
  const dim3 blocks(h_q * w_q, logits.size(0));
  // Stage r_h / r_w in shared memory only if the block has enough threads to
  // load both tables at one element per thread:
  // threads >= max(h_k + h_q - 1, w_k + w_q - 1).
  const bool use_shared_memory = (w_k * h_k + 1) >= ((h_q + h_k) > (w_q + w_k) ? (h_q + h_k) : (w_q + w_k));
  const int mask_ndim = mask.ndimension();
  const int shared_memory_amount = use_shared_memory ? (h_k + h_q - 1 + w_k + w_q - 1) : 0;
  // The kernel always takes a rank-3 accessor; a rank-2 mask is read at batch 0.
  mask = mask_ndim == 2 ? mask.unsqueeze(0) : mask;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(logits.scalar_type(), "relative_positioning_forward_2d_kernel", ([&] {
    relative_positioning_forward_2d_kernel<scalar_t><<<blocks, threads, shared_memory_amount * sizeof(scalar_t), stream>>>(
        r_h.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
        r_w.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
        new_logits.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
        mask.packed_accessor<bool, 3, torch::RestrictPtrTraits, size_t>(),
        h_q, w_q, h_k, w_k, mask_ndim, use_shared_memory, use_mask);
  }));
  gpuErrchk(cudaPeekAtLastError());
  gpuErrchk(cudaDeviceSynchronize());
  return new_logits;
}
// Host launcher for fuse_all_kernel. It computes the content term
// (q + uk) . k, the relative terms (q + uh) . rh and (q + uw) . rw, and the
// mask bias for every (batch*head, query, key), all in one kernel.
//
// Assumed shapes (TODO confirm against the caller): q (B*num_heads, h_q*w_q, d),
// k (B*num_heads, h_k*w_k, d), rh (num_heads, >= h_k+h_q-1, d),
// rw (num_heads, >= w_k+w_q-1, d), uk/uh/uw (num_heads, d),
// m bool (B*num_heads, h_q*w_q, h_k*w_k). All tensors share q's dtype (m excepted)
// and device.
//
// Returns logits of shape (q.size(0), h_q*w_q, h_k*w_k). The kernel runs on
// PyTorch's current CUDA stream of q's device.
torch::Tensor fuse_all_cuda(
    torch::Tensor q, torch::Tensor k, torch::Tensor rh, torch::Tensor rw,
    torch::Tensor uk, torch::Tensor uh, torch::Tensor uw, torch::Tensor m,
    const int h_q, const int w_q, const int h_k, const int w_k, const int num_heads) {
  const c10::cuda::CUDAGuard device_guard(q.device());
  // The grid covers every (kk, i, j) and the kernel writes each element, so no
  // zero-fill is needed.
  auto logits = torch::empty({q.size(0), h_q * w_q, h_k * w_k},
                             torch::dtype(q.dtype()).device(q.device()));
  const int threads_per_block = h_k * w_k;
  // Empty output: nothing to launch (and a zero-sized launch would be invalid).
  if (q.size(0) == 0 || h_q * w_q == 0 || threads_per_block == 0)
    return logits;
  TORCH_CHECK(threads_per_block <= 1024, "fuse_all_cuda: h_k * w_k = ", threads_per_block,
              " exceeds the 1024 threads-per-block limit");
  TORCH_CHECK(q.size(0) <= 65535, "fuse_all_cuda: q.size(0) = ", q.size(0),
              " exceeds the 65535 gridDim.y limit");
  const dim3 blocks(h_q * w_q, q.size(0));
  const dim3 threads(threads_per_block);
  const int d = q.size(2);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(logits.scalar_type(), "fuse_all_kernel", ([&] {
    // Shared memory holds the three per-head bias vectors uk, uh, uw (d each).
    fuse_all_kernel<scalar_t><<<blocks, threads, 3 * d * sizeof(scalar_t), stream>>>(
        logits.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
        q.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
        k.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
        rh.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
        rw.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
        uk.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
        uh.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
        uw.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
        m.packed_accessor<bool, 3, torch::RestrictPtrTraits, size_t>(),
        h_q, w_q, h_k, w_k, d, num_heads);
  }));
  gpuErrchk(cudaPeekAtLastError());
  gpuErrchk(cudaDeviceSynchronize());
  return logits;
}
// Backward of the 2-D relative-position add. It scatters grad_out (last dim
// h_k*w_k) into gradients for the height and width relative tables and returns
//   {grad_h_r (N, L, h_k + h_q - 1), grad_w_r (N, L, w_k + w_q - 1)}
// with N = grad_out.size(0) and L = grad_out.size(1) (assumed h_q*w_q).
//
// grad_out may be non-contiguous: it is reshaped, not viewed. Kernels run on
// PyTorch's current CUDA stream of grad_out's device.
std::vector<torch::Tensor> relative_positioning_backward_2d_cuda(
    torch::Tensor grad_out, const int h_q, const int w_q, const int h_k, const int w_k) {
  const c10::cuda::CUDAGuard device_guard(grad_out.device());
  const auto options = torch::dtype(grad_out.dtype()).device(grad_out.device());
  // Zero-initialized: for a given query only h_k (resp. w_k) of the relative
  // offsets are reachable, and only those are written.
  auto grad_h_r = torch::zeros({grad_out.size(0), grad_out.size(1), h_k + h_q - 1}, options);
  auto grad_w_r = torch::zeros({grad_out.size(0), grad_out.size(1), w_k + w_q - 1}, options);
  // Nothing to scatter, and a zero-sized grid/block would be an invalid launch.
  if (grad_out.size(0) == 0 || h_q == 0 || w_q == 0 || h_k == 0 || w_k == 0)
    return {grad_h_r, grad_w_r};
  // reshape (not view) so gradients with non-contiguous strides are accepted;
  // this also validates that the last dim equals h_k * w_k.
  auto grad_out_view = grad_out.reshape({grad_out.size(0), grad_out.size(1), h_k, w_k});
  TORCH_CHECK(grad_out.size(0) <= 65535, "relative_positioning_backward_2d_cuda: grad_out.size(0) = ",
              grad_out.size(0), " exceeds the 65535 gridDim.y limit");
  TORCH_CHECK(h_k * w_q <= 1024 && w_k * w_q <= 1024,
              "relative_positioning_backward_2d_cuda: h_k * w_q and w_k * w_q must not exceed 1024 threads per block");
  const dim3 blocks(h_q, grad_out.size(0));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  {
    const dim3 threads_h(h_k, w_q);
    AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "relative_positioning_backward_2d_kernel_h", ([&] {
      relative_positioning_backward_2d_kernel_h<scalar_t><<<blocks, threads_h, 0, stream>>>(
          grad_out.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
          grad_h_r.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
          h_q, w_q, h_k, w_k);
    }));
    gpuErrchk(cudaPeekAtLastError());
    gpuErrchk(cudaDeviceSynchronize());
  }
  {
    const dim3 threads_w(w_k, w_q);
    // Sum over key rows: (N, L, h_k, w_k) -> (N, L, w_k).
    auto grad_out_sum_h = grad_out_view.sum({2});
    AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "relative_positioning_backward_2d_kernel_place_w", ([&] {
      relative_positioning_backward_2d_kernel_place<scalar_t><<<blocks, threads_w, 0, stream>>>(
          grad_out_sum_h.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
          grad_w_r.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
          w_q, w_q, 1);
    }));
    gpuErrchk(cudaPeekAtLastError());
    gpuErrchk(cudaDeviceSynchronize());
  }
  return {grad_h_r, grad_w_r};
}
// Host launcher for relative_positioning_forward_3d_kernel.
//
// Returns a copy of `logits` with the temporal / height / width relative terms
// from r_t, r_h, r_w added. When use_mask is true, -10000 is also added
// wherever `mask` is true. The input `logits` is not modified.
//
// Assumed shapes (TODO confirm against the Python wrapper):
//   logits (N, t_q*h_q*w_q, t_k*h_k*w_k); r_t (N, t_q*h_q*w_q, t_k+t_q-1);
//   r_h (N, t_q*h_q*w_q, h_k+h_q-1); r_w (N, t_q*h_q*w_q, w_k+w_q-1);
//   mask is bool, rank 2 (shared by all N) or rank 3.
//
// Grid layout is (t_k, t_q*h_q*w_q, N). gridDim.y and gridDim.z are limited to
// 65535, which is checked up front instead of failing at launch. The kernel
// runs on PyTorch's current CUDA stream of logits' device.
torch::Tensor relative_positioning_forward_3d_cuda(
    torch::Tensor logits, torch::Tensor r_t, torch::Tensor r_h, torch::Tensor r_w, torch::Tensor mask,
    const int t_q, const int h_q, const int w_q, const int t_k, const int h_k, const int w_k, const bool use_mask) {
  const c10::cuda::CUDAGuard device_guard(logits.device());
  auto new_logits = logits.clone();
  const int threads = h_k * w_k;
  const int num_queries = t_q * h_q * w_q;
  // A zero-sized grid or block is an invalid launch; the copy is the result.
  if (logits.size(0) == 0 || t_k == 0 || num_queries == 0 || threads == 0)
    return new_logits;
  TORCH_CHECK(threads <= 1024, "relative_positioning_forward_3d_cuda: h_k * w_k = ", threads,
              " exceeds the 1024 threads-per-block limit");
  TORCH_CHECK(num_queries <= 65535, "relative_positioning_forward_3d_cuda: t_q * h_q * w_q = ",
              num_queries, " exceeds the 65535 gridDim.y limit");
  TORCH_CHECK(logits.size(0) <= 65535, "relative_positioning_forward_3d_cuda: logits.size(0) = ",
              logits.size(0), " exceeds the 65535 gridDim.z limit");
  const dim3 blocks(t_k, num_queries, logits.size(0));
  // Stage r_h / r_w in shared memory only if there is one thread per element:
  // threads >= max(h_k + h_q - 1, w_k + w_q - 1).
  const bool use_shared_memory = (w_k * h_k + 1) >= ((h_q + h_k) > (w_q + w_k) ? (h_q + h_k) : (w_q + w_k));
  const int mask_ndim = mask.ndimension();
  const int shared_memory_amount = use_shared_memory ? (h_k + h_q - 1 + w_k + w_q - 1) : 0;
  // The kernel always takes a rank-3 accessor; a rank-2 mask is read at batch 0.
  mask = mask_ndim == 2 ? mask.unsqueeze(0) : mask;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(logits.scalar_type(), "relative_positioning_forward_3d_kernel", ([&] {
    relative_positioning_forward_3d_kernel<scalar_t><<<blocks, threads, shared_memory_amount * sizeof(scalar_t), stream>>>(
        r_t.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
        r_h.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
        r_w.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
        new_logits.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
        mask.packed_accessor<bool, 3, torch::RestrictPtrTraits, size_t>(),
        t_q, h_q, w_q, t_k, h_k, w_k, mask_ndim, use_shared_memory, use_mask);
  }));
  gpuErrchk(cudaPeekAtLastError());
  gpuErrchk(cudaDeviceSynchronize());
  return new_logits;
}
// Backward of the 3-D relative-position add. It reduces grad_out (last dim
// t_k*h_k*w_k) over the two other key axes, then scatters the result into the
// temporal / height / width relative tables. It returns
//   {grad_t_r (N, L, t_k+t_q-1), grad_h_r (N, L, h_k+h_q-1), grad_w_r (N, L, w_k+w_q-1)}
// with N = grad_out.size(0) and L = grad_out.size(1) (assumed t_q*h_q*w_q).
//
// grad_out may be non-contiguous: it is reshaped, not viewed. Kernels run on
// PyTorch's current CUDA stream of grad_out's device.
std::vector<torch::Tensor> relative_positioning_backward_3d_cuda(
    torch::Tensor grad_out, const int t_q, const int h_q, const int w_q,
    const int t_k, const int h_k, const int w_k) {
  const c10::cuda::CUDAGuard device_guard(grad_out.device());
  const auto options = torch::dtype(grad_out.dtype()).device(grad_out.device());
  // Zero-initialized: for each query only the reachable relative offsets are written.
  auto grad_t_r = torch::zeros({grad_out.size(0), grad_out.size(1), t_k + t_q - 1}, options);
  auto grad_h_r = torch::zeros({grad_out.size(0), grad_out.size(1), h_k + h_q - 1}, options);
  auto grad_w_r = torch::zeros({grad_out.size(0), grad_out.size(1), w_k + w_q - 1}, options);
  // Nothing to scatter, and a zero-sized grid/block would be an invalid launch.
  if (grad_out.size(0) == 0 || t_q * h_q == 0 || w_q == 0 || t_k == 0 || h_k == 0 || w_k == 0)
    return {grad_t_r, grad_h_r, grad_w_r};
  // reshape (not view) so non-contiguous gradients are accepted; this also
  // validates that the last dim equals t_k * h_k * w_k.
  auto grad_out_view = grad_out.reshape({grad_out.size(0), grad_out.size(1), t_k, h_k, w_k});
  TORCH_CHECK(grad_out.size(0) <= 65535, "relative_positioning_backward_3d_cuda: grad_out.size(0) = ",
              grad_out.size(0), " exceeds the 65535 gridDim.y limit");
  TORCH_CHECK(t_k * w_q <= 1024 && h_k * w_q <= 1024 && w_k * w_q <= 1024,
              "relative_positioning_backward_3d_cuda: t_k * w_q, h_k * w_q and w_k * w_q must not exceed 1024 threads per block");
  const dim3 blocks(t_q * h_q, grad_out.size(0));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  {
    // (N, L, t_k, h_k, w_k) -> (N, L, t_k, h_k), shared by the t and h passes.
    auto grad_out_sum_w = grad_out_view.sum({4});
    {
      const dim3 threads_t(t_k, w_q);
      auto grad_out_sum_w_h = grad_out_sum_w.sum({3});  // (N, L, t_k)
      AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "relative_positioning_backward_3d_kernel_place_t", ([&] {
        relative_positioning_backward_3d_kernel_place<scalar_t><<<blocks, threads_t, 0, stream>>>(
            grad_out_sum_w_h.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
            grad_t_r.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
            h_q, w_q, t_q, 0);
      }));
      gpuErrchk(cudaPeekAtLastError());
      gpuErrchk(cudaDeviceSynchronize());
    }
    {
      const dim3 threads_h(h_k, w_q);
      auto grad_out_sum_w_t = grad_out_sum_w.sum({2});  // (N, L, h_k)
      AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "relative_positioning_backward_3d_kernel_place_h", ([&] {
        relative_positioning_backward_3d_kernel_place<scalar_t><<<blocks, threads_h, 0, stream>>>(
            grad_out_sum_w_t.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
            grad_h_r.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
            h_q, w_q, h_q, 1);
      }));
      gpuErrchk(cudaPeekAtLastError());
      gpuErrchk(cudaDeviceSynchronize());
    }
  }
  {
    const dim3 threads_w(w_k, w_q);
    auto grad_out_sum_h_t = grad_out_view.sum({2, 3});  // (N, L, w_k)
    AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "relative_positioning_backward_3d_kernel_place_w", ([&] {
      relative_positioning_backward_3d_kernel_place<scalar_t><<<blocks, threads_w, 0, stream>>>(
          grad_out_sum_h_t.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
          grad_w_r.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
          h_q, w_q, w_q, 2);
    }));
    gpuErrchk(cudaPeekAtLastError());
    gpuErrchk(cudaDeviceSynchronize());
  }
  return {grad_t_r, grad_h_r, grad_w_r};
}