diff --git a/src/02hardware/include/hardware/device_manager.h b/src/02hardware/include/hardware/device_manager.h index cea169de..7e008c10 100644 --- a/src/02hardware/include/hardware/device_manager.h +++ b/src/02hardware/include/hardware/device_manager.h @@ -5,6 +5,7 @@ namespace refactor::hardware::device { + Arc fetch(Device::Type); Arc fetch(Device::Type, int32_t card); Arc init(Device::Type, int32_t card, std::string_view args); diff --git a/src/02hardware/src/device_manager.cpp b/src/02hardware/src/device_manager.cpp index 031da822..97b91ac7 100644 --- a/src/02hardware/src/device_manager.cpp +++ b/src/02hardware/src/device_manager.cpp @@ -5,41 +5,43 @@ namespace refactor::hardware::device { - static std::unordered_map> DEVICES; - union DeviceKey { - struct { - Device::Type type; - int32_t card; - }; - int64_t i64; - }; - constexpr static DeviceKey CPU_KEY{{Device::Type::Cpu, 0}}; + static Arc cpu() { + static auto cpu = std::make_shared(); + return cpu; + } + static std::unordered_map>> DEVICES; + constexpr static auto CPU_KEY = static_cast(Device::Type::Cpu); + Arc fetch(Device::Type type) { + auto type_ = static_cast(type); + if (type_ == CPU_KEY) { return cpu(); } + if (auto kind = DEVICES.find(type_); kind != DEVICES.end()) { + if (auto it = kind->second.begin(); it != kind->second.end()) { + return it->second; + } + } + return init(type, 0, ""); + } Arc fetch(Device::Type type, int32_t card) { - if (type == decltype(type)::Cpu) { - auto it = DEVICES.find(CPU_KEY.i64); - return it != DEVICES.end() - ? it->second - : DEVICES.emplace(CPU_KEY.i64, std::make_shared()).first->second; + auto type_ = static_cast(type); + if (type_ == CPU_KEY) { return cpu(); } + if (auto kind = DEVICES.find(type_); kind != DEVICES.end()) { + if (auto it = kind->second.find(card); it != kind->second.end()) { + return it->second; + } } - auto it = DEVICES.find(DeviceKey{{type, card}}.i64); - return it != DEVICES.end() ? it->second : nullptr; + return nullptr; } Arc init(Device::Type type, int32_t card, std::string_view args) { - if (type == decltype(type)::Cpu) { - auto it = DEVICES.find(CPU_KEY.i64); - return it != DEVICES.end() - ? it->second - : DEVICES.emplace(CPU_KEY.i64, std::make_shared()).first->second; - } - auto key = DeviceKey{{type, card}}.i64; - if (auto it = DEVICES.find(key); it != DEVICES.end()) { - return it->second; - } - auto device = - type == Device::Type::Nvidia ? std::make_shared(card) - : UNREACHABLEX(Arc, ""); - return DEVICES.emplace(key, std::move(device)).first->second; + if (auto device = fetch(type, card); device) { return device; } + + using T = Device::Type; + // clang-format off + auto device = type == T::Nvidia ? std::make_shared(card) + : UNREACHABLEX(Arc, ""); + // clang-format on + auto [kind, ok] = DEVICES.try_emplace(static_cast(type)); + return kind->second.emplace(card, std::move(device)).first->second; } }// namespace refactor::hardware::device diff --git a/src/03runtime/include/runtime/stream.h b/src/03runtime/include/runtime/stream.h index 33a662d3..44b2deb8 100644 --- a/src/03runtime/include/runtime/stream.h +++ b/src/03runtime/include/runtime/stream.h @@ -54,7 +54,8 @@ namespace refactor::runtime { decltype(_outputsSize), graph_topo::GraphTopo, std::vector<_N>, - std::vector<_E>); + std::vector<_E>, + decltype(_device)); void setData(count_t, void const *, size_t); void setData(count_t, Arc); void getData(count_t, void *, size_t) const; diff --git a/src/03runtime/src/stream.cc b/src/03runtime/src/stream.cc index 167a8d6b..d2b4d5ac 100644 --- a/src/03runtime/src/stream.cc +++ b/src/03runtime/src/stream.cc @@ -31,7 +31,8 @@ namespace refactor::runtime { decltype(_outputsSize) outputs, graph_topo::GraphTopo topology, std::vector<_N> routines, - std::vector<_E> offsets) + std::vector<_E> offsets, + decltype(_device) device) : _resources(std::move(resources)), _stackSize(stack), _outputsSize(std::move(outputs)), @@ -40,7 +41,7 @@ namespace refactor::runtime { std::move(routines), std::move(offsets), }), - _device(device::fetch(Device::Type::Cpu, 0)), + _device(std::move(device)), _stack(nullptr) {} void Stream::setData(count_t i, void const *data, size_t size) { @@ -56,6 +57,7 @@ namespace refactor::runtime { } void Stream::dispatch(Arc device) { + ASSERT(device->type() == _device->type(), "Dispatching to heterogeneous device is not supported"); if (_device.get() == device.get()) { return; } diff --git a/src/04kernel/include/kernel/graph.h b/src/04kernel/include/kernel/graph.h index 425b33cd..071638d0 100644 --- a/src/04kernel/include/kernel/graph.h +++ b/src/04kernel/include/kernel/graph.h @@ -34,10 +34,14 @@ namespace refactor::kernel { using _E = Edge; using _G = graph_topo::Graph<_N, _E>; + Arc _device; _G _internal; public: - Graph(graph_topo::GraphTopo, std::vector<_N>, std::vector<_E>) noexcept; + Graph(decltype(_device), + graph_topo::GraphTopo, + std::vector<_N>, + std::vector<_E>) noexcept; runtime::Stream lower(Allocator) const; }; diff --git a/src/04kernel/src/graph.cc b/src/04kernel/src/graph.cc index d2a6e607..c66a28ff 100644 --- a/src/04kernel/src/graph.cc +++ b/src/04kernel/src/graph.cc @@ -2,10 +2,12 @@ namespace refactor::kernel { - Graph::Graph(graph_topo::GraphTopo topology, + Graph::Graph(decltype(_device) device, + graph_topo::GraphTopo topology, std::vector<_N> nodes, std::vector<_E> edges) noexcept - : _internal(graph_topo::Graph<_N, _E>{ + : _device(std::move(device)), + _internal(graph_topo::Graph<_N, _E>{ std::move(topology), std::move(nodes), std::move(edges), @@ -46,7 +48,8 @@ namespace refactor::kernel { std::move(outputs_), _internal.topology, std::move(nodes), - std::move(edgeOffsets)); + std::move(edgeOffsets), + _device); } }// namespace refactor::kernel diff --git a/src/05computation/src/graph.cc b/src/05computation/src/graph.cc index 0e4a558e..983a3d76 100644 --- a/src/05computation/src/graph.cc +++ b/src/05computation/src/graph.cc @@ -48,7 +48,7 @@ namespace refactor::computation { nodes[nodeIdx].kernel = std::move(candidates.front()); } - auto device = hardware::device::fetch(hardware::Device::Type::Cpu, 0); + auto device = hardware::device::fetch(target); for (auto i : range0_(edges.size())) { auto const &[tensor, name] = graph.edges[i]; if (!tensor || identities.contains(i)) { @@ -65,7 +65,7 @@ namespace refactor::computation { auto modifier = graph_topo::InplaceModifier(graph.topology); modifier.reconnect(identities); - return kernel::Graph(modifier.take(), std::move(nodes), std::move(edges)); + return kernel::Graph(std::move(device), modifier.take(), std::move(nodes), std::move(edges)); } auto Graph::internal() const -> decltype(_internal) const & { return _internal; }