Commit

todo: initial implementation of nv distributed inference, but it errors out
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Dec 5, 2024
1 parent b59edbf commit 74cb126
Showing 7 changed files with 278 additions and 54 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -30,7 +30,7 @@ env_logger = "0.11"
 build-script-cfg = "0.0"
 
 ndarray-layout = { git = "https://github.com/YdrMaster/ndarray-layout", rev = "f1fdd24" }
-operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "923949f", default-features = false }
+operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "b9e6fdd", default-features = false }
 
 search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "6846d52" }
 search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "136c30b" }

84 changes: 35 additions & 49 deletions models/llama/common-cpu/src/infer.rs
@@ -3,7 +3,7 @@ use gguf::GGufModel;
 use llama::{ext::ggml_quants::f16, LlamaRequest, LlamaStorage, LlamaWorker, Tensor};
 use operators::{
     all_reduce::common_cpu::Operator as AllReduce,
-    common_cpu::{Cpu, InprocNode, ThisThread},
+    common_cpu::{InprocNode, ThisThread},
     random_sample::{KVPair, SampleArgs},
     Blob,
 };
@@ -12,10 +12,7 @@ use std::{
     iter::zip,
     ptr::copy_nonoverlapping,
     slice::from_raw_parts_mut,
-    sync::{
-        mpsc::{Receiver, Sender},
-        Arc, Barrier,
-    },
+    sync::mpsc::{Receiver, Sender},
     thread,
 };
 use test_utils::{Inference, TokenizerAndPrompt};
@@ -52,13 +49,11 @@ fn test_infer() {
     println!("{sample_args:?}");
 
     let lens = match devices {
-        Some(devices) => {
-            let regex = Regex::new(r"\d+").unwrap();
-            regex
-                .find_iter(&devices)
-                .map(|c| c.as_str().parse::<usize>().unwrap())
-                .collect::<Vec<_>>()
-        }
+        Some(devices) => Regex::new(r"\d+")
+            .unwrap()
+            .find_iter(&devices)
+            .map(|c| c.as_str().parse::<usize>().unwrap())
+            .collect::<Vec<_>>(),
         None => vec![1],
     };
     println!("distribution: {lens:?}");
@@ -87,25 +82,27 @@ fn test_infer() {
                 meta.dh,
                 &ThisThread,
             );
+
+            let sample = RandomSample::new(&node);
+            let indices = RandomSample::build_indices(model.meta.nvoc, &ThisThread);
+            let mut pair = KVPair::new(0, f16::ZERO);
+            let mut pairs = Tensor::kv_pair_vec(1, |_| unsafe {
+                from_raw_parts_mut(&mut pair as *mut _ as *mut u8, size_of_val(&pair))
+            });
+
             for task in tasks {
                 let Task {
                     nt,
                     pos,
                     embd,
-                    logits,
-                    barrier,
+                    next,
                 } = task;
                 let mut embd = meta.embd(nt).map(|size| {
                     let mut blob = Blob::new(size);
                     unsafe { copy_nonoverlapping(embd, blob.as_mut_ptr(), size) };
                     blob
                 });
-                let mut logits = if i == 0 {
-                    meta.logits(1)
-                        .map(|size| unsafe { from_raw_parts_mut(logits, size) })
-                } else {
-                    meta.logits(0).map(|_| &mut [][..])
-                };
+                let mut logits = meta.logits(if i == 0 { 1 } else { 0 }).map(Blob::new);
                 worker
                     .launch(
                         llama::LlamaArgs {
@@ -126,17 +123,27 @@ fn test_infer() {
                         &ThisThread,
                     )
                     .unwrap();
-                barrier.wait();
+                if i == 0 {
+                    sample
+                        .launch(
+                            &mut pairs,
+                            &logits,
+                            &indices,
+                            sample_args,
+                            &mut [],
+                            &ThisThread,
+                        )
+                        .unwrap();
+                    next.send(pair.idx() as _).unwrap()
+                }
             }
         }))
     })
     .collect::<Vec<_>>();
 
-    let sample = RandomSample::new(&Cpu);
-    let indices = RandomSample::build_indices(model.meta.nvoc, &ThisThread);
+    let (next, next_recv) = std::sync::mpsc::channel();
     test_utils::test_infer(eos, tokenizer, &prompt, max_steps, |input, pos| {
         let mut embd = model.meta.embd(input.len()).map(Blob::new);
-        let mut logits = model.meta.logits(1).map(Blob::new);
 
         let d = embd.get().len() / input.len();
         for (i, &tok) in input.iter().enumerate() {
@@ -145,49 +152,28 @@ fn test_infer() {
         }
         let embd = embd.take();
 
-        let barrier = Arc::new(Barrier::new(senders.len() + 1));
         for sender in &senders {
             sender
                 .send(Task {
                     nt: input.len(),
                     pos,
                     embd: embd.as_ptr(),
-                    logits: logits.get_mut().as_mut_ptr(),
-                    barrier: barrier.clone(),
+                    next: next.clone(),
                 })
                 .unwrap();
         }
-        barrier.wait();
-
-        let mut pair = KVPair::new(0, f16::ZERO);
-        let mut pairs = Tensor::kv_pair_vec(1, |_| unsafe {
-            from_raw_parts_mut(&mut pair as *mut _ as _, size_of_val(&pair))
-        });
-
-        sample
-            .launch(
-                &mut pairs,
-                &logits,
-                &indices,
-                sample_args,
-                &mut [],
-                &ThisThread,
-            )
-            .unwrap();
-
-        pair.idx() as _
+        next_recv.recv().unwrap()
     });
 
-    drop(senders);
+    drop(senders)
 })
 }
 
 struct Task {
     nt: usize,
     pos: usize,
     embd: *const u8,
-    logits: *mut u8,
-    barrier: Arc<Barrier>,
+    next: Sender<u32>,
 }
 
 unsafe impl Send for Task {}
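
Note: the diff above replaces the Arc<Barrier> handshake with pure message passing — each worker thread receives Tasks over an mpsc channel, every rank runs its shard of the forward pass, and only rank 0 samples and reports the chosen token back over the next channel. A minimal self-contained sketch of that coordination pattern, with simplified types (a token id stands in for the real embedding and logits buffers):

use std::{
    sync::mpsc::{channel, Sender},
    thread,
};

// Simplified stand-in for the diff's `Task`: the real struct carries raw
// tensor pointers and metadata; a token id stands in for the payload here.
struct Task {
    tok: u32,
    next: Sender<u32>, // only rank 0 replies on this channel
}

fn main() {
    let n_workers = 2;
    let (senders, handles): (Vec<_>, Vec<_>) = (0..n_workers)
        .map(|rank| {
            let (tx, rx) = channel::<Task>();
            let handle = thread::spawn(move || {
                for task in rx {
                    // every rank would run its shard of the forward pass and
                    // all-reduce here; only rank 0 samples and reports back
                    if rank == 0 {
                        task.next.send(task.tok + 1).unwrap();
                    }
                }
            });
            (tx, handle)
        })
        .unzip();

    // the main loop broadcasts one task per rank, then blocks on rank 0's reply
    let (next, next_recv) = channel();
    for sender in &senders {
        sender.send(Task { tok: 41, next: next.clone() }).unwrap();
    }
    assert_eq!(next_recv.recv().unwrap(), 42);

    drop(senders); // closing the task channels lets the workers exit
    for h in handles {
        h.join().unwrap();
    }
}
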
1 change: 1 addition & 0 deletions models/llama/nvidia-gpu/Cargo.toml
@@ -17,3 +17,4 @@ search-cuda-tools.workspace = true
 [dev-dependencies]
 test-utils.workspace = true
 gguf.workspace = true
+regex.workspace = true
6 changes: 5 additions & 1 deletion models/llama/nvidia-gpu/build.rs
@@ -1,9 +1,13 @@
 fn main() {
     use build_script_cfg::Cfg;
-    use search_cuda_tools::find_cuda_root;
+    use search_cuda_tools::{find_cuda_root, find_nccl_root};
 
     let cfg = Cfg::new("hw_detected");
+    let nccl = Cfg::new("nccl_detected");
     if find_cuda_root().is_some() {
         cfg.define();
+        if find_nccl_root().is_some() {
+            nccl.define();
+        }
     }
 }
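
For context on the cfg plumbing: assuming build_script_cfg::Cfg is a thin wrapper over Cargo's rustc-cfg directives, the build script above is roughly equivalent to the hand-rolled version below (the booleans stand in for the real find_cuda_root/find_nccl_root probes):

// build.rs sketch — rough equivalent of the diff above, assuming
// build_script_cfg emits standard cargo cfg directives under the hood.
fn main() {
    let cuda_found = true; // stand-in for find_cuda_root().is_some()
    let nccl_found = true; // stand-in for find_nccl_root().is_some()

    // declare the custom cfgs so cargo's unexpected-cfg lint stays quiet
    println!("cargo::rustc-check-cfg=cfg(hw_detected)");
    println!("cargo::rustc-check-cfg=cfg(nccl_detected)");

    if cuda_found {
        println!("cargo::rustc-cfg=hw_detected");
        if nccl_found {
            // set only when both CUDA and NCCL are present, which is what
            // gates the `nccl_parallel` test module in lib.rs below
            println!("cargo::rustc-cfg=nccl_detected");
        }
    }
}
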
6 changes: 3 additions & 3 deletions models/llama/nvidia-gpu/src/infer.rs
@@ -28,9 +28,6 @@ fn test_infer() {
     else {
         return;
     };
-
-    let roll_cache_size = load_roll_cache_size();
-    println!("roll_cache_size: {roll_cache_size}");
     let gguf = GGufModel::read(model.iter().map(|s| &**s));
 
     let TokenizerAndPrompt {
@@ -45,6 +42,9 @@ fn test_infer() {
     let sample_args = SampleArgs::new(temperature, top_p, top_k).expect("invalid sample args");
     println!("{sample_args:?}");
 
+    let roll_cache_size = load_roll_cache_size();
+    println!("roll_cache_size: {roll_cache_size}");
+
     let gpu = match cuda::init() {
         Ok(()) => Device::new(0),
         Err(NoDevice) => return,

3 changes: 3 additions & 0 deletions models/llama/nvidia-gpu/src/lib.rs
@@ -317,3 +317,6 @@ impl<'ctx> WeightLoader for Weights<'ctx> {
 
 #[cfg(test)]
 mod infer;
+
+#[cfg(all(test, nccl_detected))]
+mod nccl_parallel;
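
The regex dev-dependency added to the nvidia-gpu Cargo.toml above suggests the new nccl_parallel test parses a device distribution string the same way the common-cpu test does. A hypothetical sketch of just that step (parse_devices is an illustrative name, not taken from the diff):

use regex::Regex;

// Hypothetical helper mirroring the common-cpu test's device parsing:
// pull every integer out of a spec string like "0,1".
fn parse_devices(spec: &str) -> Vec<usize> {
    Regex::new(r"\d+")
        .unwrap()
        .find_iter(spec)
        .map(|m| m.as_str().parse().unwrap())
        .collect()
}

fn main() {
    assert_eq!(parse_devices("0,1"), vec![0, 1]);
}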