Skip to content

Commit

Permalink
feat(kernel): 添加逐张量量化的 cpu kernel
Browse files Browse the repository at this point in the history
fix(kernel): 改正 mat mul integer

Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Dec 18, 2023
1 parent 3d14bf0 commit 103254b
Show file tree
Hide file tree
Showing 10 changed files with 344 additions and 54 deletions.
8 changes: 5 additions & 3 deletions src/04kernel/include/kernel/attributes/mat_mul_integer_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ namespace refactor::kernel {

struct MatMulIntegerInfo {
struct Input {
bool withZeroPoint;
bool signed_;
dim_t groupCount, groupSize;
bool
withZeroPoint,
signed_,
scalar;

Input(TensorRefs const &, size_t i) noexcept;
};
Expand All @@ -19,6 +20,7 @@ namespace refactor::kernel {
Broadcaster broadcaster;

explicit MatMulIntegerInfo(TensorRefs const &inputs) noexcept;
dim_t batch() const noexcept;
};

}// namespace refactor::kernel
Expand Down
30 changes: 20 additions & 10 deletions src/04kernel/src/attributes/mat_mul_integer_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,41 @@

namespace refactor::kernel {

#define A (inputs[0].get().shape)
#define B (inputs[1].get().shape)

MatMulIntegerInfo::Input::Input(TensorRefs const &inputs, size_t i) noexcept
: withZeroPoint(false),
signed_(true),
groupCount(1),
groupSize(1) {
scalar(true) {
if (inputs.size() > i + 2) {
auto const &t = inputs[i + 2].get();
if (withZeroPoint = t.rank() != 0 || !t.data || t.data->get<uint8_t>() != 0) {
signed_ = t.dataType == DataType::I8;
groupCount = t.elementsSize();
groupSize = inputs[i].get().elementsSize() / groupCount;
auto size = t.elementsSize();
if (t.data) {
auto data = slice(t.data->get<uint8_t>(), size);
if (std::all_of(data.begin(), data.end(), [](auto x) { return x == 0; })) {
return;
}
}
withZeroPoint = true;
signed_ = t.dataType == DataType::I8;
scalar = size == 1;
}
}

MatMulIntegerInfo::MatMulIntegerInfo(TensorRefs const &inputs) noexcept
: a(inputs, 0),
b(inputs, 1),
#define A (inputs[0].get().shape)
#define B (inputs[1].get().shape)
m(A.rbegin()[1]),
k(A.rbegin()[0]),
n(B.rbegin()[0]),
broadcaster({slice(A.data(), A.size() - 2),
slice(B.data(), B.size() - 2)}) {}
slice(B.data(), B.size() - 2)}) {
}
#undef A
#undef B

dim_t MatMulIntegerInfo::batch() const noexcept {
return broadcaster.outputsCount;
}

}// namespace refactor::kernel
6 changes: 6 additions & 0 deletions src/04kernel/src/collectors/dynamic_quantize_linear.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "kernel/collectors/dynamic_quantize_linear.h"
#include "../kernels/dynamic_quantize_linear/cpu_kernel.hh"

namespace refactor::kernel {

Expand All @@ -8,9 +9,14 @@ namespace refactor::kernel {

std::vector<KernelBox>
DynamicQuantizeLinearCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
auto size = inputs[0].get().elementsSize();

std::vector<KernelBox> ans;
switch (_target) {
case decltype(_target)::Cpu:
if (auto ptr = DynamicQuantizeLinearCpu::build(size); ptr) {
ans.emplace_back(std::move(ptr));
}
break;
case decltype(_target)::Nvidia:
break;
Expand Down
64 changes: 64 additions & 0 deletions src/04kernel/src/kernels/dynamic_quantize_linear/cpu_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#include "cpu_kernel.hh"
#include <algorithm>
#include <cmath>
#include <execution>
#include <limits>
#include <numeric>

namespace refactor::kernel {
using K = DynamicQuantizeLinearCpu;

K::DynamicQuantizeLinearCpu(decltype(size) size_) noexcept
: Kernel(), size(size_) {}

auto K::build(decltype(size) size) noexcept -> KernelBox {
return std::make_unique<K>(size);
}

auto K::typeId() noexcept -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}

auto K::kernelTypeId() const noexcept -> size_t { return typeId(); }
auto K::description() const noexcept -> std::string_view {
return "Performing dynamic quantize linear using CPU";
}

/// Builds the CPU routine for dynamic quantization (ONNX DynamicQuantizeLinear):
/// derive scale & zero-point from the input's dynamic range, then quantize
/// float32 input (inputs[0]) to uint8 (outputs[0]); outputs[1] is the scale,
/// outputs[2] the zero-point.
auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
    using namespace runtime;
    return [size = size](Resources &, void *, void const *const *inputs, void *const *outputs) {
        using TI = float;
        using TO = uint8_t;

        constexpr static auto
            ZERO = static_cast<TI>(0),
            // NOTE: lowest(), not min() — numeric_limits<float>::min() is the
            // smallest *positive* float and would be a wrong seed for max.
            _MIN = std::numeric_limits<TI>::lowest(),
            _MAX = std::numeric_limits<TI>::max(),
            QMIN = static_cast<TI>(std::numeric_limits<TO>::min()),
            QMAX = static_cast<TI>(std::numeric_limits<TO>::max()),
            QLEN = QMAX - QMIN;

        auto x = reinterpret_cast<TI const *>(inputs[0]);
        // Single pass over the input to find its min and max.
        auto [min, max] = std::accumulate(
            x, x + size,
            std::pair{_MAX, _MIN},
            [](auto acc, auto it) {
                auto [min, max] = acc;
                return std::pair{
                    std::min(min, it),
                    std::max(max, it),
                };
            });
        // Per the ONNX spec the quantization range always includes 0 so that
        // 0 is exactly representable.
        auto rangeMin = std::min(ZERO, min),
             rangeMax = std::max(ZERO, max),
             len = rangeMax - rangeMin;
        // Degenerate range (all-zero or empty input): any positive scale is
        // valid and avoids division by zero; 1 keeps the math exact.
        auto scale = len > ZERO ? len / QLEN : static_cast<TI>(1);
        // zero-point = saturate(round(qmin - rangeMin / scale)).
        auto zp = static_cast<TO>(std::round(
            std::clamp(QMIN - rangeMin / scale, QMIN, QMAX)));

        std::transform(
            std::execution::par_unseq,
            x, x + size,
            reinterpret_cast<TO *>(outputs[0]),
            [=](auto it) {
                // Saturate to the uint8 range as required by the spec.
                return static_cast<TO>(std::clamp(
                    std::round(it / scale) + zp, QMIN, QMAX));
            });
        *reinterpret_cast<TI *>(outputs[1]) = scale;
        *reinterpret_cast<TO *>(outputs[2]) = zp;
    };
}

}// namespace refactor::kernel
23 changes: 23 additions & 0 deletions src/04kernel/src/kernels/dynamic_quantize_linear/cpu_kernel.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef KERNEL_DYNAMIC_QUANTIZE_LINEAR_CPU_KERNEL_HH
#define KERNEL_DYNAMIC_QUANTIZE_LINEAR_CPU_KERNEL_HH

#include "kernel/kernel.h"

namespace refactor::kernel {

/// CPU kernel for dynamic quantization (DynamicQuantizeLinear): computes a
/// scale and zero-point from the input's value range and quantizes the
/// float input to uint8.
struct DynamicQuantizeLinearCpu final : public Kernel {
    // Number of elements in the input tensor.
    size_t size;

    explicit DynamicQuantizeLinearCpu(decltype(size)) noexcept;

    static KernelBox build(decltype(size)) noexcept;
    static size_t typeId() noexcept;

    size_t kernelTypeId() const noexcept final;
    std::string_view description() const noexcept final;
    RoutineWorkspace lower(Resources &) const noexcept final;
};

}// namespace refactor::kernel

#endif// KERNEL_DYNAMIC_QUANTIZE_LINEAR_CPU_KERNEL_HH
23 changes: 23 additions & 0 deletions src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include "cuda_kernel.hh"

namespace refactor::kernel {
using K = DynamicQuantizeLinearCuda;

K::DynamicQuantizeLinearCuda(decltype(size) size_) noexcept
: Kernel(), size(size_) {}

auto K::build(decltype(size) size) noexcept -> KernelBox {
return std::make_unique<K>(size);
}

auto K::typeId() noexcept -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}

auto K::kernelTypeId() const noexcept -> size_t { return typeId(); }
auto K::description() const noexcept -> std::string_view {
return "Performing dynamic quantize linear using Nvidia GPU";
}

}// namespace refactor::kernel
16 changes: 16 additions & 0 deletions src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#include "cuda_kernel.hh"
#include <cub/cub.cuh>

namespace refactor::kernel {
using K = DynamicQuantizeLinearCuda;

// Stub: builds the routine for dynamic quantization on an Nvidia GPU.
// NOTE(review): the returned routine is currently a no-op — the CUDA
// implementation has not been written yet, and TI/TO are as yet unused.
auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
    using namespace runtime;
    using TI = float;
    using TO = uint8_t;

    return [size = size](Resources &, void *, void const *const *inputs, void *const *outputs) {
        // TODO: implement (e.g. cub min/max reduction + elementwise quantize).
    };
}

}// namespace refactor::kernel
25 changes: 25 additions & 0 deletions src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef KERNEL_DYNAMIC_QUANTIZE_LINEAR_CUDA_KERNEL_HH
#define KERNEL_DYNAMIC_QUANTIZE_LINEAR_CUDA_KERNEL_HH

