From c74b5885c2ed262dc535e26cdae366e3c5ebe01f Mon Sep 17 00:00:00 2001
From: YdrMaster
Date: Mon, 18 Dec 2023 09:51:23 +0800
Subject: [PATCH] feat(kernel): implement MatMulInteger cublas kernel
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: YdrMaster
---
 .../src/collectors/mat_mul_integer.cc         |   4 +
 .../src/kernels/mat_mul_integer/cpu_kernel.cc |   2 +-
 .../kernels/mat_mul_integer/cublas_kernel.cc  |  28 +++++
 .../kernels/mat_mul_integer/cublas_kernel.cu  | 115 ++++++++++++++++++
 .../kernels/mat_mul_integer/cublas_kernel.hh  |  26 ++++
 5 files changed, 174 insertions(+), 1 deletion(-)
 create mode 100644 src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cc
 create mode 100644 src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu
 create mode 100644 src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.hh

diff --git a/src/04kernel/src/collectors/mat_mul_integer.cc b/src/04kernel/src/collectors/mat_mul_integer.cc
index 123f0bee..c0124de9 100644
--- a/src/04kernel/src/collectors/mat_mul_integer.cc
+++ b/src/04kernel/src/collectors/mat_mul_integer.cc
@@ -1,5 +1,6 @@
 #include "kernel/collectors/mat_mul_integer.h"
 #include "../../src/kernels/mat_mul_integer/cpu_kernel.hh"
+#include "../../src/kernels/mat_mul_integer/cublas_kernel.hh"
 #include "kernel/attributes/mat_mul_integer_info.h"
 
@@ -16,6 +17,9 @@ namespace refactor::kernel {
             }
             break;
         case decltype(_target)::Nvidia:
+            if (auto ptr = MatMulIntegerCublas::build(info); ptr) {
+                ans.emplace_back(std::move(ptr));
+            }
             break;
         default:
             UNREACHABLEX(void, "Unknown target");
diff --git a/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc b/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc
index 0cf31a95..751fb7c0 100644
--- a/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc
+++ b/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc
@@ -37,7 +37,7 @@ namespace refactor::kernel {
         }
     }
 
-    auto K::lower(Resources &res) const noexcept -> RoutineWorkspace {
+    auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
         using namespace runtime;
 
         size_t workspace = 0;
diff --git a/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cc b/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cc
new file mode 100644
index 00000000..d1eeb607
--- /dev/null
+++ b/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cc
@@ -0,0 +1,28 @@
+#include "cublas_kernel.hh"
+
+namespace refactor::kernel {
+    using K = MatMulIntegerCublas;
+    using DT = DataType;
+
+    K::MatMulIntegerCublas(decltype(info) info_) noexcept
+        : Kernel(), info(std::move(info_)) {}
+
+    auto K::build(decltype(info) info) noexcept -> KernelBox {
+#ifndef USE_CUDA
+        return nullptr;
+#endif
+
+        return std::make_unique<K>(std::move(info));
+    }
+
+    auto K::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+
+    auto K::kernelTypeId() const noexcept -> size_t { return typeId(); }
+    auto K::description() const noexcept -> std::string_view {
+        return "Performing MatMulInteger using CUBLAS";
+    }
+
+}// namespace refactor::kernel
diff --git a/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu b/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu
new file mode 100644
index 00000000..94d3aeb0
--- /dev/null
+++ b/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu
@@ -0,0 +1,115 @@
+#include "../../utilities/cuda/cublas_context.hh"
+#include "cublas_kernel.hh"
+#include <cublas_v2.h>
+#include <thrust/execution_policy.h>
+#include <thrust/tabulate.h>
+
+namespace refactor::kernel {
+    using namespace runtime;
+    using namespace cublas;
+
+    template<class T> __device__ __forceinline__ static int8_t sub(T, T);
+    template<> __device__ __forceinline__ int8_t sub(int8_t a, int8_t b) { return a - b; }
+    template<> __device__ __forceinline__ int8_t sub(uint8_t a, uint8_t b) { return static_cast<int8_t>(static_cast<int16_t>(a) - static_cast<int16_t>(b)); }
+
+    template<class T>
+    struct MatMulIntegerZPFunctor {
+        dim_t groupSize;
+        T const *src, *zp;
+
+        __device__ int8_t operator()(size_t i) const noexcept {
+            return sub(src[i], zp[i / groupSize]);
+        }
+    };
+
+    template<class T>
+    static void applyZeroPoint(MatMulIntegerInfo::Input meta, int8_t *dst, void const *src, void const *zp) {
+        thrust::tabulate(
+            thrust::device,
+            dst, dst + meta.groupCount * meta.groupSize,
+            MatMulIntegerZPFunctor<T>{
+                .groupSize = meta.groupSize,
+                .src = reinterpret_cast<T const *>(src),
+                .zp = reinterpret_cast<T const *>(zp),
+            });
+    }
+
+    auto MatMulIntegerCublas::lower(Resources &res) const noexcept -> RoutineWorkspace {
+
+        size_t workspace = 0;
+        if (info.a.withZeroPoint) {
+            workspace += info.a.groupCount * info.a.groupSize;
+        }
+        if (info.b.withZeroPoint) {
+            workspace += info.b.groupCount * info.b.groupSize;
+        }
+
+        auto routine = [info = info](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
+            auto workspacePtr = reinterpret_cast<int8_t *>(workspace);
+            auto a = reinterpret_cast<int8_t const *>(inputs[0]),
+                 b = reinterpret_cast<int8_t const *>(inputs[1]);
+            auto y = reinterpret_cast<int32_t *>(outputs[0]);
+
+            if (auto meta = info.a; meta.withZeroPoint) {
+                if (meta.signed_) {
+                    applyZeroPoint<int8_t>(meta, workspacePtr, a, inputs[2]);
+                } else {
+                    applyZeroPoint<uint8_t>(meta, workspacePtr, a, inputs[2]);
+                }
+                a = workspacePtr;
+                workspacePtr += meta.groupCount * meta.groupSize;
+            }
+            if (auto meta = info.b; meta.withZeroPoint) {
+                if (meta.signed_) {
+                    applyZeroPoint<int8_t>(meta, workspacePtr, b, inputs[3]);
+                } else {
+                    applyZeroPoint<uint8_t>(meta, workspacePtr, b, inputs[3]);
+                }
+                b = workspacePtr;
+            }
+
+            int32_t alpha = 1, beta = 0;
+            auto m = info.m,
+                 n = info.n,
+                 k = info.k;
+            auto strideY = m * n,
+                 strideA = m * k,
+                 strideB = k * n;
+            auto lda = info.k,
+                 ldb = info.n;
+            if (info.broadcaster.needBroadcast()) {
+
+                uint32_t offset[2];
+                for (auto i : range0_(info.broadcaster.outputsCount)) {
+                    info.broadcaster.locate(i, offset);
+                    cublasGemmEx(
+                        res.fetchOrStore<CublasContext>()->handle,
+                        CUBLAS_OP_N, CUBLAS_OP_N,
+                        n, m, k,
+                        &alpha,
+                        b + strideB * offset[1], CUDA_R_8I, ldb,
+                        a + strideA * offset[0], CUDA_R_8I, lda,
+                        &beta, y + strideY * i, CUDA_R_32I,
+                        n, CUDA_R_32I,
+                        CUBLAS_GEMM_DEFAULT);
+                }
+            } else {
+
+                cublasGemmStridedBatchedEx(
+                    res.fetchOrStore<CublasContext>()->handle,
+                    CUBLAS_OP_N, CUBLAS_OP_N,
+                    n, m, k,
+                    &alpha,
+                    b, CUDA_R_8I, ldb, strideB,
+                    a, CUDA_R_8I, lda, strideA,
+                    &beta, y, CUDA_R_32I,
+                    n, m * n, info.broadcaster.outputsCount, CUDA_R_32I,
+                    CUBLAS_GEMM_DEFAULT);
+            }
+        };
+
+        res.fetchOrStore<CublasContext>();
+        return {std::move(routine), workspace};
+    }
+
+}// namespace refactor::kernel
diff --git a/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.hh b/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.hh
new file mode 100644
index 00000000..d0d0400a
--- /dev/null
+++ b/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.hh
@@ -0,0 +1,26 @@
+#ifndef KERNEL_MATMUL_CUBLAS_KERNEL_HH
+#define KERNEL_MATMUL_CUBLAS_KERNEL_HH
+
+#include "kernel/attributes/mat_mul_integer_info.h"
+#include "kernel/kernel.h"
+
+namespace refactor::kernel {
+
+    struct MatMulIntegerCublas final : public Kernel {
+        MatMulIntegerInfo info;
+
+        explicit MatMulIntegerCublas(decltype(info)) noexcept;
+
+        static KernelBox build(decltype(info)) noexcept;
+        static size_t typeId() noexcept;
+
+        size_t kernelTypeId() const noexcept final;
+        std::string_view description() const noexcept final;
+#ifdef USE_CUDA
+        RoutineWorkspace lower(Resources &) const noexcept final;
+#endif
+    };
+
+}// namespace refactor::kernel
+
+#endif// KERNEL_MATMUL_CUBLAS_KERNEL_HH
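
Note (illustration, not part of the patch above): the kernel drives the column-major cuBLAS API from row-major buffers by swapping the operands, i.e. it asks cuBLAS for the n x m column-major product B x A, which is exactly the row-major Y = A x B that MatMulInteger needs; zero points are first subtracted into an int8 workspace so that both signed and unsigned inputs reach the GEMM as plain int8. The standalone sketch below shows the same cublasGemmEx call pattern for a single int8 x int8 -> int32 product. It is a minimal sketch only: error checking, zero points and batching are omitted, and the 4 x 4 x 4 shape, buffer names and the main() wrapper are assumptions made for the example, not code from the patch.

#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <vector>

int main() {
    // Row-major A (m x k), B (k x n), Y (m x n); all-ones inputs, so every
    // output element should equal k.
    int m = 4, n = 4, k = 4;
    std::vector<int8_t> hostA(m * k, 1), hostB(k * n, 1);
    std::vector<int32_t> hostY(m * n, 0);

    int8_t *a, *b;
    int32_t *y;
    cudaMalloc(&a, m * k);
    cudaMalloc(&b, k * n);
    cudaMalloc(&y, m * n * sizeof(int32_t));
    cudaMemcpy(a, hostA.data(), m * k, cudaMemcpyHostToDevice);
    cudaMemcpy(b, hostB.data(), k * n, cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);
    int32_t alpha = 1, beta = 0;
    // Row-major Y = A x B is the column-major n x m product B x A, hence the
    // (n, m, k) shape and the B-before-A operand order, mirroring the kernel.
    cublasGemmEx(handle,
                 CUBLAS_OP_N, CUBLAS_OP_N,
                 n, m, k,
                 &alpha,
                 b, CUDA_R_8I, n,
                 a, CUDA_R_8I, k,
                 &beta, y, CUDA_R_32I, n,
                 CUDA_R_32I, CUBLAS_GEMM_DEFAULT);

    cudaMemcpy(hostY.data(), y, m * n * sizeof(int32_t), cudaMemcpyDeviceToHost);
    cublasDestroy(handle);
    cudaFree(a);
    cudaFree(b);
    cudaFree(y);
    return hostY[0] == k ? 0 : 1;// expect every element of Y to be 4
}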