Skip to content

Commit

Permalink
fix(distributed): maintain opset version
Browse files Browse the repository at this point in the history
  • Loading branch information
PanZezhong1725 committed Nov 23, 2023
1 parent 01235d4 commit 3806ffb
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 19 deletions.
9 changes: 5 additions & 4 deletions examples/distributed/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,18 @@ def run_model(executor, inputs, n=10):
for _ in range(n):
executor.run()
end = time.time()
avg_time = (end - begin) / n
print(f"average time: {avg_time}")
if n > 0:
avg_time = (end - begin) / n
print(f"average time: {avg_time}")
return outputs


def run_and_compare(name, executor, inputs):
    """Run the compiled model and check its outputs against saved references.

    Args:
        name: basename of the reference file; loads ``{name}_results.npy``.
        executor: compiled executor passed through to ``run_model``.
        inputs: model inputs passed through to ``run_model``.

    Raises:
        AssertionError: if outputs deviate from the stored reference beyond
            the given tolerances.
    """
    results = np.load(f"{name}_results.npy")
    # Use run_model's default iteration count so timing stats are printed
    # (n=0 was previously passed, which skipped/broke the average-time math).
    outputs = run_model(executor, inputs)
    # Quick sanity signal before the strict element-wise comparison.
    print("outputs abs mean:", abs(outputs).mean())
    np.testing.assert_allclose(outputs, results, rtol=1e-3, atol=1e-5)

def load_inputs(name, compiler):
inputs = compiler.zero_inputs()
Expand Down
2 changes: 1 addition & 1 deletion examples/distributed/parallel_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,6 @@ def find_successor(op_type: str, idx: int, search_limit: int = 1):
tt = output.type.tensor_type
if tt.HasField("shape"):
tt.ClearField("shape")
model = helper.make_model(graph)
model = helper.make_model(graph,opset_imports=model.opset_import )
model = onnx.shape_inference.infer_shapes(model)
return model
2 changes: 1 addition & 1 deletion src/04kernel/src/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ namespace refactor::kernel {
}
#ifdef USE_CUDA
case NvidiaGpu: {
static Arc<mem_manager::MemManager> memPool = std::make_shared<mem_manager::MemPool>(50ul << 30, 256, cuda::BasicCudaMemManager::instance());
static Arc<mem_manager::MemManager> memPool = std::make_shared<mem_manager::MemPool>(20ul << 30, 256, cuda::BasicCudaMemManager::instance());
return memPool;
}
#endif
Expand Down
28 changes: 15 additions & 13 deletions src/09python_ffi/src/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,21 @@ namespace refactor::python_ffi {
std::string allocator,
std::vector<std::string> passes,
int deviceID) {

kernel::Target target_ = kernel::Target::Cpu;
if (target == "cpu") {
target_ = kernel::Target::Cpu;
} else if (target == "cuda") {
target_ = kernel::Target::NvidiaGpu;
#ifdef USE_CUDA
if (deviceID >= 0) {
kernel::cuda::setCudaDevice(deviceID);
}
#endif
} else {
UNREACHABLE();
}

_g.collectVariables();
std::vector<std::string_view> unknownVariables;
for (auto const &[_, v] : _g.variables()) {
Expand Down Expand Up @@ -63,19 +78,6 @@ namespace refactor::python_ffi {
computation.layoutPermute();
}

kernel::Target target_ = kernel::Target::Cpu;
if (target == "cpu") {
target_ = kernel::Target::Cpu;
} else if (target == "cuda") {
target_ = kernel::Target::NvidiaGpu;
#ifdef USE_CUDA
if (deviceID >= 0) {
kernel::cuda::setCudaDevice(deviceID);
}
#endif
} else {
UNREACHABLE();
}

auto kernel = computation.lower(target_);
auto stream = kernel.lower(allocator == "flat"
Expand Down

0 comments on commit 3806ffb

Please sign in to comment.