Skip to content

Commit

Permalink
refactor(kernel): 基于代码生成实现单目算子,支持更多类型
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Dec 8, 2023
1 parent 41d34b1 commit 81789e9
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 144 deletions.
1 change: 1 addition & 0 deletions src/00common/include/common/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ namespace refactor {
bool isFloat() const noexcept;
bool isSignedLarge() const noexcept;
bool isSigned() const noexcept;
bool isUnsigned() const noexcept;
bool isNumberic() const noexcept;
bool isCpuNumberic() const noexcept;
bool isBool() const noexcept;
Expand Down
5 changes: 5 additions & 0 deletions src/00common/src/data_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ namespace refactor {
DT::I8, DT::I16, DT::I32, DT::I64};
return set.contains(internal);
}
bool DT::isUnsigned() const noexcept {
    // True exactly for the four unsigned integer types; an explicit switch
    // expresses the same membership test as a hash-set lookup would.
    switch (internal) {
        case DT::U8:
        case DT::U16:
        case DT::U32:
        case DT::U64:
            return true;
        default:
            return false;
    }
}
bool DT::isNumberic() const noexcept {
static const std::unordered_set<Enum> set{
DT::F32, DT::U8, DT::I8, DT::U16, DT::I16,
Expand Down
4 changes: 3 additions & 1 deletion src/04kernel/include/kernel/collectors/simple_unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

namespace refactor::kernel {

enum class SimpleUnaryType {
enum class SimpleUnaryType : uint8_t {
Abs,
Acos,
Acosh,
Expand All @@ -27,6 +27,8 @@ namespace refactor::kernel {
Not,
};

std::string_view unaryName(SimpleUnaryType type);

struct SimpleUnaryCollector final : public InfoCollector {
SimpleUnaryType type;

Expand Down
30 changes: 30 additions & 0 deletions src/04kernel/src/collectors/simple_unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,36 @@

namespace refactor::kernel {

#define CASE(OP) \
case SimpleUnaryType::OP: \
return #OP

std::string_view unaryName(SimpleUnaryType type) {
    // Returns the operator's canonical name; each string matches the
    // enumerator's spelling exactly (same text the stringizing macro emitted).
    switch (type) {
        case SimpleUnaryType::Abs: return "Abs";
        case SimpleUnaryType::Acos: return "Acos";
        case SimpleUnaryType::Acosh: return "Acosh";
        case SimpleUnaryType::Asin: return "Asin";
        case SimpleUnaryType::Asinh: return "Asinh";
        case SimpleUnaryType::Atan: return "Atan";
        case SimpleUnaryType::Atanh: return "Atanh";
        case SimpleUnaryType::Cos: return "Cos";
        case SimpleUnaryType::Cosh: return "Cosh";
        case SimpleUnaryType::Sin: return "Sin";
        case SimpleUnaryType::Sinh: return "Sinh";
        case SimpleUnaryType::Tan: return "Tan";
        case SimpleUnaryType::Tanh: return "Tanh";
        case SimpleUnaryType::Relu: return "Relu";
        case SimpleUnaryType::Sqrt: return "Sqrt";
        case SimpleUnaryType::Sigmoid: return "Sigmoid";
        case SimpleUnaryType::Erf: return "Erf";
        case SimpleUnaryType::Neg: return "Neg";
        case SimpleUnaryType::Not: return "Not";
        default:
            UNREACHABLE();
    }
}

#define REGISTER(T) \
if (auto ptr = T::build(type, a); ptr) { \
ans.emplace_back(std::move(ptr)); \
Expand Down
163 changes: 154 additions & 9 deletions src/04kernel/src/kernels/simple_unary/cuda_kernel.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
#include "cuda_kernel.hh"
#include <execution>
#include <unordered_set>

#ifdef USE_CUDA
#include "../../generator/nvrtc_repo.h"
#include "kernel/cuda/threads_distributer.cuh"
#include <cuda_runtime.h>
#endif

namespace refactor::kernel {
using K = SimpleUnaryCuda;
using Op = SimpleUnaryType;
Expand All @@ -11,14 +16,13 @@ namespace refactor::kernel {
: Kernel(), dataType(dataType_), opType(opType_), size(size_) {}

auto K::build(Op op, Tensor const &a) noexcept -> KernelBox {
static const std::unordered_set<Op> supportedOp{
Op::Abs,
Op::Relu,
Op::Sqrt,
Op::Sigmoid,
Op::Tanh,
Op::Neg,
};
static const std::unordered_set<Op>
supportedOp{Op::Abs, Op::Relu, Op::Sqrt,
Op::Sigmoid, Op::Tanh, Op::Neg};
#ifndef USE_CUDA
return nullptr;
#endif

return supportedOp.contains(op) && a.dataType.isCpuNumberic()
? std::make_unique<K>(op, a.dataType, a.elementsSize())
: nullptr;
Expand All @@ -35,4 +39,145 @@ namespace refactor::kernel {
return "Performing unary operation on Nvidia GPU";
}

#ifdef USE_CUDA

// fmt template for the generated unary kernel's CUDA source.
// {0:} is the element type's C name (taken from the `dt` table in `lower`),
// {1:} is the expression computing fn(x) for one element (from the `op` table).
// The __global__ entry walks the array with a grid-stride loop
// (step = blockDim.x * gridDim.x), so a single launch covers any n.
constexpr static const char *TEMPLATE = R"~(
__device__ __forceinline__ static {0:} fn({0:} x) {{
return {1:};
}}
extern "C" __global__ void kernel(void *output, void const *input, size_t n) {{
auto dst = reinterpret_cast<{0:} *>(output);
auto src = reinterpret_cast<{0:} const *>(input);
for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
step = blockDim.x * gridDim.x;
tid < n;
tid += step)
dst[tid] = fn(src[tid]);
}}
)~";

// Packs an (op, dtype) pair into a single 16-bit key for the lookup tables in
// `lower`. The packing type-puns through a union, which only works if `Op`
// (underlying type uint8_t, see the enum declaration) and `DT` each occupy one
// byte with no padding between them; the static_assert makes that assumption
// fail loudly at compile time instead of silently colliding keys if either
// type ever widens. NOTE: reading the inactive union member is not a constant
// expression, so despite `constexpr` this only runs at runtime — which is how
// the map initializers below use it.
constexpr uint16_t __(Op op, DT dt) {
    union code {
        struct {
            Op op;
            DT dt;
        } info;
        uint16_t u16;
    } c{.info{op, dt}};
    static_assert(sizeof(c.info) == sizeof(uint16_t),
                  "Op and DT must pack exactly into 16 bits");
    return c.u16;
}

// Builds the CUDA routine for this elementwise unary op: either a plain
// device-to-device copy (cases that are the identity on the data) or a
// kernel whose source is rendered from TEMPLATE for this (op, dtype) pair
// and compiled with NVRTC.
auto K::lower(Resources &res) const -> RoutineWorkspace {
    using namespace runtime;

    if (dataType.isUnsigned()) {
        switch (opType) {
            case Op::Abs:
            case Op::Relu:
                // Abs and Relu leave unsigned values unchanged, so the routine
                // degenerates to a byte copy, skipped when output aliases input.
                return [n = size * dataType.size()](Resources &, void *, void const *const *inputs, void *const *outputs) {
                    if (outputs[0] != inputs[0]) {
                        cudaMemcpyAsync(outputs[0], inputs[0], n, cudaMemcpyDeviceToDevice);
                    }
                };
            case Op::Neg:
                // NOTE(review): Neg on unsigned data aborts here — presumably
                // such kernels are rejected upstream; confirm build() can
                // never hand this combination to lower().
                UNREACHABLE();
            default:
                break;
        }
    }

    // clang-format off
    // dtype -> C type name spliced into {0:} of TEMPLATE.
    static const std::unordered_map<uint8_t, std::string_view> dt {
        {DT::U8 , "unsigned char" },
        {DT::U16 , "unsigned short" },
        {DT::U32 , "unsigned int" },
        {DT::U64 , "unsigned long long"},
        {DT::I8 , "char" },
        {DT::I16 , "short" },
        {DT::I32 , "int" },
        {DT::I64 , "long long" },
        {DT::FP16, "half" },
        {DT::BF16, "nv_bfloat16" },
        {DT::F32 , "float" },
        {DT::F64 , "double" },
    };
    // packed (op, dtype) key -> per-element expression spliced into {1:}.
    // NOTE(review): Tanh only has F32/F64 entries; if build() accepts other
    // dtypes for Tanh, op.at() below throws — verify intended coverage.
    // see <https://docs.nvidia.com/cuda/cuda-math-api/index.html>.
    static const std::unordered_map<uint16_t, std::string_view> op {
        {__(Op::Abs, DT::I8 ), "x >= 0 ? x : -x"},
        {__(Op::Abs, DT::I16 ), "x >= 0 ? x : -x"},
        {__(Op::Abs, DT::I32 ), "x >= 0 ? x : -x"},
        {__(Op::Abs, DT::I64 ), "x >= 0 ? x : -x"},
        {__(Op::Abs, DT::FP16), "habs(x)" },
        {__(Op::Abs, DT::BF16), "habs(x)" },
        {__(Op::Abs, DT::F32 ), "fabsf(x)" },
        {__(Op::Abs, DT::F64 ), "fabs(x)" },

        {__(Op::Relu, DT::I8 ), "x > 0 ? x : 0"},
        {__(Op::Relu, DT::I16 ), "x > 0 ? x : 0"},
        {__(Op::Relu, DT::I32 ), "x > 0 ? x : 0"},
        {__(Op::Relu, DT::I64 ), "x > 0 ? x : 0"},
        {__(Op::Relu, DT::FP16), "x > CUDART_ZERO_FP16 ? x: CUDART_ZERO_FP16"},
        {__(Op::Relu, DT::BF16), "x > CUDART_ZERO_BF16 ? x: CUDART_ZERO_BF16"},
        {__(Op::Relu, DT::F32 ), "x > 0 ? x : 0"},
        {__(Op::Relu, DT::F64 ), "x > 0 ? x : 0"},

        // Narrow integers go through float sqrt; 32/64-bit ones through
        // double so the intermediate can represent the value.
        {__(Op::Sqrt, DT::U8 ), "__fsqrt_rn(static_cast<float>(x))" },
        {__(Op::Sqrt, DT::U16 ), "__fsqrt_rn(static_cast<float>(x))" },
        {__(Op::Sqrt, DT::U32 ), "__dsqrt_rn(static_cast<double>(x))" },
        {__(Op::Sqrt, DT::U64 ), "__dsqrt_rn(static_cast<double>(x))" },
        {__(Op::Sqrt, DT::I8 ), "__fsqrt_rn(static_cast<float>(x))" },
        {__(Op::Sqrt, DT::I16 ), "__fsqrt_rn(static_cast<float>(x))" },
        {__(Op::Sqrt, DT::I32 ), "__dsqrt_rn(static_cast<double>(x))" },
        {__(Op::Sqrt, DT::I64 ), "__dsqrt_rn(static_cast<double>(x))" },
        {__(Op::Sqrt, DT::FP16), "hsqrt(x)" },
        {__(Op::Sqrt, DT::BF16), "hsqrt(x)" },
        {__(Op::Sqrt, DT::F32 ), "__fsqrt_rn(x)" },
        {__(Op::Sqrt, DT::F64 ), "__dsqrt_rn(x)" },

        {__(Op::Sigmoid, DT::U8 ), "fdividef(1, 1 + expf(-static_cast<float>(x)))" },
        {__(Op::Sigmoid, DT::U16 ), "fdividef(1, 1 + expf(-static_cast<float>(x)))" },
        {__(Op::Sigmoid, DT::U32 ), "1.0 / (1 + exp(-static_cast<double>(x)))" },
        {__(Op::Sigmoid, DT::U64 ), "1.0 / (1 + exp(-static_cast<double>(x)))" },
        {__(Op::Sigmoid, DT::I8 ), "fdividef(1, 1 + expf(-static_cast<float>(x)))" },
        {__(Op::Sigmoid, DT::I16 ), "fdividef(1, 1 + expf(-static_cast<float>(x)))" },
        {__(Op::Sigmoid, DT::I32 ), "1.0 / (1 + exp(-static_cast<double>(x)))" },
        {__(Op::Sigmoid, DT::I64 ), "1.0 / (1 + exp(-static_cast<double>(x)))" },
        {__(Op::Sigmoid, DT::FP16), "hrcp(CUDART_ONE_FP16 + hexp(-x))" },
        {__(Op::Sigmoid, DT::BF16), "hrcp(CUDART_ONE_BF16 + hexp(-x))" },
        {__(Op::Sigmoid, DT::F32 ), "fdividef(1, 1 + expf(-x))" },
        {__(Op::Sigmoid, DT::F64 ), "1.0 / (1 + exp(-x))" },

        {__(Op::Tanh, DT::F32 ), "tanh(x)"},
        {__(Op::Tanh, DT::F64 ), "tanh(x)"},

        {__(Op::Neg, DT::I8 ), "-x"},
        {__(Op::Neg, DT::I16 ), "-x"},
        {__(Op::Neg, DT::I32 ), "-x"},
        {__(Op::Neg, DT::I64 ), "-x"},
        {__(Op::Neg, DT::FP16), "-x"},
        {__(Op::Neg, DT::BF16), "-x"},
        {__(Op::Neg, DT::F32 ), "-x"},
        {__(Op::Neg, DT::F64 ), "-x"},
    };
    // clang-format on

    // Render the kernel source for this (dtype, op), then size the launch.
    auto name = fmt::format("unary_{}_{}", dataType.name(), unaryName(opType));
    auto code = fmt::format(TEMPLATE, dt.at(dataType), op.at(__(opType, dataType)));
    auto params = cuda::ThreadsDistributer()(size);

    // Compile once here; the handler is captured by the returned routine and
    // the cached kernel is launched on every invocation.
    return [h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"),
            params](Resources &, void *, void const *const *inputs, void *const *outputs) {
        size_t n = params.n;
        void *args[]{const_cast<void **>(outputs), const_cast<void **>(inputs), &n};
        CUDA_ASSERT(cuLaunchKernel(
            h->kernel(),
            params.gridSize, 1, 1,
            params.blockSize, 1, 1,
            0, nullptr, args, nullptr));
    };
}

#endif

}// namespace refactor::kernel
133 changes: 0 additions & 133 deletions src/04kernel/src/kernels/simple_unary/cuda_kernel.cu

This file was deleted.

2 changes: 1 addition & 1 deletion src/04kernel/src/kernels/simple_unary/cuda_kernel.hh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace refactor::kernel {
size_t kernelTypeId() const noexcept final;
std::string_view description() const noexcept final;
#ifdef USE_CUDA
RoutineWorkspace lower(Resources &) const noexcept final;
RoutineWorkspace lower(Resources &) const final;
#endif
};

Expand Down
Loading

0 comments on commit 81789e9

Please sign in to comment.