feat: support conv1x1 to matmul pass
Showing 5 changed files with 259 additions and 1 deletion.
144 changes: 144 additions & 0 deletions
src/05computation/include/computation/pass/conv_to_matmul.h
@@ -0,0 +1,144 @@
#ifndef COMPUTATION_CONV_TO_MATMUL_H
#define COMPUTATION_CONV_TO_MATMUL_H

#include "../graph.h"
#include "computation/operators/conv.h"
#include "computation/operators/mat_mul.h"
#include "computation/operators/reshape.h"
#include "computation/operators/transpose.h"
#include <cstring>// std::memcpy, for filling the reshape-shape blobs

namespace refactor::computation {

    class ConvToMatmul : public Converter {
    public:
        /*
         *  input      weight
         *    |           |
         *    |           |
         * transpose  transpose
         *    |           |
         *    |           |
         * reshape     reshape
         *     \         /
         *      \       /
         *        matmul
         *          |
         *       reshape
         *          |
         *      transpose
         *          |
         *        output
         */
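        /*
         * Concrete shape trace (illustrative; these are the shapes used by
         * the tests for this pass): input [N,C,H,W] = [1,3,5,5] and weight
         * [M,C,1,1] = [2,3,1,1].
         *   input:  transpose {0,2,3,1} -> [1,5,5,3], reshape -> [25,3]    ([N*H*W, C])
         *   weight: transpose {1,0,2,3} -> [3,2,1,1], reshape -> [3,2]     ([C, M])
         *   matmul: [25,3] x [3,2] -> [25,2]                               ([N*H*W, M])
         *   output: reshape -> [1,5,5,2], transpose {0,3,1,2} -> [1,2,5,5] ([N,M,H,W])
         */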
        virtual bool execute(const std::shared_ptr<GraphMutant> &g) const override {
            auto nodesList = g->internal().nodes();
            size_t count = 0;
            for (auto opMatch : nodesList) {
                if (opMatch->info().op == nullptr) {
                    continue;
                }
                size_t optype = opMatch->info().op->opTypeId();
                if (optype != Conv::typeId()) {
                    continue;
                }
                auto convOp = dynamic_cast<Conv *>(opMatch->info().op.get());
                auto input = opMatch->inputs()[0]->info().tensor;
                auto weight = opMatch->inputs()[1]->info().tensor;
                auto shape = weight->shape;
                // only rewrite 1x1 convolutions
                if (shape.size() != 4 || shape[2] != 1 || shape[3] != 1) {
                    continue;
                }
                auto attr = convOp->attributes;
                auto poolAttrRank = attr.rank();
                auto poolAttrDilation = attr.dilations();
                auto poolAttrStride = attr.strides();
                auto poolAttrPad = attr.pads();
                // a 1x1 conv only degenerates to a matmul when all dilations
                // and strides are 1 and all pads are 0; the original
                // `bool flag = true;` initialization made the pass skip
                // every conv, so start from false and set on mismatch
                bool nontrivial = false;
                for (auto i : range0_(poolAttrRank)) {
                    if (poolAttrDilation[i] != 1 || poolAttrStride[i] != 1 ||
                        poolAttrPad[i] != 0 || poolAttrPad[i + poolAttrRank] != 0) {
                        nontrivial = true;
                        break;
                    }
                }
                if (nontrivial) { continue; }
                // create transpose op
                absl::InlinedVector<uint32_t, 4> perm1 = {0, 2, 3, 1};
                Shape shape1 = {input->shape[0], input->shape[2], input->shape[3], input->shape[1]};
                auto newTransposeOp1 = g->internal().pushNode(
                    {std::make_unique<Transpose>(perm1), fmt::format("ConvToMatmul_transpose1_{}", count)},
                    {g->internal().shareEdge({Tensor::share(input->dataType, shape1), fmt::format("ConvToMatmul_transpose1_{}_out", count)})});
                newTransposeOp1->connect(0, opMatch->inputs()[0]);
                absl::InlinedVector<uint32_t, 4> perm2 = {1, 0, 2, 3};
                Shape shape2 = {weight->shape[1], weight->shape[0], weight->shape[2], weight->shape[3]};
                auto newTransposeOp2 = g->internal().pushNode(
                    {std::make_unique<Transpose>(perm2), fmt::format("ConvToMatmul_transpose2_{}", count)},
                    {g->internal().shareEdge({Tensor::share(weight->dataType, shape2), fmt::format("ConvToMatmul_transpose2_{}_out", count)})});
                newTransposeOp2->connect(0, opMatch->inputs()[1]);
                // create reshape op
                Shape shape3 = {input->shape[0] * input->shape[2] * input->shape[3], input->shape[1]};
                Shape shape4 = {weight->shape[1], weight->shape[0]};
                int64_t data1[2] = {input->shape[0] * input->shape[2] * input->shape[3], input->shape[1]};
                int64_t data2[2] = {weight->shape[1], weight->shape[0]};
                auto [data1_, ptr1] = refactor::kernel::Blob::share(sizeof(int64_t) * 2);
                auto [data2_, ptr2] = refactor::kernel::Blob::share(sizeof(int64_t) * 2);
                // copy the shape values into the blobs; reassigning the
                // pointers returned by Blob::share would leave the blob
                // contents uninitialized
                std::memcpy(ptr1, data1, sizeof(data1));
                std::memcpy(ptr2, data2, sizeof(data2));
                auto newReshapeEdge1 = g->internal().shareEdge({Tensor::share(DataType::I64, {2}, LayoutType::Others, data1_), fmt::format("ConvToMatmul_reshape1_shape_{}", count)});
                auto newReshapeEdge2 = g->internal().shareEdge({Tensor::share(DataType::I64, {2}, LayoutType::Others, data2_), fmt::format("ConvToMatmul_reshape2_shape_{}", count)});
                auto newReshapeOp1 = g->internal().pushNode(
                    {std::make_unique<Reshape>(), fmt::format("ConvToMatmul_reshape1_{}", count)},
                    {g->internal().shareEdge({Tensor::share(input->dataType, shape3), fmt::format("ConvToMatmul_reshape1_{}_out", count)})});
                auto newReshapeOp2 = g->internal().pushNode(
                    {std::make_unique<Reshape>(), fmt::format("ConvToMatmul_reshape2_{}", count)},
                    {g->internal().shareEdge({Tensor::share(weight->dataType, shape4), fmt::format("ConvToMatmul_reshape2_{}_out", count)})});
                newReshapeOp1->connect(0, newTransposeOp1->outputs()[0]);
                newReshapeOp1->connect(1, newReshapeEdge1);
                newReshapeOp2->connect(0, newTransposeOp2->outputs()[0]);
                newReshapeOp2->connect(1, newReshapeEdge2);
                // create matmul op
                Shape shape5 = {input->shape[0] * input->shape[2] * input->shape[3], weight->shape[0]};
                auto newMatMulOp = g->internal().pushNode(
                    {std::make_unique<MatMul>(1.0, 1.0, false, false), fmt::format("ConvToMatmul_matmul_{}", count)},
                    {g->internal().shareEdge({Tensor::share(input->dataType, shape5), fmt::format("ConvToMatmul_matmul_{}_out", count)})});
                newMatMulOp->connect(0, newReshapeOp1->outputs()[0]);
                newMatMulOp->connect(1, newReshapeOp2->outputs()[0]);
                // create reshape op
                Shape shape6 = {input->shape[0], input->shape[2], input->shape[3], weight->shape[0]};
                int64_t data3[4] = {input->shape[0], input->shape[2], input->shape[3], weight->shape[0]};
                auto [data3_, ptr3] = refactor::kernel::Blob::share(sizeof(int64_t) * 4);
                std::memcpy(ptr3, data3, sizeof(data3));// copy, don't reassign the pointer
                auto newReshapeEdge3 = g->internal().shareEdge({Tensor::share(DataType::I64, {4}, LayoutType::Others, data3_), fmt::format("ConvToMatmul_reshape3_shape_{}", count)});
                auto newReshapeOp3 = g->internal().pushNode(
                    {std::make_unique<Reshape>(), fmt::format("ConvToMatmul_reshape3_{}", count)},
                    {g->internal().shareEdge({Tensor::share(input->dataType, shape6), fmt::format("ConvToMatmul_reshape3_{}_out", count)})});
                newReshapeOp3->connect(0, newMatMulOp->outputs()[0]);
                newReshapeOp3->connect(1, newReshapeEdge3);
                // create transpose op
                absl::InlinedVector<uint32_t, 4> perm3 = {0, 3, 1, 2};
                Shape shape7 = {input->shape[0], weight->shape[0], input->shape[2], input->shape[3]};
                auto newTransposeOp3 = g->internal().pushNode(
                    {std::make_unique<Transpose>(perm3), fmt::format("ConvToMatmul_transpose3_{}", count)},
                    {g->internal().shareEdge({Tensor::share(input->dataType, shape7), fmt::format("ConvToMatmul_transpose3_{}_out", count)})});
                newTransposeOp3->connect(0, newReshapeOp3->outputs()[0]);
                if (opMatch->outputs()[0]->targets().size() == 0) {// global output
                    g->internal().replaceOutput(opMatch->outputs()[0], newTransposeOp3->outputs()[0]);
                } else {
                    for (auto node : opMatch->outputs()[0]->targets()) {
                        auto it = std::find(node->inputs().begin(), node->inputs().end(), opMatch->outputs()[0]);
                        node->reconnect(node->inputs()[std::distance(node->inputs().begin(), it)], newTransposeOp3->outputs()[0]);
                    }
                }
                g->internal().eraseNode(opMatch);
                count++;
            }
            return true;
        }
    };

}// namespace refactor::computation
#endif// COMPUTATION_CONV_TO_MATMUL_H
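The rewrite is only sound because a 1x1 convolution with unit strides and dilations and zero padding is, arithmetically, a matrix multiplication. The following is a minimal standalone sketch of that identity in plain C++ (it uses none of the framework types above; the sizes match the tests below, and the naive loop nests are an illustration, not the framework's kernels):

#include <cassert>
#include <cstdio>
#include <vector>

int main() {
    int N = 1, C = 3, H = 5, W = 5, M = 2;// same sizes as the tests below
    std::vector<float> x(N * C * H * W), w(M * C);
    for (size_t i = 0; i < x.size(); ++i) { x[i] = 0.01f * float(i); }
    for (size_t i = 0; i < w.size(); ++i) { w[i] = 0.1f * float(i); }

    // direct 1x1 conv in NCHW: conv[n,m,h,p] = sum_c x[n,c,h,p] * w[m,c]
    std::vector<float> conv(N * M * H * W, 0.f);
    for (int n = 0; n < N; ++n)
        for (int m = 0; m < M; ++m)
            for (int h = 0; h < H; ++h)
                for (int p = 0; p < W; ++p)
                    for (int c = 0; c < C; ++c)
                        conv[((n * M + m) * H + h) * W + p] +=
                            x[((n * C + c) * H + h) * W + p] * w[m * C + c];

    // rewritten form: [N*H*W, C] x [C, M]; row index enumerates (n, h, p)
    std::vector<float> mm(N * H * W * M, 0.f);
    for (int n = 0; n < N; ++n)
        for (int h = 0; h < H; ++h)
            for (int p = 0; p < W; ++p)
                for (int m = 0; m < M; ++m)
                    for (int c = 0; c < C; ++c)
                        mm[((n * H + h) * W + p) * M + m] +=
                            x[((n * C + c) * H + h) * W + p] * w[m * C + c];

    // the pass's final reshape + transpose is exactly this re-indexing;
    // equality is exact because both sides do the same additions in order
    for (int n = 0; n < N; ++n)
        for (int m = 0; m < M; ++m)
            for (int h = 0; h < H; ++h)
                for (int p = 0; p < W; ++p)
                    assert(conv[((n * M + m) * H + h) * W + p] ==
                           mm[((n * H + h) * W + p) * M + m]);
    std::puts("conv1x1 == matmul: OK");
}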
@@ -0,0 +1,110 @@
#include "computation/graph.h" | ||
#include "computation/operators/conv.h" | ||
#include "computation/operators/simple_unary.h" | ||
#include <gtest/gtest.h> | ||
|
||
namespace refactor::computation { | ||
|
||
refactor::graph_topo::Builder<size_t, Node, size_t, Edge> TestConvToMatMulGraphBuild1() { | ||
auto nodes = std::unordered_map<size_t, Node>{}; | ||
int64_t dilations[2] = {1, 1}; | ||
int64_t strides[2] = {2, 2}; | ||
int64_t pads[4] = {2, 2, 2, 2}; | ||
nodes[0] = Node{std::make_unique<Conv>(PoolAttributes(2, &dilations[0], &pads[0], &strides[0])), "conv"}; | ||
nodes[1] = Node{std::make_unique<SimpleUnary>(refactor::kernel::SimpleUnaryType::Relu), "relu"}; | ||
|
||
auto tensor0 = Tensor::share(DataType::F32, {1, 3, 5, 5}, LayoutType::Others); | ||
auto tensor1 = Tensor::share(DataType::F32, {2, 3, 1, 1}, LayoutType::Others); | ||
auto tensor2 = Tensor::share(DataType::F32, {1, 2, 5, 5}, LayoutType::Others); | ||
auto tensor3 = Tensor::share(DataType::F32, {1, 2, 5, 5}, LayoutType::Others); | ||
|
||
return { | ||
{ | ||
{0, {{0, 1}, {2}}}, | ||
{1, {{2}, {3}}}, | ||
}, | ||
{0, 1},// global inputs | ||
{3}, // global outputs | ||
std::move(nodes), | ||
{ | ||
{0, {tensor0, "input"}}, | ||
{1, {tensor1, "weight"}}, | ||
{2, {tensor2, "conv_output"}}, | ||
{3, {tensor3, "output"}}, | ||
}, | ||
}; | ||
} | ||
|
||
TEST(Graph, ConvToMatMul1) { | ||
auto graphTopo = TestConvToMatMulGraphBuild1().build(); | ||
fmt::println("{}", graphTopo.topology.toString()); | ||
Graph g(std::move(graphTopo)); | ||
g.optimize(); | ||
auto const &g_ = g.internal().contiguous(); | ||
fmt::println("{}", g_.topology.toString()); | ||
fmt::println("Nodes info :"); | ||
for (size_t i = 0; i < g_.nodes.size(); ++i) { | ||
fmt::println("{}. \"{}\"", i, g_.nodes[i].name); | ||
} | ||
fmt::println("\n Edges info :"); | ||
for (size_t i = 0; i < g_.edges.size(); ++i) { | ||
fmt::println("{}. \"{}\" Shape is {}, Layout is {}", i, g_.edges[i].name, | ||
vec2str(g_.edges[i].tensor->shape), g_.edges[i].tensor->layout.name()); | ||
} | ||
} | ||
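    // Note on the test above: conv is built with strides {2,2} and pads
    // {2,2,2,2}, so the pass must skip it and the optimized graph should
    // still contain the "conv" node. A hedged assertion one might append
    // inside the test body (it assumes no other pass in optimize() renames
    // or removes the node):
    //
    //     bool stillConv = false;
    //     for (auto const &n : g_.nodes) { stillConv |= (n.name == "conv"); }
    //     EXPECT_TRUE(stillConv);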

    refactor::graph_topo::Builder<size_t, Node, size_t, Edge> TestConvToMatMulGraphBuild2() {
        auto nodes = std::unordered_map<size_t, Node>{};
        nodes[0] = Node{std::make_unique<Conv>(PoolAttributes(2, nullptr, nullptr, nullptr)), "conv0"};
        nodes[1] = Node{std::make_unique<SimpleUnary>(refactor::kernel::SimpleUnaryType::Relu), "relu0"};
        nodes[2] = Node{std::make_unique<Conv>(PoolAttributes(2, nullptr, nullptr, nullptr)), "conv1"};
        nodes[3] = Node{std::make_unique<SimpleUnary>(refactor::kernel::SimpleUnaryType::Relu), "relu1"};

        auto tensor0 = Tensor::share(DataType::F32, {1, 3, 5, 5}, LayoutType::Others);
        auto tensor1 = Tensor::share(DataType::F32, {2, 3, 1, 1}, LayoutType::Others);
        auto tensor2 = Tensor::share(DataType::F32, {1, 2, 5, 5}, LayoutType::Others);
        auto tensor3 = Tensor::share(DataType::F32, {1, 2, 5, 5}, LayoutType::Others);
        // conv1 consumes relu0's output, which has 2 channels, so its weight
        // must be [M, 2, 1, 1]; the original {4, 3, 1, 1} mismatched channels
        auto tensor4 = Tensor::share(DataType::F32, {4, 2, 1, 1}, LayoutType::Others);
        auto tensor5 = Tensor::share(DataType::F32, {1, 4, 5, 5}, LayoutType::Others);
        auto tensor6 = Tensor::share(DataType::F32, {1, 4, 5, 5}, LayoutType::Others);

        return {
            {
                {0, {{0, 1}, {2}}},
                {1, {{2}, {3}}},
                {2, {{3, 4}, {5}}},
                {3, {{5}, {6}}},
            },
            {0, 1, 4},// global inputs
            {6},      // global outputs
            std::move(nodes),
            {
                {0, {tensor0, "input0"}},
                {1, {tensor1, "weight0"}},
                {2, {tensor2, "conv0_output"}},
                {3, {tensor3, "relu0_output"}},
                {4, {tensor4, "weight1"}},
                {5, {tensor5, "conv1_output"}},
                {6, {tensor6, "output"}},
            },
        };
    }

    TEST(Graph, ConvToMatMul2) {
        auto graphTopo = TestConvToMatMulGraphBuild2().build();
        fmt::println("{}", graphTopo.topology.toString());
        Graph g(std::move(graphTopo));
        g.optimize();
        auto const &g_ = g.internal().contiguous();
        fmt::println("{}", g_.topology.toString());
        fmt::println("Nodes info :");
        for (size_t i = 0; i < g_.nodes.size(); ++i) {
            fmt::println("{}. \"{}\"", i, g_.nodes[i].name);
        }
        fmt::println("\n Edges info :");
        for (size_t i = 0; i < g_.edges.size(); ++i) {
            fmt::println("{}. \"{}\" Shape is {}, Layout is {}", i, g_.edges[i].name,
                         vec2str(g_.edges[i].tensor->shape), g_.edges[i].tensor->layout.name());
        }
    }
}// namespace refactor::computation
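Both convs in the second graph use default pool attributes (unit dilations and strides, zero pads), so the pass should rewrite each of them: two transposes and two reshapes on the inputs, a matmul, then a reshape and a transpose on the output, with the conv erased. A hedged node-count check that could be appended to ConvToMatMul2, assuming conv_to_matmul is the only pass in optimize() that modifies this graph:

// Sketch only: the 2 relus survive and each of the 2 convs expands into
// 7 new nodes, giving 2 + 2 * 7 = 16; any other pass firing would change this.
EXPECT_EQ(g_.nodes.size(), static_cast<size_t>(16));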
File renamed without changes.