-
Notifications
You must be signed in to change notification settings - Fork 0
/
util.cuh
69 lines (59 loc) · 1.72 KB
/
util.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#pragma once

#include <cuda.h>
#include <cuda_runtime.h>

// Input-validation helpers for PyTorch extension entry points.
// NOTE(review): assumes <torch/extension.h> is included by the .cu files
// that use these macros — confirm against callers.
// TORCH_CHECK / x.is_cuda() replace the deprecated AT_ASSERTM / x.type().is_cuda().
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// do/while(0) makes the two checks behave as a single statement
// (safe inside an unbraced `if`), unlike the bare pair of statements.
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)
// brought from vLLM code https://github.com/vllm-project/vllm/blob/main/csrc/reduction_utils.cuh
// Butterfly (XOR-shuffle) reduction: after log2(32) = 5 rounds every lane
// of the warp holds the sum of all 32 lanes' inputs. All lanes must be
// active (full 0xffffffff participation mask).
template<typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_xor_sync(0xffffffff, val, offset, 32);
  }
  return val;
}
// Butterfly (XOR-shuffle) reduction for max: every lane of the warp ends
// up holding the maximum of all 32 lanes' inputs. All lanes must be
// active (full 0xffffffff participation mask).
template<typename T>
__inline__ __device__ T warpReduceMax(T val) {
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    T other = __shfl_xor_sync(0xffffffff, val, offset, 32);
    val = max(val, other);
  }
  return val;
}
// Sums `val` across all threads of the block; every thread receives the
// block-wide total. Must be called by ALL threads of the block (contains
// __syncthreads()). Supports blockDim.x up to 1024 (<= 32 warps).
//
// Fixes vs. the previous version: thread 0 zero-initialized shared[]
// with no barrier before the lane-0 writes and before every thread read
// all 32 slots — a data race. The guarded two-stage pattern below needs
// no initialization: slots beyond the warp count are simply replaced by
// the additive identity at read time.
template<typename T>
__inline__ __device__ T blockReduceSum(T val) {
  static __shared__ T shared[32];           // one partial sum per warp
  int lane = threadIdx.x & 0x1f;            // lane index within the warp
  int wid = threadIdx.x >> 5;               // warp index within the block

  // Stage 1: reduce within each warp; lane 0 holds the warp's partial.
  val = warpReduceSum<T>(val);
  if (lane == 0)
    shared[wid] = val;
  __syncthreads();

  // Stage 2: each warp re-reads the per-warp partials (identity for the
  // unused slots) and reduces again, broadcasting the total to all lanes.
  val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
  val = warpReduceSum<T>(val);
  return val;
}
// Computes the maximum of `val` across all threads of the block; every
// thread receives the block-wide max. Must be called by ALL threads of
// the block (contains __syncthreads()). Supports blockDim.x up to 1024.
//
// Fixes vs. the previous version: it returned `ret`, which was always 0,
// discarding the computed maximum in `val`; it also had the same
// unsynchronized shared[] zero-initialization race as blockReduceSum
// (and 0 is the wrong identity for max anyway). The guarded read below
// substitutes -1e20f for unused slots, so no initialization is needed.
template<typename T>
__inline__ __device__ T blockReduceMax(T val) {
  static __shared__ T shared[32];           // one partial max per warp
  int lane = threadIdx.x & 0x1f;            // lane index within the warp
  int wid = threadIdx.x >> 5;               // warp index within the block

  // Stage 1: reduce within each warp; lane 0 holds the warp's partial.
  val = warpReduceMax<T>(val);
  if (lane == 0)
    shared[wid] = val;
  __syncthreads();

  // Stage 2: each warp re-reads the per-warp partials (a very negative
  // sentinel for unused slots) and reduces again, broadcasting the
  // block-wide max to all lanes.
  val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(-1e20f);
  val = warpReduceMax<T>(val);
  return val;
}