From cf4e92c540aff02bc7f03d3b11d41d5ea345f7ae Mon Sep 17 00:00:00 2001 From: wangw <271502003@qq.com> Date: Mon, 6 May 2024 10:23:43 +0800 Subject: [PATCH] add topk cuda kernel --- .../cuda/include/kernel/cuda/topk.cuh | 19 ++++ src/04kernel/cuda/src/topk.cu | 103 ++++++++++++++++++ .../include/kernel/attributes/topk_info.h | 21 ++-- src/04kernel/include/kernel/collectors/topk.h | 4 +- src/04kernel/src/attributes/topk_info.cc | 16 +-- src/04kernel/src/kernels/topk/cpu_kernel.cc | 30 ++--- src/04kernel/src/kernels/topk/cuda_kernel.cc | 57 ++++++++++ src/04kernel/src/kernels/topk/cuda_kernel.hh | 26 +++++ src/04kernel/test/kernels/topk/test_cpu.cpp | 4 +- src/04kernel/test/kernels/topk/test_cuda.cpp | 68 ++++++++++++ .../include/computation/operators/topk.h | 4 +- src/07onnx/src/operators/topk.cc | 2 +- 12 files changed, 314 insertions(+), 40 deletions(-) create mode 100644 src/04kernel/cuda/include/kernel/cuda/topk.cuh create mode 100644 src/04kernel/cuda/src/topk.cu create mode 100644 src/04kernel/src/kernels/topk/cuda_kernel.cc create mode 100644 src/04kernel/src/kernels/topk/cuda_kernel.hh create mode 100644 src/04kernel/test/kernels/topk/test_cuda.cpp diff --git a/src/04kernel/cuda/include/kernel/cuda/topk.cuh b/src/04kernel/cuda/include/kernel/cuda/topk.cuh new file mode 100644 index 00000000..b06cfc00 --- /dev/null +++ b/src/04kernel/cuda/include/kernel/cuda/topk.cuh @@ -0,0 +1,19 @@ +#ifndef KERNEL_CUDA_TOPK_CUH +#define KERNEL_CUDA_TOPK_CUH + +#include "threads_distributer.cuh" + +namespace refactor::kernel::cuda { + + void launchTopK( + KernelLaunchParameters const ¶ms, + float const *data, float *dstVal, unsigned int *dstIdx, + unsigned int topk, + unsigned int stride_axis, + unsigned int stride_in_pre, + unsigned int stride_out_pre, + unsigned int size_axis); + +}// namespace refactor::kernel::cuda + +#endif// KERNEL_CUDA_TOPK_CUH diff --git a/src/04kernel/cuda/src/topk.cu b/src/04kernel/cuda/src/topk.cu new file mode 100644 index 00000000..6b247ead --- 
/dev/null
+++ b/src/04kernel/cuda/src/topk.cu
@@ -0,0 +1,103 @@
+#include "kernel/cuda/topk.cuh"
+#include "macro.cuh"
+#include <thrust/pair.h>
+#include <thrust/sort.h>
+#include <thrust/execution_policy.h>
+
+namespace refactor::kernel::cuda {
+
+using PairType = thrust::pair<float, uint32_t>;
+
+struct ComparePair {
+    __host__ __device__
+    bool operator()(const PairType& a, const PairType& b) const {
+        return a.first > b.first;
+    }
+};
+
+/*
+    __device__
+    void process_element(unsigned int n, float *__restrict__ dstVal,
+                         uint32_t *__restrict__ dstIdx,
+                         PairType *list,
+                         uint32_t stride_axis,
+                         uint32_t init_offset){
+        for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
+                  step = blockDim.x * gridDim.x;
+             tid < n;
+             tid += step) {
+            uint32_t offset = init_offset + stride_axis * tid;
+            dstVal[offset] = list[tid].first;
+            dstIdx[offset] = list[tid].second;
+        }
+    }
+*/
+
+
+
+    __global__ static void TopKKernel(
+        unsigned long long n,
+        float const *__restrict__ data,
+        float *__restrict__ dstVal,
+        uint32_t *__restrict__ dstIdx,
+        uint32_t topk,
+        uint32_t stride_axis,
+        uint32_t stride_in_pre,
+        uint32_t stride_out_pre,
+        unsigned int size) {
+        for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
+                  step = blockDim.x * gridDim.x;
+             tid < n;
+             tid += step) {
+            PairType *list = new PairType[size];
+
+            for(uint32_t i = 0; i < size; i++){
+                uint32_t srcIdx = tid /stride_axis * stride_in_pre + tid % stride_axis + i * stride_axis;
+
+                list[i] = PairType(data[srcIdx], i);
+            }
+            // thrust has no partial_sort; possible optimization: split into size/topk groups and keep one max per group
+            thrust::sort(thrust::device, list, list + size, ComparePair());
+
+
+            uint32_t init_offset = tid /stride_axis * stride_out_pre + tid % stride_axis;
+            for (uint32_t i = 0; i < topk; i++)
+            {
+                uint32_t offset = init_offset + stride_axis * i;
+                dstVal[offset] = list[i].first;
+                dstIdx[offset] = list[i].second;
+            }
+
+            delete[] list;
+        }
+    }
+
+
+
+    void launchTopK(
+        KernelLaunchParameters const &params,
+        float const *data, float *dstVal, uint32_t *dstIdx,
+        uint32_t topk,
+        uint32_t stride_axis,
+        uint32_t stride_in_pre,
+        uint32_t stride_out_pre,
+        unsigned int size_axis) {
+
+        TopKKernel<<<
+            params.gridSize,
+            params.blockSize,
+            0,
+            reinterpret_cast<cudaStream_t>(params.stream)>>>(
+            params.n,
+            (data),
+            (dstVal),
+            (dstIdx),
+            topk,
+            stride_axis,
+            stride_in_pre,
+            stride_out_pre,
+            size_axis);
+
+    }
+
+}// namespace refactor::kernel::cuda
diff --git a/src/04kernel/include/kernel/attributes/topk_info.h b/src/04kernel/include/kernel/attributes/topk_info.h
index 491810d1..5cfc5ee6 100644
--- a/src/04kernel/include/kernel/attributes/topk_info.h
+++ b/src/04kernel/include/kernel/attributes/topk_info.h
@@ -6,18 +6,17 @@
 namespace refactor::kernel {
 
     struct TopKInfo {
+        struct Stride{
+            dim_t axis, in_pre, out_pre;
+        };
+        struct Size{
+            dim_t axis, except_axis;
+        };
+        uint32_t topk;
+        Stride stride;
+        Size size;
 
-        int64_t topk;
-        int64_t axis;
-        size_t in_stride, in_stride_pre_axis, out_stride_pre_axis;
-        size_t elem_size, axis_elem_size;
-
-        TopKInfo(int64_t topk, int64_t axis, Tensor const &input);
-        size_t getElementSize() const {return elem_size;}
-        size_t getAxisElementSize()const { return axis_elem_size;}
-        size_t getInStride()const{return in_stride;}
-        size_t getInStridePreAxis()const{return in_stride_pre_axis;}
-        size_t getOutStridePreAxis()const {return out_stride_pre_axis;}
+        TopKInfo(uint32_t topk, uint32_t axis, Tensor const &input);
     };
 
 }// namespace refactor::kernel
