Skip to content

Commit

Permalink
feat(utils): 全部格式化功能移动到 rust utilities
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Dec 4, 2023
1 parent 974cc9b commit 0bafaad
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 77 deletions.
62 changes: 0 additions & 62 deletions scripts/format.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/08communication/src/operators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace refactor::communication {
using namespace frontend;

void register_() {
// clang-format off
// clang-format off
#define REGISTER(NAME, CLASS) Operator::register_<CLASS>("onnx::" #NAME)
REGISTER(AllReduceAvg , AllReduce);
REGISTER(AllReduceSum , AllReduce);
Expand Down
93 changes: 93 additions & 0 deletions utilities/src/format.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use std::{
ffi::OsStr,
fs,
os::unix::ffi::OsStrExt,
path::Path,
process::{Child, Command},
};

use crate::proj_dir;

pub fn format() {
for mut ele in recur(&proj_dir().join("src")) {
let status = ele.wait().unwrap();
if !status.success() {
println!("{:?}", status);
}
}
}

fn recur(path: impl AsRef<Path>) -> Vec<Child> {
fs::read_dir(path)
.unwrap()
.into_iter()
.filter_map(|entry| entry.ok())
.map(|entry| {
let path = entry.path();
if path.is_dir() {
if path.ends_with("09python_ffi/pybind11") {
vec![]
} else {
recur(&path)
}
} else if let Some(child) = format_one(&path) {
vec![child]
} else {
vec![]
}
})
.flatten()
.collect::<Vec<_>>()
}

fn format_one(file: &Path) -> Option<Child> {
const C_STYLE_FILE: [&[u8]; 9] = [
b"h", b"hh", b"hpp", b"c", b"cc", b"cpp", b"cxx", b"cu", b"mlu",
];
let Some(ext) = file.extension() else {
return None;
};
if C_STYLE_FILE.contains(&ext.as_bytes()) {
Command::new("clang-format-14")
.arg("-i")
.arg(file)
.spawn()
.ok()
} else if ext == OsStr::new("py") {
Command::new("black").arg(file).spawn().ok()
} else {
None
}
}

// 根据 git diff 判断格式化哪些文件的功能,暂时没用
// if len(sys.argv) == 1:
// # Last commit.
// print("Formats git added files.")
// for line in (
// run("git status", cwd=proj_path, capture_output=True, shell=True)
// .stdout.decode()
// .splitlines()
// ):
// line = line.strip()
// # Only formats git added files.
// for pre in ["new file:", "modified:"]:
// if line.startswith(pre):
// format_file(line[len(pre) :].strip())
// break
// else:
// # Origin commit.
// origin = sys.argv[1]
// print(f'Formats changed files from "{origin}".')
// for line in (
// run(f"git diff {origin}", cwd=proj_path, capture_output=True, shell=True)
// .stdout.decode()
// .splitlines()
// ):
// diff = "diff --git "
// if line.startswith(diff):
// files = line[len(diff) :].split(" ")
// assert len(files) == 2
// assert files[0][:2] == "a/"
// assert files[1][:2] == "b/"
// format_file(files[1][2:])
7 changes: 4 additions & 3 deletions utilities/src/infer.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::{ffi::OsStr, fs, path::Path, process::Command};
use crate::proj_dir;
use std::{ffi::OsStr, fs, path::Path, process::Command};

const SCRIPT: &str = r#"
import numpy as np
Expand Down Expand Up @@ -46,15 +47,15 @@ model_path = Path(\"{}\").resolve()
)
}

pub fn infer(proj_dir: impl AsRef<Path>, path: impl AsRef<Path>) {
pub fn infer(path: impl AsRef<Path>) {
let path = path.as_ref();
assert!(
path.is_file() && path.extension() == Some(OsStr::new("onnx")),
"\"{}\" is not a onnx file",
path.display(),
);
Command::new("python")
.current_dir(proj_dir)
.current_dir(proj_dir())
.arg("-c")
.arg(format!("{}{}", model_path(path), SCRIPT))
.status()
Expand Down
20 changes: 15 additions & 5 deletions utilities/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod format;
mod infer;
mod make;

Expand All @@ -8,6 +9,7 @@ use std::{
io::ErrorKind,
path::{Path, PathBuf},
process::Command,
sync::OnceLock,
};

/// Refactor Graph utilities
Expand Down Expand Up @@ -39,35 +41,43 @@ enum Commands {
Clean,
/// run tests
Test,
/// format source files
Format,
/// run model inference
Infer { path: PathBuf },
}

pub fn proj_dir() -> &'static Path {
static PROJ: OnceLock<&Path> = OnceLock::new();
*PROJ.get_or_init(|| Path::new(std::env!("CARGO_MANIFEST_DIR")).parent().unwrap())
}

fn main() {
let proj_dir = Path::new(std::env!("CARGO_MANIFEST_DIR")).parent().unwrap();
match Cli::parse().command {
Commands::Make {
release,
install_python,
dev,
cxx_compiler,
} => make::make(proj_dir, release, install_python, dev, cxx_compiler),
} => make::make(release, install_python, dev, cxx_compiler),

Commands::Clean => match fs::remove_dir_all(proj_dir.join("build")) {
Commands::Clean => match fs::remove_dir_all(proj_dir().join("build")) {
Ok(_) => {}
Err(e) if e.kind() == ErrorKind::NotFound => {}
Err(e) => panic!("{}", e),
},

Commands::Test => {
Command::new("make")
.current_dir(proj_dir.join("build"))
.current_dir(proj_dir().join("build"))
.arg("test")
.arg("-j")
.status()
.unwrap();
}

Commands::Infer { path } => infer::infer(proj_dir, path),
Commands::Format => format::format(),

Commands::Infer { path } => infer::infer(path),
}
}
12 changes: 6 additions & 6 deletions utilities/src/make.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::{
use crate::proj_dir;
use std::{
collections::HashSet,
ffi::{OsStr, OsString},
fs,
path::{Path, PathBuf},
path::PathBuf,
process::Command,
};

Expand All @@ -13,7 +14,6 @@ enum Target {
}

pub fn make(
proj_dir: impl AsRef<Path>,
release: bool,
install_python: bool,
dev: Option<Vec<OsString>>,
Expand All @@ -40,7 +40,8 @@ pub fn make(
.collect::<HashSet<_>>();
let dev = |d: Target| if dev.contains(&d) { "ON" } else { "OFF" };

let build = proj_dir.as_ref().join("build");
let proj_dir = proj_dir();
let build = proj_dir.join("build");
fs::create_dir_all(&build).unwrap();

let mut cmd = Command::new("cmake");
Expand Down Expand Up @@ -74,15 +75,14 @@ pub fn make(
.unwrap()
.path();
let to = proj_dir
.as_ref()
.join("src/09python_ffi/src/refactor_graph")
.join(from.file_name().unwrap());
fs::copy(from, to).unwrap();

Command::new("pip")
.arg("install")
.arg("-e")
.arg(proj_dir.as_ref().join("src/09python_ffi/"))
.arg(proj_dir.join("src/09python_ffi/"))
.status()
.unwrap();
}
Expand Down

0 comments on commit 0bafaad

Please sign in to comment.