refactor(kernel): give Broadcaster explicit semantics for expressing that no broadcasting is needed
feat(kernel): implement info extraction for MatMulInteger

Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Dec 15, 2023
1 parent 6f11997 commit a8959f7
Showing 16 changed files with 205 additions and 87 deletions.
1 change: 1 addition & 0 deletions src/04kernel/include/kernel/attributes/broadcaster.h
@@ -15,6 +15,7 @@ namespace refactor::kernel {
explicit Broadcaster(std::vector<slice_t<dim_t>>);
explicit Broadcaster(TensorRefs const &inputs);
void locate(dim_t k, dim_t ans[]) const noexcept;
bool needBroadcast() const noexcept;
};

}// namespace refactor::kernel
src/04kernel/include/kernel/attributes/matmul_info.h → src/04kernel/include/kernel/attributes/mat_mul_info.h (renamed)
@@ -1,21 +1,20 @@
#ifndef KERNEL_MATMUL_INFO_H
#define KERNEL_MATMUL_INFO_H
#ifndef KERNEL_MAT_MUL_INFO_H
#define KERNEL_MAT_MUL_INFO_H

#include "kernel/attributes/broadcaster.h"
#include "kernel/attributes/expand_info.h"
#include <variant>

namespace refactor::kernel {

struct MatMulInfo {
DataType dataType;
float alpha, beta;
bool transA, transB;
size_t m, k, n;
dim_t m, k, n;
// Expand operation info for bias
std::optional<ExpandInfo> biasExpand;
// A constant batch or a 2-directional broadcaster that deals with dimensions before the last 2 dimensions
std::variant<Broadcaster, size_t> broadcasterOrBatch;
// A 2-directional broadcaster that deals with dimensions before the last 2 dimensions
Broadcaster broadcaster;

MatMulInfo(Tensor const &, Tensor const &,
std::optional<std::reference_wrapper<Tensor const>>,
@@ -24,4 +23,4 @@ namespace refactor::kernel {

}// namespace refactor::kernel

#endif// KERNEL_MATMUL_INFO_H
#endif// KERNEL_MAT_MUL_INFO_H
25 changes: 25 additions & 0 deletions src/04kernel/include/kernel/attributes/mat_mul_integer_info.h
@@ -0,0 +1,25 @@
#ifndef KERNEL_MAT_MUL_INTEGER_INFO_H
#define KERNEL_MAT_MUL_INTEGER_INFO_H

#include "kernel/attributes/broadcaster.h"

namespace refactor::kernel {

struct MatMulIntegerInfo {
struct Input {
bool signed_;
bool withZeroPoint;

Input(TensorRefs const &, size_t i) noexcept;
};

Input a, b;
dim_t m, k, n;
Broadcaster broadcaster;

explicit MatMulIntegerInfo(TensorRefs const &inputs) noexcept;
};

}// namespace refactor::kernel

#endif// KERNEL_MAT_MUL_INTEGER_INFO_H
4 changes: 4 additions & 0 deletions src/04kernel/src/attributes/broadcaster.cc
@@ -96,4 +96,8 @@ namespace refactor::kernel {
}
}

bool Broadcaster::needBroadcast() const noexcept {
return !strides.empty();
}

}// namespace refactor::kernel
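
With this change, needBroadcast() gives callers an explicit answer to "is there any real broadcasting left to do?": the constructor folds degenerate dimensions away, so an empty strides table means every input already matches the output layout and outputsCount can be treated as a flat batch. A minimal consumer-side sketch for the two-input (MatMul) case; the loop itself is illustrative, while locate and outputsCount are the existing Broadcaster members:

// Sketch only: branching on needBroadcast() replaces the former
// std::variant<Broadcaster, size_t> dispatch.
void runBatches(Broadcaster const &bc) {
    if (bc.needBroadcast()) {
        dim_t offsets[2];// one offset per input
        for (dim_t i = 0; i < bc.outputsCount; ++i) {
            bc.locate(i, offsets);// map output matrix i to its input matrices
            // ... compute output i from inputs at offsets[0] and offsets[1]
        }
    } else {
        for (dim_t i = 0; i < bc.outputsCount; ++i) {
            // ... shapes match exactly; every operand advances contiguously
        }
    }
}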
src/04kernel/src/attributes/matmul_info.cc → src/04kernel/src/attributes/mat_mul_info.cc (renamed)
@@ -1,19 +1,17 @@
#include "kernel/attributes/matmul_info.h"
#include <cstddef>
#include <numeric>
#include "kernel/attributes/mat_mul_info.h"

namespace refactor::kernel {

ExpandInfo buildBias(size_t m, size_t n,
ExpandInfo buildBias(dim_t m, dim_t n,
Tensor const &a,
Tensor const &b,
Tensor const &c) {
std::vector<dim_t> output(std::max(a.rank(), b.rank()));
auto it = output.rbegin();
*it++ = n;
*it++ = m;
for (auto da = static_cast<size_t>(a.rank() - 2),
db = static_cast<size_t>(b.rank() - 2);
for (auto da = static_cast<dim_t>(a.rank() - 2),
db = static_cast<dim_t>(b.rank() - 2);
auto i : range0_(output.size() - 2)) {
auto a_ = i < da ? a.shape[da - i - 1] : 1;
auto b_ = i < db ? b.shape[db - i - 1] : 1;
@@ -26,13 +24,6 @@ namespace refactor::kernel {
slice(output.data(), output.size()));
}

std::variant<Broadcaster, size_t> buildBroadcasterOrBatch(slice_t<dim_t> dimA, slice_t<dim_t> dimB) {
if (std::equal(dimA.begin(), dimA.end(), dimB.begin(), dimB.end())) {
return std::accumulate(dimA.begin(), dimA.end(), (size_t) 1, std::multiplies<size_t>());
}
return Broadcaster({dimA, dimB});
}

MatMulInfo::MatMulInfo(
Tensor const &a, Tensor const &b,
std::optional<std::reference_wrapper<Tensor const>> c,
@@ -44,7 +35,8 @@ namespace refactor::kernel {
k(transA ? a.shape.rbegin()[1] : a.shape.rbegin()[0]),
n(transB ? b.shape.rbegin()[1] : b.shape.rbegin()[0]),
biasExpand(c ? std::make_optional(buildBias(m, n, a, b, *c)) : std::nullopt),
broadcasterOrBatch(buildBroadcasterOrBatch(slice(a.shape.data(), a.shape.size() - 2), slice(b.shape.data(), b.shape.size() - 2))) {
broadcaster({slice(a.shape.data(), a.shape.size() - 2),
slice(b.shape.data(), b.shape.size() - 2)}) {
auto kB = transB ? b.shape.rbegin()[0] : b.shape.rbegin()[1];
ASSERT(k == kB, "MatMul: input shape not matched.");
}
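
Since the broadcaster is now built unconditionally over the leading (batch) dimensions, i.e. everything before the last two axes, a worked example with hypothetical shapes may help:

// Illustration only, shapes are hypothetical:
//   A: [2, 1, 3, 4] -> leading dims [2, 1], matrix dims m = 3, k = 4
//   B: [5, 4, 6]    -> leading dims [5],    matrix dims k = 4, n = 6
std::vector<dim_t> da{2, 1, 3, 4}, db{5, 4, 6};
Broadcaster bc({slice(da.data(), da.size() - 2),
                slice(db.data(), db.size() - 2)});
// [2, 1] and [5] broadcast to [2, 5]: bc.outputsCount == 10 output matrices
// of m x n, and bc.needBroadcast() is true because B's batch is repeated
// along the first axis.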
26 changes: 26 additions & 0 deletions src/04kernel/src/attributes/mat_mul_integer_info.cc
@@ -0,0 +1,26 @@
#include "kernel/attributes/mat_mul_integer_info.h"

namespace refactor::kernel {

#define A (inputs[0].get().shape)
#define B (inputs[1].get().shape)

MatMulIntegerInfo::Input::Input(TensorRefs const &inputs, size_t i) noexcept
: signed_(inputs[i].get().dataType == DataType::I8),
withZeroPoint(false) {
if (inputs.size() > i + 2) {
auto const &t = inputs[i + 2].get();
withZeroPoint = t.rank() != 0 || !t.data || t.data->get<uint8_t>() != 0;
}
}

MatMulIntegerInfo::MatMulIntegerInfo(TensorRefs const &inputs) noexcept
: a(inputs, 0),
b(inputs, 1),
m(A.rbegin()[1]),
k(A.rbegin()[0]),
n(B.rbegin()[0]),
broadcaster({slice(A.data(), A.size() - 2),
slice(B.data(), B.size() - 2)}) {}

}// namespace refactor::kernel
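
Note the zero-point rule in Input: a zero point may be skipped only when it is a known scalar constant equal to 0. Restated as a standalone predicate (a sketch; rank, hasConstData, and value stand in for t.rank(), t.data, and t.data->get<uint8_t>()):

// The zero point matters if it is non-scalar (per-axis), or unknown at
// kernel-build time, or a known scalar other than 0.
bool zeroPointMatters(int rank, bool hasConstData, uint8_t value) {
    return rank != 0 || !hasConstData || value != 0;
}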
8 changes: 8 additions & 0 deletions src/04kernel/src/collectors/mat_mul_integer.cc
@@ -1,11 +1,19 @@
#include "kernel/collectors/mat_mul_integer.h"
#include "../../src/kernels/mat_mul_integer/cpu_kernel.hh"
#include "kernel/attributes/mat_mul_integer_info.h"

namespace refactor::kernel {

std::vector<KernelBox>
MatMulIntegerCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
MatMulIntegerInfo info(inputs);

std::vector<KernelBox> ans;
switch (_target) {
case decltype(_target)::Cpu:
if (auto ptr = MatMulIntegerCPU::build(info); ptr) {
ans.emplace_back(std::move(ptr));
}
break;
case decltype(_target)::Nvidia:
break;
34 changes: 5 additions & 29 deletions src/04kernel/src/kernels/mat_mul/cpu_kernel.cc
@@ -1,5 +1,6 @@
#include "cpu_kernel.hh"
#include "../expand/cpu_kernel.hh"
#include "../mat_mul_common/cpu_template.hpp"

namespace refactor::kernel {
using K = MatMulCPU;
@@ -8,7 +9,7 @@ namespace refactor::kernel {
K::MatMulCPU(decltype(info) info_) noexcept
: Kernel(), info(std::move(info_)) {}

auto K::build(MatMulInfo info) noexcept -> KernelBox {
auto K::build(decltype(info) info) noexcept -> KernelBox {
return info.dataType.isCpuNumberic()
? std::make_unique<K>(std::move(info))
: nullptr;
@@ -24,31 +25,6 @@ namespace refactor::kernel {
return "Performing MatMul using CPU";
}

template<class T>
struct MatMulCPUMetaData {
size_t M, K, N;
size_t strideA0, strideA1, strideB0, strideB1;
T alpha, beta;

/*
* 2D matrix multiplication: Y = alpha * A @ B + beta * Y
* Assume bias C has been broadcast to Y already. Beta should be 0 in the absence of bias.
*/
void matrixMultiply(T const *A, T const *B, T *Y) const noexcept {
// #pragma omp parallel for
for (size_t i = 0; i < M; i++) {
for (size_t j = 0; j < N; j++) {
T sum = 0;
// #pragma omp simd reduction(+ : sum)
for (size_t k = 0; k < K; k++) {
sum += A[i * strideA0 + k * strideA1] * B[k * strideB0 + j * strideB1];
}
Y[i * N + j] = beta * Y[i * N + j] + alpha * sum;
}
}
}
};

template<class T>
static auto lowerTyped(MatMulInfo const &info, Resources &res) noexcept -> RoutineWorkspace {
MatMulCPUMetaData const md{
@@ -70,8 +46,8 @@ namespace refactor::kernel {
? std::make_optional(ExpandCpu(*info.biasExpand).lower(res).routine)
: std::nullopt;

if (std::holds_alternative<Broadcaster>(info.broadcasterOrBatch)) {
return [broadcaster = std::get<Broadcaster>(info.broadcasterOrBatch),
if (info.broadcaster.needBroadcast()) {
return [broadcaster = info.broadcaster,
stepY, stepA, stepB,
md, biasEx]//
(runtime::Resources & res, void *workspace, void const *const *inputs, void *const *outputs) {
@@ -87,7 +63,7 @@ namespace refactor::kernel {
}
};
} else {
return [batch = std::get<size_t>(info.broadcasterOrBatch),
return [batch = info.broadcaster.outputsCount,
stepY, stepA, stepB,
md, biasEx]//
(runtime::Resources & res, void *workspace, void const *const *inputs, void *const *outputs) {
7 changes: 3 additions & 4 deletions src/04kernel/src/kernels/mat_mul/cpu_kernel.hh
@@ -1,18 +1,17 @@
#ifndef KERNEL_MATMUL_CPU_KERNEL_HH
#define KERNEL_MATMUL_CPU_KERNEL_HH

#include "kernel/attributes/matmul_info.h"
#include "kernel/attributes/mat_mul_info.h"
#include "kernel/kernel.h"
#include "kernel/tensor.h"

namespace refactor::kernel {

struct MatMulCPU final : public Kernel {
MatMulInfo info;

explicit MatMulCPU(MatMulInfo) noexcept;
explicit MatMulCPU(decltype(info)) noexcept;

static KernelBox build(MatMulInfo) noexcept;
static KernelBox build(decltype(info)) noexcept;
static size_t typeId() noexcept;

size_t kernelTypeId() const noexcept final;
2 changes: 1 addition & 1 deletion src/04kernel/src/kernels/mat_mul/cublas_kernel.cc
@@ -7,7 +7,7 @@ namespace refactor::kernel {
K::MatMulCublas(decltype(info) info_) noexcept
: Kernel(), info(std::move(info_)) {}

auto K::build(MatMulInfo info) noexcept -> KernelBox {
auto K::build(decltype(info) info) noexcept -> KernelBox {
#ifndef USE_CUDA
return nullptr;
#endif
57 changes: 29 additions & 28 deletions src/04kernel/src/kernels/mat_mul/cublas_kernel.cu
@@ -28,34 +28,8 @@ namespace refactor::kernel {
? std::make_optional(ExpandCuda(*info.biasExpand).lower(res).routine)
: std::nullopt;
// clang-format on
if (std::holds_alternative<size_t>(info.broadcasterOrBatch)) {
return [batch = std::get<size_t>(info.broadcasterOrBatch),
cudaDataType,
alpha, beta, tA, tB,
m, n, k,
strideA, strideB,
lda, ldb,
biasEx]//
(Resources & res, void *workspace, void const *const *inputs, void *const *outputs) {
// Call expand kernel to broadcast bias if bias is used
if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); }

auto a = reinterpret_cast<T const *>(inputs[0]);
auto b = reinterpret_cast<T const *>(inputs[1]);
auto y = reinterpret_cast<T *>(outputs[0]);
cublasGemmStridedBatchedEx(
res.fetchOrStore<CublasContext>()->handle,
tB, tA,
n, m, k,
&alpha,
b, cudaDataType, ldb, strideB,
a, cudaDataType, lda, strideA,
&beta, y, cudaDataType,
n, m * n, batch, cudaDataType,
CUBLAS_GEMM_DEFAULT);
};
} else {// if using the broadcaster
return [broadcaster = std::get<Broadcaster>(info.broadcasterOrBatch),
if (info.broadcaster.needBroadcast()) {
return [broadcaster = info.broadcaster,
cudaDataType,
alpha, beta, tA, tB,
m, n, k,
@@ -83,6 +57,33 @@
CUBLAS_GEMM_DEFAULT);
}
};

} else {
return [batch = info.broadcaster.outputsCount,
cudaDataType,
alpha, beta, tA, tB,
m, n, k,
strideA, strideB,
lda, ldb,
biasEx]//
(Resources & res, void *workspace, void const *const *inputs, void *const *outputs) {
// Call expand kernel to broadcast bias if bias is used
if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); }

auto a = reinterpret_cast<T const *>(inputs[0]);
auto b = reinterpret_cast<T const *>(inputs[1]);
auto y = reinterpret_cast<T *>(outputs[0]);
cublasGemmStridedBatchedEx(
res.fetchOrStore<CublasContext>()->handle,
tB, tA,
n, m, k,
&alpha,
b, cudaDataType, ldb, strideB,
a, cudaDataType, lda, strideA,
&beta, y, cudaDataType,
n, m * n, batch, cudaDataType,
CUBLAS_GEMM_DEFAULT);
};
}
}
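
One note on the argument order used in both branches: cuBLAS is column-major, and a row-major M x N matrix is byte-identical to a column-major N x M matrix, so requesting Y^T = B^T * A^T in column-major terms yields Y = A * B in row-major memory. An explanatory comment, not part of the commit:

// Row-major trick: requesting dimensions (n, m, k) with operands (b, a) and
// transposes (tB, tA) makes cuBLAS compute (A * B)^T in its column-major
// view, which is exactly A * B laid out row-major in Y.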

6 changes: 3 additions & 3 deletions src/04kernel/src/kernels/mat_mul/cublas_kernel.hh
@@ -1,17 +1,17 @@
#ifndef KERNEL_MATMUL_CUBLAS_KERNEL_HH
#define KERNEL_MATMUL_CUBLAS_KERNEL_HH

#include "kernel/attributes/matmul_info.h"
#include "kernel/attributes/mat_mul_info.h"
#include "kernel/kernel.h"

namespace refactor::kernel {

struct MatMulCublas final : public Kernel {
MatMulInfo info;

explicit MatMulCublas(MatMulInfo) noexcept;
explicit MatMulCublas(decltype(info)) noexcept;

static KernelBox build(MatMulInfo) noexcept;
static KernelBox build(decltype(info)) noexcept;
static size_t typeId() noexcept;

size_t kernelTypeId() const noexcept final;
33 changes: 33 additions & 0 deletions src/04kernel/src/kernels/mat_mul_common/cpu_template.hpp
@@ -0,0 +1,33 @@
#ifndef KERNEL_MATMUL_COMMON_CPU_TEMPLATE_HPP
#define KERNEL_MATMUL_COMMON_CPU_TEMPLATE_HPP

namespace refactor::kernel {

template<class T>
struct MatMulCPUMetaData {
size_t M, K, N;
size_t strideA0, strideA1, strideB0, strideB1;
T alpha, beta;

/*
* 2D matrix multiplication: Y = alpha * A @ B + beta * Y
* Assume bias C has been broadcast to Y already. Beta should be 0 in the absence of bias.
*/
void matrixMultiply(T const *A, T const *B, T *Y) const noexcept {
// #pragma omp parallel for
for (size_t i = 0; i < M; i++) {
for (size_t j = 0; j < N; j++) {
T sum = 0;
// #pragma omp simd reduction(+ : sum)
for (size_t k = 0; k < K; k++) {
sum += A[i * strideA0 + k * strideA1] * B[k * strideB0 + j * strideB1];
}
Y[i * N + j] = beta * Y[i * N + j] + alpha * sum;
}
}
}
};

}// namespace refactor::kernel

#endif// KERNEL_MATMUL_COMMON_CPU_TEMPLATE_HPP
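
Hoisting MatMulCPUMetaData into mat_mul_common makes the naive GEMM template shareable beyond MatMulCPU (presumably for the new MatMulIntegerCPU kernel added in this commit). A minimal usage sketch with made-up values, not part of the commit:

#include "cpu_template.hpp"
#include <cassert>

int main() {
    using refactor::kernel::MatMulCPUMetaData;
    float A[]{1, 2, 3, 4, 5, 6};// 2 x 3, row-major
    float B[]{1, 0, 0, 1, 1, 1};// 3 x 2, row-major
    float Y[4]{};               // 2 x 2; beta = 0, so the initial Y is ignored
    MatMulCPUMetaData<float> md{
        2, 3, 2,   // M, K, N
        3, 1,      // strideA0, strideA1: row-major M x K
        2, 1,      // strideB0, strideB1: row-major K x N
        1.0f, 0.0f,// alpha, beta
    };
    md.matrixMultiply(A, B, Y);
    assert(Y[0] == 4 && Y[1] == 5 && Y[2] == 10 && Y[3] == 11);
}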
