Skip to content

Commit

Permalink
feat(computation): MatMulInteger 从 MatMul 分离
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Dec 15, 2023
1 parent 7a37779 commit 6f11997
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 6 deletions.
19 changes: 19 additions & 0 deletions src/04kernel/include/kernel/collectors/mat_mul_integer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#ifndef KERNEL_MAT_MUL_INTEGER_H
#define KERNEL_MAT_MUL_INTEGER_H

#include "../collector.h"

namespace refactor::kernel {

struct MatMulIntegerCollector final : public InfoCollector {

constexpr MatMulIntegerCollector(decltype(_target) target) noexcept
: InfoCollector(target) {}

std::vector<KernelBox>
filter(TensorRefs inputs, TensorRefs outputs) const final;
};

}// namespace refactor::kernel

#endif// KERNEL_MAT_MUL_INTEGER_H
1 change: 0 additions & 1 deletion src/04kernel/src/collectors/mat_mul.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include "kernel/collectors/mat_mul.h"
#include "../kernels/mat_mul/cpu_kernel.hh"
#include "../kernels/mat_mul/cublas_kernel.hh"
#include "common.h"
#include "kernel/attributes/matmul_info.h"

namespace refactor::kernel {
Expand Down
18 changes: 18 additions & 0 deletions src/04kernel/src/collectors/mat_mul_integer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#include "kernel/collectors/mat_mul_integer.h"

namespace refactor::kernel {
std::vector<KernelBox>
MatMulIntegerCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
std::vector<KernelBox> ans;
switch (_target) {
case decltype(_target)::Cpu:
break;
case decltype(_target)::Nvidia:
break;
default:
UNREACHABLEX(void, "Unknown target");
}
return ans;
}

}// namespace refactor::kernel
3 changes: 1 addition & 2 deletions src/04kernel/src/kernels/mat_mul/cublas_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@ namespace refactor::kernel {
: Kernel(), info(std::move(info_)) {}

auto K::build(MatMulInfo info) noexcept -> KernelBox {
static const std::unordered_set<decltype(DT::internal)> TYPE{DT::F32, DT::F64, DT::FP16};
#ifndef USE_CUDA
return nullptr;
#endif

return TYPE.contains(info.dataType)
return info.dataType.isIeee754() || info.dataType == DT::I8
? std::make_unique<K>(std::move(info))
: nullptr;
}
Expand Down
21 changes: 21 additions & 0 deletions src/05computation/include/computation/operators/mat_mul_integer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifndef COMPUTATION_MAT_MUL_INTEGER_H
#define COMPUTATION_MAT_MUL_INTEGER_H

#include "../operator.h"

namespace refactor::computation {

struct MatMulInteger final : public LayoutDependentOperator {

constexpr MatMulInteger() noexcept = default;

static size_t typeId() noexcept;
size_t opTypeId() const noexcept final;
std::string_view name() const noexcept final;
kernel::CollectorBox candidateKernels(Target) const noexcept final;
std::string serialize() const noexcept final;
};

}// namespace refactor::computation

#endif// #ifndef COMPUTATION_MAT_MUL_INTEGER_H
20 changes: 20 additions & 0 deletions src/05computation/src/operators/mat_mul_integer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include "computation/operators/mat_mul_integer.h"
#include "kernel/collectors/mat_mul_integer.h"

namespace refactor::computation {
using Op = MatMulInteger;

auto Op::typeId() noexcept -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}
auto Op::opTypeId() const noexcept -> size_t { return typeId(); }
auto Op::name() const noexcept -> std::string_view { return "MatMulInteger"; }
auto Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox {
return std::make_unique<kernel::MatMulIntegerCollector>(target);
}
auto Op::serialize() const noexcept -> std::string {
return "MatMulInteger()";
}

}// namespace refactor::computation
6 changes: 3 additions & 3 deletions src/07onnx/src/operators/mat_mul_integer.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "mat_mul_integer.hh"
#include "common.h"
#include "computation/operators/mat_mul.h"
#include "computation/operators/mat_mul_integer.h"
#include <unordered_set>

namespace refactor::onnx {
Expand Down Expand Up @@ -95,8 +95,8 @@ namespace refactor::onnx {
}

auto Op::lower(TensorRefs) const -> computation::OpBox {
using Op_ = computation::MatMul;
return std::make_unique<Op_>(1.0, 1.0, false, false);
using Op_ = computation::MatMulInteger;
return std::make_unique<Op_>();
}

}// namespace refactor::onnx

0 comments on commit 6f11997

Please sign in to comment.