style(kernel): simplify code by leveraging cub infrastructure
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Dec 19, 2023
1 parent 5d78855 commit bed3627
Showing 3 changed files with 10 additions and 24 deletions.
13 changes: 6 additions & 7 deletions src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cu
@@ -16,17 +16,16 @@ namespace refactor::kernel {
     template<class T>
     struct QuantizeMapMinMaxFunctor {
         __device__ __forceinline__ QuantizeMinMax<T>
-        operator()(T x) const {
+        operator()(T x) const noexcept {
             return {x, x};
         }
     };
 
     template<class T>
     struct QuantizeReduceMinMaxFunctor {
         __device__ __forceinline__ QuantizeMinMax<T>
-        operator()(QuantizeMinMax<T> a, QuantizeMinMax<T> b) const {
-            return {a.min < b.min ? a.min : b.min,
-                    a.max > b.max ? a.max : b.max};
+        operator()(QuantizeMinMax<T> a, QuantizeMinMax<T> b) const noexcept {
+            return {CUB_MIN(a.min, b.min), CUB_MAX(a.max, b.max)};
         }
     };
 
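CUB_MIN and CUB_MAX are utility macros from cub/util_macro.cuh that expand to the same ternary comparisons the functor previously spelled out by hand, so this hunk is purely cosmetic. A minimal standalone sketch of the equivalence; the demoMinMax kernel and main are hypothetical, not part of this repository:

#include <cub/util_macro.cuh> // defines CUB_MIN / CUB_MAX
#include <cstdio>

__global__ void demoMinMax() {
    float a = 3.0f, b = 5.0f;
    // CUB_MIN/CUB_MAX expand to plain ternary expressions, matching the
    // removed `a < b ? a : b` style comparisons exactly.
    printf("min=%g max=%g\n", CUB_MIN(a, b), CUB_MAX(a, b)); // min=3 max=5
}

int main() {
    demoMinMax<<<1, 1>>>();
    cudaDeviceSynchronize();
    return 0;
}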
@@ -56,8 +55,8 @@ namespace refactor::kernel {
             TO *__restrict__ zp_) {
 
             auto const [min, max] = *minmax;
-            auto temp = QuantizeReduceMinMaxFunctor<TI>{}({min, max}, {ZERO<TI>, ZERO<TI>});
-            auto scale = (temp.max - temp.min) / QLEN<TI, TO>;
+            auto cover0 = QuantizeReduceMinMaxFunctor<TI>{}({min, max}, {ZERO<TI>, ZERO<TI>});
+            auto scale = (cover0.max - cover0.min) / QLEN<TI, TO>;
             auto zp = static_cast<TO>(round(QMIN<TI, TO> - min / scale));
 
             auto tid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -73,7 +72,7 @@
         }
     }
 
-    auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
+    auto K::lower(Resources &) const -> RoutineWorkspace {
         using namespace runtime;
         using TI = float;
         using TO = uint8_t;
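For intuition on the zero-point arithmetic above: folding {ZERO, ZERO} into the observed range (the renamed cover0 step) guarantees that 0 is exactly representable, and scale/zp then follow the usual dynamic-quantization formulas. A host-side sketch, assuming the ONNX DynamicQuantizeLinear convention that QLEN<float, uint8_t> is 255 and QMIN<float, uint8_t> is 0; those constant values are an assumption here, and the program itself is hypothetical:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
    float min = -1.5f, max = 2.5f;                   // observed data range
    // The cover0 step: extend the range so it always contains 0.
    float lo = std::min(min, 0.0f), hi = std::max(max, 0.0f);
    float scale = (hi - lo) / 255.0f;                // (cover0.max - cover0.min) / QLEN
    auto zp = static_cast<uint8_t>(std::round(0.0f - min / scale)); // QMIN - min / scale
    std::printf("scale=%g zp=%u\n", scale, (unsigned) zp); // scale~=0.0156863, zp=96
    return 0;
}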
2 changes: 1 addition & 1 deletion (header file; name not captured on this page)
@@ -16,7 +16,7 @@ namespace refactor::kernel {
         size_t kernelTypeId() const noexcept final;
         std::string_view description() const noexcept final;
 #ifdef USE_CUDA
-        RoutineWorkspace lower(Resources &) const noexcept final;
+        RoutineWorkspace lower(Resources &) const final;
 #endif
     };
 
19 changes: 3 additions & 16 deletions src/04kernel/src/kernels/softmax/cuda_kernel.cu
@@ -4,9 +4,6 @@
 namespace refactor::kernel {
     using namespace runtime;
 
-    template<class T>
-    __device__ __forceinline__ T max_(T a, T b) { return a > b ? a : b; }
-
     template<class T>
     __device__ __forceinline__ T exp_(T x);
     template<> __device__ __forceinline__ float exp_<float>(float x) { return expf(x); }
@@ -58,16 +55,6 @@ namespace refactor::kernel {
         }
     }
 
-    template<class T> struct SumOp {
-        __device__ __forceinline__ T operator()(T const &a, T const &b) const {
-            return a + b;
-        }
-    };
-    template<class T> struct MaxOp {
-        __device__ __forceinline__ T operator()(T const &a, T const &b) const {
-            return max_(a, b);
-        }
-    };
     template<class T, class ReductionOp>
     __device__ __forceinline__ T WarpAllReduce(T val, ReductionOp op) {
         for (int mask = blockDim.x >> 1; mask > 0; mask >>= 1) {
@@ -92,9 +79,9 @@
 
         T maxData = -__FLT_MAX__;
         for (int i = threadIdx.x; i < dimsize; i += blockDim.x) {
-            maxData = max_(maxData, input[tid + i * stride]);
+            maxData = CUB_MAX(maxData, input[tid + i * stride]);
         }
-        maxData = WarpAllReduce(maxData, MaxOp<T>{});
+        maxData = WarpAllReduce(maxData, cub::Max());
         if (threadIdx.x == 0) {
             maxTotal[threadIdx.y] = maxData;
         }
@@ -104,7 +91,7 @@
         for (int i = threadIdx.x; i < dimsize; i += blockDim.x) {
             sumData += exp_(input[tid + i * stride] - maxTotal[threadIdx.y]);
         }
-        sumData = WarpAllReduce(sumData, SumOp<T>{});
+        sumData = WarpAllReduce(sumData, cub::Sum());
         if (threadIdx.x == 0) {
             sumTotal[threadIdx.y] = sumData;
         }
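cub::Sum and cub::Max come from cub/thread/thread_operators.cuh and are drop-in replacements for the deleted SumOp/MaxOp functors. Below is a minimal single-warp sketch of the surviving WarpAllReduce driven by the stock functors, assuming the loop body hidden by the collapsed hunk uses __shfl_xor_sync, the standard XOR-butterfly pattern; the demoReduce kernel and main are hypothetical:

#include <cub/thread/thread_operators.cuh> // cub::Sum, cub::Max
#include <cstdio>

template<class T, class ReductionOp>
__device__ __forceinline__ T WarpAllReduce(T val, ReductionOp op) {
    // XOR butterfly: after log2(blockDim.x) rounds every lane holds the result.
    for (int mask = blockDim.x >> 1; mask > 0; mask >>= 1) {
        val = op(val, __shfl_xor_sync(0xffffffff, val, mask));
    }
    return val;
}

__global__ void demoReduce() {
    float v = float(threadIdx.x);               // lanes hold 0..31
    float s = WarpAllReduce(v, cub::Sum());     // 0 + 1 + ... + 31 = 496
    float m = WarpAllReduce(v, cub::Max());     // 31
    if (threadIdx.x == 0) printf("sum=%g max=%g\n", s, m);
}

int main() {
    demoReduce<<<1, 32>>>(); // exactly one full warp, so the 0xffffffff mask is valid
    cudaDeviceSynchronize();
    return 0;
}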
