diff --git a/src/00common/include/common/data_type.h b/src/00common/include/common/data_type.h
index 9824f7739..ae473ad73 100644
--- a/src/00common/include/common/data_type.h
+++ b/src/00common/include/common/data_type.h
@@ -45,6 +45,7 @@ namespace refactor {
         bool isFloat() const noexcept;
         bool isSignedLarge() const noexcept;
         bool isSigned() const noexcept;
+        bool isUnsigned() const noexcept;
         bool isNumberic() const noexcept;
         bool isCpuNumberic() const noexcept;
         bool isBool() const noexcept;
diff --git a/src/00common/src/data_type.cc b/src/00common/src/data_type.cc
index 7c5e697ec..44925cad6 100644
--- a/src/00common/src/data_type.cc
+++ b/src/00common/src/data_type.cc
@@ -82,6 +82,11 @@ namespace refactor {
             DT::I8, DT::I16, DT::I32, DT::I64};
         return set.contains(internal);
     }
+    bool DT::isUnsigned() const noexcept {
+        static const std::unordered_set set{
+            DT::U8, DT::U16, DT::U32, DT::U64};
+        return set.contains(internal);
+    }
     bool DT::isNumberic() const noexcept {
         static const std::unordered_set set{
             DT::F32, DT::U8, DT::I8, DT::U16, DT::I16,
diff --git a/src/04kernel/include/kernel/collectors/simple_unary.h b/src/04kernel/include/kernel/collectors/simple_unary.h
index 1a25cf3b4..ee190ee17 100644
--- a/src/04kernel/include/kernel/collectors/simple_unary.h
+++ b/src/04kernel/include/kernel/collectors/simple_unary.h
@@ -5,7 +5,7 @@
 namespace refactor::kernel {
 
-    enum class SimpleUnaryType {
+    enum class SimpleUnaryType : uint8_t {
         Abs,
         Acos,
         Acosh,
@@ -27,6 +27,8 @@ namespace refactor::kernel {
         Not,
     };
 
+    std::string_view unaryName(SimpleUnaryType type);
+
     struct SimpleUnaryCollector final : public InfoCollector {
         SimpleUnaryType type;
 
diff --git a/src/04kernel/src/collectors/simple_unary.cc b/src/04kernel/src/collectors/simple_unary.cc
index 4dc6bfc60..de9e0bb07 100644
--- a/src/04kernel/src/collectors/simple_unary.cc
+++ b/src/04kernel/src/collectors/simple_unary.cc
@@ -6,6 +6,36 @@
 
 namespace refactor::kernel {
 
+#define CASE(OP)              \
+    case SimpleUnaryType::OP: \
+        return #OP
+
+    std::string_view unaryName(SimpleUnaryType type) {
+        switch (type) {
+            CASE(Abs);
+            CASE(Acos);
+            CASE(Acosh);
+            CASE(Asin);
+            CASE(Asinh);
+            CASE(Atan);
+            CASE(Atanh);
+            CASE(Cos);
+            CASE(Cosh);
+            CASE(Sin);
+            CASE(Sinh);
+            CASE(Tan);
+            CASE(Tanh);
+            CASE(Relu);
+            CASE(Sqrt);
+            CASE(Sigmoid);
+            CASE(Erf);
+            CASE(Neg);
+            CASE(Not);
+            default:
+                UNREACHABLE();
+        }
+    }
+
 #define REGISTER(T)                          \
     if (auto ptr = T::build(type, a); ptr) { \
         ans.emplace_back(std::move(ptr));    \
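[Note] The CASE macro above uses the preprocessor's stringizing operator: #OP expands to the enumerator's spelling, so each line of the switch both dispatches on a SimpleUnaryType value and returns its name, and the name list cannot drift out of sync with the enum cases it handles. A minimal self-contained sketch of the same pattern (the Color enum and colorName are illustrative, not part of this patch):

    #include <cstdio>
    #include <string_view>

    enum class Color { Red, Green, Blue };

    // #C expands to the enumerator spelled as a string literal: "Red", ...
    #define CASE(C)    \
        case Color::C: \
            return #C

    std::string_view colorName(Color c) {
        switch (c) {
            CASE(Red);
            CASE(Green);
            CASE(Blue);
            default:
                return "unknown";
        }
    }

    int main() {
        std::printf("%s\n", colorName(Color::Green).data());// prints "Green"
    }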
diff --git a/src/04kernel/src/kernels/simple_unary/cuda_kernel.cc b/src/04kernel/src/kernels/simple_unary/cuda_kernel.cc
index ac6de21d6..bc410c1ab 100644
--- a/src/04kernel/src/kernels/simple_unary/cuda_kernel.cc
+++ b/src/04kernel/src/kernels/simple_unary/cuda_kernel.cc
@@ -1,7 +1,12 @@
 #include "cuda_kernel.hh"
-#include
 #include <unordered_set>
+
+#ifdef USE_CUDA
+#include "../../generator/nvrtc_repo.h"
+#include "kernel/cuda/threads_distributer.cuh"
+#include <cuda_runtime.h>
+#endif
 
 namespace refactor::kernel {
     using K = SimpleUnaryCuda;
     using Op = SimpleUnaryType;
@@ -11,14 +16,13 @@ namespace refactor::kernel {
 
         : Kernel(), dataType(dataType_), opType(opType_), size(size_) {}
 
     auto K::build(Op op, Tensor const &a) noexcept -> KernelBox {
-        static const std::unordered_set supportedOp{
-            Op::Abs,
-            Op::Relu,
-            Op::Sqrt,
-            Op::Sigmoid,
-            Op::Tanh,
-            Op::Neg,
-        };
+        static const std::unordered_set
+            supportedOp{Op::Abs, Op::Relu, Op::Sqrt,
+                        Op::Sigmoid, Op::Tanh, Op::Neg};
+#ifndef USE_CUDA
+        return nullptr;
+#endif
+
         return supportedOp.contains(op) && a.dataType.isCpuNumberic() ?
                    std::make_unique<K>(op, a.dataType, a.elementsSize()) : nullptr;
     }
@@ -35,4 +39,145 @@ namespace refactor::kernel {
         return "Performing unary operation on Nvidia GPU";
     }
 
+#ifdef USE_CUDA
+
+    constexpr static const char *TEMPLATE = R"~(
+__device__ __forceinline__ static {0:} fn({0:} x) {{
+    return {1:};
+}}
+
+extern "C" __global__ void kernel(void *output, void const *input, size_t n) {{
+    auto dst = reinterpret_cast<{0:} *>(output);
+    auto src = reinterpret_cast<{0:} const *>(input);
+    for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
+              step = blockDim.x * gridDim.x;
+         tid < n;
+         tid += step)
+        dst[tid] = fn(src[tid]);
+}}
+)~";
+
+    constexpr uint16_t __(Op op, DT dt) {
+        union code {
+            struct {
+                Op op;
+                DT dt;
+            } info;
+            uint16_t u16;
+        } c{.info{op, dt}};
+        return c.u16;
+    }
+
+    auto K::lower(Resources &res) const -> RoutineWorkspace {
+        using namespace runtime;
+
+        if (dataType.isUnsigned()) {
+            switch (opType) {
+                case Op::Abs:
+                case Op::Relu:
+                    return [n = size * dataType.size()](Resources &, void *, void const *const *inputs, void *const *outputs) {
+                        if (outputs[0] != inputs[0]) {
+                            cudaMemcpyAsync(outputs[0], inputs[0], n, cudaMemcpyDeviceToDevice);
+                        }
+                    };
+                case Op::Neg:
+                    UNREACHABLE();
+                default:
+                    break;
+            }
+        }
+
+        // clang-format off
+        static const std::unordered_map dt {
+            {DT::U8  , "unsigned char"     },
+            {DT::U16 , "unsigned short"    },
+            {DT::U32 , "unsigned int"      },
+            {DT::U64 , "unsigned long long"},
+            {DT::I8  , "char"              },
+            {DT::I16 , "short"             },
+            {DT::I32 , "int"               },
+            {DT::I64 , "long long"         },
+            {DT::FP16, "half"              },
+            {DT::BF16, "nv_bfloat16"       },
+            {DT::F32 , "float"             },
+            {DT::F64 , "double"            },
+        };
+        // see .
+        static const std::unordered_map op {
+            {__(Op::Abs, DT::I8  ), "x >= 0 ? x : -x"},
+            {__(Op::Abs, DT::I16 ), "x >= 0 ? x : -x"},
+            {__(Op::Abs, DT::I32 ), "x >= 0 ? x : -x"},
+            {__(Op::Abs, DT::I64 ), "x >= 0 ? x : -x"},
+            {__(Op::Abs, DT::FP16), "habs(x)"        },
+            {__(Op::Abs, DT::BF16), "habs(x)"        },
+            {__(Op::Abs, DT::F32 ), "fabsf(x)"       },
+            {__(Op::Abs, DT::F64 ), "fabs(x)"        },
+
+            {__(Op::Relu, DT::I8  ), "x > 0 ? x : 0"},
+            {__(Op::Relu, DT::I16 ), "x > 0 ? x : 0"},
+            {__(Op::Relu, DT::I32 ), "x > 0 ? x : 0"},
+            {__(Op::Relu, DT::I64 ), "x > 0 ? x : 0"},
+            {__(Op::Relu, DT::FP16), "x > CUDART_ZERO_FP16 ? x: CUDART_ZERO_FP16"},
+            {__(Op::Relu, DT::BF16), "x > CUDART_ZERO_BF16 ? x: CUDART_ZERO_BF16"},
+            {__(Op::Relu, DT::F32 ), "x > 0 ? x : 0"},
+            {__(Op::Relu, DT::F64 ), "x > 0 ? x : 0"},
+
+            {__(Op::Sqrt, DT::U8  ), "__fsqrt_rn(static_cast<float>(x))" },
+            {__(Op::Sqrt, DT::U16 ), "__fsqrt_rn(static_cast<float>(x))" },
+            {__(Op::Sqrt, DT::U32 ), "__dsqrt_rn(static_cast<double>(x))"},
+            {__(Op::Sqrt, DT::U64 ), "__dsqrt_rn(static_cast<double>(x))"},
+            {__(Op::Sqrt, DT::I8  ), "__fsqrt_rn(static_cast<float>(x))" },
+            {__(Op::Sqrt, DT::I16 ), "__fsqrt_rn(static_cast<float>(x))" },
+            {__(Op::Sqrt, DT::I32 ), "__dsqrt_rn(static_cast<double>(x))"},
+            {__(Op::Sqrt, DT::I64 ), "__dsqrt_rn(static_cast<double>(x))"},
+            {__(Op::Sqrt, DT::FP16), "hsqrt(x)"                          },
+            {__(Op::Sqrt, DT::BF16), "hsqrt(x)"                          },
+            {__(Op::Sqrt, DT::F32 ), "__fsqrt_rn(x)"                     },
+            {__(Op::Sqrt, DT::F64 ), "__dsqrt_rn(x)"                     },
+
+            {__(Op::Sigmoid, DT::U8  ), "fdividef(1, 1 + expf(-static_cast<float>(x)))"},
+            {__(Op::Sigmoid, DT::U16 ), "fdividef(1, 1 + expf(-static_cast<float>(x)))"},
+            {__(Op::Sigmoid, DT::U32 ), "1.0 / (1 + exp(-static_cast<double>(x)))"     },
+            {__(Op::Sigmoid, DT::U64 ), "1.0 / (1 + exp(-static_cast<double>(x)))"     },
+            {__(Op::Sigmoid, DT::I8  ), "fdividef(1, 1 + expf(-static_cast<float>(x)))"},
+            {__(Op::Sigmoid, DT::I16 ), "fdividef(1, 1 + expf(-static_cast<float>(x)))"},
+            {__(Op::Sigmoid, DT::I32 ), "1.0 / (1 + exp(-static_cast<double>(x)))"     },
+            {__(Op::Sigmoid, DT::I64 ), "1.0 / (1 + exp(-static_cast<double>(x)))"     },
+            {__(Op::Sigmoid, DT::FP16), "hrcp(CUDART_ONE_FP16 + hexp(-x))"             },
+            {__(Op::Sigmoid, DT::BF16), "hrcp(CUDART_ONE_BF16 + hexp(-x))"             },
+            {__(Op::Sigmoid, DT::F32 ), "fdividef(1, 1 + expf(-x))"                    },
+            {__(Op::Sigmoid, DT::F64 ), "1.0 / (1 + exp(-x))"                          },
+
+            {__(Op::Tanh, DT::F32 ), "tanh(x)"},
+            {__(Op::Tanh, DT::F64 ), "tanh(x)"},
+
+            {__(Op::Neg, DT::I8  ), "-x"},
+            {__(Op::Neg, DT::I16 ), "-x"},
+            {__(Op::Neg, DT::I32 ), "-x"},
+            {__(Op::Neg, DT::I64 ), "-x"},
+            {__(Op::Neg, DT::FP16), "-x"},
+            {__(Op::Neg, DT::BF16), "-x"},
+            {__(Op::Neg, DT::F32 ), "-x"},
+            {__(Op::Neg, DT::F64 ), "-x"},
+        };
+        // clang-format on
+
+        auto name = fmt::format("unary_{}_{}", dataType.name(), unaryName(opType));
+        auto code = fmt::format(TEMPLATE, dt.at(dataType), op.at(__(opType, dataType)));
+        auto params = cuda::ThreadsDistributer()(size);
+
+        return [h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"),
+                params](Resources &, void *, void const *const *inputs, void *const *outputs) {
+            size_t n = params.n;
+            void *args[]{const_cast<void **>(outputs), const_cast<void **>(inputs), &n};
+            CUDA_ASSERT(cuLaunchKernel(
+                h->kernel(),
+                params.gridSize, 1, 1,
+                params.blockSize, 1, 1,
+                0, nullptr, args, nullptr));
+        };
+    }
+
+#endif
+
 }// namespace refactor::kernel
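[Note] The oddly named __ helper above packs an (Op, DataType) pair into a single 16-bit key for the string tables by type-punning through a union; this leans on both fields being one byte wide and on reading an inactive union member, which mainstream compilers accept but standard C++ leaves undefined. A layout-independent sketch of the same keying idea using shifts (Op2/Dt2 are illustrative stand-ins, not from this patch):

    #include <cstdint>

    enum class Op2 : std::uint8_t { Abs, Relu, Sqrt };
    enum class Dt2 : std::uint8_t { F32, F64 };

    // Same 16-bit key, built arithmetically: independent of member layout,
    // endianness, and union rules, and usable in constant expressions.
    constexpr std::uint16_t key(Op2 op, Dt2 dt) {
        return static_cast<std::uint16_t>((static_cast<std::uint16_t>(op) << 8) |
                                          static_cast<std::uint16_t>(dt));
    }

    static_assert(key(Op2::Relu, Dt2::F64) == 0x0101);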
x : 0"}, + + {__(Op::Sqrt, DT::U8 ), "__fsqrt_rn(static_cast(x))" }, + {__(Op::Sqrt, DT::U16 ), "__fsqrt_rn(static_cast(x))" }, + {__(Op::Sqrt, DT::U32 ), "__dsqrt_rn(static_cast(x))" }, + {__(Op::Sqrt, DT::U64 ), "__dsqrt_rn(static_cast(x))" }, + {__(Op::Sqrt, DT::I8 ), "__fsqrt_rn(static_cast(x))" }, + {__(Op::Sqrt, DT::I16 ), "__fsqrt_rn(static_cast(x))" }, + {__(Op::Sqrt, DT::I32 ), "__dsqrt_rn(static_cast(x))" }, + {__(Op::Sqrt, DT::I64 ), "__dsqrt_rn(static_cast(x))" }, + {__(Op::Sqrt, DT::FP16), "hsqrt(x)" }, + {__(Op::Sqrt, DT::BF16), "hsqrt(x)" }, + {__(Op::Sqrt, DT::F32 ), "__fsqrt_rn(x)" }, + {__(Op::Sqrt, DT::F64 ), "__dsqrt_rn(x)" }, + + {__(Op::Sigmoid, DT::U8 ), "fdividef(1, 1 + expf(-static_cast(x)))" }, + {__(Op::Sigmoid, DT::U16 ), "fdividef(1, 1 + expf(-static_cast(x)))" }, + {__(Op::Sigmoid, DT::U32 ), "1.0 / (1 + exp(-static_cast(x)))" }, + {__(Op::Sigmoid, DT::U64 ), "1.0 / (1 + exp(-static_cast(x)))" }, + {__(Op::Sigmoid, DT::I8 ), "fdividef(1, 1 + expf(-static_cast(x)))" }, + {__(Op::Sigmoid, DT::I16 ), "fdividef(1, 1 + expf(-static_cast(x)))" }, + {__(Op::Sigmoid, DT::I32 ), "1.0 / (1 + exp(-static_cast(x)))" }, + {__(Op::Sigmoid, DT::I64 ), "1.0 / (1 + exp(-static_cast(x)))" }, + {__(Op::Sigmoid, DT::FP16), "hrcp(CUDART_ONE_FP16 + hexp(-x))" }, + {__(Op::Sigmoid, DT::BF16), "hrcp(CUDART_ONE_BF16 + hexp(-x))" }, + {__(Op::Sigmoid, DT::F32 ), "fdividef(1, 1 + expf(-x))" }, + {__(Op::Sigmoid, DT::F64 ), "1.0 / (1 + exp(-x))" }, + + {__(Op::Tanh, DT::F32 ), "tanh(x)"}, + {__(Op::Tanh, DT::F64 ), "tanh(x)"}, + + {__(Op::Neg, DT::I8 ), "-x"}, + {__(Op::Neg, DT::I16 ), "-x"}, + {__(Op::Neg, DT::I32 ), "-x"}, + {__(Op::Neg, DT::I64 ), "-x"}, + {__(Op::Neg, DT::FP16), "-x"}, + {__(Op::Neg, DT::BF16), "-x"}, + {__(Op::Neg, DT::F32 ), "-x"}, + {__(Op::Neg, DT::F64 ), "-x"}, + }; + // clang-format on + + auto name = fmt::format("unary_{}_{}", dataType.name(), unaryName(opType)); + auto code = fmt::format(TEMPLATE, dt.at(dataType), op.at(__(opType, dataType))); + auto params = cuda::ThreadsDistributer()(size); + + return [h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"), + params](Resources &, void *, void const *const *inputs, void *const *outputs) { + size_t n = params.n; + void *args[]{const_cast(outputs), const_cast(inputs), &n}; + CUDA_ASSERT(cuLaunchKernel( + h->kernel(), + params.gridSize, 1, 1, + params.blockSize, 1, 1, + 0, nullptr, args, nullptr)); + }; + } + +#endif + }// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/simple_unary/cuda_kernel.cu b/src/04kernel/src/kernels/simple_unary/cuda_kernel.cu deleted file mode 100644 index 8ba2ca7de..000000000 --- a/src/04kernel/src/kernels/simple_unary/cuda_kernel.cu +++ /dev/null @@ -1,133 +0,0 @@ -#include "cuda_kernel.hh" -#include -#include -#include - -namespace refactor::kernel { - using K = SimpleUnaryCuda; - using Op = SimpleUnaryType; - using DT = DataType; - - template struct AbsFunctor { - __device__ T operator()(T x) const { return abs(x); } - }; - template struct NegFunctor { - __device__ T operator()(T x) const { return -x; } - }; - template struct ReluFunctor { - __device__ T operator()(T x) const { return x > 0 ? 
diff --git a/src/04kernel/src/kernels/simple_unary/cuda_kernel.hh b/src/04kernel/src/kernels/simple_unary/cuda_kernel.hh
index e5308b63b..e9a0e8a5e 100644
--- a/src/04kernel/src/kernels/simple_unary/cuda_kernel.hh
+++ b/src/04kernel/src/kernels/simple_unary/cuda_kernel.hh
@@ -19,7 +19,7 @@ namespace refactor::kernel {
         size_t kernelTypeId() const noexcept final;
         std::string_view description() const noexcept final;
 #ifdef USE_CUDA
-        RoutineWorkspace lower(Resources &) const noexcept final;
+        RoutineWorkspace lower(Resources &) const final;
 #endif
     };
 
diff --git a/src/04kernel/src/kernels/split/cuda_kernel.cc b/src/04kernel/src/kernels/split/cuda_kernel.cc
index 3b1e42914..ac5bbeafc 100644
--- a/src/04kernel/src/kernels/split/cuda_kernel.cc
+++ b/src/04kernel/src/kernels/split/cuda_kernel.cc
@@ -17,6 +17,7 @@ namespace refactor::kernel {
 #ifndef USE_CUDA
         return nullptr;
 #endif
+        return std::make_unique<K>(std::move(info));
     }
 
     auto K::typeId() noexcept -> size_t {
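[Note] K::build above and the split kernel's build share the same compile-time guard: a build without USE_CUDA returns nullptr before the CUDA-specific construction is reached, leaving that line as dead code rather than hiding it behind an #else. A condensed sketch of the pattern with a hypothetical kernel class:

    #include <memory>

    struct Kernel { virtual ~Kernel() = default; };
    struct HypotheticalGpuKernel final : Kernel {};

    std::unique_ptr<Kernel> build() {
    #ifndef USE_CUDA
        return nullptr;// CPU-only builds: the factory reports "not available"
    #endif
        return std::make_unique<HypotheticalGpuKernel>();// unreachable in CPU-only builds
    }

The header change fits the same move to runtime compilation: lower() drops its noexcept, presumably because generating and NVRTC-compiling the kernel source can now fail.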