Commit
cf4e92c · 1 parent 2b698b3 · Showing 12 changed files with 314 additions and 40 deletions.
@@ -0,0 +1,19 @@
#ifndef KERNEL_CUDA_TOPK_CUH
#define KERNEL_CUDA_TOPK_CUH

#include "threads_distributer.cuh"

namespace refactor::kernel::cuda {

    void launchTopK(
        KernelLaunchParameters const &params,
        float const *data, float *dstVal, unsigned int *dstIdx,
        unsigned int topk,
        unsigned int stride_axis,
        unsigned int stride_in_pre,
        unsigned int stride_out_pre,
        unsigned int size_axis);

}// namespace refactor::kernel::cuda

#endif// KERNEL_CUDA_TOPK_CUH
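For orientation, here is a hedged, standalone sketch (not part of the commit) of how this launcher might be called for a row-major [pre, axis, post] tensor with top-k taken along the middle axis. The shape, the buffer names (d_in, d_val, d_idx), and the main() harness are assumptions; the stride values follow from how TopKKernel in the .cu file below indexes its input and output, and ThreadsDistributer is reused the same way the CUDA kernel wrapper below uses it.

// Hypothetical harness: top-k along the middle axis of a row-major [pre, axis, post] tensor.
#include "kernel/cuda/threads_distributer.cuh"
#include "kernel/cuda/topk.cuh"
#include <cuda_runtime.h>

int main() {
    unsigned int pre = 2, axis = 5, post = 3, topk = 2;

    float *d_in, *d_val;
    unsigned int *d_idx;
    cudaMalloc(&d_in, sizeof(float) * pre * axis * post);
    cudaMalloc(&d_val, sizeof(float) * pre * topk * post);
    cudaMalloc(&d_idx, sizeof(unsigned int) * pre * topk * post);
    // ... fill d_in with input data ...

    // One worker per (pre, post) position; the kernel walks the axis itself.
    auto params = refactor::kernel::cuda::ThreadsDistributer()(pre * post);
    refactor::kernel::cuda::launchTopK(
        params,
        d_in, d_val, d_idx,
        topk,
        /*stride_axis=*/post,           // step between neighbours along the reduced axis
        /*stride_in_pre=*/axis * post,  // step between consecutive "pre" slices of the input
        /*stride_out_pre=*/topk * post, // step between consecutive "pre" slices of the output
        /*size_axis=*/axis);
    cudaDeviceSynchronize();

    cudaFree(d_in);
    cudaFree(d_val);
    cudaFree(d_idx);
    return 0;
}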
@@ -0,0 +1,103 @@
#include "kernel/cuda/topk.cuh"
#include "macro.cuh"
#include <cstdint>
#include <thrust/device_vector.h>
#include <thrust/sort.h>

namespace refactor::kernel::cuda {

    using PairType = thrust::pair<float, uint32_t>;

    // Orders (value, index) pairs by value, descending, so the largest values come first.
    struct ComparePair {
        __host__ __device__
        bool operator()(const PairType &a, const PairType &b) const {
            return a.first > b.first;
        }
    };

    /*
    __device__
    void process_element(unsigned int n, float *__restrict__ dstVal,
                         uint32_t *__restrict__ dstIdx,
                         PairType *list,
                         uint32_t stride_axis,
                         uint32_t init_offset) {
        for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
                  step = blockDim.x * gridDim.x;
             tid < n;
             tid += step) {
            uint32_t offset = init_offset + stride_axis * tid;
            dstVal[offset] = list[tid].first;
            dstIdx[offset] = list[tid].second;
        }
    }
    */

    __global__ static void TopKKernel(
        unsigned long long n,
        float const *__restrict__ data,
        float *__restrict__ dstVal,
        uint32_t *__restrict__ dstIdx,
        uint32_t topk,
        uint32_t stride_axis,
        uint32_t stride_in_pre,
        uint32_t stride_out_pre,
        unsigned int size) {
        // Each thread handles one (pre, post) position and scans the whole axis.
        for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
                  step = blockDim.x * gridDim.x;
             tid < n;
             tid += step) {
            // Gather the axis into a per-thread (value, index) list on the device heap.
            PairType *list = new PairType[size];

            for (uint32_t i = 0; i < size; i++) {
                uint32_t srcIdx = tid / stride_axis * stride_in_pre + tid % stride_axis + i * stride_axis;

                list[i] = PairType(data[srcIdx], i);
            }
            // Thrust has no partial_sort algorithm; a possible optimization: split the axis
            // into size/topk groups and take one maximum from each group.
            thrust::sort(thrust::device, list, list + size, ComparePair());

            // Write the topk largest values and their original indices along the output axis.
            uint32_t init_offset = tid / stride_axis * stride_out_pre + tid % stride_axis;
            for (uint32_t i = 0; i < topk; i++) {
                uint32_t offset = init_offset + stride_axis * i;
                dstVal[offset] = list[i].first;
                dstIdx[offset] = list[i].second;
            }

            delete[] list;
        }
    }

    void launchTopK(
        KernelLaunchParameters const &params,
        float const *data, float *dstVal, uint32_t *dstIdx,
        uint32_t topk,
        uint32_t stride_axis,
        uint32_t stride_in_pre,
        uint32_t stride_out_pre,
        unsigned int size_axis) {

        TopKKernel<<<
            params.gridSize,
            params.blockSize,
            0,
            reinterpret_cast<cudaStream_t>(params.stream)>>>(
            params.n,
            data,
            dstVal,
            dstIdx,
            topk,
            stride_axis,
            stride_in_pre,
            stride_out_pre,
            size_axis);
    }

}// namespace refactor::kernel::cuda
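The comment inside TopKKernel hints at avoiding the full sort. Below is a hedged sketch (not part of the commit) of one such alternative: keep a fixed-size, descending top-k buffer in each thread and insert into it while scanning the axis, which removes the per-thread new/delete and the full sort. TopKSelectKernel and MAX_K are hypothetical names; the sketch assumes topk <= MAX_K and topk <= size, and it reuses the same indexing as TopKKernel.

// Hypothetical alternative kernel: per-thread insertion into a fixed-size buffer
// instead of allocating and fully sorting the whole axis.
// Assumes topk <= MAX_K and topk <= size; indexing matches TopKKernel above.
constexpr uint32_t MAX_K = 64;// assumed compile-time cap on topk

__global__ static void TopKSelectKernel(
    unsigned long long n,
    float const *__restrict__ data,
    float *__restrict__ dstVal,
    uint32_t *__restrict__ dstIdx,
    uint32_t topk,
    uint32_t stride_axis,
    uint32_t stride_in_pre,
    uint32_t stride_out_pre,
    unsigned int size) {
    for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
              step = blockDim.x * gridDim.x;
         tid < n;
         tid += step) {
        float val[MAX_K];
        uint32_t idx[MAX_K];
        uint32_t count = 0;// how many buffer slots are filled so far

        for (uint32_t i = 0; i < size; i++) {
            uint32_t srcIdx = tid / stride_axis * stride_in_pre + tid % stride_axis + i * stride_axis;
            float v = data[srcIdx];
            // Start at the end of the filled region (or just past the last kept slot) and
            // shift smaller entries down, so the buffer stays sorted in descending order.
            uint32_t pos = count < topk ? count++ : topk;
            while (pos > 0 && val[pos - 1] < v) {
                if (pos < topk) {
                    val[pos] = val[pos - 1];
                    idx[pos] = idx[pos - 1];
                }
                --pos;
            }
            if (pos < topk) {
                val[pos] = v;
                idx[pos] = i;
            }
        }

        uint32_t init_offset = tid / stride_axis * stride_out_pre + tid % stride_axis;
        for (uint32_t i = 0; i < topk; i++) {
            uint32_t offset = init_offset + stride_axis * i;
            dstVal[offset] = val[i];
            dstIdx[offset] = idx[i];
        }
    }
}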
@@ -0,0 +1,57 @@
#include "cuda_kernel.hh"

#ifdef USE_CUDA
#include "kernel/cuda/threads_distributer.cuh"
#include "kernel/cuda/topk.cuh"
#include <cuda_runtime.h>
#include <sstream>
#include <thrust/device_vector.h>
#include <thrust/sort.h>
#endif

namespace refactor::kernel {
    using K = TopKCuda;

    K::TopKCuda(TopKInfo info_) noexcept
        : Kernel(), info(std::move(info_)) {}

    auto K::build(TopKInfo info) noexcept -> KernelBox {
#ifndef USE_CUDA
        return nullptr;
#endif

        return std::make_unique<K>(std::move(info));
    }
    auto K::typeId() noexcept -> size_t {
        static uint8_t ID = 1;
        return reinterpret_cast<size_t>(&ID);
    }

    auto K::kernelTypeId() const noexcept -> size_t {
        return typeId();
    }
    auto K::description() const noexcept -> std::string_view {
        return "Performing topk operation using CUDA";
    }

#ifdef USE_CUDA
    auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
        return [info = this->info, params = cuda::ThreadsDistributer()(info.size.except_axis)]
               (Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
            cuda::launchTopK(
                params,
                reinterpret_cast<float const *>(inputs[0]),
                reinterpret_cast<float *>(outputs[0]),
                reinterpret_cast<uint32_t *>(outputs[1]),
                info.topk,
                info.stride.axis,
                info.stride.in_pre,
                info.stride.out_pre,
                info.size.axis);
        };
    }
#endif
}// namespace refactor::kernel
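The lowering above reads several fields off TopKInfo (topk, stride.axis, stride.in_pre, stride.out_pre, size.axis, size.except_axis), whose definitions live in kernel/attributes/topk_info.h and are not part of the shown diff. The hedged helper below only illustrates how such values would typically be derived from a row-major shape and a reduction axis, based on how the kernel indexes memory; TopKParams and makeTopKParams are hypothetical names, not the real TopKInfo API.

#include <cstdint>
#include <vector>

// Hypothetical illustration only: derive TopK addressing parameters from a
// row-major shape, a reduction axis, and k.
struct TopKParams {
    uint32_t topk, stride_axis, stride_in_pre, stride_out_pre;
    uint32_t size_axis, size_except_axis;
};

TopKParams makeTopKParams(std::vector<uint32_t> const &shape, size_t axis, uint32_t k) {
    uint32_t pre = 1, post = 1;
    for (size_t i = 0; i < axis; ++i) { pre *= shape[i]; }
    for (size_t i = axis + 1; i < shape.size(); ++i) { post *= shape[i]; }
    uint32_t size_axis = shape[axis];
    return TopKParams{
        k,
        /*stride_axis=*/post,                // step between neighbours along the axis
        /*stride_in_pre=*/size_axis * post,  // step between "pre" slices in the input
        /*stride_out_pre=*/k * post,         // step between "pre" slices in the output
        size_axis,
        /*size_except_axis=*/pre * post,     // how many axis scans the kernel performs
    };
}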
@@ -0,0 +1,26 @@
#ifndef KERNEL_TOPK_CUDA_KERNEL_HH
#define KERNEL_TOPK_CUDA_KERNEL_HH

#include "kernel/attributes/topk_info.h"
#include "kernel/kernel.h"

namespace refactor::kernel {

    struct TopKCuda final : public Kernel {
        TopKInfo info;

        explicit TopKCuda(TopKInfo) noexcept;

        static KernelBox build(TopKInfo) noexcept;
        static size_t typeId() noexcept;

        size_t kernelTypeId() const noexcept final;
        std::string_view description() const noexcept final;
#ifdef USE_CUDA
        RoutineWorkspace lower(Resources &) const noexcept final;
#endif
    };

}// namespace refactor::kernel

#endif// KERNEL_TOPK_CUDA_KERNEL_HH