diff --git a/src/04kernel/include/kernel/collectors/topk.h b/src/04kernel/include/kernel/collectors/topk.h
index 3e0dc288..c4d8490f 100644
--- a/src/04kernel/include/kernel/collectors/topk.h
+++ b/src/04kernel/include/kernel/collectors/topk.h
@@ -6,9 +6,9 @@
 namespace refactor::kernel {
 
     struct TopKCollector final : public InfoCollector {
-        int64_t topk, axis;
+        uint32_t topk, axis;
 
-        constexpr TopKCollector(decltype(_target) target, int64_t topk, int64_t axis_) noexcept
+        constexpr TopKCollector(decltype(_target) target, uint32_t topk, uint32_t axis_) noexcept
             : InfoCollector(target), topk(topk),
axis(axis_) {}
 
         std::vector<KernelBox>
diff --git a/src/04kernel/src/attributes/topk_info.cc b/src/04kernel/src/attributes/topk_info.cc
index 52032db9..532f385d 100644
--- a/src/04kernel/src/attributes/topk_info.cc
+++ b/src/04kernel/src/attributes/topk_info.cc
@@ -3,12 +3,14 @@
 
 namespace refactor::kernel {
 
-TopKInfo::TopKInfo(int64_t topk, int64_t axis, Tensor const &input):topk(topk),
-                axis(axis),
-                in_stride(input.strides()[axis]),
-                in_stride_pre_axis(axis == 0 ? 0 : input.strides()[axis - 1]),
-                out_stride_pre_axis(in_stride_pre_axis/input.shape[axis]*topk),
-                elem_size(input.elementsSize()),
-                axis_elem_size(input.shape[axis]){}
+TopKInfo::TopKInfo(uint32_t topk, uint32_t axis, Tensor const &input){
+    this->topk =topk;
+    auto tmpStride = axis == 0 ? 0 : input.strides()[axis - 1];
+    this->stride = {input.strides()[axis],\
+                    tmpStride,\
+                    tmpStride/input.shape[axis]*topk};
+    this->size = {input.shape[axis], \
+                  input.elementsSize()/input.shape[axis]};
+}
 
 }
diff --git a/src/04kernel/src/kernels/topk/cpu_kernel.cc b/src/04kernel/src/kernels/topk/cpu_kernel.cc
index 06e1683a..e695e3f7 100644
--- a/src/04kernel/src/kernels/topk/cpu_kernel.cc
+++ b/src/04kernel/src/kernels/topk/cpu_kernel.cc
@@ -1,6 +1,6 @@
 #include "cpu_kernel.hh"
 #include <algorithm>
-#include <list>
+#include <vector>
 
 namespace refactor::kernel {
     using K = TopKCpu;
@@ -29,31 +29,31 @@
 
         auto src = reinterpret_cast<float const *>(inputs[0]);
         auto dstVal = reinterpret_cast<float *>(outputs[0]);//T
-        auto dstIndex = reinterpret_cast<int64_t *>(outputs[1]);
+        auto dstIndex = reinterpret_cast<uint32_t *>(outputs[1]);
 
-        size_t M = info.getElementSize() / info.getAxisElementSize();
-        size_t N = info.getAxisElementSize();
-        auto inStride1 = info.getInStridePreAxis();
-        auto inStride2 = info.getInStride();
-        auto outStride1 = info.getOutStridePreAxis();
-        auto outStride2 = inStride2;
+        size_t M = info.size.except_axis;
+        size_t N = info.size.axis;
 
         for(size_t m = 0; m < M; m ++){
-            using PairType = std::pair<float, int64_t>;
-            std::list<PairType> list;
+            using PairType = std::pair<float, uint32_t>;
+            std::vector<PairType> list;
             for(size_t n = 0; n < N; n++){
-                auto srcIdx = m /inStride2 * inStride1 + m % inStride2 + n * inStride2;
+                auto srcIdx = m /info.stride.axis * info.stride.in_pre + m % info.stride.axis + n * info.stride.axis;
                 list.push_back({src[srcIdx],n});
             }
-            list.sort([](const PairType &a, const PairType &b)->bool{return a.first > b.first;});
+            //list.sort([](const PairType &a, const PairType &b)->bool{return a.first > b.first;});
+            std::partial_sort(list.begin(), \
+                              list.begin() + info.topk, \
+                              list.end(), \
+                              [](const PairType &a, const PairType &b)->bool{return a.first > b.first;});
 
-            size_t offset = m /inStride2 * outStride1 + m % inStride2;
-            std::for_each_n(list.begin(), (int64_t)info.topk,
+            size_t offset = m /info.stride.axis * info.stride.out_pre + m % info.stride.axis;
+            std::for_each_n(list.begin(), (uint32_t)info.topk,
                             [&](auto &elem) {
                                 dstVal[offset] = elem.first;
                                 dstIndex[offset] = elem.second;
-                                offset += outStride2;
+                                offset += info.stride.axis;
                             });
         }
     };
diff --git a/src/04kernel/src/kernels/topk/cuda_kernel.cc b/src/04kernel/src/kernels/topk/cuda_kernel.cc
new file mode 100644
index 00000000..acfa4733
--- /dev/null
+++ b/src/04kernel/src/kernels/topk/cuda_kernel.cc
@@ -0,0 +1,57 @@
+#include "cuda_kernel.hh"
+
+#ifdef USE_CUDA
+#include "kernel/cuda/threads_distributer.cuh"
+#include "kernel/cuda/topk.cuh"
+#include <thrust/device_vector.h>
+#include <thrust/execution_policy.h>
+#include <thrust/pair.h>
+#include <thrust/sort.h>
+#endif
+
+namespace refactor::kernel {
+    using K = TopKCuda;
+
+    K::TopKCuda(TopKInfo info_) noexcept
+        : Kernel(), info(std::move(info_)) {}
+
+    auto K::build(TopKInfo info) noexcept -> KernelBox {
+#ifndef USE_CUDA
+        return nullptr;
+#endif
+
+        return std::make_unique<K>(std::move(info));
+    }
+    auto K::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+
+    auto K::kernelTypeId() const noexcept -> size_t {
+        return typeId();
+    }
+    auto K::description() const noexcept -> std::string_view {
+        return "Performing topk operation using CUDA";
+    }
+
+#ifdef USE_CUDA
+    auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
+        //return [info = this->info](Resources &, void *workspace, void const *const *inputs, void *const *outputs){
+
+        //}
+        return [info = this->info, params = cuda::ThreadsDistributer()(info.size.except_axis)]
+            (Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
+                cuda::launchTopK(
+                    params,
+                    reinterpret_cast<float const *>(inputs[0]),
+                    reinterpret_cast<float *>(outputs[0]),
+                    reinterpret_cast<unsigned int *>(outputs[1]),
+                    info.topk,
+                    info.stride.axis,
+                    info.stride.in_pre,
+                    info.stride.out_pre,
+                    info.size.axis);
+            };
+    }
+#endif
+}// namespace refactor::kernel
diff --git a/src/04kernel/src/kernels/topk/cuda_kernel.hh b/src/04kernel/src/kernels/topk/cuda_kernel.hh
new file mode 100644
index 00000000..069bbd44
--- /dev/null
+++ b/src/04kernel/src/kernels/topk/cuda_kernel.hh
@@ -0,0 +1,26 @@
+#ifndef KERNEL_TOPK_CUDA_KERNEL_HH
+#define KERNEL_TOPK_CUDA_KERNEL_HH
+
+#include "kernel/attributes/topk_info.h"
+#include "kernel/kernel.h"
+
+namespace refactor::kernel {
+
+    struct TopKCuda final : public Kernel {
+        TopKInfo info;
+
+        explicit TopKCuda(TopKInfo) noexcept;
+
+        static KernelBox build(TopKInfo) noexcept;
+        static size_t typeId() noexcept;
+
+        size_t kernelTypeId() const noexcept final;
+        std::string_view description() const noexcept final;
+#ifdef USE_CUDA
+        RoutineWorkspace lower(Resources &) const noexcept final;
+#endif
+    };
+
+}// namespace refactor::kernel
+
+#endif// KERNEL_TOPK_CUDA_KERNEL_HH
diff --git a/src/04kernel/test/kernels/topk/test_cpu.cpp b/src/04kernel/test/kernels/topk/test_cpu.cpp
index cea4e066..b0dcaa80 100644
--- a/src/04kernel/test/kernels/topk/test_cpu.cpp
+++ b/src/04kernel/test/kernels/topk/test_cpu.cpp
@@ -9,7 +9,7 @@ TEST(kernel, TopKCpu) {
     // build routine
     auto inputTensor = Tensor::share(DataType::F32, Shape{3, 4});
     auto outputTensor0 = Tensor::share(DataType::F32, Shape{3, 3});
-    auto outputTensor1 = Tensor::share(DataType::I64, Shape{3, 3});
+    auto outputTensor1 =
Tensor::share(DataType::U32, Shape{3, 3}); auto kernel = TopKCpu::build(TopKInfo(3,1, *inputTensor)); ASSERT_TRUE(kernel); @@ -28,7 +28,7 @@ TEST(kernel, TopKCpu) { // check std::vector expectVal = {3,2,1,7,6,5,11,10,9}; - std::vector expectIdx = {3,2,1,3,2,1,3,2,1}; + std::vector expectIdx = {3,2,1,3,2,1,3,2,1}; std::for_each(out0.begin(), out0.end(),[](const float &val){std::cout< +#include + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +TEST(kernel, TopKCuda) { + // build routine + auto inputTensor = Tensor::share(DataType::F32, Shape{3, 4}); + std::vector> outputTensors{ + Tensor::share(DataType::F32, Shape{3, 3}), + Tensor::share(DataType::U32, Shape{3, 3})}; + + auto kCpu = TopKCpu::build(TopKInfo(3,1, *inputTensor)); + auto kCuda = TopKCuda::build(TopKInfo(3,1, *inputTensor)); + ASSERT_TRUE(kCpu); + ASSERT_TRUE(kCuda); + auto res = runtime::Resources(); + auto rCpu = kCpu->lower(res).routine; + auto rCuda = kCuda->lower(res).routine; + + // device malloc + auto &dev = *device::init(Device::Type::Nvidia, 0, ""); + Arc + gpuIn = dev.malloc(inputTensor->bytesSize()), + gpuOuts[]{ + dev.malloc(outputTensors[0]->bytesSize()), + dev.malloc(outputTensors[1]->bytesSize()), + }; + // put input data + std::vector data(inputTensor->elementsSize()); + + std::vector outCpu1(outputTensors[0]->elementsSize()); + std::vector outCpu2(outputTensors[1]->elementsSize()); + + + std::vector out1(outputTensors[0]->elementsSize()); + std::vector out2(outputTensors[1]->elementsSize()); + + std::iota(data.begin(), data.end(), 0); + gpuIn->copyFromHost(data.data(), inputTensor->bytesSize()); + // inference + { + void const *inputs[]{*gpuIn}; + void *outputs[]{*gpuOuts[0], *gpuOuts[1]}; + rCuda(res, nullptr, inputs, outputs); + } + { + void const *inputs[]{data.data()}; + void *outputs[]{outCpu1.data(), outCpu2.data()}; + rCpu(res, nullptr, inputs, outputs); + } + // check + + gpuOuts[0]->copyToHost(out1.data(), outputTensors[0]->bytesSize()); + 
EXPECT_EQ(out1, outCpu1); + gpuOuts[1]->copyToHost(out2.data(), outputTensors[1]->bytesSize()); + EXPECT_EQ(out2, outCpu2); + +} + +#endif diff --git a/src/05computation/include/computation/operators/topk.h b/src/05computation/include/computation/operators/topk.h index 8ecbdfed..d5c401f4 100644 --- a/src/05computation/include/computation/operators/topk.h +++ b/src/05computation/include/computation/operators/topk.h @@ -6,8 +6,8 @@ namespace refactor::computation { struct TopK final : public Operator { - int64_t topk, axis; - constexpr TopK(int64_t topk, int64_t axis) noexcept : topk(topk), axis(axis){} + uint32_t topk, axis; + constexpr TopK(uint32_t topk, uint32_t axis) noexcept : topk(topk), axis(axis){} static size_t typeId() noexcept; size_t opTypeId() const noexcept final; diff --git a/src/07onnx/src/operators/topk.cc b/src/07onnx/src/operators/topk.cc index 98653472..c1e908e6 100644 --- a/src/07onnx/src/operators/topk.cc +++ b/src/07onnx/src/operators/topk.cc @@ -40,7 +40,7 @@ namespace refactor::onnx { auto dependencies = extractDependency(inputs); ans[0] = Tensor::share(input.dataType, input.shape, dependencies); ans[0]->shape[axis_] = DimExpr(topk); - ans[1] = Tensor::share(DataType::I64, input.shape, dependencies); + ans[1] = Tensor::share(DataType::U32, input.shape, dependencies); ans[1]->shape[axis_] = DimExpr(topk); return Ok(Tensors{std::move(ans)}); }