todo(iluvatar): adapt to Iluvatar CoreX (天数)
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Jan 3, 2025
1 parent ae79579 commit 315a1e3
Showing 7 changed files with 69 additions and 68 deletions.
5 changes: 3 additions & 2 deletions Cargo.toml
@@ -34,8 +34,9 @@ itertools = "0.13"
 env_logger = "0.11"
 build-script-cfg = "0.0"
 
-operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "807ea2b", default-features = false }
+operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "94a081e", default-features = false }
 
 search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "9b6289d" }
 search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "f40bcb5" }
-search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "5aec761" }
+search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "b320cd9" }
+search-corex-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "b320cd9" }
51 changes: 14 additions & 37 deletions models/llama/common/src/compute.rs
@@ -124,34 +124,6 @@ impl<Ops: Operators, W> LlamaWorker<Ops, W> {
     pub const fn meta(&self) -> &LlamaMeta {
         &self.meta
     }
-
-    pub fn workspace_size(&self, nt: usize, max_seq_len: usize, max_att_len: usize) -> usize {
-        let LlamaMeta {
-            nh,
-            nkvh,
-            nexp,
-            dh,
-            di,
-            ..
-        } = self.meta;
-
-        let embd = self.meta.embd(nt);
-        let dt = embd.dt();
-        let embd = embd.take();
-
-        let qkv = Tensor::new(dt, &[nt * (nh + nkvh + nkvh), dh]).take();
-        let q = Tensor::new(dt, &[max_seq_len, nh, dh]).take();
-        let att = Tensor::new(dt, &[nh, max_seq_len, max_att_len]).take();
-
-        if self.meta.is_moe() {
-            let routes = Tensor::new(dt, &[nt, nexp]).take();
-            let gate_up = Tensor::new(dt, &[1, di * 2]).take();
-            embd + (qkv + q + att).max(routes).max(gate_up)
-        } else {
-            let gate_up = Tensor::new(dt, &[nt, di * 2]).take();
-            embd + (qkv + q + att).max(gate_up)
-        }
-    }
 }
 
 impl<Ops, W> LlamaWorker<Ops, W>
@@ -170,7 +142,7 @@ where
         QA: QueueAlloc<Hardware = Ops::Hardware>,
     {
         let Args {
-            embd,
+            embd: mut x,
             sin_cos,
             mut logits,
             mut requests,
@@ -188,18 +160,23 @@
             ..
         } = self.meta;
 
-        let workspace_size = self.workspace_size(nt, max_seq_len, max_att_len);
-        let mut workspace = Workspace::new(queue_alloc, workspace, workspace_size);
+        let tensor = |shape: &[usize]| Tensor::new(x.dt(), shape);
+        let x1 = tensor(x.shape());
+        let qkv = tensor(&[nt, (nh + nkvh + nkvh) * dh]);
+        let q = tensor(&[max_seq_len, nh, dh]).take();
+        let att = tensor(&[nh, max_seq_len, max_att_len]).take();
+        let gate_up = tensor(&[if self.meta.is_moe() { 1 } else { nt }, di * 2]);
+        let routes = tensor(&[nt, nexp]);
+
+        let workspace_size = *x1.get()
+            + (*qkv.get() + q + att)
+                .max(*routes.get())
+                .max(*gate_up.get());
 
-        let mut x = embd;
-        let x1 = Tensor::new(x.dt(), x.shape());
+        let mut workspace = Workspace::new(queue_alloc, workspace, workspace_size);
         let (buf, workspace) = workspace.split_at_mut(*x1.get());
         let mut x1 = x1.map(|_| buf);
 
-        let qkv = Tensor::new(x.dt(), &[nt, (nh + nkvh + nkvh) * dh]);
-        let gate_up = Tensor::new(x.dt(), &[if self.meta.is_moe() { 1 } else { nt }, di * 2]);
-        let routes = Tensor::new(x.dt(), &[nt, nexp]);
-
         let sin = sin_cos.clone().index(0, 0);
         let cos = sin_cos.index(0, 1);
 
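Editor's note: with the `workspace_size` helper removed, the hunk above now sizes the workspace inline from the byte counts of the per-step temporaries. A minimal standalone sketch of how those sizes compose, using plain byte counts in place of the crate's `Tensor` type (the function name and numbers below are illustrative, only the formula comes from the diff):

```rust
/// The scratch copy of the embedding (`x1`) always occupies the front of the
/// workspace, while the attention buffers (qkv + q + att), the MoE routing
/// table, and the gate/up projection reuse the space behind it, so only the
/// largest of those alternatives is added.
fn workspace_bytes(x1: usize, qkv: usize, q: usize, att: usize, routes: usize, gate_up: usize) -> usize {
    x1 + (qkv + q + att).max(routes).max(gate_up)
}

fn main() {
    // Toy byte counts, not real model dimensions.
    let bytes = workspace_bytes(4096, 6144, 2048, 1024, 512, 8192);
    assert_eq!(bytes, 4096 + 9216); // the attention path dominates here
}
```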
3 changes: 2 additions & 1 deletion models/llama/nvidia-gpu/Cargo.toml
@@ -10,11 +10,12 @@ authors = ["YdrMaster <ydrml@hotmail.com>"]
 llama.path = "../common"
 common.workspace = true
 log.workspace = true
-operators = { workspace = true, features = ["nvidia-gpu"] }
+operators = { workspace = true, features = ["nvidia-gpu", "iluvatar-gpu"] }
 
 [build-dependencies]
 build-script-cfg.workspace = true
 search-cuda-tools.workspace = true
+search-corex-tools.workspace = true
 
 [dev-dependencies]
 test-utils = { workspace = true, features = ["llama"] }
20 changes: 15 additions & 5 deletions models/llama/nvidia-gpu/build.rs
@@ -1,13 +1,23 @@
 fn main() {
     use build_script_cfg::Cfg;
+    use search_corex_tools::find_corex;
     use search_cuda_tools::{find_cuda_root, find_nccl_root};
 
-    let driver = Cfg::new("driver_detected");
-    let nccl = Cfg::new("nccl_detected");
-    if find_cuda_root().is_some() {
-        driver.define();
+    let nvidia = Cfg::new("use_nvidia");
+    let iluvatar = Cfg::new("use_iluvatar");
+    let nccl = Cfg::new("use_nccl");
+
+    let nvidia_detected = find_cuda_root().is_some();
+    let iluvatar_detected = find_corex().is_some();
+
+    if nvidia_detected {
+        nvidia.define();
         if find_nccl_root().is_some() {
-            nccl.define();
+            nccl.define()
         }
     }
+
+    if iluvatar_detected {
+        iluvatar.define()
+    }
 }
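Editor's note: for reference, a sketch of how the cfg switches defined above are consumed in the crate's sources (this mirrors the lib.rs change later in this commit; it assumes, as is usual for build-script cfgs, that `Cfg::define` emits the matching `cargo:rustc-cfg` line):

```rust
// Top of a source file in this crate: compile only when some backend was found.
#![cfg(any(use_nvidia, use_iluvatar))]

// The NCCL-based multi-GPU test is additionally gated on `use_nccl`.
#[cfg(all(test, use_nccl))]
mod nccl_parallel;
```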
40 changes: 27 additions & 13 deletions models/llama/nvidia-gpu/src/infer.rs
@@ -1,12 +1,12 @@
-use crate::{Operators, RandomSample, Weights};
+use crate::{Operators, RandomSample, Weights};
 use gguf::GGufModel;
 use llama::{
     ext::ggml_quants::f16, LlamaArgs, LlamaMeta, LlamaRequest, LlamaStorage, LlamaWorker, Tensor,
 };
 use operators::{
-    cuda::{self, memcpy_d2h, Device, NoDevice},
-    nvidia_gpu::{Config, Gpu},
+    cuda::{self, memcpy_d2h, Config, Device, Gpu, MemPoolBlob, NoDevice, StreamMemPool},
     random_sample::{KVPair, SampleArgs},
+    Alloc, QueueAlloc,
 };
 use std::{slice::from_raw_parts_mut, time::Instant, usize};
 use test_utils::{load_roll_cache_size, Inference, TokenizerAndPrompt};
@@ -69,23 +69,30 @@ fn test_infer() {
         let stream = ctx.stream();
 
         let time = Instant::now();
-        let token_embd = stream.from_host(model.token_embd);
+        let token_embd = stream.ctx().from_host(model.token_embd);
         let weights = Weights::new(&model, .., 1, roll_cache_size, ctx);
         println!("load weights: {:?}", time.elapsed());
 
+        let (free, _) = ctx.mem_info();
+        let queue_alloc = StreamMemPool::new(stream);
+        queue_alloc.put((free.0 >> 30) << 30);
+
+        let alloc = |size| -> MemPoolBlob { queue_alloc.alloc(size) };
+
         let mut worker = Worker::new(0, &gpu, meta.clone(), weights);
-        let mut cache = meta.kv_cache(nctx).map(|size| stream.malloc::<u8>(size));
-        let sin_cos = <Operators as llama::Operators>::build_sin_cos(dt_embd, nctx, dh, &stream);
-        let indices = RandomSample::build_indices(nvoc, &stream);
+        let mut cache = meta.kv_cache(nctx).map(alloc);
+        let sin_cos =
+            <Operators as llama::Operators>::build_sin_cos(dt_embd, nctx, dh, &queue_alloc);
+        let indices = RandomSample::build_indices(nvoc, &queue_alloc);
         let sample = RandomSample::new(gpu);
 
         test_utils::test_infer(eos, tokenizer, &prompt, max_steps, |input, pos| {
-            let mut embd = meta.embd(input.len()).map(|len| stream.malloc::<u8>(len));
-            let mut logits = meta.logits(1).map(|len| stream.malloc::<u8>(len));
+            let mut embd = meta.embd(input.len()).map(alloc);
+            let mut logits = meta.logits(1).map(alloc);
 
             let d = embd.get().len() / input.len();
             for (i, &tok) in input.iter().enumerate() {
-                stream.memcpy_d2d(
+                queue_alloc.queue().memcpy_d2d(
                     &mut embd.get_mut()[i * d..][..d],
                     &token_embd[tok as usize * d..][..d],
                 )
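Editor's note: the memory pool above is seeded with the free device memory rounded down to a whole number of GiB, which is what `(free.0 >> 30) << 30` computes. A self-contained illustration of that arithmetic (the byte count is made up):

```rust
fn main() {
    // Pretend the driver reports a bit more than 11 GiB free.
    let free: usize = (11 << 30) + 123_456_789;
    // Shifting right then left by 30 clears the low 30 bits,
    // i.e. rounds down to a multiple of 1 GiB (2^30 bytes).
    let pool = (free >> 30) << 30;
    assert_eq!(pool, 11 << 30);
}
```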
@@ -108,14 +115,21 @@ fn test_infer() {
                         max_att_len: pos + input.len(),
                     },
                     &mut [],
-                    &stream,
+                    &queue_alloc,
                 )
                 .unwrap();
 
-            let mut pairs = Tensor::kv_pair_vec(1, |size| stream.malloc::<u8>(size));
+            let mut pairs = Tensor::kv_pair_vec(1, alloc);
 
             sample
-                .launch(&mut pairs, &logits, &indices, sample_args, &mut [], &stream)
+                .launch(
+                    &mut pairs,
+                    &logits,
+                    &indices,
+                    sample_args,
+                    &mut [],
+                    &queue_alloc,
+                )
                 .unwrap();
 
             let mut pair = KVPair::new(0, f16::ZERO);
13 changes: 6 additions & 7 deletions models/llama/nvidia-gpu/src/lib.rs
@@ -1,14 +1,13 @@
-#![cfg(driver_detected)]
+#![cfg(any(use_nvidia, use_iluvatar))]
 
 use common::{Contiguous, Slab};
 use llama::{BlkWeight, LlamaBlkStorage, LlamaStorage, Tensor, WeightLoader};
 use log::trace;
 use operators::{
     all_reduce::{AllReduce, NonAllReduce},
-    cuda::{memcpy_d2h, AsRaw, CurrentCtx, DevByte, DevMem, Event, HostMem, Stream},
-    nvidia_gpu::Gpu,
-    random_sample::nvidia_gpu::Operator as RandomSampleGpu,
-    rearrange::nvidia_gpu::Operator as Rearrange,
+    cuda::{memcpy_d2h, AsRaw, CurrentCtx, DevByte, DevMem, Event, Gpu, HostMem, Stream},
+    random_sample::cuda::Operator as RandomSampleGpu,
+    rearrange::cuda::Operator as Rearrange,
     ByteOf, QueueOf, TopoNode,
 };
 use std::{
@@ -119,7 +118,7 @@ impl Drop for WeightResult<'_, '_> {
 
 macro_rules! op {
     ($name:ident) => {
-        operators::$name::nvidia_gpu::Operator
+        operators::$name::cuda::Operator
     };
 }
 
@@ -378,5 +377,5 @@ impl<'ctx> WeightLoader for Weights<'ctx> {
 #[cfg(test)]
 mod infer;
 
-#[cfg(all(test, nccl_detected))]
+#[cfg(all(test, use_nccl))]
 mod nccl_parallel;
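Editor's note on the `op!` macro changed above: it builds the operator type path, and after this commit it points at the `cuda` operator modules (which the commit uses for both the NVIDIA and Iluvatar builds) instead of the former `nvidia_gpu` modules. A minimal sketch of the expansion; `rms_norm` is only an illustrative operator name:

```rust
macro_rules! op {
    ($name:ident) => {
        operators::$name::cuda::Operator
    };
}

// A hypothetical use site:
//     type RmsNorm = op!(rms_norm);
// now expands to `operators::rms_norm::cuda::Operator`
// (previously `operators::rms_norm::nvidia_gpu::Operator`).
```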
5 changes: 2 additions & 3 deletions models/llama/nvidia-gpu/src/nccl_parallel.rs
@@ -1,12 +1,11 @@
-use crate::{Operators, RandomSample, Weights};
+use crate::{Operators, RandomSample, Weights};
 use gguf::GGufModel;
 use llama::{ext::ggml_quants::f16, LlamaRequest, LlamaStorage, LlamaWorker, Tensor};
 use log::info;
 use operators::{
     all_reduce::nccl::Operator as AllReduce,
-    cuda::{self, memcpy_d2h, NoDevice},
+    cuda::{self, memcpy_d2h, NcclNode, NoDevice},
     nccl::CommunicatorGroup,
-    nvidia_gpu::NcclNode,
     random_sample::{KVPair, SampleArgs},
     TopoNode,
 };