diff --git a/src/04kernel/src/generator/nvrtc_repo.cc b/src/04kernel/src/generator/nvrtc_repo.cc index d62d88f7..1767a869 100644 --- a/src/04kernel/src/generator/nvrtc_repo.cc +++ b/src/04kernel/src/generator/nvrtc_repo.cc @@ -10,6 +10,14 @@ nvrtcGetErrorString(status))); \ } +#define CUDA_ASSERT(CALL) \ + if (auto result = CALL; result != CUDA_SUCCESS) { \ + const char *msg; \ + cuGetErrorName(result, &msg); \ + RUNTIME_ERROR(fmt::format("cuda driver failed on \"" #CALL "\" with {} ({})", \ + msg, (int) result)); \ + } + namespace refactor::kernel::nvrtc { Handler::Handler(std::string_view name, @@ -85,8 +93,22 @@ namespace refactor::kernel::nvrtc { return it->second; } - CUfunction Handler::kernel() const { - return _kernel; + void Handler::launch(unsigned int gridDimX, + unsigned int gridDimY, + unsigned int gridDimZ, + unsigned int blockDimX, + unsigned int blockDimY, + unsigned int blockDimZ, + unsigned int sharedMemBytes, + void **kernelParams) const { + CUDA_ASSERT(cuLaunchKernel( + _kernel, + gridDimX, gridDimY, gridDimZ, + blockDimX, blockDimY, blockDimZ, + sharedMemBytes, + nullptr, + kernelParams, + nullptr)); } std::string_view memCopyType(size_t size) { diff --git a/src/04kernel/src/generator/nvrtc_repo.h b/src/04kernel/src/generator/nvrtc_repo.h index cf8b2207..3c1ee345 100644 --- a/src/04kernel/src/generator/nvrtc_repo.h +++ b/src/04kernel/src/generator/nvrtc_repo.h @@ -4,14 +4,6 @@ #include "common.h" #include -#define CUDA_ASSERT(CALL) \ - if (auto result = CALL; result != CUDA_SUCCESS) { \ - const char *msg; \ - cuGetErrorName(result, &msg); \ - RUNTIME_ERROR(fmt::format("cuda driver failed on \"" #CALL "\" with {} ({})", \ - msg, (int) result)); \ - } - namespace refactor::kernel::nvrtc { class Handler { @@ -29,7 +21,14 @@ namespace refactor::kernel::nvrtc { std::string_view name, std::string_view code, std::string_view symbol); - CUfunction kernel() const; + void launch(unsigned int gridDimX, + unsigned int gridDimY, + unsigned int gridDimZ, + unsigned int blockDimX, + unsigned int blockDimY, + unsigned int blockDimZ, + unsigned int sharedMemBytes, + void **kernelParams) const; }; std::string_view memCopyType(size_t); diff --git a/src/04kernel/src/kernels/cast/cuda_kernel.cc b/src/04kernel/src/kernels/cast/cuda_kernel.cc index 699282f3..b2f3d773 100644 --- a/src/04kernel/src/kernels/cast/cuda_kernel.cc +++ b/src/04kernel/src/kernels/cast/cuda_kernel.cc @@ -115,11 +115,9 @@ extern "C" __global__ void kernel( auto x = inputs[0]; auto n = params.n; void *args[]{&y, &x, &n}; - CUDA_ASSERT(cuLaunchKernel( - h->kernel(), - params.gridSize, 1, 1, - params.blockSize, 1, 1, - 0, nullptr, args, nullptr)); + h->launch(params.gridSize, 1, 1, + params.blockSize, 1, 1, + 0, args); }; } diff --git a/src/04kernel/src/kernels/concat/cuda_kernel.cc b/src/04kernel/src/kernels/concat/cuda_kernel.cc index 0dfe9ec8..88035577 100644 --- a/src/04kernel/src/kernels/concat/cuda_kernel.cc +++ b/src/04kernel/src/kernels/concat/cuda_kernel.cc @@ -107,11 +107,9 @@ extern "C" __global__ void kernel( return [h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"), params](Resources &, void *, void const *const *inputs, void *const *outputs) { void *args[]{const_cast(outputs), const_cast(inputs)}; - CUDA_ASSERT(cuLaunchKernel( - h->kernel(), - params.gridSize, 1, 1, - params.blockSize, 1, 1, - 0, nullptr, args, nullptr)); + h->launch(params.gridSize, 1, 1, + params.blockSize, 1, 1, + 0, args); }; } diff --git a/src/04kernel/src/kernels/simple_binary/cuda_kernel.cc b/src/04kernel/src/kernels/simple_binary/cuda_kernel.cc index 871a4fab..3c2351c2 100644 --- a/src/04kernel/src/kernels/simple_binary/cuda_kernel.cc +++ b/src/04kernel/src/kernels/simple_binary/cuda_kernel.cc @@ -164,14 +164,13 @@ extern "C" __global__ void kernel( b = inputs[1]; auto n = params.n; void *args[]{&c, &a, &b, &n}; - CUDA_ASSERT(cuLaunchKernel( - h->kernel(), - params.gridSize, 1, 1, - params.blockSize, 1, 1, - 0, nullptr, args, nullptr)); + h->launch(params.gridSize, 1, 1, + params.blockSize, 1, 1, + 0, args); }; + } else if (auto rank = broadcaster.strides.size() / (broadcaster.inputsCount + 1); rank == 1) { - static std::vector S0{0, 1, 1}, S1{1, 0, 1}; + static const std::vector S0{0, 1, 1}, S1{1, 0, 1}; auto name = fmt::format("binaryScalar{}", postfix); auto code = fmt::format(SCALAR, dt_, op_); return [params, h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"), @@ -185,12 +184,11 @@ extern "C" __global__ void kernel( v = inputs[1 - scalar]; auto n = params.n; void *args[]{&c, &v, &s, &n}; - CUDA_ASSERT(cuLaunchKernel( - h->kernel(), - params.gridSize, 1, 1, - params.blockSize, 1, 1, - 0, nullptr, args, nullptr)); + h->launch(params.gridSize, 1, 1, + params.blockSize, 1, 1, + 0, args); }; + } else { auto name = fmt::format("binary{}{}", rank, postfix); auto code = fmt::format(BROADCAST, dt_, op_, rank); @@ -202,11 +200,9 @@ extern "C" __global__ void kernel( b = inputs[1]; auto n = params.n; void *args[]{&c, &a, &b, const_cast(strides.data()), &n}; - CUDA_ASSERT(cuLaunchKernel( - h->kernel(), - params.gridSize, 1, 1, - params.blockSize, 1, 1, - 0, nullptr, args, nullptr)); + h->launch(params.gridSize, 1, 1, + params.blockSize, 1, 1, + 0, args); }; } } diff --git a/src/04kernel/src/kernels/simple_unary/cuda_kernel.cc b/src/04kernel/src/kernels/simple_unary/cuda_kernel.cc index 76221b71..e3c260db 100644 --- a/src/04kernel/src/kernels/simple_unary/cuda_kernel.cc +++ b/src/04kernel/src/kernels/simple_unary/cuda_kernel.cc @@ -152,11 +152,9 @@ extern "C" __global__ void kernel( auto x = inputs[0]; auto n = params.n; void *args[]{&y, &x, &n}; - CUDA_ASSERT(cuLaunchKernel( - h->kernel(), - params.gridSize, 1, 1, - params.blockSize, 1, 1, - 0, nullptr, args, nullptr)); + h->launch(params.gridSize, 1, 1, + params.blockSize, 1, 1, + 0, args); }; } diff --git a/src/04kernel/src/kernels/split/cuda_kernel.cc b/src/04kernel/src/kernels/split/cuda_kernel.cc index e1286367..d132bd3e 100644 --- a/src/04kernel/src/kernels/split/cuda_kernel.cc +++ b/src/04kernel/src/kernels/split/cuda_kernel.cc @@ -107,11 +107,9 @@ extern "C" __global__ void kernel( return [h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"), params](Resources &, void *, void const *const *inputs, void *const *outputs) { void *args[]{const_cast(outputs), const_cast(inputs)}; - CUDA_ASSERT(cuLaunchKernel( - h->kernel(), - params.gridSize, 1, 1, - params.blockSize, 1, 1, - 0, nullptr, args, nullptr)); + h->launch(params.gridSize, 1, 1, + params.blockSize, 1, 1, + 0, args); }; } diff --git a/src/04kernel/test/generator/test_cuda.cpp b/src/04kernel/test/generator/test_cuda.cpp index dcd847b2..662ee2a6 100644 --- a/src/04kernel/test/generator/test_cuda.cpp +++ b/src/04kernel/test/generator/test_cuda.cpp @@ -14,10 +14,10 @@ extern "C" __global__ void kernel() { TEST(generator, nvrtc) { auto handler = nvrtc::Handler::compile("helloWorld.cu", code, "kernel"); - CUDA_ASSERT(cuLaunchKernel(handler->kernel(), - 1, 1, 1, - 1, 1, 1, - 0, nullptr, nullptr, nullptr)); + handler->launch( + 1, 1, 1, + 1, 1, 1, + 0, nullptr); } #endif// USE_CUDA