-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(kernel): 为 Broadcaster 表示不需要广播明确语义
feat(kernel): 实现 MatMulInteger 的信息抽取 Signed-off-by: YdrMaster <ydrml@hotmail.com>
- Loading branch information
Showing
16 changed files
with
205 additions
and
87 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
25 changes: 25 additions & 0 deletions
25
src/04kernel/include/kernel/attributes/mat_mul_integer_info.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#ifndef KERNEL_MAT_MUL_INTEGER_INFO_H | ||
#define KERNEL_MAT_MUL_INTEGER_INFO_H | ||
|
||
#include "kernel/attributes/broadcaster.h" | ||
|
||
namespace refactor::kernel { | ||
|
||
struct MatMulIntegerInfo { | ||
struct Input { | ||
bool signed_; | ||
bool withZeroPoint; | ||
|
||
Input(TensorRefs const &, size_t i) noexcept; | ||
}; | ||
|
||
Input a, b; | ||
dim_t m, k, n; | ||
Broadcaster broadcaster; | ||
|
||
explicit MatMulIntegerInfo(TensorRefs const &inputs) noexcept; | ||
}; | ||
|
||
}// namespace refactor::kernel | ||
|
||
#endif// KERNEL_MAT_MUL_INTEGER_INFO_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
#include "kernel/attributes/mat_mul_integer_info.h" | ||
|
||
namespace refactor::kernel { | ||
|
||
#define A (inputs[0].get().shape) | ||
#define B (inputs[1].get().shape) | ||
|
||
/// Describes operand `i` of the operator.
///
/// `signed_` is set when the operand's data type is `DataType::I8`.
/// `withZeroPoint` stays false when no tensor exists at index `i + 2`;
/// otherwise it is derived from that tensor as described below.
MatMulIntegerInfo::Input::Input(TensorRefs const &inputs, size_t i) noexcept
    : signed_(inputs[i].get().dataType == DataType::I8),
      withZeroPoint(false) {
    auto const zeroPointIdx = i + 2;
    if (inputs.size() <= zeroPointIdx) {
        // No tensor at the zero point position for this operand.
        return;
    }

    auto const &zeroPoint = inputs[zeroPointIdx].get();
    if (zeroPoint.rank() != 0) {
        // A zero point with non-zero rank is always kept.
        withZeroPoint = true;
        return;
    }
    if (!zeroPoint.data) {
        // No data attached to the tensor, so it cannot be ruled out here.
        withZeroPoint = true;
        return;
    }
    // NOTE(review): the comparison applies to whatever get<uint8_t>() returns;
    // if that is a pointer rather than a value, this tests the pointer and not
    // the stored byte — confirm against the data container's API.
    withZeroPoint = zeroPoint.data->get<uint8_t>() != 0;
}
|
||
/// Extracts the MatMulInteger description from the operator inputs.
///
/// `A` and `B` are the file-local shorthands for the shapes of `inputs[0]`
/// and `inputs[1]`:
/// - `m` is the second-to-last dimension of `A`, `k` the last dimension of `A`;
/// - `n` is the last dimension of `B`;
/// - `broadcaster` is built from the dimensions of both shapes that precede
///   the last two.
///
/// NOTE(review): there is no rank check here; `size() - 2` and `rbegin()[1]`
/// presumably rely on both shapes having at least two dimensions — confirm
/// that callers guarantee this.
MatMulIntegerInfo::MatMulIntegerInfo(TensorRefs const &inputs) noexcept
    : a(inputs, 0),
      b(inputs, 1),
      m(A.rbegin()[1]),
      k(A.rbegin()[0]),
      n(B.rbegin()[0]),
      broadcaster({slice(A.data(), A.size() - 2),
                   slice(B.data(), B.size() - 2)}) {}

// The single-letter shorthands are only meant for the constructor above;
// drop them so they cannot rewrite unrelated identifiers later in the
// translation unit.
#undef A
#undef B
|
||
}// namespace refactor::kernel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#ifndef KERNEL_MATMUL_COMMON_CPU_TEMPLATE_HPP | ||
#define KERNEL_MATMUL_COMMON_CPU_TEMPLATE_HPP | ||
|
||
namespace refactor::kernel { | ||
|
||
/// Parameters of a plain (non-batched) CPU matrix multiplication.
///
/// `A` is addressed as `A[i * strideA0 + k * strideA1]`, `B` as
/// `B[k * strideB0 + j * strideB1]`, and `Y` is a dense row-major `M x N`
/// matrix addressed as `Y[i * N + j]`.
template<class T>
struct MatMulCPUMetaData {
    size_t M, K, N;
    size_t strideA0, strideA1, strideB0, strideB1;
    T alpha, beta;

    /*
     * 2D matrix multiplication: Y = alpha * A @ B + beta * Y
     * Assume bias C has been broadcast to Y already. Beta should be 0 in the absence of bias.
     *
     * When beta is exactly 0 the previous contents of Y are not read at all, so Y may be
     * an unset buffer: multiplying stale NaN/Inf values by 0 would otherwise yield NaN.
     */
    void matrixMultiply(T const *A, T const *B, T *Y) const noexcept {
        // Exact comparison is intended: only a true zero means "no bias term".
        bool const overwrite = beta == static_cast<T>(0);
        // #pragma omp parallel for
        for (size_t i = 0; i < M; i++) {
            for (size_t j = 0; j < N; j++) {
                T sum = 0;
                // #pragma omp simd reduction(+ : sum)
                for (size_t k = 0; k < K; k++) {
                    sum += A[i * strideA0 + k * strideA1] * B[k * strideB0 + j * strideB1];
                }
                if (overwrite) {
                    Y[i * N + j] = alpha * sum;
                } else {
                    Y[i * N + j] = beta * Y[i * N + j] + alpha * sum;
                }
            }
        }
    }
};
|
||
}// namespace refactor::kernel | ||
|
||
#endif// KERNEL_MATMUL_COMMON_CPU_TEMPLATE_HPP |
Oops, something went wrong.