feat(kernel): support asymmetric padding for conv
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Dec 1, 2023
1 parent a8f52fc commit f9d5015
Showing 5 changed files with 126 additions and 27 deletions.
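
How the commit works: cuDNN's `cudnnSetConvolution2dDescriptor` accepts only one symmetric `(padH, padW)` pair, so each asymmetric pad pair is split into a symmetric core that cuDNN applies and a signed leftover that a small CUDA kernel materializes as explicit zeros in workspace memory before `cudnnConvolutionForward` runs. A minimal sketch of the decomposition (illustration only, not code from this commit; it assumes the ONNX-style pad ordering `{top, left, bottom, right}`):

```cpp
#include <algorithm>
#include <cstdio>

int main() {
    int pads[4] = {2, 1, 0, 1};// assumed ordering: top, left, bottom, right
    // symmetric core handed to cudnnSetConvolution2dDescriptor
    int symH = std::min(pads[0], pads[2]),
        symW = std::min(pads[1], pads[3]);
    // signed leftover handled by the extra padding kernel:
    // positive pads the begin side, negative the end side
    int extraH = pads[0] - pads[2],
        extraW = pads[1] - pads[3];
    std::printf("cudnn: (%d, %d), extra: (%+d, %+d)\n", symH, symW, extraH, extraW);
    // prints: cudnn: (0, 1), extra: (+2, +0)
    return 0;
}
```

Because begin = core + max(leftover, 0) and end = core + max(-leftover, 0), the two parts reconstruct the original pads exactly, and the common symmetric case skips the extra kernel entirely.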
14 changes: 14 additions & 0 deletions src/04kernel/src/attributes/pad_2d_info.cc
@@ -0,0 +1,14 @@
#include "pad_2d_info.h"
#include <numeric>

namespace refactor::kernel {

Pad2DInfo::Pad2DInfo(DataType dt, slice_t<dim_t> input, ddim_t const *pads)
: blockCount(std::accumulate(input.begin(), input.end() - 2, 1, std::multiplies<>())),
blockSize(dt.size()),
hw(input.end()[-1] * input.end()[-2]),
w(input.end()[-1]),
padHW(pads[0] - pads[2]),
padW(pads[1] - pads[3]) {}

}// namespace refactor::kernel
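
For orientation, a sketch reproducing the arithmetic above on a made-up F32 NCHW input (the shape and pads are assumptions, not values from the repo): `blockCount` folds every dimension except the last two, and the two signed fields carry the one-sided leftover per axis.

```cpp
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

int main() {
    std::vector<int> input{2, 3, 224, 224};// hypothetical NCHW shape
    int pads[4] = {1, 1, 0, 0};            // hypothetical top/left/bottom/right
    // same expressions as the constructor above
    auto blockCount = std::accumulate(input.begin(), input.end() - 2, 1, std::multiplies<>());
    auto hw = input.end()[-1] * input.end()[-2];
    auto w = input.end()[-1];
    std::printf("blockCount=%d hw=%d w=%d padHW=%d padW=%d\n",
                blockCount, hw, w, pads[0] - pads[2], pads[1] - pads[3]);
    // prints: blockCount=6 hw=50176 w=224 padHW=1 padW=1
    return 0;
}
```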
17 changes: 17 additions & 0 deletions src/04kernel/src/attributes/pad_2d_info.h
@@ -0,0 +1,17 @@
+#ifndef KERNEL_PAD_2D_INFO_H
+#define KERNEL_PAD_2D_INFO_H
+
+#include "kernel/tensor.h"
+
+namespace refactor::kernel {
+    /// @brief Optimized slice description for computation.
+    struct Pad2DInfo {
+        dim_t blockCount, blockSize, hw, w;
+        ddim_t padHW, padW;
+
+        Pad2DInfo(DataType, slice_t<dim_t> input, ddim_t const *pads);
+    };
+
+}// namespace refactor::kernel
+
+#endif// KERNEL_PAD_2D_INFO_H
10 changes: 2 additions & 8 deletions src/04kernel/src/kernels/conv/cudnn_kernel.cc
@@ -27,13 +27,7 @@ namespace refactor::kernel {
         }
 
         // group is not supported
-        if (w.rank() != 4 || x.shape[1] != w.shape[1]) {
-            return nullptr;
-        }
-        auto padsBegin = poolAttributes.padsBegin(),
-             padsEnd = poolAttributes.padsEnd();
-        if (padsBegin[0] != padsEnd[0] ||
-            padsBegin[1] != padsEnd[1]) {
+        if (w.rank() != 4 || poolAttributes.rank() != 2) {
             return nullptr;
         }
         auto d = poolAttributes.dilations(),
@@ -60,7 +54,7 @@ namespace refactor::kernel {
                     static_cast<int>(y.shape[3]),
                 },
                 {d[0], d[1]},
-                {p[0], p[1]},
+                {p[0], p[1], p[2], p[3]},
                 {s[0], s[1]},
                 std::move(biasExpand),
             });
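
Why the early-return on asymmetric pads could be dropped: both ends of each pad now reach the kernel, and the output extent follows the usual convolution shape rule, where begin and end padding contribute independently. A sketch under standard conv-arithmetic assumptions (not code from this repo):

```cpp
#include <cstdio>

int main() {
    // standard formula: out = (in + padBegin + padEnd - dilation*(kernel-1) - 1) / stride + 1
    int in = 224, kernel = 3, dilation = 1, stride = 2;
    int padBegin = 1, padEnd = 0;// asymmetric: one extra row only at the top
    int out = (in + padBegin + padEnd - dilation * (kernel - 1) - 1) / stride + 1;
    std::printf("out = %d\n", out);// prints: out = 112
    return 0;
}
```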
110 changes: 92 additions & 18 deletions src/04kernel/src/kernels/conv/cudnn_kernel.cu
@@ -2,22 +2,81 @@
#include "../../utilities/cuda/cudnn_functions.h"
#include "../expand/cuda_kernel.hh"
#include "cudnn_kernel.hh"
#include "hardware/functions.h"
#include <thrust/execution_policy.h>
#include <thrust/tabulate.h>

namespace refactor::kernel {
using namespace cudnn;
using namespace runtime;

struct ExtraPadding {
DataType dt;
int nc, sohw, sow, h, w, padH, padW;

static std::optional<ExtraPadding> build(DataType dt, int const *shape, int const *pads) {
if (pads[0] == pads[2] && pads[1] == pads[3]) {
return std::nullopt;
}
int padH = pads[0] - pads[2], padW = pads[1] - pads[3];
return ExtraPadding{
dt,
shape[0] * shape[1],
(shape[2] + std::abs(padH)) * (shape[3] + std::abs(padW)),
shape[3] + std::abs(padW),
shape[2],
shape[3],
padH,
padW};
}

size_t workspace() const {
return nc * sohw * dt.size();
}
};

+    template<class T>
+    struct ExtraPaddingFunctor {
+        ExtraPadding info;
+        void const *src;
+
+        __device__ T operator()(size_t i) const noexcept {
+            // localize the row within one H×W slice, not across the whole batch
+            auto h = i / info.sow % (info.sohw / info.sow),
+                 w = i % info.sow;
+            if (0 < info.padH) {
+                if (h < info.padH) {
+                    return 0;
+                }
+                h -= info.padH;
+            } else if (h >= info.h) {
+                return 0;
+            }
+            if (0 < info.padW) {
+                if (w < info.padW) {
+                    return 0;
+                }
+                w -= info.padW;
+            } else if (w >= info.w) {
+                return 0;
+            }
+            return reinterpret_cast<T const *>(src)[i / info.sohw * info.h * info.w + h * info.w + w];
+        }
+    };
+
     auto ConvCudnn::lower(Resources &res) const -> RoutineWorkspace {
         // RAII for closure
         struct Descriptors {
             cudnnTensorDescriptor_t x, y;
             cudnnFilterDescriptor_t w;
             cudnnConvolutionDescriptor_t conv;
             cudnnConvolutionFwdAlgo_t algo;
+            std::optional<ExtraPadding> extraPadding;
             std::optional<Routine> biasExpand;
             bool f64;
 
-            Descriptors(bool f64_) : biasExpand(std::nullopt), f64(f64_) {
+            Descriptors(bool f64_) : extraPadding(std::nullopt),
+                                     biasExpand(std::nullopt),
+                                     f64(f64_) {
                 CUDNN_ASSERT(cudnnCreateTensorDescriptor(&x));
                 CUDNN_ASSERT(cudnnCreateTensorDescriptor(&y));
                 CUDNN_ASSERT(cudnnCreateFilterDescriptor(&w));
@@ -34,6 +93,7 @@ namespace refactor::kernel {
             Descriptors(Descriptors &&) = delete;
         };
         auto d = std::make_shared<Descriptors>(info.dt == DataType::F64);
+        d->extraPadding = ExtraPadding::build(info.dt, info.xShape, info.pad);
         if (info.biasExpand) {
             d->biasExpand = ExpandCuda(*info.biasExpand).lower(res).routine;
         }
@@ -46,7 +106,13 @@ namespace refactor::kernel {
         auto pp = info.pad;
         auto ss = info.stride;
         auto dd = info.dilation;
-        CUDNN_ASSERT(cudnnSetConvolution2dDescriptor(d->conv, pp[0], pp[1], ss[0], ss[1], dd[0], dd[1], CUDNN_CROSS_CORRELATION, cudnnDataType));
+        CUDNN_ASSERT(cudnnSetConvolution2dDescriptor(
+            d->conv,
+            std::min(pp[0], pp[2]), std::min(pp[1], pp[3]),
+            ss[0], ss[1],
+            dd[0], dd[1],
+            CUDNN_CROSS_CORRELATION,
+            cudnnDataType));
 
         if (auto group = info.xShape[1] / ws[1]; group > 1) {
             CUDNN_ASSERT(cudnnSetConvolutionGroupCount(d->conv, group));
@@ -63,7 +129,7 @@ namespace refactor::kernel {
             ASSERT(returnedAlgoCount == 1, "returnedAlgoCount != 1");
             d->algo = perfResults.algo;
             // for high accuracy, use this algo only
-            d->algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
+            // d->algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
         }
         size_t workspaceSize;
         {
@@ -73,41 +139,49 @@ namespace refactor::kernel {
                 d->algo,
                 &workspaceSize));
         }
+        if (d->extraPadding) {
+            workspaceSize = hardware::alignBytes(workspaceSize, 256);
+        }
 
         // nvcc at c++11 doesn't support real move capture
-        auto routine = [d_ = std::move(d),
-                        workspaceSize](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
-            auto const &d = *d_;
-            if (d.biasExpand) { (*(d.biasExpand))(res, workspace, inputs + 2, outputs); }
-            // fetch cudnn handle from resources
-            auto handle = res.fetchOrStore<CudnnContext>()->handle;
+        auto routine = [d, workspaceSize](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
+            void const *x = inputs[0], *w = inputs[1];
+            if (d->extraPadding) {
+                // the zero-padded copy lives right after the cudnn workspace
+                auto extra = reinterpret_cast<float *>(reinterpret_cast<uint8_t *>(workspace) + workspaceSize);
+                thrust::tabulate(thrust::device,
+                                 extra, extra + d->extraPadding->workspace() / sizeof(float),
+                                 ExtraPaddingFunctor<float>{*d->extraPadding, x});
+                x = extra;
+            }
+            if (d->biasExpand) { (*(d->biasExpand))(res, workspace, inputs + 2, outputs); }
             // build alpha/beta for double
             union {
                 float f32[2];
                 double f64[2];
             };
             void *alpha, *beta;
-            if (d.f64) {
+            if (d->f64) {
                 f64[0] = 1;
-                f64[1] = d.biasExpand ? 1 : 0;
+                f64[1] = d->biasExpand ? 1 : 0;
                 alpha = f64;
                 beta = f64 + 1;
             } else {
                 f32[0] = 1;
-                f32[1] = d.biasExpand ? 1 : 0;
+                f32[1] = d->biasExpand ? 1 : 0;
                 alpha = f32;
                 beta = f32 + 1;
             }
             CUDNN_ASSERT(cudnnConvolutionForward(
-                handle,
+                res.fetchOrStore<CudnnContext>()->handle,
                 alpha,
-                d.x, inputs[0],
-                d.w, inputs[1],
-                d.conv, d.algo,
+                d->x, x,
+                d->w, w,
+                d->conv, d->algo,
                 workspace, workspaceSize,
                 beta,
-                d.y, outputs[0]));
+                d->y, outputs[0]));
         };
-        return {std::move(routine), workspaceSize};
+        return {std::move(routine), d->extraPadding ? workspaceSize + d->extraPadding->workspace() : workspaceSize};
     }
 
 }// namespace refactor::kernel
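
What `ExtraPaddingFunctor` computes, as a host-side reference: for every element of the padded buffer it either emits a zero (when the element falls in the one-sided pad band) or copies the corresponding source element. A sketch under the same layout assumptions as the device code (float data, `nc` stacked H×W slices, positive `padH` meaning extra rows at the top; the `% soh` keeps the row local to its slice):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    int nc = 2, h = 2, w = 3, padH = 1, padW = 0;// one extra row on top of each slice
    int soh = h + std::abs(padH), sow = w + std::abs(padW);
    std::vector<float> src(nc * h * w);
    for (size_t i = 0; i < src.size(); ++i) { src[i] = float(i); }
    std::vector<float> dst(nc * soh * sow);
    for (size_t i = 0; i < dst.size(); ++i) {// the loop thrust::tabulate runs on device
        long hh = long(i / sow % soh), ww = long(i % sow);
        float v = 0;
        if (padH > 0 ? hh >= padH : hh < h) {// valid row? (columns analogous via padW)
            if (padH > 0) { hh -= padH; }
            v = src[i / (soh * sow) * h * w + hh * w + ww];
        }
        dst[i] = v;
    }
    for (int i = 0; i < nc * soh; ++i) {
        std::printf("%g %g %g%s", dst[i * sow], dst[i * sow + 1], dst[i * sow + 2],
                    i % soh == soh - 1 ? "\n\n" : "\n");
    }
    // slice 0: 0 0 0 / 0 1 2 / 3 4 5   slice 1: 0 0 0 / 6 7 8 / 9 10 11
    return 0;
}
```

The padded copy costs `nc * sohw` extra elements of workspace, which matches `ExtraPadding::workspace()` and the enlarged `RoutineWorkspace` size returned above.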
2 changes: 1 addition & 1 deletion src/04kernel/src/kernels/conv/cudnn_kernel.hh
@@ -17,7 +17,7 @@ namespace refactor::kernel {
             wShape[4],
             yShape[4],
             dilation[2],
-            pad[2],
+            pad[4],
             stride[2];
         std::optional<ExpandInfo> biasExpand;
     } info;
