diff --git a/src/01graph_topo/include/graph_topo/linked_graph.hpp b/src/01graph_topo/include/graph_topo/linked_graph.hpp
index adb9c58fb..1793034d1 100644
--- a/src/01graph_topo/include/graph_topo/linked_graph.hpp
+++ b/src/01graph_topo/include/graph_topo/linked_graph.hpp
@@ -56,8 +56,8 @@ namespace refactor::graph_topo {
         TN const &info() const;
         std::vector<Rc<Edge>> const &inputs() const;
         std::vector<Rc<Edge>> const &outputs() const;
-        std::unordered_set<Rc<Node>> const &predecessors() const;
-        std::unordered_set<Rc<Node>> const &successors() const;
+        std::unordered_set<Rc<Node>> predecessors() const;
+        std::unordered_set<Rc<Node>> successors() const;
         void connect(count_t, Rc<Edge>);
         void disconnect(count_t);
         void reconnect(Rc<Edge>, Rc<Edge>);
@@ -253,7 +253,7 @@ namespace refactor::graph_topo {
         return _outputs;
     }
 
-    LINKED_GRAPH_FN Node::predecessors() const->std::unordered_set<Rc<Node>> const & {
+    LINKED_GRAPH_FN Node::predecessors() const->std::unordered_set<Rc<Node>> {
         std::unordered_set<Rc<Node>> ans;
         for (auto const &e : _inputs) {
             if (e->_source) {
@@ -263,7 +263,7 @@ namespace refactor::graph_topo {
         return ans;
     }
 
-    LINKED_GRAPH_FN Node::successors() const->std::unordered_set<Rc<Node>> const & {
+    LINKED_GRAPH_FN Node::successors() const->std::unordered_set<Rc<Node>> {
         std::unordered_set<Rc<Node>> ans;
         for (auto const &e : _outputs) {
             for (auto const &[n, _] : e->_targets) {
diff --git a/src/05computation/include/computation/graph.h b/src/05computation/include/computation/graph.h
index 733db5c7a..1104f8c51 100644
--- a/src/05computation/include/computation/graph.h
+++ b/src/05computation/include/computation/graph.h
@@ -26,11 +26,23 @@ namespace refactor::computation {
         Graph(graph_topo::GraphTopo, std::vector<Node>, std::vector<Edge>) noexcept;
 
         void layoutPermute();
+        void optimize();
         kernel::Graph lower(Target) const;
 
         auto internal() const -> decltype(_internal) const &;
     };
 
+    //using GraphMutant = Graph;
+    class GraphMutant {
+        graph_topo::LinkedGraph<Node, Edge> _internal;
+
+    public:
+        explicit GraphMutant(Graph const &) noexcept;
+        auto internal() const -> decltype(_internal) const &;
+        auto internal() -> decltype(_internal) &;
+    };
+
 }// namespace refactor::computation
 
 #endif// COMPUTATION_GRAPH_H
diff --git a/src/05computation/include/computation/pass/convert.h b/src/05computation/include/computation/pass/convert.h
new file mode 100644
index 000000000..e3b4528bc
--- /dev/null
+++ b/src/05computation/include/computation/pass/convert.h
@@ -0,0 +1,40 @@
+#ifndef COMPUTATION_CONVERT_H
+#define COMPUTATION_CONVERT_H
+
+#include "computation/graph.h"
+
+namespace refactor::computation {
+
+    class Converter {
+    public:
+        Converter() = default;
+        virtual ~Converter() = default;
+        virtual bool execute(const std::shared_ptr<GraphMutant> &) const = 0;
+        static Converter *get(std::string_view key) {
+            //fmt::println("{}", storage().size());
+            if (storage().find(key) != storage().end()) {
+                return storage().at(key).get();
+            }
+            return nullptr;
+        };
+        static void add(std::shared_ptr<Converter> converter, std::string_view key) {
+            storage().insert(std::make_pair(key, converter));
+        };
+        static std::unordered_map<std::string_view, std::shared_ptr<Converter>> &storage() {
+            static std::unordered_map<std::string_view, std::shared_ptr<Converter>> passStorage;
+            return passStorage;
+        }
+    };
+
+    template<class T>
+    class ConverterRegister {
+    public:
+        ConverterRegister(const char *claim) {
+            T *instance = new T;
+            Converter::add(std::shared_ptr<Converter>(instance), claim);
+        }
+    };
+
+}// namespace refactor::computation
+
+#endif// COMPUTATION_CONVERT_H
\ No newline at end of file
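
A minimal usage sketch of the Converter registry defined in convert.h, assuming the std::string_view key type reconstructed above; `MyPass` and `graphMutantPtr` are hypothetical names used only for illustration, everything else is the API from this diff:

    // Define a pass by overriding Converter::execute.
    class MyPass : public Converter {
    public:
        bool execute(const std::shared_ptr<GraphMutant> &g) const override {
            return g != nullptr;// a real pass would mutate g->internal() here
        }
    };
    // Register it once under a string key, normally at static-initialization time:
    static ConverterRegister<MyPass> registerMyPass("MyPass");
    // Later, look it up by key and run it on a shared GraphMutant owned by the caller:
    //     if (auto *pass = Converter::get("MyPass")) { pass->execute(graphMutantPtr); }

Because storage() is keyed by std::string_view, a registration key must outlive the map; the string literals used throughout this diff satisfy that.
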
diff --git a/src/05computation/include/computation/pass/kvcache_attention.h b/src/05computation/include/computation/pass/kvcache_attention.h
new file mode 100644
index 000000000..688467d3b
--- /dev/null
+++ b/src/05computation/include/computation/pass/kvcache_attention.h
@@ -0,0 +1,93 @@
+#ifndef COMPUTATION_KVCACHE_ATTENTION_H
+#define COMPUTATION_KVCACHE_ATTENTION_H
+
+#include "computation/operators/concat.h"
+#include "computation/operators/mat_mul.h"
+#include "computation/operators/simple_binary.h"
+#include "computation/operators/softmax.h"
+#include "computation/operators/transpose.h"
+#include "../graph.h"
+#include "convert.h"
+#include "kernel/collectors/simple_binary.h"
+
+namespace refactor::computation {
+
+    /*    concat
+            |
+        transpose
+            |
+          matmul
+            |
+           div
+            |
+        softmax   concat
+            \      /
+             matmul
+    */
+    class KVCacheAttention : public Converter {
+        static inline size_t count = 0;
+
+    public:
+        virtual bool execute(const std::shared_ptr<GraphMutant> &g) const override {
+            for (auto opMatch : g->internal().nodes()) {
+                size_t optype = opMatch->info().op->opTypeId();
+                if (optype != MatMul::typeId()) {
+                    continue;
+                }
+                // match the matmul op
+                //auto matmulPredecessors = opMatch->predecessors();
+                if (opMatch->predecessors().size() != 2) {
+                    continue;
+                }
+                // producers in operand order (both exist, since two distinct predecessors were found)
+                auto matmulInputLeft = opMatch->inputs()[0]->source();
+                auto matmulInputRight = opMatch->inputs()[1]->source();
+                if (matmulInputLeft->info().op->opTypeId() != Softmax::typeId() ||
+                    matmulInputRight->info().op->opTypeId() != Concat::typeId()) {
+                    continue;
+                }
+                //auto softmaxPredecessors = matmulInputLeft->predecessors();
+                auto concatInputs = matmulInputRight->inputs();
+                if (matmulInputLeft->predecessors().size() != 1 || concatInputs.size() != 2) {
+                    continue;
+                }
+                auto softmaxInput = *matmulInputLeft->predecessors().begin();
+                if (softmaxInput->info().op->opTypeId() != SimpleBinary::typeId(SimpleBinaryType::Div)) {
+                    continue;
+                }
+                if (softmaxInput->predecessors().size() != 1) {
+                    continue;
+                }
+                auto divInput = *softmaxInput->predecessors().begin();
+                if (divInput->info().op->opTypeId() != MatMul::typeId()) {
+                    continue;
+                }
+                auto matmulInputs = divInput->inputs();
+                if (divInput->predecessors().size() != 2 || matmulInputs.size() != 2) {
+                    continue;
+                }
+                auto matmul1InputLeft = matmulInputs[0]->source();
+                auto matmul1InputRight = matmulInputs[1]->source();
+                if (matmul1InputLeft->info().op->opTypeId() != SimpleBinary::typeId(SimpleBinaryType::Add) ||
+                    matmul1InputRight->info().op->opTypeId() != Transpose::typeId()) {
+                    continue;
+                }
+                if (matmul1InputRight->predecessors().size() != 1) {
+                    continue;
+                }
+                auto transposeInput = *matmul1InputRight->predecessors().begin();
+                if (transposeInput->info().op->opTypeId() != Concat::typeId()) {
+                    continue;
+                }
+                auto concatInputs1 = transposeInput->inputs();
+                if (concatInputs1.size() != 2) {
+                    continue;
+                }
+                //auto newNode = g->internal().pushNode();
+            }
+            return true;
+        };
+    };
+
+}// namespace refactor::computation
+
+#endif// COMPUTATION_KVCACHE_ATTENTION_H
\ No newline at end of file
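
One consequence of the linked_graph.hpp change at the top of this diff: predecessors() and successors() now return a freshly built std::unordered_set by value, so they are suitable for counting and membership tests but cannot be indexed, and their iteration order is unspecified. When a matcher needs the producer of a specific operand, as the pattern above does, the operand-ordered route goes through inputs(); a small sketch using only the linked-graph API already present in this diff:

    // Producer of operand i, in operand order (nullptr when that edge is a graph input):
    //     auto producer = node->inputs()[i]->source();
    // All distinct producers, order unspecified:
    //     for (auto const &pred : node->predecessors()) { ... }
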
diff --git a/src/05computation/include/computation/pass/matmul_transpose.h b/src/05computation/include/computation/pass/matmul_transpose.h
new file mode 100644
index 000000000..ffc8d72c0
--- /dev/null
+++ b/src/05computation/include/computation/pass/matmul_transpose.h
@@ -0,0 +1,58 @@
+#ifndef COMPUTATION_MATMUL_TRANSPOSE_H
+#define COMPUTATION_MATMUL_TRANSPOSE_H
+
+#include "../graph.h"
+#include "computation/operators/mat_mul.h"
+#include "computation/operators/transpose.h"
+#include "computation/pass/convert.h"
+
+namespace refactor::computation {
+
+    class MatMulTransposeFuse : public Converter {
+    public:
+        virtual bool execute(const std::shared_ptr<GraphMutant> &g) const override {
+            for (auto opMatch : g->internal().nodes()) {
+                size_t optype = opMatch->info().op->opTypeId();
+                if (optype != MatMul::typeId()) {
+                    continue;
+                }
+                //auto matmulPreOps = opMatch->predecessors();
+                if (opMatch->predecessors().size() == 0) {
+                    continue;
+                }
+                for (size_t i = 0; i < opMatch->inputs().size(); ++i) {
+                    if (auto preOp = opMatch->inputs()[i]->source();
+                        preOp != nullptr && preOp->info().op->opTypeId() == Transpose::typeId()) {
+                        auto transposeOp = dynamic_cast<Transpose *>(preOp->info().op.get());
+                        auto matmulOp = dynamic_cast<MatMul *>(opMatch->info().op.get());
+                        auto axis = transposeOp->perm;
+                        // fusable only if the permutation swaps the last two axes and keeps the rest in place
+                        bool flag = false;
+                        if (axis[axis.size() - 1] == axis.size() - 2 && axis[axis.size() - 2] == axis.size() - 1) {
+                            flag = true;
+                        }
+                        for (size_t index = 0; index < axis.size() - 2; ++index) {
+                            if (index == axis[index]) {
+                                continue;
+                            }
+                            flag = false;
+                            break;
+                        }
+                        if (flag) {
+                            if (i == 0) {
+                                matmulOp->transA = !matmulOp->transA;
+                            } else {
+                                matmulOp->transB = !matmulOp->transB;
+                            }
+                            opMatch->reconnect(opMatch->inputs()[i], preOp->inputs()[0]);
+                            g->internal().eraseNode(preOp);
+                        }
+                    }
+                }
+            }
+            return true;
+        };
+    };
+
+    //static ConverterRegister<MatMulTransposeFuse> __l("MatMulTransposeFuse");
+}// namespace refactor::computation
+#endif// COMPUTATION_MATMUL_TRANSPOSE_H
\ No newline at end of file
diff --git a/src/05computation/include/computation/pass/pass_manager.h b/src/05computation/include/computation/pass/pass_manager.h
new file mode 100644
index 000000000..3976b0717
--- /dev/null
+++ b/src/05computation/include/computation/pass/pass_manager.h
@@ -0,0 +1,16 @@
+#ifndef COMPUTATION_PASS_MANAGER_H
+#define COMPUTATION_PASS_MANAGER_H
+#include "convert.h"
+#include "matmul_transpose.h"
+
+namespace refactor::computation {
+
+    inline void register_() {
+#define REGISTER(PASS, NAME) static ConverterRegister<PASS> __l_##PASS("" #NAME);
+        REGISTER(MatMulTransposeFuse, MatMulTransposeFuse)
+    };
+
+}// namespace refactor::computation
+
+#endif
\ No newline at end of file
diff --git a/src/05computation/src/graph.cc b/src/05computation/src/graph.cc
index bb042229b..03e8e6157 100644
--- a/src/05computation/src/graph.cc
+++ b/src/05computation/src/graph.cc
@@ -1,4 +1,6 @@
 #include "computation/graph.h"
+#include "computation/pass/convert.h"
+#include "computation/pass/pass_manager.h"
 
 namespace refactor::computation {
 
@@ -63,4 +65,34 @@ namespace refactor::computation {
     auto Graph::internal() const -> decltype(_internal) const & {
         return _internal;
     }
+    void RunOptimizePass(std::vector<std::string_view> passes, const std::shared_ptr<GraphMutant> &g) {
+        for (auto pass : passes) {
+            auto convert = Converter::get(pass);
+            if (nullptr == convert) {
+                fmt::println("Can't find pass {}.", pass);
+                continue;
+            }
+            bool valid = convert->execute(g);
+            if (!valid) {
+                fmt::println("Run {} Error", pass);
+            }
+        }
+    }
+
+    void Graph::optimize() {
+        auto graphMutant = GraphMutant(*this);
+        std::vector<std::string_view> passes = {
+            "MatMulTransposeFuse",
+        };
+        register_();// register all passes
+        auto g = std::make_shared<GraphMutant>(graphMutant);
+        RunOptimizePass(passes, g);
+        _internal = g->internal();
+    }
+
+    GraphMutant::GraphMutant(Graph const &g) noexcept {
+        _internal = g.internal().linked();
+    }
+    auto GraphMutant::internal() const -> decltype(_internal) const & { return _internal; }
+    auto GraphMutant::internal() -> decltype(_internal) & { return _internal; }
 }// namespace refactor::computation
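
For reference, wiring an additional pass into this pipeline takes two steps with the machinery above; the KVCacheAttention pass from this diff is used as the example (it is not registered anywhere yet, and pass_manager.h would also need to include kvcache_attention.h):

    // 1. pass_manager.h: register the converter under a key.
    //        REGISTER(KVCacheAttention, KVCacheAttention)
    // 2. Graph::optimize(): add the same key to the pass list.
    //        std::vector<std::string_view> passes = {"MatMulTransposeFuse", "KVCacheAttention"};
    // RunOptimizePass then resolves each key via Converter::get and executes the passes in order.
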
diff --git a/src/05computation/test/test_matmul.cpp b/src/05computation/test/test_matmul.cpp
new file mode 100644
index 000000000..4aca39ca0
--- /dev/null
+++ b/src/05computation/test/test_matmul.cpp
@@ -0,0 +1,57 @@
+#include "computation/graph.h"
+#include "computation/operators/mat_mul.h"
+#include "computation/operators/transpose.h"
+#include <gtest/gtest.h>
+
+namespace refactor::computation {
+
+    refactor::graph_topo::Builder<size_t, size_t, Node, Edge> TestMatMulTransposeGraphBuild() {
+        absl::InlinedVector<uint32_t, 4> perm = {0, 1, 3, 2};
+        auto nodes = std::unordered_map<size_t, Node>{};
+        nodes[0] = Node{std::make_unique<Transpose>(perm), "transpose0"};
+        nodes[1] = Node{std::make_unique<Transpose>(perm), "transpose1"};
+        nodes[2] = Node{std::make_unique<MatMul>(1.0, 1.0, false, false), "matmul"};
+
+        auto tensor0 = Tensor::share(DataType::F32, {1, 3, 3, 5}, LayoutType::Others);
+        auto tensor1 = Tensor::share(DataType::F32, {2, 3, 5, 3}, LayoutType::Others);
+        auto tensor2 = Tensor::share(DataType::F32, {1, 3, 5, 3}, LayoutType::Others);
+        auto tensor3 = Tensor::share(DataType::F32, {2, 3, 3, 5}, LayoutType::Others);
+        auto tensor4 = Tensor::share(DataType::F32, {2, 3, 5, 5}, LayoutType::Others);
+
+        return {
+            {
+                {0, {{0}, {2}}},
+                {1, {{1}, {3}}},
+                {2, {{2, 3}, {4}}},
+            },
+            {0, 1},// global inputs
+            {4},   // global outputs
+            std::move(nodes),
+            {
+                {0, {tensor0, "input0"}},
+                {1, {tensor1, "input1"}},
+                {2, {tensor2, "input0_transpose"}},
+                {3, {tensor3, "input1_transpose"}},
+                {4, {tensor4, "output"}},
+            },
+        };
+    }
+
+    TEST(Graph, MatMulTranspose) {
+        auto graphTopo = TestMatMulTransposeGraphBuild().build();
+        fmt::println("{}", graphTopo.topology.toString());
+        Graph g(std::move(graphTopo));
+        g.optimize();
+        auto const &g_ = g.internal().contiguous();
+        fmt::println("{}", g_.topology.toString());
+        fmt::println("Nodes info :");
+        for (size_t i = 0; i < g_.nodes.size(); ++i) {
+            fmt::println("{}. \"{}\"", i, g_.nodes[i].name);
+        }
+        fmt::println("\n Edges info :");
+        for (size_t i = 0; i < g_.edges.size(); ++i) {
+            fmt::println("{}. \"{}\" Shape is {}, Layout is {}", i, g_.edges[i].name,
+                         vec2str(g_.edges[i].tensor->shape), g_.edges[i].tensor->layout.name());
+        }
+    }
+}// namespace refactor::computation
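
Assuming MatMulTransposeFuse behaves as written above (a transpose that only swaps the last two axes is folded into the matmul's transA/transB flags and then erased), this test graph should reduce to the single "matmul" node after g.optimize(). The test currently only prints the result; hypothetical assertions that would pin that behavior down, shown as a sketch rather than as part of the diff:

    // after g.optimize() and g.internal().contiguous():
    //     EXPECT_EQ(g_.nodes.size(), 1);
    //     EXPECT_EQ(g_.nodes[0].name, "matmul");
    //     EXPECT_TRUE(dynamic_cast<MatMul *>(g_.nodes[0].op.get())->transA);
    //     EXPECT_TRUE(dynamic_cast<MatMul *>(g_.nodes[0].op.get())->transB);
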