feat: add graph optimize
bitzyz committed Dec 4, 2023
1 parent c162ed4 commit 0b87b19
Showing 8 changed files with 312 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/01graph_topo/include/graph_topo/linked_graph.hpp
@@ -56,8 +56,8 @@ namespace refactor::graph_topo {
TN const &info() const;
std::vector<Rc<Edge>> const &inputs() const;
std::vector<Rc<Edge>> const &outputs() const;
-            std::unordered_set<Rc<Node>> const &predecessors() const;
-            std::unordered_set<Rc<Node>> const &successors() const;
+            std::unordered_set<Rc<Node>> predecessors() const;
+            std::unordered_set<Rc<Node>> successors() const;
void connect(count_t, Rc<Edge>);
void disconnect(count_t);
void reconnect(Rc<Edge>, Rc<Edge>);
@@ -253,7 +253,7 @@ namespace refactor::graph_topo {
return _outputs;
}

-        LINKED_GRAPH_FN Node::predecessors() const->std::unordered_set<Rc<Node>> const & {
+        LINKED_GRAPH_FN Node::predecessors() const->std::unordered_set<Rc<Node>> {
std::unordered_set<Rc<Node>> ans;
for (auto const &e : _inputs) {
if (e->_source) {
@@ -263,7 +263,7 @@ namespace refactor::graph_topo {
return ans;
}

-        LINKED_GRAPH_FN Node::successors() const->std::unordered_set<Rc<Node>> const & {
+        LINKED_GRAPH_FN Node::successors() const->std::unordered_set<Rc<Node>> {
std::unordered_set<Rc<Node>> ans;
for (auto const &e : _outputs) {
for (auto const &[n, _] : e->_targets) {
12 changes: 12 additions & 0 deletions src/05computation/include/computation/graph.h
@@ -26,11 +26,23 @@ namespace refactor::computation {
Graph(graph_topo::GraphTopo, std::vector<Node>, std::vector<Edge>) noexcept;

void layoutPermute();
void optimize();

kernel::Graph lower(Target) const;
auto internal() const -> decltype(_internal) const &;
};

class GraphMutant {
    graph_topo::LinkedGraph<Node, Edge> _internal;
public:
explicit GraphMutant(Graph const &) noexcept;
auto internal() const -> decltype(_internal) const &;
auto internal() -> decltype(_internal) &;
};

}// namespace refactor::computation

#endif// COMPUTATION_GRAPH_H
40 changes: 40 additions & 0 deletions src/05computation/include/computation/pass/convert.h
@@ -0,0 +1,40 @@
#ifndef COMPUTATION_CONVERT_H
#define COMPUTATION_CONVERT_H

#include "computation/graph.h"

namespace refactor::computation {

class Converter {
public:
Converter() = default;
virtual ~Converter() = default;
virtual bool execute(const std::shared_ptr<GraphMutant> &) const = 0;
static Converter *get(std::string_view key) {
    auto it = storage().find(key);
    return it != storage().end() ? it->second.get() : nullptr;
}
static void add(std::shared_ptr<Converter> converter, std::string_view key) {
    storage().emplace(key, std::move(converter));
}
static std::unordered_map<std::string_view, std::shared_ptr<Converter>> &storage() {
static std::unordered_map<std::string_view, std::shared_ptr<Converter>> passStorage;
return passStorage;
}
};

template<class T>
class ConverterRegister {
public:
    explicit ConverterRegister(const char *claim) {
        Converter::add(std::make_shared<T>(), claim);
    }
};

}// namespace refactor::computation

#endif// COMPUTATION_CONVERT_H
93 changes: 93 additions & 0 deletions src/05computation/include/computation/pass/kvcache_attention.h
@@ -0,0 +1,93 @@
#ifndef COMPUTATION_KVCACHE_ATTENTION_H
#define COMPUTATION_KVCACHE_ATTENTION_H

#include "computation/operators/concat.h"
#include "computation/operators/mat_mul.h"
#include "computation/operators/simple_binary.h"
#include "computation/operators/softmax.h"
#include "computation/operators/transpose.h"
#include "convert.h"
#include "graph.h"
#include "kernel/collectors/simple_binary.h"

namespace refactor::computation {

/* concat
|
transpose
|
matmul
|
div
|
softmax concat
\ /
matmul
*/
class KVCacheAttention : public Converter {
    static inline size_t count = 0;

public:
    virtual bool execute(const std::shared_ptr<GraphMutant> &g) const override {
        for (auto opMatch : g->internal().nodes()) {
            // match the bottom matmul of the pattern
            if (opMatch->info().op->opTypeId() != MatMul::typeId()) {
                continue;
            }
            if (opMatch->predecessors().size() != 2) {
                continue;
            }
            // operand order matters, so walk input edges
            // (predecessors() is an unordered set)
            auto matmulInputLeft = opMatch->inputs()[0]->source();
            auto matmulInputRight = opMatch->inputs()[1]->source();
            if (matmulInputLeft == nullptr || matmulInputRight == nullptr ||
                matmulInputLeft->info().op->opTypeId() != Softmax::typeId() ||
                matmulInputRight->info().op->opTypeId() != Concat::typeId()) {
                continue;
            }
            auto concatInputs = matmulInputRight->inputs();
            if (matmulInputLeft->predecessors().size() != 1 || concatInputs.size() != 2) {
                continue;
            }
            // the softmax is fed by a div
            auto divNode = matmulInputLeft->inputs()[0]->source();
            if (divNode == nullptr ||
                divNode->info().op->opTypeId() != SimpleBinary::typeId(SimpleBinaryType::Div)) {
                continue;
            }
            if (divNode->predecessors().size() != 1) {
                continue;
            }
            // the div's numerator comes from the top matmul
            auto matmul1 = divNode->inputs()[0]->source();
            if (matmul1 == nullptr || matmul1->info().op->opTypeId() != MatMul::typeId()) {
                continue;
            }
            auto matmul1Inputs = matmul1->inputs();
            if (matmul1->predecessors().size() != 2 || matmul1Inputs.size() != 2) {
                continue;
            }
            auto matmul1InputLeft = matmul1->inputs()[0]->source();
            auto matmul1InputRight = matmul1->inputs()[1]->source();
            if (matmul1InputLeft == nullptr || matmul1InputRight == nullptr ||
                matmul1InputLeft->info().op->opTypeId() != SimpleBinary::typeId(SimpleBinaryType::Add) ||
                matmul1InputRight->info().op->opTypeId() != Transpose::typeId()) {
                continue;
            }
            if (matmul1InputRight->predecessors().size() != 1) {
                continue;
            }
            // the transpose is fed by the second concat
            auto transposeInput = matmul1InputRight->inputs()[0]->source();
            if (transposeInput == nullptr ||
                transposeInput->info().op->opTypeId() != Concat::typeId()) {
                continue;
            }
            auto concatInputs1 = transposeInput->inputs();
            if (concatInputs1.size() != 2) {
                continue;
            }
            // TODO: splice the matched subgraph into a fused node,
            // e.g. auto newNode = g->internal().pushNode(...);
        }
        return true;
    }
};

}// namespace refactor::computation

#endif// COMPUTATION_KVCACHE_ATTENTION_H
58 changes: 58 additions & 0 deletions src/05computation/include/computation/pass/matmul_transpose.h
@@ -0,0 +1,58 @@
#ifndef COMPUTATION_MATMUL_TRANSPOSE_H
#define COMPUTATION_MATMUL_TRANSPOSE_H

#include "../graph.h"
#include "computation/operators/mat_mul.h"
#include "computation/operators/transpose.h"
#include "computation/pass/convert.h"

namespace refactor::computation {
class MatMulTransposeFuse : public Converter {
public:
virtual bool execute(const std::shared_ptr<GraphMutant> &g) const override {
for (auto opMatch : g->internal().nodes()) {
size_t optype = opMatch->info().op->opTypeId();
if (optype != MatMul::typeId()) {
continue;
}
if (opMatch->predecessors().empty()) {
    continue;
}
for (size_t i = 0; i < opMatch->inputs().size(); ++i) {
if (auto preOp = opMatch->inputs()[i]->source();
preOp != nullptr && preOp->info().op->opTypeId() == Transpose::typeId()) {
auto transposeOp = dynamic_cast<Transpose *>(preOp->info().op.get());
auto matmulOp = dynamic_cast<MatMul *>(opMatch->info().op.get());
auto axis = transposeOp->perm;
// foldable iff the perm swaps the last two axes and fixes every leading axis
bool flag = axis.size() >= 2 &&
            axis[axis.size() - 1] == axis.size() - 2 &&
            axis[axis.size() - 2] == axis.size() - 1;
for (size_t index = 0; flag && index + 2 < axis.size(); ++index) {
    if (axis[index] != index) {
        flag = false;
    }
}
if (flag) {
if (i == 0) {
matmulOp->transA = !matmulOp->transA;
} else {
matmulOp->transB = !matmulOp->transB;
}
opMatch->reconnect(opMatch->inputs()[i], preOp->inputs()[0]);
g->internal().eraseNode(preOp);
}
}
}
}
return true;
};
};

//static ConverterRegister<MatMulTransposeFuse> __l("MatMulTransposeFuse");
}// namespace refactor::computation
#endif// COMPUTATION_MATMUL_TRANSPOSE_H
16 changes: 16 additions & 0 deletions src/05computation/include/computation/pass/pass_manager.h
@@ -0,0 +1,16 @@
#ifndef COMPUTATION_PASS_MANAGER_H
#define COMPUTATION_PASS_MANAGER_H
#include "convert.h"
#include "matmul_transpose.h"

namespace refactor::computation {

inline void register_() {
#define REGISTER(PASS, NAME) static ConverterRegister<PASS> __l_##PASS("" #NAME);
    REGISTER(MatMulTransposeFuse, MatMulTransposeFuse)
#undef REGISTER
}


}// namespace refactor::computation

#endif// COMPUTATION_PASS_MANAGER_H
32 changes: 32 additions & 0 deletions src/05computation/src/graph.cc
@@ -1,4 +1,6 @@
#include "computation/graph.h"
#include "computation/pass/convert.h"
#include "computation/pass/pass_manager.h"

namespace refactor::computation {

@@ -63,4 +65,34 @@ namespace refactor::computation {

auto Graph::internal() const -> decltype(_internal) const & { return _internal; }

static void RunOptimizePass(std::vector<std::string_view> const &passes, const std::shared_ptr<GraphMutant> &g) {
for (auto pass : passes) {
auto convert = Converter::get(pass);
if (nullptr == convert) {
fmt::println("Can't find pass of {}.", pass);
continue;
}
bool valid = convert->execute(g);
if (!valid) {
fmt::println("Run {} Error", pass);
}
}
}

void Graph::optimize() {
    std::vector<std::string_view> passes = {
        "MatMulTransposeFuse",
    };
    register_();// register all known passes (no-op after the first call)
    auto g = std::make_shared<GraphMutant>(*this);
    RunOptimizePass(passes, g);
    _internal = g->internal();
}

GraphMutant::GraphMutant(Graph const &g) noexcept {
_internal = g.internal().linked();
}
auto GraphMutant::internal() const -> decltype(_internal) const & { return _internal; }
auto GraphMutant::internal() -> decltype(_internal) & { return _internal; }
}// namespace refactor::computation
57 changes: 57 additions & 0 deletions src/05computation/test/test_matmul.cpp
@@ -0,0 +1,57 @@
#include "computation/graph.h"
#include "computation/operators/mat_mul.h"
#include "computation/operators/transpose.h"
#include <gtest/gtest.h>

namespace refactor::computation {

refactor::graph_topo::Builder<size_t, Node, size_t, Edge> TestMatMulTransposeGraphBuild() {
absl::InlinedVector<uint32_t, 4> perm = {0, 1, 3, 2};
auto nodes = std::unordered_map<size_t, Node>{};
nodes[0] = Node{std::make_unique<Transpose>(perm), "transpose0"};
nodes[1] = Node{std::make_unique<Transpose>(perm), "transpose1"};
nodes[2] = Node{std::make_unique<MatMul>(1.0, 1.0, false, false), "matmul"};

auto tensor0 = Tensor::share(DataType::F32, {1, 3, 3, 5}, LayoutType::Others);
auto tensor1 = Tensor::share(DataType::F32, {2, 3, 5, 3}, LayoutType::Others);
auto tensor2 = Tensor::share(DataType::F32, {1, 3, 5, 3}, LayoutType::Others);
auto tensor3 = Tensor::share(DataType::F32, {2, 3, 3, 5}, LayoutType::Others);
auto tensor4 = Tensor::share(DataType::F32, {2, 3, 5, 5}, LayoutType::Others);

return {
{
{0, {{0}, {2}}},
{1, {{1}, {3}}},
{2, {{2, 3}, {4}}},
},
{0, 1},// global inputs
{4}, // global outputs
std::move(nodes),
{
{0, {tensor0, "input0"}},
{1, {tensor1, "input1"}},
{2, {tensor2, "input0_transpose"}},
{3, {tensor3, "input1_transpose"}},
{4, {tensor4, "output"}},
},
};
}

TEST(Graph, MatMulTranspose) {
auto graphTopo = TestMatMulTransposeGraphBuild().build();
fmt::println("{}", graphTopo.topology.toString());
Graph g(std::move(graphTopo));
g.optimize();
auto const &g_ = g.internal().contiguous();
fmt::println("{}", g_.topology.toString());
fmt::println("Nodes info :");
for (size_t i = 0; i < g_.nodes.size(); ++i) {
fmt::println("{}. \"{}\"", i, g_.nodes[i].name);
}
fmt::println("\n Edges info :");
for (size_t i = 0; i < g_.edges.size(); ++i) {
fmt::println("{}. \"{}\" Shape is {}, Layout is {}", i, g_.edges[i].name,
vec2str(g_.edges[i].tensor->shape), g_.edges[i].tensor->layout.name());
}
}
}// namespace refactor::computation
