feat(kernel): provide bias support for conv
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Nov 16, 2023
1 parent 2861890 commit db27623
Showing 4 changed files with 66 additions and 43 deletions.
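In ONNX terms, Conv gains its optional third input B, a per-channel vector of length C that is added at every spatial position of the matching output channel. As a minimal sketch of what this computes (plain C++, not code from this commit; the function and its parameters are hypothetical):

    // Conceptually: y[n, c, h, w] += bias[c] over an NCHW output.
    // The kernel below realizes this broadcast with an ExpandInfo
    // descriptor instead of an explicit loop nest.
    void addBias(float *y, float const *bias, int N, int C, int H, int W) {
        for (int n = 0; n < N; ++n)
            for (int c = 0; c < C; ++c)
                for (int hw = 0; hw < H * W; ++hw)
                    y[(n * C + c) * H * W + hw] += bias[c];
    }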
3 changes: 2 additions & 1 deletion src/04kernel/src/collectors/conv.cc
@@ -14,14 +14,15 @@ namespace refactor::kernel {
     ConvCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
         auto const &x = inputs[0].get();
         auto const &w = inputs[1].get();
+        auto b = inputs.size() == 3 ? std::make_optional(inputs[2]) : std::nullopt;
         auto const &y = outputs[0].get();
 
         std::vector<KernelBox> ans;
         switch (target) {
             case Target::Cpu:
                 break;
             case Target::NvidiaGpu:
-                if (auto ptr = ConvCudnn::build(poolAttrs, x, w, y); ptr) {
+                if (auto ptr = ConvCudnn::build(poolAttrs, x, w, b, y); ptr) {
                     ans.emplace_back(std::move(ptr));
                 }
                 break;
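A note on the new b argument above: TensorRefs elements are const tensor references, so the optional bias is forwarded as std::optional<std::reference_wrapper<Tensor const>> rather than as a raw pointer. A self-contained sketch of the idiom, with Tensor stubbed out and only the API shape taken from the diff:

    #include <functional>
    #include <optional>
    #include <vector>

    struct Tensor {};
    using TensorRefs = std::vector<std::reference_wrapper<Tensor const>>;

    // "Maybe present, never owned": an optional reference to the third input.
    std::optional<std::reference_wrapper<Tensor const>>
    maybeBias(TensorRefs const &inputs) {
        return inputs.size() == 3 ? std::make_optional(inputs[2]) : std::nullopt;
    }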
65 changes: 40 additions & 25 deletions src/04kernel/src/kernels/conv/cudnn_kernel.cc
@@ -3,17 +3,31 @@
 namespace refactor::kernel {
     using K = ConvCudnn;
 
-    K::ConvCudnn(decltype(info) info_) noexcept
-        : Kernel(), info(std::move(info_)) {}
+    K::ConvCudnn(decltype(info) info_, decltype(biasExpand) biasExpand_) noexcept
+        : Kernel(),
+          info(std::move(info_)),
+          biasExpand(std::move(biasExpand_)) {}
 
     auto K::build(PoolAttributes const &poolAttributes,
                   Tensor const &x,
                   Tensor const &w,
+                  std::optional<std::reference_wrapper<Tensor const>> b,
                   Tensor const &y) noexcept -> KernelBox {
 #ifndef USE_CUDA
         return nullptr;
 #endif
 
+        std::optional<ExpandInfo> biasExpand = std::nullopt;
+        if (b) {
+            ASSERT(b->get().shape[0] == y.shape[1], "bias channels must match output channels");
+            std::vector<dim_t> input(y.rank(), 1);
+            input[1] = y.shape[1];
+            // emplace rather than writing through operator* here:
+            // dereferencing the still-empty optional is undefined behavior
+            biasExpand.emplace(
+                b->get().dataType,
+                slice(input.data(), input.size()),
+                slice(y.shape.data(), y.rank()));
+        }

         // group is not supported
         if (w.rank() != 4 || x.shape[1] != w.shape[1]) {
             return nullptr;
@@ -28,29 +42,30 @@ namespace refactor::kernel {
             p = poolAttributes.pads(),
             s = poolAttributes.strides();
         return std::make_unique<K>(decltype(info){
-            x.dataType,
-            {
-                static_cast<int>(x.shape[0]),
-                static_cast<int>(x.shape[1]),
-                static_cast<int>(x.shape[2]),
-                static_cast<int>(x.shape[3]),
-            },
-            {
-                static_cast<int>(w.shape[0]),
-                static_cast<int>(w.shape[1]),
-                static_cast<int>(w.shape[2]),
-                static_cast<int>(w.shape[3]),
-            },
-            {
-                static_cast<int>(y.shape[0]),
-                static_cast<int>(y.shape[1]),
-                static_cast<int>(y.shape[2]),
-                static_cast<int>(y.shape[3]),
-            },
-            {d[0], d[1]},
-            {p[0], p[1]},
-            {s[0], s[1]}});
-    }
+            x.dataType,
+            {
+                static_cast<int>(x.shape[0]),
+                static_cast<int>(x.shape[1]),
+                static_cast<int>(x.shape[2]),
+                static_cast<int>(x.shape[3]),
+            },
+            {
+                static_cast<int>(w.shape[0]),
+                static_cast<int>(w.shape[1]),
+                static_cast<int>(w.shape[2]),
+                static_cast<int>(w.shape[3]),
+            },
+            {
+                static_cast<int>(y.shape[0]),
+                static_cast<int>(y.shape[1]),
+                static_cast<int>(y.shape[2]),
+                static_cast<int>(y.shape[3]),
+            },
+            {d[0], d[1]},
+            {p[0], p[1]},
+            {s[0], s[1]}},
+            std::move(biasExpand));
+    }

     auto K::typeId() noexcept -> size_t {
         static uint8_t ID = 1;
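On the biasExpand descriptor built in build() above: the (C,) bias is first given a rank-matching source shape that is all ones except on the channel axis, then expanded to the full output shape. A small standalone sketch of that shape arithmetic (dim_t aliased here by assumption; ExpandInfo semantics are inferred from its use in the diff):

    #include <cstdint>
    #include <vector>
    using dim_t = std::uint32_t; // assumption for this sketch

    // Rank-aligned source shape for broadcasting a per-channel bias
    // over an N-D output: ones everywhere except axis 1 (channels).
    std::vector<dim_t> biasSourceShape(std::vector<dim_t> const &outShape) {
        std::vector<dim_t> s(outShape.size(), 1);
        s[1] = outShape[1];
        return s;
    }
    // e.g. outShape {8, 64, 28, 28} -> source shape {1, 64, 1, 1}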
34 changes: 19 additions & 15 deletions src/04kernel/src/kernels/conv/cudnn_kernel.cu
@@ -47,21 +47,25 @@ namespace refactor::kernel {
         CUDNN_ASSERT(cudnnSetConvolution2dDescriptor(d->conv, pp[0], pp[1], ss[0], ss[1], dd[0], dd[1], CUDNN_CROSS_CORRELATION, cudnnDataType));
 
         auto handle = res.fetchOrStore<CudnnContext>()->handle;
-        int returnedAlgoCount;
-        cudnnConvolutionFwdAlgoPerf_t perfResults;
-        CUDNN_ASSERT(cudnnFindConvolutionForwardAlgorithm(
-            handle,
-            d->x, d->w, d->conv, d->y,
-            1, &returnedAlgoCount, &perfResults));
-        ASSERT(returnedAlgoCount == 1, "returnedAlgoCount != 1");
-        // for high accuracy, use this algo only
-        // d->algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
-        d->algo = perfResults.algo;
-        CUDNN_ASSERT(cudnnGetConvolutionForwardWorkspaceSize(
-            handle,
-            d->x, d->w, d->conv, d->y,
-            perfResults.algo,
-            &d->workspaceSize));
+        {
+            int returnedAlgoCount;
+            cudnnConvolutionFwdAlgoPerf_t perfResults;
+            CUDNN_ASSERT(cudnnFindConvolutionForwardAlgorithm(
+                handle,
+                d->x, d->w, d->conv, d->y,
+                1, &returnedAlgoCount, &perfResults));
+            ASSERT(returnedAlgoCount == 1, "returnedAlgoCount != 1");
+            d->algo = perfResults.algo;
+            // for high accuracy, use this algo only
+            // d->algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
+        }
+        {
+            CUDNN_ASSERT(cudnnGetConvolutionForwardWorkspaceSize(
+                handle,
+                d->x, d->w, d->conv, d->y,
+                d->algo,
+                &d->workspaceSize));
+        }
         // nvcc at c++11 doesn't support real move capture
         return [d_ = std::move(d)](Resources &res, void const **inputs, void **outputs) {
             using mem_manager::ForeignBlob;
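For context on where the cached d->algo and d->workspaceSize are consumed: the lambda truncated above would launch the convolution roughly as follows. A hedged sketch only; alpha/beta, the workspace pointer, and the input/output device pointers come from the surrounding runtime, and when biasExpand is set, the expand kernel would then add the broadcast bias to the output:

    // Standard cuDNN forward launch with the algo/workspace found above.
    float alpha = 1.f, beta = 0.f;
    CUDNN_ASSERT(cudnnConvolutionForward(
        handle, &alpha,
        d->x, inputs[0],     // input descriptor + device data
        d->w, inputs[1],     // filter descriptor + device data
        d->conv, d->algo,
        workspace, d->workspaceSize,
        &beta,
        d->y, outputs[0]));  // output descriptor + device data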
7 changes: 5 additions & 2 deletions src/04kernel/src/kernels/conv/cudnn_kernel.hh
@@ -1,9 +1,10 @@
 #ifndef KERNEL_CONV_CUDNN_KERNEL_HH
 #define KERNEL_CONV_CUDNN_KERNEL_HH
 
+#include "kernel/attributes/expand_info.h"
 #include "kernel/attributes/pool_attributes.h"
 #include "kernel/kernel.h"
 #include "kernel/tensor.h"
+#include <optional>
 
 namespace refactor::kernel {
 
@@ -19,12 +20,14 @@ namespace refactor::kernel {
             pad[2],
             stride[2];
         } info;
+        std::optional<ExpandInfo> biasExpand;
 
-        explicit ConvCudnn(decltype(info)) noexcept;
+        explicit ConvCudnn(decltype(info), decltype(biasExpand)) noexcept;
 
         static KernelBox build(PoolAttributes const &,
                                Tensor const &,
                                Tensor const &,
+                               std::optional<std::reference_wrapper<Tensor const>>,
                                Tensor const &) noexcept;
         static size_t typeId() noexcept;
 
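Call sites other than the collector follow the same pattern; a bias-free conv simply passes std::nullopt for the new parameter (hypothetical usage, assuming poolAttrs, x, w, b, and y are already in scope):

    auto noBias = ConvCudnn::build(poolAttrs, x, w, std::nullopt, y);
    auto withBias = ConvCudnn::build(poolAttrs, x, w, std::cref(b), y);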