#include "kernel/kernel.h"

namespace refactor::kernel {

/// CUDA kernel for dynamic quantization (DynamicQuantizeLinear).
/// `lower` is only declared when built with CUDA support.
struct DynamicQuantizeLinearCuda final : public Kernel {
    // Number of elements in the input tensor.
    size_t size;

    explicit DynamicQuantizeLinearCuda(decltype(size)) noexcept;

    static KernelBox build(decltype(size)) noexcept;
    static size_t typeId() noexcept;

    size_t kernelTypeId() const noexcept final;
    std::string_view description() const noexcept final;
#ifdef USE_CUDA
    RoutineWorkspace lower(Resources &) const noexcept final;
#endif
};

}// namespace refactor::kernel

#endif// KERNEL_DYNAMIC_QUANTIZE_LINEAR_CUDA_KERNEL_HH
79 changes: 62 additions & 17 deletions src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "cpu_kernel.hh"
#include "../mat_mul_common/cpu_template.hpp"
#include <execution>

namespace refactor::kernel {
using K = MatMulIntegerCpu;
Expand Down Expand Up @@ -27,25 +28,49 @@ namespace refactor::kernel {
template<> int8_t sub<uint8_t>(uint8_t a, uint8_t b) { return static_cast<int8_t>(static_cast<int16_t>(a) - static_cast<int16_t>(b)); }

template<class T>
static void applyZeroPoint(MatMulIntegerInfo::Input meta, int8_t *dst, void const *src_, void const *zp_) {
static void applyZeroPointScalar(
size_t size, int8_t *dst, void const *src_, void const *zp_) {

auto src = reinterpret_cast<T const *>(src_);
auto zp = *reinterpret_cast<T const *>(zp_);
std::transform(std::execution::par_unseq,
src, src + size,
dst, [zp](auto x) { return sub(x, zp); });
}
template<class T>
static void applyZeroPointA(
dim_t b, dim_t m, dim_t n,
int8_t *dst, void const *src_, void const *zp_) {

auto src = reinterpret_cast<T const *>(src_),
zp = reinterpret_cast<T const *>(zp_);
for (auto i : range0_(meta.groupCount)) {
for (auto j : range0_(meta.groupSize)) {
dst[meta.groupSize * i + j] = sub(src[meta.groupSize * i + j], zp[i]);
}
}
for (auto i : range0_(b))
for (auto j : range0_(m))
for (auto k : range0_(n))
dst[i * m * n + j * n + k] = sub(src[i * m * n + j * n + k], zp[i * m + j]);
}
template<class T>
static void applyZeroPointB(
dim_t b, dim_t m, dim_t n,
int8_t *dst, void const *src_, void const *zp_) {

auto src = reinterpret_cast<T const *>(src_),
zp = reinterpret_cast<T const *>(zp_);
for (auto i : range0_(b))
for (auto j : range0_(m))
for (auto k : range0_(n))
dst[i * m * n + j * n + k] = sub(src[i * m * n + j * n + k], zp[i * n + k]);
}

auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
using namespace runtime;

size_t workspace = 0;
if (info.a.withZeroPoint) {
workspace += info.a.groupCount * info.a.groupSize;
workspace += info.batch() * info.m * info.k;
}
if (info.b.withZeroPoint) {
workspace += info.b.groupCount * info.b.groupSize;
workspace += info.batch() * info.k * info.n;
}

auto routine = [info = info](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
Expand All @@ -55,19 +80,39 @@ namespace refactor::kernel {
auto y = reinterpret_cast<int32_t *>(outputs[0]);

if (auto meta = info.a; meta.withZeroPoint) {
if (meta.signed_) {
applyZeroPoint<int8_t>(meta, workspacePtr, a, inputs[2]);
auto size = info.batch() * info.m * info.k;
auto zp = inputs[2];
if (meta.scalar) {
if (meta.signed_) {
applyZeroPointScalar<int8_t>(size, workspacePtr, a, zp);
} else {
applyZeroPointScalar<uint8_t>(size, workspacePtr, a, zp);
}
} else {
applyZeroPoint<uint8_t>(meta, workspacePtr, a, inputs[2]);
if (meta.signed_) {
applyZeroPointA<int8_t>(info.batch(), info.m, info.k, workspacePtr, a, zp);
} else {
applyZeroPointA<uint8_t>(info.batch(), info.m, info.k, workspacePtr, a, zp);
}
}
a = workspacePtr;
workspacePtr += meta.groupCount * meta.groupSize;
workspacePtr += size;
}
if (auto meta = info.b; meta.withZeroPoint) {
if (meta.signed_) {
applyZeroPoint<int8_t>(meta, workspacePtr, b, inputs[3]);
auto size = info.batch() * info.k * info.n;
auto zp = inputs[3];
if (meta.scalar) {
if (meta.signed_) {
applyZeroPointScalar<int8_t>(size, workspacePtr, b, zp);
} else {
applyZeroPointScalar<uint8_t>(size, workspacePtr, b, zp);
}
} else {
applyZeroPoint<uint8_t>(meta, workspacePtr, b, inputs[3]);
if (meta.signed_) {
applyZeroPointA<int8_t>(info.batch(), info.k, info.n, workspacePtr, b, zp);
} else {
applyZeroPointA<uint8_t>(info.batch(), info.k, info.n, workspacePtr, b, zp);
}
}
b = workspacePtr;
}
Expand All @@ -89,12 +134,12 @@ namespace refactor::kernel {

if (info.broadcaster.needBroadcast()) {
dim_t offset[2];
for (auto i : range0_(info.broadcaster.outputsCount)) {
for (auto i : range0_(info.batch())) {
info.broadcaster.locate(i, offset);
md.matrixMultiply(a + stepA * offset[0], b + stepB * offset[1], y + stepY * i);
}
} else {
for (auto i : range0_(info.broadcaster.outputsCount)) {
for (auto i : range0_(info.batch())) {
md.matrixMultiply(a + stepA * i, b + stepB * i, y + stepY * i);
}
}
Expand Down
Loading

0 comments on commit 103254b

Please sign in to comment.