Skip to content

Commit

Permalink
feat(kernel): 实现 MatMulInteger cpu kernel
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Dec 15, 2023
1 parent a8959f7 commit 050c0b1
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ namespace refactor::kernel {

struct MatMulIntegerInfo {
struct Input {
    // True when this operand comes with a zero-point tensor that must be
    // subtracted before the integer GEMM (see cpu_kernel.cc, inputs[2]/[3]).
    bool withZeroPoint;
    // True for int8 element type, false for uint8; selects which
    // specialization of the zero-point subtraction is applied.
    bool signed_;
    // Zero points are applied per group: groupCount groups of groupSize
    // elements each share one zero-point value — confirm grouping layout
    // against the Input constructor implementation.
    dim_t groupCount, groupSize;

    Input(TensorRefs const &, size_t i) noexcept;
};
Expand Down
74 changes: 31 additions & 43 deletions src/04kernel/src/kernels/mat_mul/cpu_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,56 +27,44 @@ namespace refactor::kernel {

template<class T>
static auto lowerTyped(MatMulInfo const &info, Resources &res) noexcept -> RoutineWorkspace {
    // Pre-lower the optional bias expansion: when a bias is present, the
    // routine broadcasts it (inputs[2]) into the output buffer before the
    // GEMM, and the GEMM then accumulates onto it via beta.
    auto biasEx = info.biasExpand
                      ? std::make_optional(ExpandCpu(*info.biasExpand).lower(res).routine)
                      : std::nullopt;

    return [info = info, biasEx](runtime::Resources &res, void *, void const *const *inputs, void *const *outputs) {
        if (biasEx) { (*biasEx)(res, nullptr, inputs + 2, outputs); }

        // Strides choose row-major vs transposed access for A and B.
        // beta is forced to 0 when there is no bias, so Y is overwritten
        // instead of accumulating onto uninitialized memory.
        MatMulCPUMetaData<T, T> const md{
            .M = info.m,
            .K = info.k,
            .N = info.n,
            .strideA0 = info.transA ? 1 : info.k,
            .strideA1 = info.transA ? info.m : 1,
            .strideB0 = info.transB ? 1 : info.n,
            .strideB1 = info.transB ? info.k : 1,
            .alpha = static_cast<T>(info.alpha),
            .beta = static_cast<T>(info.biasExpand ? info.beta : 0.0f),
        };
        // Per-batch element counts for Y, A and B.
        auto const stepY = info.m * info.n,
                   stepA = info.m * info.k,
                   stepB = info.k * info.n;

        auto a = reinterpret_cast<T const *>(inputs[0]);
        auto b = reinterpret_cast<T const *>(inputs[1]);
        auto y = reinterpret_cast<T *>(outputs[0]);
        if (info.broadcaster.needBroadcast()) {
            // Batched matmul with broadcasting: map each output batch index
            // to the (possibly repeated) A/B batch offsets.
            dim_t offset[2];
            for (auto i : range0_(info.broadcaster.outputsCount)) {
                info.broadcaster.locate(i, offset);
                md.matrixMultiply(a + stepA * offset[0], b + stepB * offset[1], y + stepY * i);
            }
        } else {
            // Same batch count on every operand: plain strided iteration.
            for (auto i : range0_(info.broadcaster.outputsCount)) {
                md.matrixMultiply(a + stepA * i, b + stepB * i, y + stepY * i);
            }
        }
    };
}

auto K::lower(Resources &res) const noexcept -> RoutineWorkspace {
Expand Down
18 changes: 10 additions & 8 deletions src/04kernel/src/kernels/mat_mul_common/cpu_template.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,28 @@

namespace refactor::kernel {

/// Metadata for a single 2D matrix multiplication on CPU.
/// TO is the output/accumulator element type, TI the input element type
/// (they differ for quantized kernels, e.g. int8 inputs with int32 output).
template<class TO, class TI>
struct MatMulCPUMetaData {
    // M x K times K x N -> M x N; strides select plain vs transposed
    // access for A (strideA0/strideA1) and B (strideB0/strideB1).
    size_t M, K, N,
        strideA0, strideA1,
        strideB0, strideB1;
    TI alpha;
    TO beta;

    /*
     * 2D matrix multiplication: Y = alpha * A @ B + beta * Y
     * Assume bias C has been broadcast to Y already. Beta should be 0 in
     * the absence of bias, so Y is overwritten rather than accumulated.
     */
    void matrixMultiply(TI const *a, TI const *b, TO *y) const noexcept {
        // #pragma omp parallel for
        for (size_t i = 0; i < M; i++) {
            for (size_t j = 0; j < N; j++) {
                // Accumulate in the (wider) output type to avoid overflow
                // of narrow input types such as int8.
                TO sum = 0;
                // #pragma omp simd reduction(+ : sum)
                for (size_t k = 0; k < K; k++) {
                    sum += static_cast<TO>(a[i * strideA0 + k * strideA1] * b[k * strideB0 + j * strideB1]);
                }
                y[i * N + j] = beta * y[i * N + j] + alpha * sum;
            }
        }
    }
};
Expand Down
81 changes: 79 additions & 2 deletions src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,85 @@ namespace refactor::kernel {
return "Performing MatMulInteger using CPU";
}

auto K::lower(Resources &res) const -> RoutineWorkspace {
TODO("");
template<class T> static int8_t sub(T, T);
template<> int8_t sub<int8_t>(int8_t a, int8_t b) { return a - b; }
template<> int8_t sub<uint8_t>(uint8_t a, uint8_t b) { return static_cast<int8_t>(static_cast<int16_t>(a) - static_cast<int16_t>(b)); }

template<class T>
static void applyZeroPoint(MatMulIntegerInfo::Input meta, int8_t *dst, void const *src_, void const *zp_) {
auto src = reinterpret_cast<T const *>(src_),
zp = reinterpret_cast<T const *>(zp_);
for (auto i : range0_(meta.groupCount)) {
for (auto j : range0_(meta.groupSize)) {
dst[meta.groupSize * i + j] = sub(src[meta.groupSize * i + j], zp[i]);
}
}
}

auto K::lower(Resources &res) const noexcept -> RoutineWorkspace {
using namespace runtime;

size_t workspace = 0;
if (info.a.withZeroPoint) {
workspace += info.a.groupCount * info.a.groupSize;
}
if (info.b.withZeroPoint) {
workspace += info.b.groupCount * info.b.groupSize;
}

auto routine = [info = info](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
auto workspacePtr = reinterpret_cast<int8_t *>(workspace);
auto a = reinterpret_cast<int8_t const *>(inputs[0]),
b = reinterpret_cast<int8_t const *>(inputs[1]);
auto y = reinterpret_cast<int32_t *>(outputs[0]);

if (auto meta = info.a; meta.withZeroPoint) {
if (meta.signed_) {
applyZeroPoint<int8_t>(meta, workspacePtr, a, inputs[2]);
} else {
applyZeroPoint<uint8_t>(meta, workspacePtr, a, inputs[2]);
}
a = workspacePtr;
workspacePtr += meta.groupCount * meta.groupSize;
}
if (auto meta = info.b; meta.withZeroPoint) {
if (meta.signed_) {
applyZeroPoint<int8_t>(meta, workspacePtr, b, inputs[3]);
} else {
applyZeroPoint<uint8_t>(meta, workspacePtr, b, inputs[3]);
}
b = workspacePtr;
}

MatMulCPUMetaData<int32_t, int8_t> const md{
.M = info.m,
.K = info.k,
.N = info.n,
.strideA0 = info.k,
.strideA1 = 1,
.strideB0 = info.n,
.strideB1 = 1,
.alpha = 1,
.beta = 0,
};
auto const stepY = info.m * info.n,
stepA = info.m * info.k,
stepB = info.k * info.n;

if (info.broadcaster.needBroadcast()) {
dim_t offset[2];
for (auto i : range0_(info.broadcaster.outputsCount)) {
info.broadcaster.locate(i, offset);
md.matrixMultiply(a + stepA * offset[0], b + stepB * offset[1], y + stepY * i);
}
} else {
for (auto i : range0_(info.broadcaster.outputsCount)) {
md.matrixMultiply(a + stepA * i, b + stepB * i, y + stepY * i);
}
}
};

return {std::move(routine), workspace};
};

}// namespace refactor::kernel
2 changes: 1 addition & 1 deletion src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.hh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace refactor::kernel {
size_t kernelTypeId() const noexcept final;
std::string_view description() const noexcept final;

RoutineWorkspace lower(Resources &) const final;
RoutineWorkspace lower(Resources &) const noexcept final;
};

}// namespace refactor::kernel
Expand Down

0 comments on commit 050c0b1

Please sign in to comment.