Skip to content

Commit

Permalink
refactor(kernel): 简化 nvrtc kernel 调用
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Dec 14, 2023
1 parent d27db1e commit 59939c6
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 51 deletions.
26 changes: 24 additions & 2 deletions src/04kernel/src/generator/nvrtc_repo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@
nvrtcGetErrorString(status))); \
}

#define CUDA_ASSERT(CALL) \
if (auto result = CALL; result != CUDA_SUCCESS) { \
const char *msg; \
cuGetErrorName(result, &msg); \
RUNTIME_ERROR(fmt::format("cuda driver failed on \"" #CALL "\" with {} ({})", \
msg, (int) result)); \
}

namespace refactor::kernel::nvrtc {

Handler::Handler(std::string_view name,
Expand Down Expand Up @@ -85,8 +93,22 @@ namespace refactor::kernel::nvrtc {
return it->second;
}

CUfunction Handler::kernel() const {
return _kernel;
void Handler::launch(unsigned int gridDimX,
unsigned int gridDimY,
unsigned int gridDimZ,
unsigned int blockDimX,
unsigned int blockDimY,
unsigned int blockDimZ,
unsigned int sharedMemBytes,
void **kernelParams) const {
CUDA_ASSERT(cuLaunchKernel(
_kernel,
gridDimX, gridDimY, gridDimZ,
blockDimX, blockDimY, blockDimZ,
sharedMemBytes,
nullptr,
kernelParams,
nullptr));
}

std::string_view memCopyType(size_t size) {
Expand Down
17 changes: 8 additions & 9 deletions src/04kernel/src/generator/nvrtc_repo.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,6 @@
#include "common.h"
#include <cuda.h>

#define CUDA_ASSERT(CALL) \
if (auto result = CALL; result != CUDA_SUCCESS) { \
const char *msg; \
cuGetErrorName(result, &msg); \
RUNTIME_ERROR(fmt::format("cuda driver failed on \"" #CALL "\" with {} ({})", \
msg, (int) result)); \
}

namespace refactor::kernel::nvrtc {

class Handler {
Expand All @@ -29,7 +21,14 @@ namespace refactor::kernel::nvrtc {
std::string_view name,
std::string_view code,
std::string_view symbol);
CUfunction kernel() const;
void launch(unsigned int gridDimX,
unsigned int gridDimY,
unsigned int gridDimZ,
unsigned int blockDimX,
unsigned int blockDimY,
unsigned int blockDimZ,
unsigned int sharedMemBytes,
void **kernelParams) const;
};

std::string_view memCopyType(size_t);
Expand Down
8 changes: 3 additions & 5 deletions src/04kernel/src/kernels/cast/cuda_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,9 @@ extern "C" __global__ void kernel(
auto x = inputs[0];
auto n = params.n;
void *args[]{&y, &x, &n};
CUDA_ASSERT(cuLaunchKernel(
h->kernel(),
params.gridSize, 1, 1,
params.blockSize, 1, 1,
0, nullptr, args, nullptr));
h->launch(params.gridSize, 1, 1,
params.blockSize, 1, 1,
0, args);
};
}

Expand Down
8 changes: 3 additions & 5 deletions src/04kernel/src/kernels/concat/cuda_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,9 @@ extern "C" __global__ void kernel(
return [h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"),
params](Resources &, void *, void const *const *inputs, void *const *outputs) {
void *args[]{const_cast<void **>(outputs), const_cast<void **>(inputs)};
CUDA_ASSERT(cuLaunchKernel(
h->kernel(),
params.gridSize, 1, 1,
params.blockSize, 1, 1,
0, nullptr, args, nullptr));
h->launch(params.gridSize, 1, 1,
params.blockSize, 1, 1,
0, args);
};
}

Expand Down
28 changes: 12 additions & 16 deletions src/04kernel/src/kernels/simple_binary/cuda_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,14 +164,13 @@ extern "C" __global__ void kernel(
b = inputs[1];
auto n = params.n;
void *args[]{&c, &a, &b, &n};
CUDA_ASSERT(cuLaunchKernel(
h->kernel(),
params.gridSize, 1, 1,
params.blockSize, 1, 1,
0, nullptr, args, nullptr));
h->launch(params.gridSize, 1, 1,
params.blockSize, 1, 1,
0, args);
};

} else if (auto rank = broadcaster.strides.size() / (broadcaster.inputsCount + 1); rank == 1) {
static std::vector<dim_t> S0{0, 1, 1}, S1{1, 0, 1};
static const std::vector<dim_t> S0{0, 1, 1}, S1{1, 0, 1};
auto name = fmt::format("binaryScalar{}", postfix);
auto code = fmt::format(SCALAR, dt_, op_);
return [params, h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"),
Expand All @@ -185,12 +184,11 @@ extern "C" __global__ void kernel(
v = inputs[1 - scalar];
auto n = params.n;
void *args[]{&c, &v, &s, &n};
CUDA_ASSERT(cuLaunchKernel(
h->kernel(),
params.gridSize, 1, 1,
params.blockSize, 1, 1,
0, nullptr, args, nullptr));
h->launch(params.gridSize, 1, 1,
params.blockSize, 1, 1,
0, args);
};

} else {
auto name = fmt::format("binary{}{}", rank, postfix);
auto code = fmt::format(BROADCAST, dt_, op_, rank);
Expand All @@ -202,11 +200,9 @@ extern "C" __global__ void kernel(
b = inputs[1];
auto n = params.n;
void *args[]{&c, &a, &b, const_cast<dim_t *>(strides.data()), &n};
CUDA_ASSERT(cuLaunchKernel(
h->kernel(),
params.gridSize, 1, 1,
params.blockSize, 1, 1,
0, nullptr, args, nullptr));
h->launch(params.gridSize, 1, 1,
params.blockSize, 1, 1,
0, args);
};
}
}
Expand Down
8 changes: 3 additions & 5 deletions src/04kernel/src/kernels/simple_unary/cuda_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,9 @@ extern "C" __global__ void kernel(
auto x = inputs[0];
auto n = params.n;
void *args[]{&y, &x, &n};
CUDA_ASSERT(cuLaunchKernel(
h->kernel(),
params.gridSize, 1, 1,
params.blockSize, 1, 1,
0, nullptr, args, nullptr));
h->launch(params.gridSize, 1, 1,
params.blockSize, 1, 1,
0, args);
};
}

Expand Down
8 changes: 3 additions & 5 deletions src/04kernel/src/kernels/split/cuda_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,9 @@ extern "C" __global__ void kernel(
return [h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"),
params](Resources &, void *, void const *const *inputs, void *const *outputs) {
void *args[]{const_cast<void **>(outputs), const_cast<void **>(inputs)};
CUDA_ASSERT(cuLaunchKernel(
h->kernel(),
params.gridSize, 1, 1,
params.blockSize, 1, 1,
0, nullptr, args, nullptr));
h->launch(params.gridSize, 1, 1,
params.blockSize, 1, 1,
0, args);
};
}

Expand Down
8 changes: 4 additions & 4 deletions src/04kernel/test/generator/test_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ extern "C" __global__ void kernel() {

TEST(generator, nvrtc) {
auto handler = nvrtc::Handler::compile("helloWorld.cu", code, "kernel");
CUDA_ASSERT(cuLaunchKernel(handler->kernel(),
1, 1, 1,
1, 1, 1,
0, nullptr, nullptr, nullptr));
handler->launch(
1, 1, 1,
1, 1, 1,
0, nullptr);
}

#endif// USE_CUDA

0 comments on commit 59939c6

Please sign in to comment.