Skip to content

Commit

Permalink
feat(kernel): 实现 MatMulInteger cublas kernel
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Dec 18, 2023
1 parent 916fd3d commit c74b588
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/04kernel/src/collectors/mat_mul_integer.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "kernel/collectors/mat_mul_integer.h"
#include "../../src/kernels/mat_mul_integer/cpu_kernel.hh"
#include "../../src/kernels/mat_mul_integer/cublas_kernel.hh"
#include "kernel/attributes/mat_mul_integer_info.h"

namespace refactor::kernel {
Expand All @@ -16,6 +17,9 @@ namespace refactor::kernel {
}
break;
case decltype(_target)::Nvidia:
if (auto ptr = MatMulIntegerCublas::build(info); ptr) {
ans.emplace_back(std::move(ptr));
}
break;
default:
UNREACHABLEX(void, "Unknown target");
Expand Down
2 changes: 1 addition & 1 deletion src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ namespace refactor::kernel {
}
}

auto K::lower(Resources &res) const noexcept -> RoutineWorkspace {
auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
using namespace runtime;

size_t workspace = 0;
Expand Down
28 changes: 28 additions & 0 deletions src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#include "cublas_kernel.hh"

namespace refactor::kernel {
using K = MatMulIntegerCublas;
using DT = DataType;

// Takes ownership of the precomputed operator description.
K::MatMulIntegerCublas(decltype(info) meta) noexcept : info(std::move(meta)) {}

/// Builds the cuBLAS MatMulInteger kernel for the given operator info.
/// Returns nullptr when the project is compiled without CUDA support so the
/// collector falls through to other kernel candidates.
auto K::build(decltype(info) info) noexcept -> KernelBox {
#ifndef USE_CUDA
    return nullptr;
#else
    // Original placed this return after #endif, leaving it as unreachable
    // (yet still compiled) dead code in non-CUDA builds; #else makes exactly
    // one return path exist per configuration.
    return std::make_unique<K>(std::move(info));
#endif
}

auto K::typeId() noexcept -> size_t {
    // The address of a function-local static is unique program-wide, which
    // makes it a cheap, collision-free identifier for this kernel type.
    static uint8_t MARK = 1;
    return reinterpret_cast<size_t>(&MARK);
}

// Runtime type identification: forwards to the static per-type id.
auto K::kernelTypeId() const noexcept -> size_t { return typeId(); }
// Human-readable description surfaced by logging/diagnostics.
auto K::description() const noexcept -> std::string_view {
    return "Performing MatMulInteger using CUBLAS";
}

}// namespace refactor::kernel
115 changes: 115 additions & 0 deletions src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#include "../../utilities/cuda/cublas_context.hh"
#include "cublas_kernel.hh"
#include <cublas_v2.h>
#include <thrust/execution_policy.h>
#include <thrust/tabulate.h>

namespace refactor::kernel {
using namespace runtime;
using namespace cublas;

// Zero-point subtraction narrowed to int8 (the element type cublasGemmEx
// consumes as CUDA_R_8I).
template<class T> __device__ __forceinline__ static int8_t sub(T, T);
// Signed inputs: plain difference (integer-promoted, then narrowed to int8).
template<> __device__ __forceinline__ int8_t sub<int8_t>(int8_t a, int8_t b) { return a - b; }
// Unsigned inputs: widen to int16 first so the u8 difference (range
// [-255, 255]) is computed exactly before narrowing back to int8.
template<> __device__ __forceinline__ int8_t sub<uint8_t>(uint8_t a, uint8_t b) { return static_cast<int8_t>(static_cast<int16_t>(a) - static_cast<int16_t>(b)); }

/// Element functor for thrust::tabulate: produces the zero-point-adjusted
/// int8 value for flat index i of an input tensor.
template<class T>
struct MatMulIntegerZPFunctor {
    dim_t groupSize;   // number of consecutive elements sharing one zero point
    T const *src, *zp; // raw input elements and their zero-point values

    __device__ int8_t operator()(size_t i) const noexcept {
        // Element i belongs to zero-point group i / groupSize.
        return sub(src[i], zp[i / groupSize]);
    }
};

/// Launches a device-wide pass that writes the zero-point-adjusted copy of a
/// whole input tensor (meta.groupCount * meta.groupSize elements) into `dst`.
/// T selects signed vs unsigned interpretation of the raw bytes.
template<class T>
static void applyZeroPoint(MatMulIntegerInfo::Input meta, int8_t *dst, void const *src, void const *zp) {
    thrust::tabulate(
        thrust::device,// executes on the current CUDA device
        dst, dst + meta.groupCount * meta.groupSize,
        MatMulIntegerZPFunctor<T>{
            .groupSize = meta.groupSize,
            .src = reinterpret_cast<T const *>(src),
            .zp = reinterpret_cast<T const *>(zp),
        });
}

auto MatMulIntegerCublas::lower(Resources &res) const noexcept -> RoutineWorkspace {

    // Workspace holds zero-point-adjusted copies of A and/or B. Elements are
    // int8, so element count equals byte count.
    size_t workspace = 0;
    if (info.a.withZeroPoint) {
        workspace += info.a.groupCount * info.a.groupSize;
    }
    if (info.b.withZeroPoint) {
        workspace += info.b.groupCount * info.b.groupSize;
    }

    // The routine captures `info` by value so it stays valid after lower()
    // returns.
    auto routine = [info = info](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
        auto workspacePtr = reinterpret_cast<int8_t *>(workspace);
        auto a = reinterpret_cast<int8_t const *>(inputs[0]),
             b = reinterpret_cast<int8_t const *>(inputs[1]);
        auto y = reinterpret_cast<int32_t *>(outputs[0]);

        // Subtract A's zero point into the workspace, then read A from there.
        // NOTE(review): assumes a_zero_point is bound at inputs[2] whenever
        // withZeroPoint is set — confirm how optional operator inputs are laid
        // out by the frontend.
        if (auto meta = info.a; meta.withZeroPoint) {
            if (meta.signed_) {
                applyZeroPoint<int8_t>(meta, workspacePtr, a, inputs[2]);
            } else {
                applyZeroPoint<uint8_t>(meta, workspacePtr, a, inputs[2]);
            }
            a = workspacePtr;
            workspacePtr += meta.groupCount * meta.groupSize;// advance past A's copy
        }
        // Same treatment for B. NOTE(review): assumes b_zero_point is always
        // inputs[3], even when a_zero_point is absent — verify the frontend
        // pads missing optional inputs positionally.
        if (auto meta = info.b; meta.withZeroPoint) {
            if (meta.signed_) {
                applyZeroPoint<int8_t>(meta, workspacePtr, b, inputs[3]);
            } else {
                applyZeroPoint<uint8_t>(meta, workspacePtr, b, inputs[3]);
            }
            b = workspacePtr;
        }

        // cuBLAS works in column-major order; swapping the operands computes
        // Y^T = B^T * A^T, which is exactly the row-major product Y = A * B
        // without explicit transposes. alpha/beta are int32 to match the
        // CUDA_R_32I compute type.
        int32_t alpha = 1, beta = 0;
        auto m = info.m,
             n = info.n,
             k = info.k;
        auto strideY = m * n,
             strideA = m * k,
             strideB = k * n;
        auto lda = info.k,// row-major A is k-strided
             ldb = info.n;// row-major B is n-strided
        if (info.broadcaster.needBroadcast()) {

            // Batch dims of A and B differ: resolve each output batch index
            // to its source batch offsets and issue one GEMM per batch.
            uint32_t offset[2];
            for (auto i : range0_(info.broadcaster.outputsCount)) {
                info.broadcaster.locate(i, offset);
                // NOTE(review): cublasGemmEx status is not checked anywhere
                // in this routine — consider asserting CUBLAS_STATUS_SUCCESS.
                cublasGemmEx(
                    res.fetchOrStore<CublasContext>()->handle,
                    CUBLAS_OP_N, CUBLAS_OP_N,
                    n, m, k,
                    &alpha,
                    b + strideB * offset[1], CUDA_R_8I, ldb,
                    a + strideA * offset[0], CUDA_R_8I, lda,
                    &beta, y + strideY * i, CUDA_R_32I,
                    n, CUDA_R_32I,
                    CUBLAS_GEMM_DEFAULT);
            }
        } else {

            // Identical batch layout on both sides: one strided-batched call
            // covers all outputsCount matrices.
            cublasGemmStridedBatchedEx(
                res.fetchOrStore<CublasContext>()->handle,
                CUBLAS_OP_N, CUBLAS_OP_N,
                n, m, k,
                &alpha,
                b, CUDA_R_8I, ldb, strideB,
                a, CUDA_R_8I, lda, strideA,
                &beta, y, CUDA_R_32I,
                n, m * n, info.broadcaster.outputsCount, CUDA_R_32I,
                CUBLAS_GEMM_DEFAULT);
        }
    };

    // Warm the cuBLAS handle now so the first routine invocation does not pay
    // for context creation.
    res.fetchOrStore<CublasContext>();
    return {std::move(routine), workspace};
}

}// namespace refactor::kernel
26 changes: 26 additions & 0 deletions src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef KERNEL_MATMUL_INTEGER_CUBLAS_KERNEL_HH
#define KERNEL_MATMUL_INTEGER_CUBLAS_KERNEL_HH

#include "kernel/attributes/mat_mul_integer_info.h"
#include "kernel/kernel.h"

namespace refactor::kernel {

    /// MatMulInteger kernel backed by cuBLAS (cublasGemmEx /
    /// cublasGemmStridedBatchedEx with int8 inputs and int32 output).
    struct MatMulIntegerCublas final : public Kernel {
        MatMulIntegerInfo info;// shapes, zero-point and broadcast metadata

        explicit MatMulIntegerCublas(decltype(info)) noexcept;

        // Returns nullptr when built without CUDA support.
        static KernelBox build(decltype(info)) noexcept;
        static size_t typeId() noexcept;

        size_t kernelTypeId() const noexcept final;
        std::string_view description() const noexcept final;
#ifdef USE_CUDA
        // Defined in cublas_kernel.cu; only available in CUDA builds.
        RoutineWorkspace lower(Resources &) const noexcept final;
#endif
    };

}// namespace refactor::kernel

// NOTE: guard renamed from KERNEL_MATMUL_CUBLAS_KERNEL_HH, which collided
// with the non-integer MatMul cuBLAS kernel header and could silently skip
// one of the two headers when both are included.
#endif// KERNEL_MATMUL_INTEGER_CUBLAS_KERNEL_HH

0 comments on commit c74b588

Please sign in to comment.