feat(kernel): unary operators now support more types
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Dec 7, 2023
1 parent 369f109 commit 94fde13
Showing 4 changed files with 105 additions and 82 deletions.
3 changes: 1 addition & 2 deletions src/04kernel/src/kernels/simple_unary/cuda_kernel.cc
@@ -1,5 +1,4 @@
 #include "cuda_kernel.hh"
-#include <execution>
 #include <unordered_set>
 
 namespace refactor::kernel {
@@ -19,7 +18,7 @@ namespace refactor::kernel {
             Op::Tanh,
             Op::Neg,
         };
-        return supportedOp.contains(op) && a.dataType.isCpuNumberic()
+        return supportedOp.contains(op) && a.dataType.isNumberic()
                    ? std::make_unique<K>(op, a.dataType, a.elementsSize())
                    : nullptr;
     }
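Note on the change above: the relaxed gate is what actually widens type support for the CUDA kernel. isCpuNumberic() rejects the GPU-only half precisions, while isNumberic() accepts any numeric element type, so FP16/BF16 tensors now reach the kernel builder. A minimal sketch of the intended distinction, assuming a hypothetical mirror of the DataType tags (the real refactor::DataType implementation may differ):

// Hypothetical mirror of refactor::DataType's tags; names follow the
// DT::... cases used later in this commit.
enum class DT { F32, F64, FP16, BF16, I8, I16, I32, I64, U8, U16, U32, U64, Bool };

// Sketch: any numeric tag counts, including the half precisions...
constexpr bool isNumberic(DT dt) { return dt != DT::Bool; }

// ...while the old CPU-oriented predicate also excluded FP16/BF16,
// which have no native host representation.
constexpr bool isCpuNumberic(DT dt) {
    return isNumberic(dt) && dt != DT::FP16 && dt != DT::BF16;
}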
68 changes: 47 additions & 21 deletions src/04kernel/src/kernels/simple_unary/cuda_kernel.cu
@@ -9,26 +9,39 @@ namespace refactor::kernel {
     using DT = DataType;
 
     template<class T> struct AbsFunctor {
-        __device__ T operator()(T x) const { return abs(x); }
+        __device__ T operator()(T x) const { return x >= static_cast<T>(0) ? x : -x; }
     };
     template<class T> struct NegFunctor {
         __device__ T operator()(T x) const { return -x; }
     };
     template<class T> struct ReluFunctor {
-        __device__ T operator()(T x) const { return x > 0 ? x : 0; }
-    };
-    template<class T> struct SqrtFunctor {
-        __device__ T operator()(T x) const {
-            using M = std::conditional_t<sizeof(T) <= 4, float, double>;
-            return static_cast<T>(sqrt(static_cast<M>(x)));
-        }
-    };
-    template<class T> struct SigmoidFunctor {
-        __device__ T operator()(T x) const {
-            using M = std::conditional_t<sizeof(T) <= 4, float, double>;
-            return static_cast<T>(1 / (1 + std::exp(-static_cast<M>(x))));
-        }
+        __device__ T operator()(T x) const { return x > static_cast<T>(0) ? x : static_cast<T>(0); }
     };
+
+    template<class T> struct SqrtFunctor {};
+#define SQRT_FN(TY, FN)                                         \
+    template<> struct SqrtFunctor<TY> {                         \
+        __device__ TY operator()(TY x) const { return FN(x); } \
+    }
+    SQRT_FN(nv_bfloat16, hsqrt);
+    SQRT_FN(half, hsqrt);
+    SQRT_FN(float, sqrtf);
+    SQRT_FN(double, sqrt);
+#undef SQRT_FN
+
+    template<class T> constexpr __device__ __forceinline__ T reciprocal(T x) { return 1 / x; }
+    template<> __device__ __forceinline__ float reciprocal(float x) { return fdividef(1.0f, x); }
+    template<class T> struct SigmoidFunctor {};
+#define SIGMOID_FN(TY, RECIPROCAL, EXP)                                                           \
+    template<> struct SigmoidFunctor<TY> {                                                        \
+        __device__ TY operator()(TY x) const { return RECIPROCAL(static_cast<TY>(1) + EXP(-x)); } \
+    }
+    SIGMOID_FN(nv_bfloat16, hrcp, hexp);
+    SIGMOID_FN(half, hrcp, hexp);
+    SIGMOID_FN(float, reciprocal, expf);
+    SIGMOID_FN(double, reciprocal, exp);
+#undef SIGMOID_FN
 
     template<class T> struct TanhFunctor {
         __device__ T operator()(T x) const {
             using M = std::conditional_t<sizeof(T) <= 4, float, double>;
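To make the macro indirection above concrete, here is what two of those instantiations expand to after preprocessing. hsqrt is CUDA's half-precision square-root intrinsic from <cuda_fp16.h>; the float sigmoid routes its division through fdividef, the fast single-precision divide:

// Expansion of SQRT_FN(half, hsqrt); : a full specialization that calls
// the native half intrinsic instead of round-tripping through float.
template<> struct SqrtFunctor<half> {
    __device__ half operator()(half x) const { return hsqrt(x); }
};

// Expansion of SIGMOID_FN(float, reciprocal, expf); : sigmoid(x) = 1/(1 + e^-x),
// with the division lowered to fdividef via the reciprocal<float> overload.
template<> struct SigmoidFunctor<float> {
    __device__ float operator()(float x) const { return reciprocal(static_cast<float>(1) + expf(-x)); }
};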
@@ -47,16 +60,28 @@ namespace refactor::kernel {
         };
     }
 
+    template<decltype(DT::internal) DT_> struct cudatype {
+        using type = typename primitive<DT_>::type;
+    };
+    template<> struct cudatype<DT::FP16> {
+        using type = half;
+    };
+    template<> struct cudatype<DT::BF16> {
+        using type = nv_bfloat16;
+    };
+
 #define CASE(FUNC, TYPE) \
     case DT::TYPE: \
-        return lowerTyped<primitive<DT::TYPE>::type, FUNC##Functor<primitive<DT::TYPE>::type>>(size)
+        return lowerTyped<cudatype<DT::TYPE>::type, FUNC##Functor<cudatype<DT::TYPE>::type>>(size)
 #define COPY \
     return [size = size * dataType.size()](Resources &, void *workspace, void const *const *inputs, void *const *outputs) { \
         cudaMemcpyAsync(outputs[0], inputs[0], size, cudaMemcpyDeviceToDevice); \
     }
 #define GROUP_F(FUNC) \
     CASE(FUNC, F32); \
-    CASE(FUNC, F64)
+    CASE(FUNC, F64); \
+    CASE(FUNC, FP16); \
+    CASE(FUNC, BF16)
 #define GROUP_I(FUNC) \
     CASE(FUNC, I8); \
     CASE(FUNC, I16); \
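The new cudatype trait is what lets the generic CASE macro name a device element type: it falls through to the host-side primitive<> mapping and overrides only the two tags whose CUDA representation differs. Assuming primitive<DT::F32>::type is float, as the CPU path suggests, the mapping behaves like this:

#include <type_traits>

// Usage sketch: resolve a DataType tag to the on-device element type
// (cudatype and DT come from the file above).
static_assert(std::is_same_v<cudatype<DT::F32>::type, float>);        // falls through to primitive<>
static_assert(std::is_same_v<cudatype<DT::FP16>::type, half>);        // overridden: device half
static_assert(std::is_same_v<cudatype<DT::BF16>::type, nv_bfloat16>); // overridden: device bfloat16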
@@ -68,7 +93,8 @@ namespace refactor::kernel {
     CASE(FUNC, U32); \
     CASE(FUNC, U64)
 
-    auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
+    auto
+    K::lower(Resources &) const noexcept -> RoutineWorkspace {
         switch (opType) {
             case Op::Abs:
                 switch (dataType) {
@@ -97,16 +123,16 @@ namespace refactor::kernel {
             case Op::Sqrt:
                 switch (dataType) {
                     GROUP_F(Sqrt);
-                    GROUP_I(Sqrt);
-                    GROUP_U(Sqrt);
+                    // GROUP_I(Sqrt);
+                    // GROUP_U(Sqrt);
                     default:
                         UNREACHABLE();
                 }
             case Op::Sigmoid:
                 switch (dataType) {
                     GROUP_F(Sigmoid);
-                    GROUP_I(Sigmoid);
-                    GROUP_U(Sigmoid);
+                    // GROUP_I(Sigmoid);
+                    // GROUP_U(Sigmoid);
                     default:
                         UNREACHABLE();
                 }
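The lowerTyped helper that CASE expands into lives in a collapsed part of this file, and the commented-out GROUP_I/GROUP_U cases reflect that hsqrt, hexp, and friends exist only for floating-point types. For orientation, a hedged sketch of the kind of elementwise launcher such a routine typically wraps; the kernel, the launchUnary name, and the block size are assumptions, not this repository's actual code:

#include <cuda_runtime.h>
#include <cstddef>

// Hypothetical elementwise launcher: apply Functor to each of n elements.
template<class T, class Functor>
__global__ void unaryKernel(T const *x, T *y, size_t n) {
    size_t tid = blockIdx.x * static_cast<size_t>(blockDim.x) + threadIdx.x;
    if (tid < n) { y[tid] = Functor{}(x[tid]); }
}

template<class T, class Functor>
void launchUnary(T const *x, T *y, size_t n, cudaStream_t stream = nullptr) {
    constexpr unsigned block = 256;
    auto grid = static_cast<unsigned>((n + block - 1) / block);
    unaryKernel<T, Functor><<<grid, block, 0, stream>>>(x, y, n);
}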
60 changes: 57 additions & 3 deletions src/04kernel/src/kernels/simple_unary/cudnn_activation_kernel.cc
@@ -1,9 +1,13 @@
 #include "cudnn_activation_kernel.hh"
 #include "kernel/collectors/simple_unary.h"
-#include "kernel/kernel.h"
-#include "kernel/tensor.h"
 #include <unordered_set>
 
+#ifdef USE_CUDA
+#include "../../utilities/cuda/cudnn_context.hh"
+#include "../../utilities/cuda/cudnn_functions.h"
+#include <cudnn.h>
+#endif
+
 namespace refactor::kernel {
     using K = ActivationCudnn;
     using DT = DataType;
@@ -19,7 +23,7 @@ namespace refactor::kernel {
             return nullptr;
 #endif
 
-        return ARTHIMETIC.contains(op) && a.dataType.isCpuNumberic()
+        return ARTHIMETIC.contains(op) && a.dataType.isNumberic()
                    ? std::make_unique<K>(op, a.dataType, static_cast<int>(a.elementsSize()))
                    : nullptr;
     }
@@ -33,4 +37,54 @@ namespace refactor::kernel {
         return "Performing activation using CUDNN";
     }
 
+#ifdef USE_CUDA
+
+    auto K::lower(Resources &res) const -> RoutineWorkspace {
+        using namespace cudnn;
+        using namespace runtime;
+
+        // RAII for closure
+        struct Descriptors {
+            cudnnActivationDescriptor_t activation;
+            cudnnTensorDescriptor_t tensor;
+
+            Descriptors() : activation(nullptr), tensor(nullptr) {
+                CUDNN_ASSERT(cudnnCreateActivationDescriptor(&activation));
+                CUDNN_ASSERT(cudnnCreateTensorDescriptor(&tensor));
+            }
+            ~Descriptors() noexcept(false) {
+                CUDNN_ASSERT(cudnnDestroyActivationDescriptor(activation));
+                CUDNN_ASSERT(cudnnDestroyTensorDescriptor(tensor));
+            }
+
+            Descriptors(const Descriptors &) = delete;
+            Descriptors(Descriptors &&) = delete;
+        };
+        auto d = std::make_shared<Descriptors>();
+
+        // clang-format off
+        auto mode = type == Op::Relu    ? CUDNN_ACTIVATION_RELU
+                  : type == Op::Sigmoid ? CUDNN_ACTIVATION_SIGMOID
+                  : type == Op::Tanh    ? CUDNN_ACTIVATION_TANH
+                  : UNREACHABLEX(cudnnActivationMode_t, "");
+        // clang-format on
+        CUDNN_ASSERT(cudnnSetActivationDescriptor(d->activation, mode, CUDNN_PROPAGATE_NAN, 0.0));
+        CUDNN_ASSERT(cudnnSetTensor4dDescriptor(d->tensor, CUDNN_TENSOR_NCHW, cudnnDataTypeConvert(dataType), 1, 1, 1, size));
+
+        res.fetchOrStore<CudnnContext>();
+        // nvcc at c++11 doesn't support real move capture
+        return [d = std::move(d)](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
+            // fetch cudnn handle from resources
+            auto handle = res.fetchOrStore<CudnnContext>()->handle;
+            // name inputs and outputs
+            auto x = inputs[0];
+            auto y = outputs[0];
+            // call cudnn activation
+            float alpha = 1, beta = 0;
+            CUDNN_ASSERT(cudnnActivationForward(handle, d->activation, &alpha, d->tensor, x, &beta, d->tensor, y));
+        };
+    }
+
+#endif
+
 }// namespace refactor::kernel
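Two details of the lower() added above are worth calling out. First, because activation is elementwise, the tensor descriptor flattens whatever shape arrives into NCHW 1x1x1xN. Second, cuDNN's blending convention computes y = alpha * op(x) + beta * y, so alpha = 1, beta = 0 is a plain overwrite; this inline implementation is what replaces the deleted .cu file below. A minimal standalone sketch of the same call pattern (not project code; error handling collapsed to a bare check):

#include <cudnn.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

#define CHECK(stat) \
    do { if ((stat) != CUDNN_STATUS_SUCCESS) { std::printf("cudnn error\n"); return 1; } } while (0)

int main() {
    constexpr int n = 8;
    std::vector<float> host{-3, -2, -1, 0, 1, 2, 3, 4};
    float *x, *y;
    cudaMalloc(&x, n * sizeof(float));
    cudaMalloc(&y, n * sizeof(float));
    cudaMemcpy(x, host.data(), n * sizeof(float), cudaMemcpyHostToDevice);

    cudnnHandle_t handle;
    CHECK(cudnnCreate(&handle));
    // flatten to 1x1x1xn: elementwise ops ignore the logical shape
    cudnnTensorDescriptor_t tensor;
    CHECK(cudnnCreateTensorDescriptor(&tensor));
    CHECK(cudnnSetTensor4dDescriptor(tensor, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, 1, 1, n));
    cudnnActivationDescriptor_t activation;
    CHECK(cudnnCreateActivationDescriptor(&activation));
    CHECK(cudnnSetActivationDescriptor(activation, CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0.0));

    // y = alpha * relu(x) + beta * y
    float alpha = 1, beta = 0;
    CHECK(cudnnActivationForward(handle, activation, &alpha, tensor, x, &beta, tensor, y));

    cudaMemcpy(host.data(), y, n * sizeof(float), cudaMemcpyDeviceToHost);
    for (auto v : host) std::printf("%g ", v); // 0 0 0 0 1 2 3 4
    cudnnDestroyActivationDescriptor(activation);
    cudnnDestroyTensorDescriptor(tensor);
    cudnnDestroy(handle);
    cudaFree(x);
    cudaFree(y);
    return 0;
}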
56 changes: 0 additions & 56 deletions src/04kernel/src/kernels/simple_unary/cudnn_activation_kernel.cu

This file was deleted.
