Skip to content

Commit

Permalink
feat(kernel): 实现反量化算子
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Dec 19, 2023
1 parent ee14c96 commit 2e75d38
Show file tree
Hide file tree
Showing 9 changed files with 331 additions and 1 deletion.
9 changes: 9 additions & 0 deletions src/04kernel/src/collectors/dequantize_linear.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#include "kernel/collectors/dequantize_linear.h"
#include "../kernels/dequantize_linear/cpu_kernel.hh"
#include "../kernels/dequantize_linear/cuda_kernel.hh"

namespace refactor::kernel {

Expand All @@ -8,11 +10,18 @@ namespace refactor::kernel {

std::vector<KernelBox>
DequantizeLinearCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
auto const &output = outputs[0];
std::vector<KernelBox> ans;
switch (_target) {
case decltype(_target)::Cpu:
if (auto ptr = DequantizeLinearCpu::build(inputs, output); ptr) {
ans.emplace_back(std::move(ptr));
}
break;
case decltype(_target)::Nvidia:
if (auto ptr = DequantizeLinearCuda::build(inputs, output); ptr) {
ans.emplace_back(std::move(ptr));
}
break;
default:
UNREACHABLEX(void, "Unknown target");
Expand Down
4 changes: 4 additions & 0 deletions src/04kernel/src/collectors/dynamic_quantize_linear.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "kernel/collectors/dynamic_quantize_linear.h"
#include "../kernels/dynamic_quantize_linear/cpu_kernel.hh"
#include "../kernels/dynamic_quantize_linear/cuda_kernel.hh"

namespace refactor::kernel {

Expand All @@ -19,6 +20,9 @@ namespace refactor::kernel {
}
break;
case decltype(_target)::Nvidia:
if (auto ptr = DynamicQuantizeLinearCuda::build(size); ptr) {
ans.emplace_back(std::move(ptr));
}
break;
default:
UNREACHABLEX(void, "Unknown target");
Expand Down
76 changes: 76 additions & 0 deletions src/04kernel/src/kernels/dequantize_linear/cpu_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#include "cpu_kernel.hh"
#include <execution>
#include <numeric>

namespace refactor::kernel {
using K = DequantizeLinearCpu;

/// Stores the quantized element type, the element count and the zero-point flag as given.
K::DequantizeLinearCpu(
    DataType srcType,
    size_t elementCount,
    bool hasZeroPoint) noexcept
    : Kernel(),
      from(srcType),
      size(elementCount),
      withZeroPoint(hasZeroPoint) {}

/// Tries to create a CPU dequantize-linear kernel for the given tensors.
///
/// `inputs` holds the quantized tensor, the scale and, optionally, the zero point (in that order).
/// Returns nullptr (kernel not applicable) unless all of the following hold:
/// - the scale holds exactly one element;
/// - the scale and the output are F32, because the generated routine reads the scale as `float`
///   and writes `float` results;
/// - the quantized element type is one that `lower` can dispatch on.
auto K::build(TensorRefs const &inputs, Tensor const &output) noexcept -> KernelBox {
    auto const &x = inputs[0].get();
    auto const &scale = inputs[1].get();
    if (scale.elementsSize() != 1) {
        return nullptr;
    }
    if (scale.dataType != DataType::F32) {
        return nullptr;
    }
    if (output.dataType != DataType::F32) {
        return nullptr;
    }
    // Refuse element types `lower` has no instantiation for instead of
    // running into UNREACHABLE() when the kernel is lowered.
    switch (x.dataType) {
        case DataType::U8:
        case DataType::U16:
        case DataType::I8:
        case DataType::I16:
        case DataType::I32:
            break;
        default:
            return nullptr;
    }
    return std::make_unique<K>(
        x.dataType,
        x.elementsSize(),
        inputs.size() > 2);
}

/// Returns the address of a function-local static byte, converted to an integer,
/// as the identifier of this kernel type.
auto K::typeId() noexcept -> size_t {
    static uint8_t ID = 1;
    auto const *anchor = &ID;
    return reinterpret_cast<size_t>(anchor);
}

/// Instance-level accessor that forwards to the static typeId().
auto K::kernelTypeId() const noexcept -> size_t {
    return typeId();
}
/// Human-readable summary of what this kernel does.
auto K::description() const noexcept -> std::string_view {
    constexpr std::string_view text = "Performing dequantize linear using CPU";
    return text;
}

/// Builds the CPU routine for one (quantized type `TI`, output type `TO`) pair.
///
/// The routine reads `size` elements of `TI` from inputs[0], a single `TO` scale from inputs[1]
/// and, when `withZeroPoint` is set, a single `TI` zero point from inputs[2] (otherwise 0 is used).
/// It writes `(x - zeroPoint) * scale` for every element to outputs[0].
template<class TI, class TO>
auto lowerTyped(size_t size, bool withZeroPoint) noexcept -> RoutineWorkspace {

    return [size, withZeroPoint]//
    (Resources &, void *, void const *const *inputs, void *const *outputs) {
        auto x = reinterpret_cast<TI const *>(inputs[0]);
        auto scale = *reinterpret_cast<TO const *>(inputs[1]);
        // Keep the zero point in 64 bits so that the subtraction below cannot overflow,
        // even for 32-bit inputs (signed overflow in `int` would be undefined behaviour).
        int64_t const zp = withZeroPoint ? *reinterpret_cast<TI const *>(inputs[2]) : 0;
        auto y = reinterpret_cast<TO *>(outputs[0]);
        std::transform(
            std::execution::par_unseq,
            x, x + size,
            y,
            [scale, zp](TI x) {
                return static_cast<TO>(static_cast<int64_t>(x) - zp) * scale;
            });
    };
}

/// Selects the routine instantiation matching the stored quantized element type.
/// The output element type is always `float`.
auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
    // `sample` only carries the element type; its value is never used.
    auto const emit = [this](auto sample) {
        return lowerTyped<decltype(sample), float>(size, withZeroPoint);
    };
    switch (from) {
        case DataType::U8:
            return emit(uint8_t{});
        case DataType::U16:
            return emit(uint16_t{});
        case DataType::I8:
            return emit(int8_t{});
        case DataType::I16:
            return emit(int16_t{});
        case DataType::I32:
            return emit(int32_t{});
        default:
            UNREACHABLE();
    }
}

}// namespace refactor::kernel
29 changes: 29 additions & 0 deletions src/04kernel/src/kernels/dequantize_linear/cpu_kernel.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#ifndef KERNEL_DEQUANTIZE_LINEAR_CPU_KERNEL_HH
#define KERNEL_DEQUANTIZE_LINEAR_CPU_KERNEL_HH

#include "kernel/kernel.h"
#include "kernel/tensor.h"

namespace refactor::kernel {

// CPU kernel for the dequantize-linear operator.
struct DequantizeLinearCpu final : public Kernel {
// Element type of the quantized input tensor.
DataType from;
// Number of elements in the quantized input tensor.
size_t size;
// Whether a zero-point input is supplied in addition to the input and the scale.
bool withZeroPoint;

DequantizeLinearCpu(
decltype(from),
decltype(size),
decltype(withZeroPoint)) noexcept;

// Returns a kernel for the given inputs/output, or a null box when this kernel does not apply.
static KernelBox build(TensorRefs const &, Tensor const &) noexcept;
// Identifier of this kernel type.
static size_t typeId() noexcept;

size_t kernelTypeId() const noexcept final;
std::string_view description() const noexcept final;
// Creates the routine that performs the dequantization.
RoutineWorkspace lower(Resources &) const noexcept final;
};

}// namespace refactor::kernel

#endif// KERNEL_DEQUANTIZE_LINEAR_CPU_KERNEL_HH
92 changes: 92 additions & 0 deletions src/04kernel/src/kernels/dequantize_linear/cuda_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#include "cuda_kernel.hh"

#ifdef USE_CUDA
#include "../../generator/nvrtc_repo.h"
#include "kernel/cuda/threads_distributer.cuh"
#endif

namespace refactor::kernel {
using K = DequantizeLinearCuda;

/// Stores the input/output element types, the element count and the zero-point flag as given.
K::DequantizeLinearCuda(
    DataType srcType,
    DataType dstType,
    size_t elementCount,
    bool hasZeroPoint) noexcept
    : Kernel(),
      from(srcType),
      to(dstType),
      size(elementCount),
      withZeroPoint(hasZeroPoint) {}

/// Tries to create a CUDA dequantize-linear kernel for the given tensors.
///
/// Always yields nullptr when the library is built without USE_CUDA.
/// Otherwise yields nullptr unless the scale (inputs[1]) holds exactly one element;
/// a third input, when present, is treated as the zero point.
auto K::build(TensorRefs const &inputs, Tensor const &output) noexcept -> KernelBox {
#ifdef USE_CUDA
    auto const &scale = inputs[1].get();
    if (scale.elementsSize() != 1) {
        return nullptr;
    }
    auto const &quantized = inputs[0].get();
    bool const hasZeroPoint = inputs.size() > 2;
    return std::make_unique<K>(
        quantized.dataType,
        output.dataType,
        quantized.elementsSize(),
        hasZeroPoint);
#else
    // Nothing to inspect without CUDA support; silence unused-parameter diagnostics.
    static_cast<void>(inputs);
    static_cast<void>(output);
    return nullptr;
#endif
}

/// Returns the address of a function-local static byte, converted to an integer,
/// as the identifier of this kernel type.
auto K::typeId() noexcept -> size_t {
    static uint8_t ID = 1;
    auto const *anchor = &ID;
    return reinterpret_cast<size_t>(anchor);
}

/// Instance-level accessor that forwards to the static typeId().
auto K::kernelTypeId() const noexcept -> size_t {
    return typeId();
}
/// Human-readable summary of what this kernel does.
auto K::description() const noexcept -> std::string_view {
    constexpr std::string_view text = "Performing dequantize linear using Nvidia GPU";
    return text;
}

#ifdef USE_CUDA

// CUDA source of the dequantize kernel, used as an fmt format string by `lower`:
//   {0:} is replaced by the output element type (also used for the scale),
//   {1:} is replaced by the input element type (also used for the zero point),
//   `{{` and `}}` are fmt escapes that produce literal braces.
// The generated kernel computes y[i] = (x[i] - zp) * scale for i in [0, n). Each thread starts
// at its global index and advances by the total number of launched threads.
// A null `zp_` pointer means a zero point of 0.
constexpr static const char *TEMPLATE = R"~(
extern "C" __global__ void kernel(
{0:} *__restrict__ y,
{1:} const *__restrict__ x,
{0:} const *__restrict__ scale_,
{1:} const *__restrict__ zp_,
size_t n
) {{
auto zp = zp_ ? *zp_ : static_cast<{1:}>(0);
auto scale = *scale_;
for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
step = blockDim.x * gridDim.x;
tid < n;
tid += step) {{
y[tid] = static_cast<{0:}>(x[tid] - zp) * scale;
}}
}}
)~";

/// Instantiates the kernel source for this (from, to) type pair, compiles it and returns a
/// routine that launches the compiled kernel with the tensors it is given.
auto K::lower(Resources &res) const -> RoutineWorkspace {
    using namespace runtime;

    auto const kernelName = fmt::format("DequantizeLinear{}->{}", from.name(), to.name());
    auto const source = fmt::format(TEMPLATE, nvrtc::dataType(to), nvrtc::dataType(from));
    return [hasZeroPoint = withZeroPoint,
            launch = cuda::ThreadsDistributer()(size),
            handler = nvrtc::Handler::compile(kernelName.c_str(), source.c_str(), "kernel")]//
    (Resources &, void *, void const *const *inputs, void *const *outputs) {
        void *y = outputs[0];
        void const *x = inputs[0];
        void const *scale = inputs[1];
        // The kernel treats a null zero-point pointer as "zero point is 0".
        void const *zp = nullptr;
        if (hasZeroPoint) {
            zp = inputs[2];
        }
        auto n = launch.n;
        // Same order as the kernel's parameter list: y, x, scale_, zp_, n.
        void *args[]{&y, &x, &scale, &zp, &n};
        handler->launch(launch.gridSize, 1, 1,
                        launch.blockSize, 1, 1,
                        0, args);
    };
}

#endif

}// namespace refactor::kernel
32 changes: 32 additions & 0 deletions src/04kernel/src/kernels/dequantize_linear/cuda_kernel.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#ifndef KERNEL_DEQUANTIZE_LINEAR_CUDA_KERNEL_HH
#define KERNEL_DEQUANTIZE_LINEAR_CUDA_KERNEL_HH

#include "kernel/kernel.h"
#include "kernel/tensor.h"

namespace refactor::kernel {

// Nvidia GPU kernel for the dequantize-linear operator.
struct DequantizeLinearCuda final : public Kernel {
// Element types of the quantized input (`from`) and of the output (`to`).
DataType from, to;
// Number of elements in the quantized input tensor.
size_t size;
// Whether a zero-point input is supplied in addition to the input and the scale.
bool withZeroPoint;

DequantizeLinearCuda(
decltype(from),
decltype(to),
decltype(size),
decltype(withZeroPoint)) noexcept;

// Returns a kernel for the given inputs/output, or a null box when this kernel does not apply.
static KernelBox build(TensorRefs const &, Tensor const &) noexcept;
// Identifier of this kernel type.
static size_t typeId() noexcept;

size_t kernelTypeId() const noexcept final;
std::string_view description() const noexcept final;
// `lower` is only declared when the library is built with CUDA support.
#ifdef USE_CUDA
RoutineWorkspace lower(Resources &) const final;
#endif
};

}// namespace refactor::kernel

#endif// KERNEL_DEQUANTIZE_LINEAR_CUDA_KERNEL_HH
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ namespace refactor::kernel {

}// namespace refactor::kernel

#endif// KERNEL_SOFTMAX_CPU_KERNEL_HH
#endif// KERNEL_DYNAMIC_QUANTIZE_LINEAR_CPU_KERNEL_HH
31 changes: 31 additions & 0 deletions src/04kernel/test/kernels/dequantize_linear/test_cpu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#include "../../../src/kernels/dequantize_linear/cpu_kernel.hh"
#include <gtest/gtest.h>
#include <numeric>

using namespace refactor;
using namespace kernel;

// Dequantizes four u8 values with scale 2 and zero point 128 and checks the exact results.
TEST(kernel, DequantizeLinearCpu) {
    // Tensor descriptions: 4 quantized values, scalar scale, scalar zero point, 4 outputs.
    auto quantized = Tensor::share(DataType::U8, {4});
    auto scaleInfo = Tensor::share(DataType::F32, {});
    auto zeroPointInfo = Tensor::share(DataType::U8, {});
    auto dequantized = Tensor::share(DataType::F32, {4});

    // Build the kernel and lower it to a callable routine.
    auto kernel = DequantizeLinearCpu::build({*quantized, *scaleInfo, *zeroPointInfo}, *dequantized);
    ASSERT_TRUE(kernel);
    auto res = runtime::Resources();
    auto routine = kernel->lower(res).routine;

    // Host-side data.
    std::vector<uint8_t> xData{0, 3, 128, 255};
    float scaleValue = 2;
    uint8_t zeroPointValue = 128;
    std::vector<float> yData(xData.size());

    // Run the routine.
    void const *inputs[]{xData.data(), &scaleValue, &zeroPointValue};
    void *outputs[]{yData.data()};
    routine(res, nullptr, inputs, outputs);

    // (x - 128) * 2 for every element.
    std::vector<float> const expected{-256, -250, 0, 254};
    ASSERT_EQ(yData, expected);
}
57 changes: 57 additions & 0 deletions src/04kernel/test/kernels/dequantize_linear/test_cuda.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#ifdef USE_CUDA

#include "../../../src/kernels/dequantize_linear/cpu_kernel.hh"
#include "../../../src/kernels/dequantize_linear/cuda_kernel.hh"
#include "hardware/device_manager.h"
#include <gtest/gtest.h>

using namespace refactor;
using namespace kernel;
using namespace hardware;

// Runs the CUDA kernel and the CPU kernel on the same data and expects identical results.
TEST(kernel, DequantizeLinearCuda) {
    // Tensor descriptions: 4 quantized values, scalar scale, scalar zero point, 4 outputs.
    auto quantized = Tensor::share(DataType::U8, {4});
    auto scaleInfo = Tensor::share(DataType::F32, {});
    auto zeroPointInfo = Tensor::share(DataType::U8, {});
    auto dequantized = Tensor::share(DataType::F32, {4});

    // Build both kernels and lower them to routines.
    auto gpuKernel = DequantizeLinearCuda::build({*quantized, *scaleInfo, *zeroPointInfo}, *dequantized);
    auto cpuKernel = DequantizeLinearCpu::build({*quantized, *scaleInfo, *zeroPointInfo}, *dequantized);
    ASSERT_TRUE(gpuKernel && cpuKernel);
    auto res = runtime::Resources();
    auto gpuRoutine = gpuKernel->lower(res).routine;
    auto cpuRoutine = cpuKernel->lower(res).routine;

    // Device buffers.
    auto &dev = *device::init(Device::Type::Nvidia, 0, "");
    auto xGpu = dev.malloc(quantized->bytesSize());
    auto scaleGpu = dev.malloc(sizeof(float));
    auto zpGpu = dev.malloc(sizeof(uint8_t));
    auto yGpu = dev.malloc(dequantized->bytesSize());

    // Host-side data, mirrored to the device.
    std::vector<uint8_t> xData{0, 3, 128, 255};
    float scaleValue = 2;
    uint8_t zeroPointValue = 128;
    std::vector<float> yData(xData.size());
    xGpu->copyFromHost(xData.data());
    scaleGpu->copyFromHost(&scaleValue);
    zpGpu->copyFromHost(&zeroPointValue);

    // Run on the GPU.
    {
        void const *inputs[]{*xGpu, *scaleGpu, *zpGpu};
        void *outputs[]{*yGpu};
        gpuRoutine(res, nullptr, inputs, outputs);
    }
    // Run the CPU reference.
    {
        void const *inputs[]{xData.data(), &scaleValue, &zeroPointValue};
        void *outputs[]{yData.data()};
        cpuRoutine(res, nullptr, inputs, outputs);
    }

    // Compare the device result with the CPU reference.
    std::vector<float> result(yData.size());
    yGpu->copyToHost(result.data());
    EXPECT_EQ(result, yData);
}

#endif

0 comments on commit 2e75d38

Please sign in to comment.