style(onnx): tidy up MatMul cpu kernel
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Dec 14, 2023
1 parent 8c1b8fa commit 4bbf121
Showing 3 changed files with 97 additions and 99 deletions.
158 changes: 81 additions & 77 deletions src/04kernel/src/kernels/mat_mul/cpu_kernel.cc
@@ -9,11 +9,9 @@ namespace refactor::kernel {
         : Kernel(), info(std::move(info_)) {}
 
     auto K::build(MatMulInfo info) noexcept -> KernelBox {
-        if (!info.dataType.isCpuNumberic()) {
-            return nullptr;
-        }
-
-        return std::make_unique<K>(std::move(info));
+        return info.dataType.isCpuNumberic()
+                   ? std::make_unique<K>(std::move(info))
+                   : nullptr;
     }
 
     auto K::typeId() noexcept -> size_t {
@@ -26,97 +24,103 @@ namespace refactor::kernel {
         return "Performing MatMul using CPU";
     }
 
+    template<class T>
     struct MatMulCPUMetaData {
         size_t M, K, N;
         size_t strideA0, strideA1, strideB0, strideB1;
-    };
+        T alpha, beta;
 
-    /*
-     * 2D matrix multiplication: Y = a * A @ B + b * Y
-     * Assume bias C has been broadcast to Y already. Beta should be 0 in the absence of bias.
-     */
-    template<typename T>
-    void matrixMultiply(T const *A, T const *B, T *Y,
-                        T const alpha, T const beta,
-                        const MatMulCPUMetaData md) {
-        // #pragma omp parallel for
-        for (size_t i = 0; i < md.M; i++) {
-            for (size_t j = 0; j < md.N; j++) {
-                T sum = 0;
-                // #pragma omp simd reduction(+ : sum)
-                for (size_t k = 0; k < md.K; k++) {
-                    sum += A[i * md.strideA0 + k * md.strideA1] * B[k * md.strideB0 + j * md.strideB1];
-                }
-                Y[i * md.N + j] = beta * Y[i * md.N + j] + alpha * sum;
-            }
-        }
-    }
+        /*
+         * 2D matrix multiplication: Y = a * A @ B + b * Y
+         * Assume bias C has been broadcast to Y already. Beta should be 0 in the absence of bias.
+         */
+        void matrixMultiply(T const *A, T const *B, T *Y) const noexcept {
+            // #pragma omp parallel for
+            for (size_t i = 0; i < M; i++) {
+                for (size_t j = 0; j < N; j++) {
+                    T sum = 0;
+                    // #pragma omp simd reduction(+ : sum)
+                    for (size_t k = 0; k < K; k++) {
+                        sum += A[i * strideA0 + k * strideA1] * B[k * strideB0 + j * strideB1];
+                    }
+                    Y[i * N + j] = beta * Y[i * N + j] + alpha * sum;
+                }
+            }
+        }
+    };
 
-#define CASE(T)                                                                                         \
-    case DT::T: {                                                                                       \
-        using T_ = primitive<DT::T>::type;                                                              \
-        if (std::holds_alternative<Broadcaster>(info.broadcasterOrBatch)) {                             \
-            return [alpha = static_cast<T_>(info.alpha),                                                \
-                    beta = static_cast<T_>(info.biasExpand ? info.beta : 0.0f),                         \
-                    broadcaster = std::get<Broadcaster>(info.broadcasterOrBatch),                       \
-                    md,                                                                                 \
-                    stepY = info.m * info.n,                                                            \
-                    stepA = info.m * info.k,                                                            \
-                    stepB = info.k * info.n,                                                            \
-                    biasEx = info.biasExpand                                                            \
-                                 ? std::make_optional(ExpandCpu(*info.biasExpand).lower(res).routine)   \
-                                 : std::nullopt](runtime::Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { \
-                if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); }                         \
-                auto A = reinterpret_cast<T_ const *>(inputs[0]);                                       \
-                auto B = reinterpret_cast<T_ const *>(inputs[1]);                                       \
-                auto Y = reinterpret_cast<T_ *>(outputs[0]);                                            \
-                dim_t offset[2];                                                                        \
-                for (size_t i = 0; i < broadcaster.outputsCount; i++) {                                 \
-                    broadcaster.locate(i, offset);                                                      \
-                    matrixMultiply(A + stepA * offset[0], B + stepB * offset[1], Y + stepY * i, alpha, beta, md); \
-                }                                                                                       \
-            };                                                                                          \
-        } else {                                                                                        \
-            return [alpha = static_cast<T_>(info.alpha),                                                \
-                    beta = static_cast<T_>(info.biasExpand ? info.beta : 0.0f),                         \
-                    batch = std::get<size_t>(info.broadcasterOrBatch),                                  \
-                    md,                                                                                 \
-                    stepY = info.m * info.n,                                                            \
-                    stepA = info.m * info.k,                                                            \
-                    stepB = info.k * info.n,                                                            \
-                    biasEx = info.biasExpand                                                            \
-                                 ? std::make_optional(ExpandCpu(*info.biasExpand).lower(res).routine)   \
-                                 : std::nullopt](runtime::Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { \
-                if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); }                         \
-                auto A = reinterpret_cast<T_ const *>(inputs[0]);                                       \
-                auto B = reinterpret_cast<T_ const *>(inputs[1]);                                       \
-                auto Y = reinterpret_cast<T_ *>(outputs[0]);                                            \
-                for (size_t i = 0; i < batch; i++) {                                                    \
-                    matrixMultiply(A + stepA * i, B + stepB * i, Y + stepY * i, alpha, beta, md);       \
-                }                                                                                       \
-            };                                                                                          \
-        }                                                                                               \
-    }
+    template<class T>
+    static auto lowerTyped(MatMulInfo const &info, Resources &res) noexcept -> RoutineWorkspace {
+        MatMulCPUMetaData const md{
+            .M = info.m,
+            .K = info.k,
+            .N = info.n,
+            .strideA0 = info.transA ? 1 : info.k,
+            .strideA1 = info.transA ? info.m : 1,
+            .strideB0 = info.transB ? 1 : info.n,
+            .strideB1 = info.transB ? info.k : 1,
+            .alpha = static_cast<T>(info.alpha),
+            .beta = static_cast<T>(info.biasExpand ? info.beta : 0.0f),
+        };
+
+        auto stepY = info.m * info.n,
+             stepA = info.m * info.k,
+             stepB = info.k * info.n;
+        auto biasEx = info.biasExpand
+                          ? std::make_optional(ExpandCpu(*info.biasExpand).lower(res).routine)
+                          : std::nullopt;
+
+        if (std::holds_alternative<Broadcaster>(info.broadcasterOrBatch)) {
+            return [broadcaster = std::get<Broadcaster>(info.broadcasterOrBatch),
+                    stepY, stepA, stepB,
+                    md, biasEx]//
+                (runtime::Resources & res, void *workspace, void const *const *inputs, void *const *outputs) {
+                    if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); }
+
+                    auto a = reinterpret_cast<T const *>(inputs[0]);
+                    auto b = reinterpret_cast<T const *>(inputs[1]);
+                    auto y = reinterpret_cast<T *>(outputs[0]);
+                    dim_t offset[2];
+                    for (auto i : range0_(broadcaster.outputsCount)) {
+                        broadcaster.locate(i, offset);
+                        md.matrixMultiply(a + stepA * offset[0], b + stepB * offset[1], y + stepY * i);
+                    }
+                };
+        } else {
+            return [batch = std::get<size_t>(info.broadcasterOrBatch),
+                    stepY, stepA, stepB,
+                    md, biasEx]//
+                (runtime::Resources & res, void *workspace, void const *const *inputs, void *const *outputs) {
+                    if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); }
+
+                    auto a = reinterpret_cast<T const *>(inputs[0]);
+                    auto b = reinterpret_cast<T const *>(inputs[1]);
+                    auto y = reinterpret_cast<T *>(outputs[0]);
+                    for (auto i : range0_(batch)) {
+                        md.matrixMultiply(a + stepA * i, b + stepB * i, y + stepY * i);
+                    }
+                };
+        }
+    }
 
     auto K::lower(Resources &res) const noexcept -> RoutineWorkspace {
-        MatMulCPUMetaData md;
-        md.M = info.m, md.K = info.k, md.N = info.n;
-        md.strideA0 = info.transA ? 1 : info.k;
-        md.strideA1 = info.transA ? info.m : 1;
-        md.strideB0 = info.transB ? 1 : info.n;
-        md.strideB1 = info.transB ? info.k : 1;
+#define CASE(T) \
+    case DataType::T: \
+        return lowerTyped<primitive<DataType::T>::type>(info, res);
 
         switch (info.dataType) {
             CASE(F32);
+            CASE(F64);
+
             CASE(U8);
-            CASE(I8);
             CASE(U16);
+            CASE(U32);
+            CASE(U64);
+
+            CASE(I8);
             CASE(I16);
             CASE(I32);
             CASE(I64);
-            CASE(F64);
-            CASE(U32);
-            CASE(U64);
             default:
                 UNREACHABLE();
        }
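
Not part of the commit — a self-contained sketch for reference. The stride convention that lowerTyped packs into the metadata is: an untransposed A is walked with strides {k, 1}, a transposed A with {1, m}, and likewise B with {n, 1} untransposed or {1, k} transposed. The MetaDataSketch struct below is a hypothetical minimal copy of the templated struct from the diff, re-declared so the check compiles on its own (C++20 for the designated initializers):

#include <cstddef>
#include <cstdio>

// Minimal stand-in for the MatMulCPUMetaData<T> in the diff above;
// renamed so as not to suggest it is the repository's exact type.
template<class T>
struct MetaDataSketch {
    size_t M, K, N;
    size_t strideA0, strideA1, strideB0, strideB1;
    T alpha, beta;

    // Y = alpha * A @ B + beta * Y, A and B read through the strides.
    void matrixMultiply(T const *A, T const *B, T *Y) const noexcept {
        for (size_t i = 0; i < M; i++)
            for (size_t j = 0; j < N; j++) {
                T sum = 0;
                for (size_t k = 0; k < K; k++)
                    sum += A[i * strideA0 + k * strideA1] * B[k * strideB0 + j * strideB1];
                Y[i * N + j] = beta * Y[i * N + j] + alpha * sum;
            }
    }
};

int main() {
    // Y = A @ B^T: A is 2x3 row-major; B is stored 2x3 but read transposed.
    float A[]{1, 2, 3,
              4, 5, 6};
    float B[]{1, 0, 1,
              0, 1, 0};
    float Y[4]{};
    MetaDataSketch<float> md{
        .M = 2, .K = 3, .N = 2,
        .strideA0 = 3, .strideA1 = 1,// transA = false: strides {k, 1}
        .strideB0 = 1, .strideB1 = 3,// transB = true:  strides {1, k}
        .alpha = 1, .beta = 0,       // beta = 0: no bias expanded into Y
    };
    md.matrixMultiply(A, B, Y);
    std::printf("%g %g\n%g %g\n", Y[0], Y[1], Y[2], Y[3]);// prints 4 2 / 10 5
}
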
31 changes: 14 additions & 17 deletions src/04kernel/src/kernels/simple_unary/cudnn_activation_kernel.cc
@@ -64,28 +64,25 @@ namespace refactor::kernel {
         auto d = std::make_shared<Descriptors>();
 
         // clang-format off
-        cudnnActivationMode_t
-            mode = type == Ty::Relu      ? CUDNN_ACTIVATION_RELU
-                   : type == Ty::Sigmoid ? CUDNN_ACTIVATION_SIGMOID
-                   : type == Ty::Tanh    ? CUDNN_ACTIVATION_TANH
-                                         : UNREACHABLEX(cudnnActivationMode_t, "");
+        auto mode = type == Ty::Relu      ? CUDNN_ACTIVATION_RELU
+                    : type == Ty::Sigmoid ? CUDNN_ACTIVATION_SIGMOID
+                    : type == Ty::Tanh    ? CUDNN_ACTIVATION_TANH
+                                          : UNREACHABLEX(cudnnActivationMode_t, "");
         // clang-format on
 
+        setCudnnTensor(d->tensor, dataType, slice(&size, 1));
         CUDNN_ASSERT(cudnnSetActivationDescriptor(d->activation, mode, CUDNN_PROPAGATE_NAN, 0.0));
-        CUDNN_ASSERT(cudnnSetTensor4dDescriptor(d->tensor, CUDNN_TENSOR_NCHW, cudnnDataTypeConvert(dataType), 1, 1, 1, size));
 
         res.fetchOrStore<CudnnContext>();
-        // nvcc at c++11 doesn't support real move capture
-        return [d = std::move(d)](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
-            // fetch cudnn handle from resources
-            auto handle = res.fetchOrStore<CudnnContext>()->handle;
-            // name inputs and outputs
-            auto x = inputs[0];
-            auto y = outputs[0];
-            // call cudnn activation
-            float alpha = 1, beta = 0;
-            CUDNN_ASSERT(cudnnActivationForward(handle, d->activation, &alpha, d->tensor, x, &beta, d->tensor, y));
-        };
+        return [d = std::move(d)]//
+            (Resources & res, void *, void const *const *inputs, void *const *outputs) {
+                float alpha = 1, beta = 0;
+                CUDNN_ASSERT(cudnnActivationForward(
+                    res.fetchOrStore<CudnnContext>()->handle,
+                    d->activation,
+                    &alpha, d->tensor, inputs[0],
+                    &beta, d->tensor, outputs[0]));
+            };
     }
 
 #endif
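
Again not part of the commit. The alpha/beta pair forwarded to cudnnActivationForward follows cuDNN's documented blending rule, y = alpha * op(x) + beta * y, so the hard-coded alpha = 1, beta = 0 in the lambda above simply overwrites the output buffer. A host-side restatement of that contract, with ReLU standing in for whichever mode was selected:

#include <algorithm>
#include <cstddef>

// Reference semantics of the cuDNN alpha/beta blend:
// y[i] = alpha * act(x[i]) + beta * y[i].
// With alpha = 1 and beta = 0, as in the lambda above,
// this reduces to a plain activation, y[i] = act(x[i]).
void activationForwardRef(float alpha, float beta,
                          float const *x, float *y, size_t n) {
    for (size_t i = 0; i < n; i++) {
        float activated = std::max(x[i], 0.0f);// ReLU as the example mode
        y[i] = alpha * activated + beta * y[i];
    }
}
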
7 changes: 2 additions & 5 deletions src/04kernel/src/kernels/softmax/cudnn_kernel.cc
@@ -58,11 +58,8 @@ namespace refactor::kernel {
         auto d = std::make_shared<Descriptors>(
             static_cast<cudnnSoftmaxAlgorithm_t>(algo),
             dataType != DataType::F64);
-        CUDNN_ASSERT(cudnnSetTensor4dDescriptor(
-            d->t,
-            CUDNN_TENSOR_NCHW,
-            cudnnDataTypeConvert(dataType),
-            pre, mid, post, 1));
+        int dims[]{pre, mid, post, 1};
+        setCudnnTensor(d->t, dataType, slice(dims, 4));
 
         res.fetchOrStore<CudnnContext>();
         return [d = std::move(d)](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
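
A reference sketch, not repository code. The dims array {pre, mid, post, 1} is the usual flattening of a tensor around its softmax axis: pre collects the dimensions before the axis, mid is the axis itself, and post the dimensions after it. Assuming that convention holds here, the hypothetical helper splitAt below shows the computation:

#include <cstddef>
#include <vector>

// Flatten a shape around one axis into {pre, mid, post}:
// pre  = product of dimensions before the axis,
// mid  = the softmax axis itself,
// post = product of dimensions after it.
struct AxisSplit { int pre, mid, post; };

AxisSplit splitAt(std::vector<int> const &shape, size_t axis) {
    AxisSplit s{1, shape[axis], 1};
    for (size_t i = 0; i < axis; i++) { s.pre *= shape[i]; }
    for (size_t i = axis + 1; i < shape.size(); i++) { s.post *= shape[i]; }
    return s;
}

// e.g. shape {2, 3, 4, 5} with axis = 2 gives {pre = 6, mid = 4, post = 5},
// so the descriptor would be set for a {6, 4, 5, 1} NCHW tensor.
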
