Skip to content

Commit

Permalink
refactor(kernel): 正确设置 f64 的 alpha beta (correctly set alpha/beta for f64)
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Branch information
YdrMaster committed Nov 16, 2023
1 parent f911439 commit 2861890
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 16 deletions.
15 changes: 8 additions & 7 deletions src/04kernel/src/kernels/batch_normalization/cudnn_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace refactor::kernel {
// RAII for closure
struct Descriptors {
cudnnTensorDescriptor_t x, param;
bool f64;

Descriptors() : x(nullptr), param(nullptr) {
CUDNN_ASSERT(cudnnCreateTensorDescriptor(&x));
Expand All @@ -26,6 +27,7 @@ namespace refactor::kernel {
Descriptors(Descriptors &&) = delete;
};
auto d = std::make_shared<Descriptors>();
d->f64 = info.dtParam == DT::F64;

int strideAx[4]{0, 0, 0, 1}, // to calculate
dimAp[4]{1, info.dimAx[1], 1, 1}, // 1xCx1x1
Expand All @@ -42,7 +44,6 @@ namespace refactor::kernel {

// nvcc at c++11 doesn't support real move capture
return [d = std::move(d),
param32 = info.dtParam == DT::F32,
epsilon = info.epsilon](Resources &res, void const **inputs, void **outputs) {
// fetch cudnn handle from resources
auto handle = res.fetchOrStore<CudnnContext>()->handle;
Expand All @@ -59,16 +60,16 @@ namespace refactor::kernel {
double f64[2];
};
void *alpha, *beta;
if (param32) {
f32[0] = 1;
f32[1] = 0;
alpha = f32;
beta = f32 + 1;
} else {
if (d->f64) {
f64[0] = 1;
f64[1] = 0;
alpha = f64;
beta = f64 + 1;
} else {
f32[0] = 1;
f32[1] = 0;
alpha = f32;
beta = f32 + 1;
}
CUDNN_ASSERT(cudnnBatchNormalizationForwardInference(
handle, CUDNN_BATCHNORM_SPATIAL, alpha, beta,
Expand Down
27 changes: 23 additions & 4 deletions src/04kernel/src/kernels/conv/cudnn_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ namespace refactor::kernel {
cudnnConvolutionDescriptor_t conv;
cudnnConvolutionFwdAlgo_t algo;
size_t workspaceSize;
bool f64;

Descriptors() : workspaceSize(0) {
CUDNN_ASSERT(cudnnCreateTensorDescriptor(&x));
Expand All @@ -33,6 +34,7 @@ namespace refactor::kernel {
Descriptors(Descriptors &&) = delete;
};
auto d = std::make_shared<Descriptors>();
d->f64 = info.dt == DataType::F64;

auto cudnnDataType = cudnnDataTypeConvert(info.dt);
auto xs = info.xShape, ys = info.yShape, ws = info.wShape;
Expand All @@ -52,6 +54,8 @@ namespace refactor::kernel {
d->x, d->w, d->conv, d->y,
1, &returnedAlgoCount, &perfResults));
ASSERT(returnedAlgoCount == 1, "returnedAlgoCount != 1");
// for high accuracy, use this algo only
// d->algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
d->algo = perfResults.algo;
CUDNN_ASSERT(cudnnGetConvolutionForwardWorkspaceSize(
handle,
Expand All @@ -65,16 +69,31 @@ namespace refactor::kernel {
// fetch cudnn handle from resources
auto handle = res.fetchOrStore<CudnnContext>()->handle;
auto workspace = ForeignBlob::share(res.fetch<MemManager>()->manager, d.workspaceSize);
// TODO? build alpha/beta for double
float alpha = 1, beta = 0;
// build alpha/beta for double
union {
float f32[2];
double f64[2];
};
void *alpha, *beta;
if (d.f64) {
f64[0] = 1;
f64[1] = 0;
alpha = f64;
beta = f64 + 1;
} else {
f32[0] = 1;
f32[1] = 0;
alpha = f32;
beta = f32 + 1;
}
CUDNN_ASSERT(cudnnConvolutionForward(
handle,
&alpha,
alpha,
d.x, inputs[0],
d.w, inputs[1],
d.conv, d.algo,
*workspace, d.workspaceSize,
&beta,
beta,
d.y, outputs[0]));
};
}
Expand Down
6 changes: 1 addition & 5 deletions src/04kernel/src/kernels/mat_mul/cublas_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,7 @@ namespace refactor::kernel {
? std::make_optional(ExpandCuda(*info.biasExpand).lower(res))
: std::nullopt,
broadcaster = info.broadcaster](Resources &res, void const **inputs, void **outputs) {
if (biasEx) {
void const *inputs_[]{inputs[2]};
void *outputs_[]{outputs[0]};
(*biasEx)(res, inputs_, outputs_);
}
if (biasEx) { (*biasEx)(res, inputs + 2, outputs); }

auto handle = res.fetchOrStore<CublasContext>()->handle;
auto a = reinterpret_cast<T const *>(inputs[0]);
Expand Down

0 comments on commit 2861890

Please sign in to comment.