diff --git a/src/03runtime/include/runtime/stream.h b/src/03runtime/include/runtime/stream.h index 4e3bfe30..6b8e58a1 100644 --- a/src/03runtime/include/runtime/stream.h +++ b/src/03runtime/include/runtime/stream.h @@ -49,7 +49,7 @@ namespace refactor::runtime { auto prepare() -> std::vector; void run(); auto bench(void (*sync)()) -> std::vector; - void trace(std::function); + void trace(std::function); }; }// namespace refactor::runtime diff --git a/src/03runtime/src/stream.cc b/src/03runtime/src/stream.cc index 2c6fe884..05228da9 100644 --- a/src/03runtime/src/stream.cc +++ b/src/03runtime/src/stream.cc @@ -96,7 +96,6 @@ namespace refactor::runtime { } void Stream::run() { - auto map = [this](auto i) { return _internal.edges[i](*_stack); }; std::vector buffer(16); for (auto const [nodeIdx, i, o] : _internal.topology) { auto [inputs, outputs] = collectAddress(*_stack, _internal.edges, buffer, i, o); @@ -105,7 +104,6 @@ namespace refactor::runtime { } auto Stream::bench(void (*sync)()) -> std::vector { - auto map = [this](auto i) { return _internal.edges[i](*_stack); }; std::vector buffer(16); std::vector ans(_internal.nodes.size()); for (auto const [nodeIdx, i, o] : _internal.topology) { @@ -119,8 +117,7 @@ namespace refactor::runtime { return ans; } - void Stream::trace(std::function record) { - auto map = [this](auto i) { return _internal.edges[i](*_stack); }; + void Stream::trace(std::function record) { std::vector buffer(16); for (auto const [nodeIdx, i, o] : _internal.topology) { auto [inputs, outputs] = collectAddress(*_stack, _internal.edges, buffer, i, o); diff --git a/src/07onnx/src/operators.cpp b/src/07onnx/src/operators.cpp index cc25b712..20598ab4 100644 --- a/src/07onnx/src/operators.cpp +++ b/src/07onnx/src/operators.cpp @@ -34,7 +34,7 @@ namespace refactor::onnx { void register_() { -// clang-format off + // clang-format off #define REGISTER(NAME, CLASS) Operator::register_("onnx::" #NAME) REGISTER(BatchNormalization, BatchNormalization); REGISTER(Cast , Cast ); diff --git a/src/08communication/src/operators.cpp b/src/08communication/src/operators.cpp index 9525588d..3f47f256 100644 --- a/src/08communication/src/operators.cpp +++ b/src/08communication/src/operators.cpp @@ -6,7 +6,7 @@ namespace refactor::communication { using namespace frontend; void register_() { -// clang-format off + // clang-format off #define REGISTER(NAME, CLASS) Operator::register_("onnx::" #NAME) REGISTER(AllReduceAvg , AllReduce); REGISTER(AllReduceSum , AllReduce); diff --git a/src/09python_ffi/src/executor.cc b/src/09python_ffi/src/executor.cc index 0455fc28..7018f263 100644 --- a/src/09python_ffi/src/executor.cc +++ b/src/09python_ffi/src/executor.cc @@ -51,11 +51,13 @@ namespace refactor::python_ffi { void Executor::trace(std::string path_) { namespace fs = std::filesystem; + auto path = fs::path(std::move(path_)); fs::create_directories(path); ASSERT(fs::is_directory(path), "Failed to create \"{}\"", path.c_str()); + auto it = _graph.internal().contiguous().topology.begin(); - _stream.trace([&](count_t nodeIdx, void const **inputs, void **outputs) { + _stream.trace([&](count_t nodeIdx, void const *const *inputs, void const *const *outputs) { auto [nodeIdx_, i_, o_] = *it++; ASSERT(nodeIdx_ == nodeIdx, "node index mismatch"); auto nodeName = _graph.internal().contiguous().nodes[nodeIdx].name; @@ -63,8 +65,9 @@ namespace refactor::python_ffi { std::replace(nodeName.begin(), nodeName.end(), '.', '-'); std::vector buffer; - auto fn = [&](char dir, count_t idx, computation::Edge const &edge, void const *ptr) { - if (!ptr) { return; } + auto fn = [&](char dir, count_t idx, count_t edgeIdx, void const *const *addresses) { + if (!addresses[idx]) { return; } + auto edge = _graph.internal().contiguous().edges[edgeIdx]; auto size = edge.tensor->bytesSize(); buffer.resize(size); @@ -75,14 +78,13 @@ namespace refactor::python_ffi { fs::remove(file); std::ofstream os(file, std::ios::binary); #ifdef USE_CUDA - kernel::cuda::copyOut(buffer.data(), ptr, size); + kernel::cuda::copyOut(buffer.data(), addresses[idx], size); #endif os.write(buffer.data(), size); }; - auto const &edges = _graph.internal().contiguous().edges; - for (auto i : range0_(i_.size())) { fn('i', i, edges[i_[i]], inputs[i]); } - for (auto i : range0_(o_.size())) { fn('o', i, edges[o_[i]], outputs[i]); } + for (auto i : range0_(i_.size())) { fn('i', i, i_[i], inputs); } + for (auto i : range0_(o_.size())) { fn('o', i, o_[i], outputs); } }); }