Skip to content

Commit

Permalink
refactor(utils): 复杂功能分散到不同文件
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Dec 4, 2023
1 parent 6e1e10c commit 974cc9b
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 129 deletions.
62 changes: 62 additions & 0 deletions utilities/src/infer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use std::{ffi::OsStr, fs, path::Path, process::Command};

const SCRIPT: &str = r#"
import numpy as np
from functools import reduce
from onnxruntime import InferenceSession
from onnx import load
from refactor_graph.onnx import make_compiler
compiler = make_compiler(
load(model_path.__str__(), load_external_data=False),
model_path.parent.__str__(),
)
executor = compiler.compile("cuda", "default", [])
inputs = compiler.zero_inputs()
for i, input in enumerate(inputs):
if input.dtype in [np.int64, np.int32]:
input[...] = np.random.randint(
0, reduce(lambda x, y: x * y, input.shape), input.shape
).astype(input.dtype)
elif input.dtype in [np.float32, np.float64]:
input[...] = np.random.random(input.shape)
else:
raise NotImplementedError
executor.set_input(i, input)
executor.run()
session = InferenceSession(model_path)
answer = session.run(
None,
{session.get_inputs()[i].name: input for i, input in enumerate(inputs)},
)
for i, ans in enumerate(answer):
print((executor.get_output(i) - ans).flatten())
"#;

