-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
312 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#ifndef COMPUTATION_CONVERT_H | ||
#define COMPUTATION_CONVERT_H | ||
|
||
#include "computation/graph.h" | ||
|
||
namespace refactor::computation { | ||
|
||
class Converter { | ||
public: | ||
Converter() = default; | ||
virtual ~Converter() = default; | ||
virtual bool execute(const std::shared_ptr<GraphMutant> &) const = 0; | ||
static Converter *get(std::string_view key) { | ||
//fmt::println("{}", storage().size()); | ||
if (storage().find(key) != storage().end()) { | ||
return storage().at(key).get(); | ||
} | ||
return nullptr; | ||
}; | ||
static void add(std::shared_ptr<Converter> converter, std::string_view key) { | ||
storage().insert(std::make_pair(key, converter)); | ||
}; | ||
static std::unordered_map<std::string_view, std::shared_ptr<Converter>> &storage() { | ||
static std::unordered_map<std::string_view, std::shared_ptr<Converter>> passStorage; | ||
return passStorage; | ||
} | ||
}; | ||
|
||
template<class T> | ||
class ConverterRegister { | ||
public: | ||
ConverterRegister(const char *claim) { | ||
T *instance = new T; | ||
Converter::add(std::shared_ptr<Converter>(instance), claim); | ||
} | ||
}; | ||
|
||
}// namespace refactor::computation | ||
|
||
#endif// COMPUTATION_CONVERT_H |
93 changes: 93 additions & 0 deletions
93
src/05computation/include/computation/pass/kvcache_attention.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
#ifndef COMPUTATION_KVCACHE_ATTENTION_H | ||
#define COMPUTATION_KVCACHE_ATTENTION_H | ||
|
||
#include "computation/operators/concat.h" | ||
#include "computation/operators/mat_mul.h" | ||
#include "computation/operators/simple_binary.h" | ||
#include "computation/operators/softmax.h" | ||
#include "computation/operators/transpose.h" | ||
#include "convert.h" | ||
#include "graph.h" | ||
#include "kernel/collectors/simple_binary.h" | ||
|
||
namespace refactor::computation { | ||
|
||
/* concat | ||
| | ||
transpose | ||
| | ||
matmul | ||
| | ||
div | ||
| | ||
softmax concat | ||
\ / | ||
matmul | ||
*/ | ||
class KVCacheAttention : public Converter { | ||
static size_t count = 0; | ||
|
||
public: | ||
virtual bool execute(std::shared_ptr<GraphMutant> &g) const override { | ||
for (auto opMatch : g->internal().nodes()) { | ||
size_t optype = opMatch->info().op->typeId(); | ||
if (optype != MatMul::typeId()) { | ||
continue; | ||
} | ||
// match the matmul op | ||
//auto matmulPredecessors = opMatch->predecessors(); | ||
if (opMatch->predecessors().size() != 2) { | ||
continue; | ||
} | ||
auto matmulInputLeft = opMatch->predecessors()[0]; | ||
auto matmulInputRight = opMatch->predecessors()[1]; | ||
if (matmulInputLeft->info().op->opTypeId() != Softmax::typeId() || | ||
matmulInputRight->info().op->opTypeId() != Concat::typeId()) { | ||
continue; | ||
} | ||
//auto softmaxPredecessors = matmulInputLeft->predecessors(); | ||
auto concatInputs = matmulInputRight->inputs(); | ||
if (matmulInputLeft->predecessors().size() != 1 || concatInputs.size() != 2) { | ||
continue; | ||
} | ||
auto softmaxInput = matmulInputLeft->predecessors()[0]; | ||
if (softmaxInput->info().op->opTypeId() != SimpleBinary::typeId(SimpleBinaryType::Div)) { | ||
continue; | ||
} | ||
if (softmaxInput->predecessors().size() != 1) { | ||
continue; | ||
} | ||
divInput = softmaxInput->predecessors()[0]; | ||
if (divInput->info().op->opTypeId() != MatMul::typeId()) { | ||
continue; | ||
} | ||
auto matmulInputs = divInput->inputs(); | ||
if (divInput->predecessors().size() != 2 || matmulInputs.size() != 2) { | ||
continue; | ||
} | ||
auto matmul1InputLeft = divInput->predecessors()[0]; | ||
auto matmul1InputRight = divInput->predecessors()[1]; | ||
if (matmul1InputLeft->info().op->opTypeId() != SimpleBinary::typeId(SimpleBinaryType::Add) || | ||
matmul1InputRight->info().op->opTypeId() != Transpose::typeId()) { | ||
continue; | ||
} | ||
if (matmul1InputRight->predecessors().size() != 1) { | ||
continue; | ||
} | ||
transposeInput = matmul1InputRight->predecessors()[1]; | ||
if (transposeInput->info().op->opTypeId() != Concat::typeId()) { | ||
continue; | ||
} | ||
auto concatInputs1 = transposeInput->inputs(); | ||
if (concatInputs1.size() != 2) { | ||
continue; | ||
} | ||
//auto newNode = g->internal().pushNode(); | ||
} | ||
return true; | ||
}; | ||
}; | ||
|
||
}// namespace refactor::computation | ||
|
||
#endif// COMPUTATION_KVCACHE_ATTENTION_H |
58 changes: 58 additions & 0 deletions
58
src/05computation/include/computation/pass/matmul_transpose.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
#ifndef COMPUTATION_MATMUL_TRANSPOSE_H | ||
#define COMPUTATION_MATMUL_TRANSPOSE_H | ||
|
||
#include "../graph.h" | ||
#include "computation/operators/mat_mul.h" | ||
#include "computation/operators/transpose.h" | ||
#include "computation/pass/convert.h" | ||
|
||
namespace refactor::computation { | ||
class MatMulTransposeFuse : public Converter { | ||
public: | ||
virtual bool execute(const std::shared_ptr<GraphMutant> &g) const override { | ||
for (auto opMatch : g->internal().nodes()) { | ||
size_t optype = opMatch->info().op->opTypeId(); | ||
if (optype != MatMul::typeId()) { | ||
continue; | ||
} | ||
//std::unordered_set<refactor::Rc<refactor::graph_topo::LinkedGraph<Node, Edge>::Node>> matmulPreOps = opMatch->predecessors(); | ||
//auto matmulPreOps = opMatch->predecessors(); | ||
if (opMatch->predecessors().size() == 0) { | ||
continue; | ||
} | ||
for (size_t i = 0; i < opMatch->inputs().size(); ++i) { | ||
if (auto preOp = opMatch->inputs()[i]->source(); | ||
preOp != nullptr && preOp->info().op->opTypeId() == Transpose::typeId()) { | ||
auto transposeOp = dynamic_cast<Transpose *>(preOp->info().op.get()); | ||
auto matmulOp = dynamic_cast<MatMul *>(opMatch->info().op.get()); | ||
auto axis = transposeOp->perm; | ||
bool flag = false; | ||
if (axis[axis.size() - 1] == axis.size() - 2 && axis[axis.size() - 2] == axis.size() - 1) { | ||
flag = true; | ||
} | ||
for (size_t index = 0; index < axis.size() - 2; ++index) { | ||
if (index == axis[index]) { | ||
continue; | ||
} | ||
flag = false; | ||
break; | ||
} | ||
if (flag) { | ||
if (i == 0) { | ||
matmulOp->transA = !matmulOp->transA; | ||
} else { | ||
matmulOp->transB = !matmulOp->transB; | ||
} | ||
opMatch->reconnect(opMatch->inputs()[i], preOp->inputs()[0]); | ||
g->internal().eraseNode(preOp); | ||
} | ||
} | ||
} | ||
} | ||
return true; | ||
}; | ||
}; | ||
|
||
//static ConverterRegister<MatMulTransposeFuse> __l("MatMulTransposeFuse"); | ||
}// namespace refactor::computation | ||
#endif// COMPUTATION_MATMUL_TRANSPOSE_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
#ifndef COMPUTATION_PASS_MANAGER_H | ||
#define COMPUTATION_PASS_MANAGER_H | ||
#include "convert.h" | ||
#include "matmul_transpose.h" | ||
|
||
namespace refactor::computation { | ||
|
||
void register_() { | ||
#define REGISTER(PASS, NAME) static ConverterRegister<PASS> __l("" #NAME); | ||
REGISTER(MatMulTransposeFuse, MatMulTransposeFuse) | ||
}; | ||
|
||
|
||
}// namespace refactor::computation | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
#include "computation/graph.h" | ||
#include "computation/operators/mat_mul.h" | ||
#include "computation/operators/transpose.h" | ||
#include <gtest/gtest.h> | ||
|
||
namespace refactor::computation { | ||
|
||
/// Builds the graph used by the MatMulTranspose test: two Transpose nodes
/// (permutation {0, 1, 3, 2}) whose outputs are the two inputs of one MatMul.
refactor::graph_topo::Builder<size_t, Node, size_t, Edge> TestMatMulTransposeGraphBuild() {
    // Edge tensors; the variable suffix matches the edge key used below.
    auto edge0 = Tensor::share(DataType::F32, {1, 3, 3, 5}, LayoutType::Others);
    auto edge1 = Tensor::share(DataType::F32, {2, 3, 5, 3}, LayoutType::Others);
    auto edge2 = Tensor::share(DataType::F32, {1, 3, 5, 3}, LayoutType::Others);
    auto edge3 = Tensor::share(DataType::F32, {2, 3, 3, 5}, LayoutType::Others);
    auto edge4 = Tensor::share(DataType::F32, {2, 3, 5, 5}, LayoutType::Others);

    // Permutation that swaps the last two of four axes.
    absl::InlinedVector<uint32_t, 4> swapLastTwo = {0, 1, 3, 2};
    std::unordered_map<size_t, Node> ops;
    ops[0] = Node{std::make_unique<Transpose>(swapLastTwo), "transpose0"};
    ops[1] = Node{std::make_unique<Transpose>(swapLastTwo), "transpose1"};
    ops[2] = Node{std::make_unique<MatMul>(1.0, 1.0, false, false), "matmul"};

    return {
        // Topology entries read as {node key, {{input edges}, {output edges}}}.
        {
            {0, {{0}, {2}}},
            {1, {{1}, {3}}},
            {2, {{2, 3}, {4}}},
        },
        {0, 1},// global inputs
        {4},   // global outputs
        std::move(ops),
        {
            {0, {edge0, "input0"}},
            {1, {edge1, "input1"}},
            {2, {edge2, "input0_transpose"}},
            {3, {edge3, "input1_transpose"}},
            {4, {edge4, "output"}},
        },
    };
}
|
||
// Builds the Transpose/Transpose/MatMul graph, optimizes it, and prints the
// topology before and after together with the remaining nodes and edges.
// The test only prints; it makes no assertions.
TEST(Graph, MatMulTranspose) {
    auto builtTopo = TestMatMulTransposeGraphBuild().build();
    fmt::println("{}", builtTopo.topology.toString());

    Graph graph(std::move(builtTopo));
    graph.optimize();

    auto const &flat = graph.internal().contiguous();
    fmt::println("{}", flat.topology.toString());

    fmt::println("Nodes info :");
    for (size_t nodeIdx = 0; nodeIdx < flat.nodes.size(); ++nodeIdx) {
        fmt::println("{}. \"{}\"", nodeIdx, flat.nodes[nodeIdx].name);
    }

    fmt::println("\n Edges info :");
    for (size_t edgeIdx = 0; edgeIdx < flat.edges.size(); ++edgeIdx) {
        fmt::println("{}. \"{}\" Shape is {}, Layout is {}", edgeIdx, flat.edges[edgeIdx].name,
                     vec2str(flat.edges[edgeIdx].tensor->shape), flat.edges[edgeIdx].tensor->layout.name());
    }
}
}// namespace refactor::computation |