Skip to content

Commit

Permalink
fix(hardware): Stream 禁止调度到异构设备 (forbid dispatching a Stream to a heterogeneous device)
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Nov 28, 2023
1 parent 16e24fd commit 4aa1efb
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 39 deletions.
1 change: 1 addition & 0 deletions src/02hardware/include/hardware/device_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

namespace refactor::hardware::device {

Arc<Device> fetch(Device::Type);
Arc<Device> fetch(Device::Type, int32_t card);
Arc<Device> init(Device::Type, int32_t card, std::string_view args);

Expand Down
62 changes: 32 additions & 30 deletions src/02hardware/src/device_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,43 @@

namespace refactor::hardware::device {

static std::unordered_map<int64_t, Arc<Device>> DEVICES;
union DeviceKey {
struct {
Device::Type type;
int32_t card;
};
int64_t i64;
};
constexpr static DeviceKey CPU_KEY{{Device::Type::Cpu, 0}};
// Returns the process-wide CPU device, constructed lazily on first use
// and shared by every caller (thread-safe via C++11 static-local init).
static Arc<Device> cpu() {
    static auto instance = std::make_shared<Cpu>();
    return instance;
}
static std::unordered_map<int32_t, std::unordered_map<int32_t, Arc<Device>>> DEVICES;
constexpr static auto CPU_KEY = static_cast<int32_t>(Device::Type::Cpu);

Arc<Device> fetch(Device::Type type) {
auto type_ = static_cast<int32_t>(type);
if (type_ == CPU_KEY) { return cpu(); }
if (auto kind = DEVICES.find(type_); kind != DEVICES.end()) {
if (auto it = kind->second.begin(); it != kind->second.end()) {
return it->second;
}
}
return init(type, 0, "");
}
Arc<Device> fetch(Device::Type type, int32_t card) {
if (type == decltype(type)::Cpu) {
auto it = DEVICES.find(CPU_KEY.i64);
return it != DEVICES.end()
? it->second
: DEVICES.emplace(CPU_KEY.i64, std::make_shared<Cpu>()).first->second;
auto type_ = static_cast<int32_t>(type);
if (type_ == CPU_KEY) { return cpu(); }
if (auto kind = DEVICES.find(type_); kind != DEVICES.end()) {
if (auto it = kind->second.find(card); it != kind->second.end()) {
return it->second;
}
}
auto it = DEVICES.find(DeviceKey{{type, card}}.i64);
return it != DEVICES.end() ? it->second : nullptr;
return nullptr;
}
Arc<Device> init(Device::Type type, int32_t card, std::string_view args) {
if (type == decltype(type)::Cpu) {
auto it = DEVICES.find(CPU_KEY.i64);
return it != DEVICES.end()
? it->second
: DEVICES.emplace(CPU_KEY.i64, std::make_shared<Cpu>()).first->second;
}
auto key = DeviceKey{{type, card}}.i64;
if (auto it = DEVICES.find(key); it != DEVICES.end()) {
return it->second;
}
auto device =
type == Device::Type::Nvidia ? std::make_shared<Nvidia>(card)
: UNREACHABLEX(Arc<Device>, "");
return DEVICES.emplace(key, std::move(device)).first->second;
if (auto device = fetch(type, card); device) { return device; }

using T = Device::Type;
// clang-format off
auto device = type == T::Nvidia ? std::make_shared<Nvidia>(card)
: UNREACHABLEX(Arc<Device>, "");
// clang-format on
auto [kind, ok] = DEVICES.try_emplace(static_cast<int32_t>(type));
return kind->second.emplace(card, std::move(device)).first->second;
}

}// namespace refactor::hardware::device
3 changes: 2 additions & 1 deletion src/03runtime/include/runtime/stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ namespace refactor::runtime {
decltype(_outputsSize),
graph_topo::GraphTopo,
std::vector<_N>,
std::vector<_E>);
std::vector<_E>,
decltype(_device));
void setData(count_t, void const *, size_t);
void setData(count_t, Arc<hardware::Device::Blob>);
void getData(count_t, void *, size_t) const;
Expand Down
6 changes: 4 additions & 2 deletions src/03runtime/src/stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ namespace refactor::runtime {
decltype(_outputsSize) outputs,
graph_topo::GraphTopo topology,
std::vector<_N> routines,
std::vector<_E> offsets)
std::vector<_E> offsets,
decltype(_device) device)
: _resources(std::move(resources)),
_stackSize(stack),
_outputsSize(std::move(outputs)),
Expand All @@ -40,7 +41,7 @@ namespace refactor::runtime {
std::move(routines),
std::move(offsets),
}),
_device(device::fetch(Device::Type::Cpu, 0)),
_device(std::move(device)),
_stack(nullptr) {}

void Stream::setData(count_t i, void const *data, size_t size) {
Expand All @@ -56,6 +57,7 @@ namespace refactor::runtime {
}

void Stream::dispatch(Arc<hardware::Device> device) {
ASSERT(device->type() == _device->type(), "Dispatching to heterogeneous device is not supported");
if (_device.get() == device.get()) {
return;
}
Expand Down
6 changes: 5 additions & 1 deletion src/04kernel/include/kernel/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,14 @@ namespace refactor::kernel {
using _E = Edge;
using _G = graph_topo::Graph<_N, _E>;

Arc<hardware::Device> _device;
_G _internal;

public:
Graph(graph_topo::GraphTopo, std::vector<_N>, std::vector<_E>) noexcept;
Graph(decltype(_device),
graph_topo::GraphTopo,
std::vector<_N>,
std::vector<_E>) noexcept;
runtime::Stream lower(Allocator) const;
};

Expand Down
9 changes: 6 additions & 3 deletions src/04kernel/src/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

namespace refactor::kernel {

Graph::Graph(graph_topo::GraphTopo topology,
Graph::Graph(decltype(_device) device,
graph_topo::GraphTopo topology,
std::vector<_N> nodes,
std::vector<_E> edges) noexcept
: _internal(graph_topo::Graph<_N, _E>{
: _device(std::move(device)),
_internal(graph_topo::Graph<_N, _E>{
std::move(topology),
std::move(nodes),
std::move(edges),
Expand Down Expand Up @@ -46,7 +48,8 @@ namespace refactor::kernel {
std::move(outputs_),
_internal.topology,
std::move(nodes),
std::move(edgeOffsets));
std::move(edgeOffsets),
_device);
}

}// namespace refactor::kernel
4 changes: 2 additions & 2 deletions src/05computation/src/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ namespace refactor::computation {
nodes[nodeIdx].kernel = std::move(candidates.front());
}

auto device = hardware::device::fetch(hardware::Device::Type::Cpu, 0);
auto device = hardware::device::fetch(target);
for (auto i : range0_(edges.size())) {
auto const &[tensor, name] = graph.edges[i];
if (!tensor || identities.contains(i)) {
Expand All @@ -65,7 +65,7 @@ namespace refactor::computation {

auto modifier = graph_topo::InplaceModifier(graph.topology);
modifier.reconnect(identities);
return kernel::Graph(modifier.take(), std::move(nodes), std::move(edges));
return kernel::Graph(std::move(device), modifier.take(), std::move(nodes), std::move(edges));
}

auto Graph::internal() const -> decltype(_internal) const & { return _internal; }
Expand Down

0 comments on commit 4aa1efb

Please sign in to comment.