fn model_path(path: impl AsRef<Path>) -> String {
format!(
"\
from pathlib import Path
model_path = Path(\"{}\").resolve()
",
fs::canonicalize(path).unwrap().display(),
)
}

pub fn infer(proj_dir: impl AsRef<Path>, path: impl AsRef<Path>) {
let path = path.as_ref();
assert!(
path.is_file() && path.extension() == Some(OsStr::new("onnx")),
"\"{}\" is not a onnx file",
path.display(),
);
Command::new("python")
.current_dir(proj_dir)
.arg("-c")
.arg(format!("{}{}", model_path(path), SCRIPT))
.status()
.unwrap();
}
136 changes: 7 additions & 129 deletions utilities/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
mod infer;
mod make;

use clap::{Parser, Subcommand};
use std::{
collections::HashSet,
ffi::{OsStr, OsString},
ffi::OsString,
fs,
io::ErrorKind,
path::{Path, PathBuf},
Expand Down Expand Up @@ -41,12 +43,6 @@ enum Commands {
Infer { path: PathBuf },
}

#[derive(PartialEq, Eq, Hash, Debug)]
enum Target {
Nvidia,
Baidu,
}

fn main() {
let proj_dir = Path::new(std::env!("CARGO_MANIFEST_DIR")).parent().unwrap();
match Cli::parse().command {
Expand All @@ -55,79 +51,14 @@ fn main() {
install_python,
dev,
cxx_compiler,
} => {
let release = if release { "Release" } else { "Debug" };
let dev = dev
.unwrap_or_default()
.into_iter()
.map(|d| d.to_ascii_lowercase())
.filter_map(|d| {
if d == OsStr::new("cuda") || d == OsStr::new("nvidia") {
Some(Target::Nvidia)
} else if d == OsStr::new("kunlun")
|| d == OsStr::new("kunlunxin")
|| d == OsStr::new("baidu")
{
Some(Target::Baidu)
} else {
eprintln!("warning: unknown device: {:?}", d);
None
}
})
.collect::<HashSet<_>>();
let dev = |d: Target| if dev.contains(&d) { "ON" } else { "OFF" };

let build = proj_dir.join("build");
fs::create_dir_all(&build).unwrap();

let mut cmd = Command::new("cmake");
cmd.current_dir(&proj_dir)
.arg("-Bbuild")
.arg(format!("-DCMAKE_BUILD_TYPE={release}"))
.arg(format!("-DUSE_CUDA={}", dev(Target::Nvidia)))
.arg(format!("-DUSE_KUNLUN={}", dev(Target::Baidu)));
if let Some(cxx_compiler) = cxx_compiler {
cmd.arg(format!("-DCMAKE_CXX_COMPILER={}", cxx_compiler.display()));
}
cmd.status().unwrap();

Command::new("make")
.current_dir(&build)
.arg("-j")
.status()
.unwrap();
} => make::make(proj_dir, release, install_python, dev, cxx_compiler),

if install_python {
let from = fs::read_dir(build.join("src/09python_ffi"))
.unwrap()
.filter_map(|ele| ele.ok())
.find(|entry| {
entry
.path()
.extension()
.filter(|&ext| ext == OsStr::new("so"))
.is_some()
})
.unwrap()
.path();
let to = proj_dir
.join("src/09python_ffi/src/refactor_graph")
.join(from.file_name().unwrap());
fs::copy(from, to).unwrap();

Command::new("pip")
.arg("install")
.arg("-e")
.arg(proj_dir.join("src/09python_ffi/"))
.status()
.unwrap();
}
}
Commands::Clean => match fs::remove_dir_all(proj_dir.join("build")) {
Ok(_) => {}
Err(e) if e.kind() == ErrorKind::NotFound => {}
Err(e) => panic!("{}", e),
},

Commands::Test => {
Command::new("make")
.current_dir(proj_dir.join("build"))
Expand All @@ -136,60 +67,7 @@ fn main() {
.status()
.unwrap();
}
Commands::Infer { path } => {
const SCRIPT: &str = r#"
import numpy as np
from functools import reduce
from onnxruntime import InferenceSession
from onnx import load
from refactor_graph.onnx import make_compiler
compiler = make_compiler(
load(model_path.__str__(), load_external_data=False),
model_path.parent.__str__(),
)
executor = compiler.compile("cuda", "default", [])
inputs = compiler.zero_inputs()
for i, input in enumerate(inputs):
if input.dtype in [np.int64, np.int32]:
input[...] = np.random.randint(
0, reduce(lambda x, y: x * y, input.shape), input.shape
).astype(input.dtype)
elif input.dtype in [np.float32, np.float64]:
input[...] = np.random.random(input.shape)
else:
raise NotImplementedError
executor.set_input(i, input)

executor.run()
session = InferenceSession(model_path)
answer = session.run(
None,
{session.get_inputs()[i].name: input for i, input in enumerate(inputs)},
)
for i, ans in enumerate(answer):
print((executor.get_output(i) - ans).flatten())
"#;
assert!(
path.is_file() && path.extension() == Some(OsStr::new("onnx")),
"\"{}\" is not a onnx file",
path.display(),
);
Command::new("python")
.current_dir(proj_dir)
.arg("-c")
.arg(format!(
"\
from pathlib import Path
model_path = Path(\"{}\").resolve()
{}",
fs::canonicalize(path).unwrap().display(),
SCRIPT
))
.status()
.unwrap();
}
Commands::Infer { path } => infer::infer(proj_dir, path),
}
}
89 changes: 89 additions & 0 deletions utilities/src/make.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
use std::{
collections::HashSet,
ffi::{OsStr, OsString},
fs,
path::{Path, PathBuf},
process::Command,
};

#[derive(PartialEq, Eq, Hash, Debug)]
enum Target {
Nvidia,
Baidu,
}

pub fn make(
proj_dir: impl AsRef<Path>,
release: bool,
install_python: bool,
dev: Option<Vec<OsString>>,
cxx_compiler: Option<PathBuf>,
) {
let release = if release { "Release" } else { "Debug" };
let dev = dev
.unwrap_or_default()
.into_iter()
.map(|d| d.to_ascii_lowercase())
.filter_map(|d| {
if d == OsStr::new("cuda") || d == OsStr::new("nvidia") {
Some(Target::Nvidia)
} else if d == OsStr::new("kunlun")
|| d == OsStr::new("kunlunxin")
|| d == OsStr::new("baidu")
{
Some(Target::Baidu)
} else {
eprintln!("warning: unknown device: {:?}", d);
None
}
})
.collect::<HashSet<_>>();
let dev = |d: Target| if dev.contains(&d) { "ON" } else { "OFF" };

let build = proj_dir.as_ref().join("build");
fs::create_dir_all(&build).unwrap();

let mut cmd = Command::new("cmake");
cmd.current_dir(&proj_dir)
.arg("-Bbuild")
.arg(format!("-DCMAKE_BUILD_TYPE={release}"))
.arg(format!("-DUSE_CUDA={}", dev(Target::Nvidia)))
.arg(format!("-DUSE_KUNLUN={}", dev(Target::Baidu)));
if let Some(cxx_compiler) = cxx_compiler {
cmd.arg(format!("-DCMAKE_CXX_COMPILER={}", cxx_compiler.display()));
}
cmd.status().unwrap();

Command::new("make")
.current_dir(&build)
.arg("-j")
.status()
.unwrap();

if install_python {
let from = fs::read_dir(build.join("src/09python_ffi"))
.unwrap()
.filter_map(|ele| ele.ok())
.find(|entry| {
entry
.path()
.extension()
.filter(|&ext| ext == OsStr::new("so"))
.is_some()
})
.unwrap()
.path();
let to = proj_dir
.as_ref()
.join("src/09python_ffi/src/refactor_graph")
.join(from.file_name().unwrap());
fs::copy(from, to).unwrap();

Command::new("pip")
.arg("install")
.arg("-e")
.arg(proj_dir.as_ref().join("src/09python_ffi/"))
.status()
.unwrap();
}
}

0 comments on commit 974cc9b

Please sign in to comment.