refactor(computation): lowering the computation graph to kernel now removes meaningless nodes directly from the graph structure
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Dec 6, 2023
1 parent 1600518 commit e2ba442
Showing 6 changed files with 31 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/00common/include/common/range.h
@@ -16,7 +16,7 @@ namespace refactor {

        bool empty() const noexcept { return end_ == begin_; }
        size_t size() const noexcept { return end_ - begin_; }
-       t at(size_t i) const noexcept {
+       t at(size_t i) const {
            ASSERT(i < size(), "Index out of range");
            return operator[](i);
        }
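Dropping noexcept from at is a small correctness fix: the ASSERT inside may abort or throw on an out-of-range index, so the old signature over-promised. A minimal sketch of the same pattern, with illustrative names rather than the project's actual range_t:

    #include <cstddef>
    #include <stdexcept>

    // Hypothetical bounds-checked view, showing why a checked `at` should not be noexcept.
    template <class T>
    struct span_view {
        T *begin_, *end_;

        std::size_t size() const noexcept { return end_ - begin_; }
        T &operator[](std::size_t i) const noexcept { return begin_[i]; }

        // `at` validates the index, so it may throw and must not be declared noexcept.
        T &at(std::size_t i) const {
            if (i >= size()) { throw std::out_of_range("Index out of range"); }
            return operator[](i);
        }
    };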
16 changes: 8 additions & 8 deletions src/01graph_topo/include/graph_topo/linked_graph.hpp
@@ -37,7 +37,7 @@ namespace refactor::graph_topo {
        Rc<Node> pushNode(TN, std::vector<Rc<Edge>>);
        void eraseNode(count_t);
        void eraseNode(Rc<Node>);
-       size_t cleanup(bool useless(TE const &) = nullptr);
+       size_t cleanup(bool useful(TE const &) = nullptr);
        bool sort();
    };

@@ -86,6 +86,7 @@ namespace refactor::graph_topo {
                node.disconnect(i);
            }
            for (auto const &out : node._outputs) {
+               ASSERT(out->_targets.empty(), "Output edge should not have targets");
                out->_source = nullptr;
            }
        }
@@ -180,23 +181,22 @@ namespace refactor::graph_topo {
        _nodes.erase(it);
    }

-   LINKED_GRAPH_FN cleanup(bool useless(TE const &))->size_t {
+   LINKED_GRAPH_FN cleanup(bool useful(TE const &))->size_t {
        std::unordered_set<Edge *> outputs;
        outputs.reserve(_outputs.size());
        std::transform(_outputs.begin(), _outputs.end(), std::inserter(outputs, outputs.end()), [](auto const &e) { return e.get(); });
-       auto useful = [&](Rc<Edge> const &e) {
-           return !e->_targets.empty() ||      // some node still consumes this edge
-                  outputs.contains(e.get()) || // this edge is a global output of the graph
-                  !useless ||                  // no further check is required
-                  !useless(e->_info);          // this edge is useful for some other reason
+       auto useful_ = [&](Rc<Edge> const &e) {
+           return !e->_targets.empty() ||      // some node still consumes this edge
+                  outputs.contains(e.get()) || // this edge is a global output of the graph
+                  (useful && useful(e->_info));// this edge is useful for some other reason
        };

        auto before = _nodes.size();
        while (true) {
            auto endit = std::remove_if(
                _nodes.begin(), _nodes.end(),
                [&, this](auto &n) {
-                   auto useless_ = std::none_of(n->_outputs.begin(), n->_outputs.end(), useful);
+                   auto useless_ = std::none_of(n->_outputs.begin(), n->_outputs.end(), useful_);
                    if (useless_) { _cleanupNode(*n); }
                    return useless_;
                });
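The predicate's meaning flips with this change: callers now pass an optional useful hook that can keep an otherwise dead edge alive, instead of a useless hook that vetoes it, and without a hook a node survives only if some output is still consumed or is a graph output. Because removing one node can strand its upstream producers, the loop runs until a fixed point. A self-contained toy version of that loop follows (ToyNode and this cleanup signature are illustrative, not the LinkedGraph API):

    #include <algorithm>
    #include <cstddef>
    #include <string>
    #include <unordered_set>
    #include <vector>

    struct ToyNode {
        std::string name;
        std::vector<std::string> inputs;   // edge names this node reads
        std::vector<std::string> outputs;  // edge names this node produces
    };

    // Remove nodes none of whose outputs are consumed, graph outputs, or marked useful.
    // Returns how many nodes were dropped.
    std::size_t cleanup(std::vector<ToyNode> &nodes,
                        std::unordered_set<std::string> const &graphOutputs,
                        bool useful(std::string const &) = nullptr) {
        auto const before = nodes.size();
        while (true) {
            // Edges still consumed by the nodes that survive so far.
            std::unordered_set<std::string> consumed;
            for (auto const &n : nodes) {
                consumed.insert(n.inputs.begin(), n.inputs.end());
            }
            auto usefulEdge = [&](std::string const &e) {
                return consumed.count(e) > 0 ||     // some node still reads this edge
                       graphOutputs.count(e) > 0 || // this edge is a global output
                       (useful && useful(e));       // the caller says it matters anyway
            };
            auto endit = std::remove_if(nodes.begin(), nodes.end(), [&](ToyNode const &n) {
                return std::none_of(n.outputs.begin(), n.outputs.end(), usefulEdge);
            });
            if (endit == nodes.end()) { break; }  // fixed point: nothing left to drop
            nodes.erase(endit, nodes.end());      // dropping nodes may strand their producers
        }
        return before - nodes.size();
    }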
3 changes: 3 additions & 0 deletions src/03runtime/include/runtime/stream.h
@@ -24,6 +24,7 @@ namespace refactor::runtime {
        struct Edge {
            Arc<hardware::Device::Blob> blob;
            size_t stackOffset;
+           std::string name;
        };

        class Stream {
@@ -39,6 +40,8 @@ namespace refactor::runtime {
            std::vector<Node>,
            std::vector<Edge>,
            decltype(_device));
+
+       decltype(_graph) const &graph() const noexcept { return _graph; }
        void setData(count_t, void const *, size_t);
        void setData(count_t, Arc<hardware::Device::Blob>);
        bool getData(count_t, void *, size_t) const;
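With runtime edges now carrying a name and the stream exposing its graph read-only, callers can correlate runtime edges with the original computation-graph tensors by name; plain indices stop lining up once the cleanup pass has removed nodes and their edges. A hypothetical lookup helper, assuming only the graph().edges shape visible in this diff:

    #include <cstddef>
    #include <optional>
    #include <string_view>

    // Hypothetical helper (not part of the project): find a runtime edge index by name.
    // stream.graph().edges is assumed to be the std::vector<Edge> declared above.
    template <class Stream>
    std::optional<std::size_t> findEdgeByName(Stream const &stream, std::string_view name) {
        auto const &edges = stream.graph().edges;
        for (std::size_t i = 0; i < edges.size(); ++i) {
            if (edges[i].name == name) { return i; }
        }
        return std::nullopt;
    }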
1 change: 1 addition & 0 deletions src/04kernel/src/graph.cc
@@ -33,6 +33,7 @@ namespace refactor::kernel {

        for (auto i : range0_(edges_.size())) {
            auto const &edge = _internal.edges[i];
+           edges_[i].name = edge.name;
            if (edge.data) {
                auto blob = device->malloc(edge.size);
                blob->copyFromHost(edge.data->get<void>());
11 changes: 10 additions & 1 deletion src/05computation/src/graph.cc
@@ -58,7 +58,16 @@ namespace refactor::computation {

        auto modifier = graph_topo::InplaceModifier(graph.topology);
        modifier.reconnect(identities);
-       return kernel::Graph(modifier.take(), std::move(nodes), std::move(edges));
+
+       auto temp = graph_topo::LinkedGraph(graph_topo::Graph{
+           modifier.take(),
+           std::move(nodes),
+           std::move(edges),
+       });
+       temp.cleanup();
+       auto [topo__, nodes__, edges__] = temp.intoGraph();
+
+       return kernel::Graph(std::move(topo__), std::move(nodes__), std::move(edges__));
    }

    auto Graph::internal() const -> decltype(_internal) const & { return _internal; }
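This hunk is the heart of the commit: reconnect(identities) bypasses identity-like nodes by rewiring their consumers to their inputs, which leaves those nodes producing edges nobody reads, and wrapping the lowered graph in a LinkedGraph and calling cleanup() erases them before the kernel graph is built. A tiny self-contained illustration of the before/after picture (hypothetical structures, not the real API):

    #include <algorithm>
    #include <cstdio>
    #include <set>
    #include <string>
    #include <vector>

    struct N { std::string op; std::vector<std::string> in, out; };

    int main() {
        // Conv produces t0, Identity copies t0 to t1, and t1 is the graph output.
        std::vector<N> g{{"Conv", {"x"}, {"t0"}}, {"Identity", {"t0"}, {"t1"}}};
        std::set<std::string> outputs{"t1"};

        // Reconnect: whatever consumed the Identity's output now reads t0 directly.
        outputs = {"t0"};

        // Cleanup: drop every node none of whose outputs is consumed or a graph output.
        std::set<std::string> consumed;
        for (auto const &n : g) { consumed.insert(n.in.begin(), n.in.end()); }
        g.erase(std::remove_if(g.begin(), g.end(),
                               [&](N const &n) {
                                   for (auto const &o : n.out) {
                                       if (consumed.count(o) || outputs.count(o)) { return false; }
                                   }
                                   return true;  // no useful output, so the node goes away
                               }),
                g.end());

        std::printf("%zu node(s) remain\n", g.size());  // prints "1 node(s) remain": only Conv survives
    }

The real cleanup repeats this pass until nothing more can be removed, since dropping a node can make its own producers removable too.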
16 changes: 8 additions & 8 deletions src/09python_ffi/src/executor.cc
@@ -33,21 +33,21 @@ namespace refactor::python_ffi {
    }

    void Executor::setInput(count_t i, pybind11::array data) {
-       auto globalInputs = _graph.internal().contiguous().topology.globalInputs();
-       ASSERT(i < globalInputs.size(), "input index out of range");
-       i = globalInputs[i];
+       i = _stream.graph().topology.globalInputs().at(i);

-       auto const &tensor = *_graph.internal().contiguous().edges[i].tensor;
+       auto const &name = _stream.graph().edges[i].name;
+       auto const &edges = _graph.internal().contiguous().edges;
+       auto const &tensor = *std::find_if(edges.begin(), edges.end(), [&](auto const &e) { return e.name == name; })->tensor;
        ASSERT(tensor.bytesSize() == static_cast<size_t>(data.nbytes()), "input size mismatch");
        _stream.setData(i, data.data(), data.nbytes());
    }

    auto Executor::getOutput(count_t i) -> pybind11::array {
-       auto globalOutputs = _graph.internal().contiguous().topology.globalOutputs();
-       ASSERT(i < globalOutputs.size(), "output index out of range");
-       i = globalOutputs[i];
+       i = _stream.graph().topology.globalOutputs().at(i);

-       auto const &tensor = *_graph.internal().contiguous().edges[i].tensor;
+       auto const &name = _stream.graph().edges[i].name;
+       auto const &edges = _graph.internal().contiguous().edges;
+       auto const &tensor = *std::find_if(edges.begin(), edges.end(), [&](auto const &e) { return e.name == name; })->tensor;
        auto ans = pybind11::array(buildNumpyDType(tensor.dataType), std::move(tensor.shape));
        _stream.getData(i, ans.mutable_data(), ans.nbytes());
        return ans;
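After the cleanup pass, edge indices in the runtime stream no longer match the computation graph's, so the bindings resolve tensor metadata (dtype, shape) by edge name while still addressing the stream itself by runtime index. A sketch of that matching step with stand-in types (the real ones are runtime::Stream::Edge and the computation graph's edges):

    #include <algorithm>
    #include <cstddef>
    #include <string>
    #include <vector>

    struct MetaEdge { std::string name; int dtype; std::vector<int> shape; };  // stand-in for computation edges
    struct RuntimeEdge { std::string name; };                                  // stand-in for runtime::Stream::Edge

    // Given a runtime edge index, look up the matching metadata edge by name.
    MetaEdge const *metadataFor(std::vector<RuntimeEdge> const &runtime,
                                std::vector<MetaEdge> const &meta,
                                std::size_t runtimeIndex) {
        auto const &name = runtime.at(runtimeIndex).name;
        auto it = std::find_if(meta.begin(), meta.end(),
                               [&](MetaEdge const &e) { return e.name == name; });
        return it != meta.end() ? &*it : nullptr;  // nullptr if no edge carries that name
    }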
