diff --git a/src/04kernel/include/kernel/collectors/select.h b/src/04kernel/include/kernel/collectors/select.h
index 97f7061c..40985a8c 100644
--- a/src/04kernel/include/kernel/collectors/select.h
+++ b/src/04kernel/include/kernel/collectors/select.h
@@ -10,6 +10,8 @@ namespace refactor::kernel {
         Min,
     };
 
+    std::string_view opName(SelectType type);
+
     struct SelectCollector final : public InfoCollector {
         SelectType selectType;
 
diff --git a/src/04kernel/src/collectors/select.cc b/src/04kernel/src/collectors/select.cc
index 45da6483..e4eff8f4 100644
--- a/src/04kernel/src/collectors/select.cc
+++ b/src/04kernel/src/collectors/select.cc
@@ -1,7 +1,27 @@
 #include "kernel/collectors/select.h"
+#include "../kernels/select/cpu_kernel.hh"
+#include "../kernels/select/cuda_kernel.hh"
 
 namespace refactor::kernel {
 
+#define REGISTER(T)                                     \
+    if (auto ptr = T::build(selectType, inputs); ptr) { \
+        ans.emplace_back(std::move(ptr));               \
+    }
+
+#define CASE(OP)         \
+    case SelectType::OP: \
+        return #OP
+
+    std::string_view opName(SelectType type) {
+        switch (type) {
+            CASE(Max);
+            CASE(Min);
+            default:
+                UNREACHABLE();
+        }
+    }
+
     SelectCollector::SelectCollector(decltype(_target) target, SelectType type) noexcept
         : InfoCollector(target), selectType(type) {}
 
@@ -10,8 +30,10 @@ namespace refactor::kernel {
         std::vector<KernelBox> ans;
         switch (_target) {
             case decltype(_target)::Cpu:
+                REGISTER(SelectCpu)
                 break;
             case decltype(_target)::Nvidia:
+                REGISTER(SelectCuda)
                 break;
             default:
                 UNREACHABLEX(void, "Unknown target");
diff --git a/src/04kernel/src/kernels/concat/cuda_kernel.cc b/src/04kernel/src/kernels/concat/cuda_kernel.cc
index 88035577..1f3a23eb 100644
--- a/src/04kernel/src/kernels/concat/cuda_kernel.cc
+++ b/src/04kernel/src/kernels/concat/cuda_kernel.cc
@@ -81,13 +81,6 @@ extern "C" __global__ void kernel(
         }
         auto segments = ss.str();
 
-        ss.str("");
-        for (auto i : range0_(inputCount)) {
-            ss << std::endl
-               << "    reinterpret_cast(inputs[" << i << "]), ";
-        }
-        auto castInputs = ss.str();
-        ss.str("");
         ss << "Concat_" << info.blockCount << ',' << unit;
         for (auto seg : info.segments) {
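Note on the CASE/REGISTER shorthands in select.cc above: CASE stringifies the enumerator with the preprocessor's # operator, so opName returns "Max"/"Min" verbatim (the name is later baked into the NVRTC kernel name in lower()), while REGISTER appends a kernel only when build() returns a non-null box. A minimal standalone sketch of the stringification pattern; SelectType and UNREACHABLE are redeclared locally here purely for illustration:

    #include <cstdio>
    #include <cstdlib>
    #include <string_view>

    enum class SelectType { Max, Min };
    #define UNREACHABLE() std::abort()

    // Expands to: case SelectType::Max: return "Max";
    #define CASE(OP)         \
        case SelectType::OP: \
            return #OP

    std::string_view opName(SelectType type) {
        switch (type) {
            CASE(Max);
            CASE(Min);
            default:
                UNREACHABLE();
        }
    }

    int main() {
        std::printf("%s\n", opName(SelectType::Min).data());// prints "Min"
        return 0;
    }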
diff --git a/src/04kernel/src/kernels/select/cuda_kernel.cc b/src/04kernel/src/kernels/select/cuda_kernel.cc
new file mode 100644
index 00000000..9e44d8ec
--- /dev/null
+++ b/src/04kernel/src/kernels/select/cuda_kernel.cc
@@ -0,0 +1,162 @@
+#include "cuda_kernel.hh"
+
+#ifdef USE_CUDA
+#include "../../generator/nvrtc_repo.h"
+#include "kernel/cuda/threads_distributer.cuh"
+#endif
+
+namespace refactor::kernel {
+    using K = SelectCuda;
+
+    K::SelectCuda(decltype(dataType) dataType_,
+                  decltype(selectType) selectType_,
+                  decltype(broadcaster) broadcaster_,
+                  decltype(inputsNum) inputsNum_) noexcept
+        : dataType(dataType_),
+          selectType(selectType_),
+          broadcaster(broadcaster_),
+          inputsNum(inputsNum_) {}
+
+    auto K::build(SelectType selectType_, TensorRefs inputs_) noexcept -> KernelBox {
+#ifndef USE_CUDA
+        return nullptr;
+#endif
+
+        return std::make_unique<K>(inputs_[0].get().dataType, selectType_, Broadcaster(inputs_), inputs_.size());
+    }
+
+    auto K::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+
+    auto K::kernelTypeId() const noexcept -> size_t {
+        return typeId();
+    }
+    auto K::description() const noexcept -> std::string_view {
+        return "Performing select operation on Nvidia GPU";
+    }
+
+#ifdef USE_CUDA
+
+    constexpr static const char *NO_BROADCAST = R"~(
+struct Inputs {{
+    {dt} const *const addr[{inputsNum}];
+}};
+
+__device__ __forceinline__ static {dt} fn({dt} a, {dt} b) {{
+    return {op};
+}}
+
+extern "C" __global__ void kernel(
+    {dt} *__restrict__ output,
+    Inputs inputs
+) {{
+    for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
+              step = blockDim.x * gridDim.x;
+         tid < {n};
+         tid += step) {{
+        output[tid] = inputs.addr[0][tid];
+        for (auto idx = 1; idx < {inputsNum}; ++idx) {{
+            output[tid] = fn(inputs.addr[idx][tid], output[tid]);
+        }}
+    }}
+}}
+)~";
+
+    constexpr static const char *BROADCAST = R"~(
+struct Inputs {{
+    {dt} const *const addr[{inputsNum}];
+}};
+
+struct Strides {{
+    unsigned int s[({inputsNum}+1) * {rank}];
+}};
+
+__device__ __forceinline__ static {dt} fn({dt} a, {dt} b) {{
+    return {op};
+}}
+
+extern "C" __global__ void kernel(
+    {dt} *__restrict__ output,
+    Inputs inputs,
+    Strides strides
+) {{
+    for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
+              step = blockDim.x * gridDim.x;
+         tid < {n};
+         tid += step) {{
+        auto rem = tid;
+        size_t ans[{inputsNum}]{{}};
+        for (auto i = 0; i < {rank}; ++i) {{
+            auto dim = strides.s + ({inputsNum} + 1) * i;
+            auto quot = rem / dim[{inputsNum}];
+            for (auto j = 0; j < {inputsNum}; ++j) {{ ans[j] += dim[j] * quot; }}
+            rem %= dim[{inputsNum}];
+        }}
+        output[tid] = inputs.addr[0][ans[0]];
+        for (auto idx = 1; idx < {inputsNum}; ++idx) {{
+            output[tid] = fn(inputs.addr[idx][ans[idx]], output[tid]);
+        }}
+    }}
+}}
+)~";
+
+    constexpr static std::string_view op(SelectType op, DataType dt) {
+        switch (op) {
+            case SelectType::Max:
+                return "a > b ? a : b";
+            case SelectType::Min:
+                return "a < b ? a : b";
+            default:
+                UNREACHABLE();
+        }
+    }
+
+    auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
+        using namespace runtime;
+
+        auto postfix = fmt::format("_{}_{}", dataType.name(), opName(selectType));
+        auto dt_ = nvrtc::dataType(dataType);
+        auto op_ = op(selectType, dataType);
+        auto params = cuda::ThreadsDistributer()(broadcaster.outputsCount);
+
+        if (!broadcaster.needBroadcast()) {
+            auto name = fmt::format("select{}", postfix);
+            auto code = fmt::format(NO_BROADCAST,
+                                    fmt::arg("dt", dt_),
+                                    fmt::arg("op", op_),
+                                    fmt::arg("inputsNum", inputsNum),
+                                    fmt::arg("n", params.n));
+            return [params, h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel")]//
+                (Resources &, void *, void const *const *inputs, void *const *outputs) {
+                    auto output = outputs[0];
+                    void *args[]{&output, const_cast<void **>(inputs)};
+                    h->launch(params.gridSize, 1, 1,
+                              params.blockSize, 1, 1,
+                              0, args);
+                };
+        } else {
+            auto name = fmt::format("select{}", postfix);
+            auto rank = broadcaster.strides.size() / (broadcaster.inputsCount + 1);
+            auto code = fmt::format(
+                BROADCAST,
+                fmt::arg("dt", dt_),
+                fmt::arg("op", op_),
+                fmt::arg("inputsNum", inputsNum),
+                fmt::arg("n", params.n),
+                fmt::arg("rank", rank));
+            return [params, h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"),
+                    strides = broadcaster.strides]//
+                (Resources &, void *, void const *const *inputs, void *const *outputs) {
+                    void *args[]{const_cast<void **>(outputs), const_cast<void **>(inputs), const_cast<dim_t *>(strides.data())};
+                    h->launch(params.gridSize, 1, 1,
+                              params.blockSize, 1, 1,
+                              0, args);
+                };
+        }
+    }
+
+#endif
+
+}// namespace refactor::kernel
diff --git a/src/04kernel/src/kernels/select/cuda_kernel.hh b/src/04kernel/src/kernels/select/cuda_kernel.hh
new file mode 100644
index 00000000..c22de6ad
--- /dev/null
+++ b/src/04kernel/src/kernels/select/cuda_kernel.hh
@@ -0,0 +1,31 @@
+#ifndef KERNEL_SELECT_CUDA_KERNEL_HH
+#define KERNEL_SELECT_CUDA_KERNEL_HH
+
+#include "kernel/attributes/broadcaster.h"
+#include "kernel/collectors/select.h"
+#include "kernel/kernel.h"
+#include "kernel/tensor.h"
+
+namespace refactor::kernel {
+
+    struct SelectCuda final : public Kernel {
+        DataType dataType;
+        SelectType selectType;
+        Broadcaster broadcaster;
+        size_t inputsNum;
+
+        SelectCuda(decltype(dataType), decltype(selectType), decltype(broadcaster), decltype(inputsNum)) noexcept;
+
+        static KernelBox build(SelectType, TensorRefs) noexcept;
+        static size_t typeId() noexcept;
+
+        size_t kernelTypeId() const noexcept final;
+        std::string_view description() const noexcept final;
+#ifdef USE_CUDA
+        RoutineWorkspace lower(Resources &) const noexcept final;
+#endif
+    };
+
+}// namespace refactor::kernel
+
+#endif// KERNEL_SELECT_CUDA_KERNEL_HH
diff --git a/src/04kernel/test/kernels/select/test_cuda.cpp b/src/04kernel/test/kernels/select/test_cuda.cpp
new file mode 100644
index 00000000..a1640ffc
--- /dev/null
+++ b/src/04kernel/test/kernels/select/test_cuda.cpp
@@ -0,0 +1,96 @@
+#ifdef USE_CUDA
+
+#include "../../../src/kernels/select/cuda_kernel.hh"
+#include "hardware/device_manager.h"
+#include <gtest/gtest.h>
+#include <functional>
+#include <vector>
+
+using namespace refactor;
+using namespace kernel;
+using namespace hardware;
+
+static void testSelect(const SelectType selectType, const std::vector<Shape> &shapes, const Shape &outShape, const std::vector<std::vector<float>> &data,
+                       const std::vector<float> expectData) {
+    // build routine
+    TensorRefs dataTensors;
+    std::vector<Tensor> tensorsVec;
+    for (size_t i = 0; i < shapes.size(); ++i) {
+        tensorsVec.push_back(Tensor(DataType::F32, shapes[i], LayoutType::Others, nullptr));
+    }
+    for (size_t i = 0; i < shapes.size(); ++i) {
+        dataTensors.push_back(std::cref(tensorsVec[i]));
+    }
+    auto result = Tensor::share(DataType::F32, outShape);
+    auto kernel = SelectCuda::build(selectType, dataTensors);
+    ASSERT_TRUE(kernel);
+    auto res = runtime::Resources();
+    auto routine = kernel->lower(res).routine;
+    // cuda malloc
+    auto &dev = *device::init(Device::Type::Nvidia, 0, "");
+    Arc<Device::Blob>
+        gpuIns[]{
+            dev.malloc(dataTensors[0].get().bytesSize()),
+            dev.malloc(dataTensors[1].get().bytesSize()),
+            dev.malloc(dataTensors[2].get().bytesSize()),
+        },
+        gpuOut = dev.malloc(result->bytesSize());
+    // put input data
+    gpuIns[0]->copyFromHost(data[0].data(), dataTensors[0].get().bytesSize());
+    gpuIns[1]->copyFromHost(data[1].data(), dataTensors[1].get().bytesSize());
+    gpuIns[2]->copyFromHost(data[2].data(), dataTensors[2].get().bytesSize());
+    // inference
+    {
+        void const *inputs[]{*gpuIns[0], *gpuIns[1], *gpuIns[2]};
+        void *outputs[]{*gpuOut};
+        routine(res, nullptr, inputs, outputs);
+    }
+    // check
+    std::vector<float> out(result->elementsSize());
+    gpuOut->copyToHost(out.data(), result->bytesSize());
+    for (auto i : range0_(expectData.size())) {
+        EXPECT_FLOAT_EQ(expectData[i], out[i]);
+    }
+}
+
+TEST(kernel, SelectCuda) {
+    // no broadcast needed
+    testSelect(SelectType::Max,
+               {{1, 3}, {1, 3}, {1, 3}},
+               {1, 3},
+               {{3, 2, 1}, {1, 4, 4}, {2, 5, 3}},
+               {3, 5, 4});
+
+    testSelect(SelectType::Min,
+               {{1, 3}, {1, 3}, {1, 3}},
+               {1, 3},
+               {{3, 2, 1}, {1, 4, 4}, {2, 5, 3}},
+               {1, 2, 1});
+
+    // broadcast needed
+    testSelect(SelectType::Max,
+               {{3}, {1, 3}, {1, 3}},
+               {1, 3},
+               {{3, 3, 3}, {1, 4, 4}, {2, 5, 3}},
+               {3, 5, 4});
+
+    testSelect(SelectType::Min,
+               {{3}, {1, 3}, {1, 3}},
+               {1, 3},
+               {{3, 3, 3}, {1, 4, 4}, {2, 5, 3}},
+               {1, 3, 3});
+
+    testSelect(SelectType::Max,
+               {{1}, {1, 3}, {1, 3}},
+               {1, 3},
+               {{3}, {1, 4, 4}, {2, 5, 3}},
+               {3, 5, 4});
+
+    testSelect(SelectType::Min,
+               {{1}, {1, 3}, {1, 3}},
+               {1, 3},
+               {{3}, {1, 4, 4}, {2, 5, 3}},
+               {1, 3, 3});
+}
+
+#endif