Skip to content

Commit

Permalink
test: 添加环境变量控制硬件选项,合并 CPU 单卡和分布式版本测试
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Dec 5, 2024
1 parent 9ae4f93 commit 917fcf4
Show file tree
Hide file tree
Showing 13 changed files with 32 additions and 119 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ test-utils.path = "test-utils"

ggus = "0.3"
itertools = "0.13"
regex = "1.11"
env_logger = "0.11"
build-script-cfg = "0.0"

ndarray-layout = { git = "https://github.com/YdrMaster/ndarray-layout", rev = "f1fdd24" }
Expand Down
1 change: 1 addition & 0 deletions models/llama/common-cpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ operators = { workspace = true, features = ["common-cpu"] }
[dev-dependencies]
test-utils.workspace = true
gguf.workspace = true
regex.workspace = true
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use operators::{
random_sample::{KVPair, SampleArgs},
Blob,
};
use regex::Regex;
use std::{
iter::zip,
ptr::copy_nonoverlapping,
Expand All @@ -22,9 +23,10 @@ use test_utils::{Inference, TokenizerAndPrompt};
type Worker<'w> = LlamaWorker<Operators<InprocNode<usize>, AllReduce>, Weights<'w>>;

#[test]
fn test_dist() {
fn test_infer() {
let Some(Inference {
model,
devices,
prompt,
as_user,
temperature,
Expand All @@ -49,7 +51,17 @@ fn test_dist() {
let sample_args = SampleArgs::new(temperature, top_p, top_k).expect("invalid sample args");
println!("{sample_args:?}");

let lens = [1; 4];
let lens = match devices {
Some(devices) => {
let regex = Regex::new(r"\d+").unwrap();
regex
.find_iter(&devices)
.map(|c| c.as_str().parse::<usize>().unwrap())
.collect::<Vec<_>>()
}
None => vec![1],
};
println!("distribution: {lens:?}");
let count = lens.iter().sum();
let (seeds, senders) = WorkerSeed::new(lens.len());
thread::scope(|s| {
Expand Down
5 changes: 1 addition & 4 deletions models/llama/common-cpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,4 @@ impl WeightLoader for Weights<'_> {
}

#[cfg(test)]
mod test_infer;

#[cfg(test)]
mod test_dist;
mod infer;
109 changes: 0 additions & 109 deletions models/llama/common-cpu/src/test_infer.rs

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ fn test_infer() {
top_p,
top_k,
max_steps,
..
}) = Inference::load()
else {
return;
Expand Down
2 changes: 1 addition & 1 deletion models/llama/infini/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,4 +169,4 @@ impl WeightLoader for Weights {
}

#[cfg(test)]
mod test_infer;
mod infer;
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@ fn test_infer() {
top_p,
top_k,
max_steps,
..
}) = Inference::load()
else {
return;
};

let roll_cache_size = load_roll_cache_size();
println!("roll_cache_size: {}", roll_cache_size);
println!("roll_cache_size: {roll_cache_size}");
let gguf = GGufModel::read(model.iter().map(|s| &**s));

let TokenizerAndPrompt {
Expand Down
2 changes: 1 addition & 1 deletion models/llama/nvidia-gpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,4 +316,4 @@ impl<'ctx> WeightLoader for Weights<'ctx> {
}

#[cfg(test)]
mod test_infer;
mod infer;
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ fn test_infer() {
top_p,
top_k,
max_steps,
..
}) = Inference::load()
else {
return;
Expand Down
2 changes: 1 addition & 1 deletion models/llama/opencl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,4 @@ impl WeightLoader for Weights {
}

#[cfg(test)]
mod test_infer;
mod infer;
1 change: 1 addition & 0 deletions test-utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ authors = ["YdrMaster <ydrml@hotmail.com>"]

[dependencies]
gguf.workspace = true
env_logger.workspace = true
cli-table = "0.4.9"
6 changes: 6 additions & 0 deletions test-utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ use std::{
fmt,
path::{Path, PathBuf},
str::FromStr,
sync::Once,
time::{Duration, Instant},
vec,
};

pub struct Inference {
pub model: Box<[Mmap]>,
pub devices: Option<String>,
pub prompt: String,
pub as_user: bool,
pub temperature: f32,
Expand All @@ -23,6 +25,9 @@ pub struct Inference {

impl Inference {
pub fn load() -> Option<Self> {
static ONCE: Once = Once::new();
ONCE.call_once(|| env_logger::init());

Check warning

Code scanning / clippy

redundant closure Warning test

redundant closure

let Some(path) = var_os("TEST_MODEL") else {
println!("TEST_MODEL not set");
return None;
Expand All @@ -42,6 +47,7 @@ impl Inference {

Some(Self {
model: map_files(path),
devices: var("DEVICES").ok(),
prompt: var("PROMPT").unwrap_or_else(|_| String::from("Once upon a time,")),
as_user: var("AS_USER").ok().is_some_and(|s| !s.is_empty()),
temperature: parse("TEMPERATURE", 0.),
Expand Down

0 comments on commit 917fcf4

Please sign in to comment.