Skip to content

Commit

Permalink
style(llama): 整理和优化 MOE 模型结构
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Dec 27, 2024
1 parent e14c588 commit 2896798
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 72 deletions.
22 changes: 22 additions & 0 deletions gguf/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,25 @@ impl GGufMetaMap for GGufModel<'_> {
self.meta_kvs.get(key).map(|kv| (kv.ty(), kv.value_bytes()))
}
}

mod macros {
    /// Reads a metadata value from a `GGufModel`-like object by calling the
    /// accessor method named `$key`.
    ///
    /// Two forms:
    /// - `meta!(gguf => key)` — panics (via `unwrap`) if the key is missing
    ///   or malformed; use for mandatory metadata.
    /// - `meta!(gguf => key; default)` — substitutes `default` when the key
    ///   does not exist, but still panics on any other read error so that a
    ///   corrupted file is not silently papered over.
    #[macro_export]
    macro_rules! meta {
        ($gguf:expr => $key:ident) => {
            // Mandatory key: absence is a hard error.
            $gguf.$key().unwrap()
        };
        ($gguf:expr => $key:ident; $default:expr) => {
            match $gguf.$key() {
                Ok(val) => val,
                // Only "not present" falls back to the default;
                // every other error is a real failure.
                Err(gguf::GGufMetaError::NotExist) => $default,
                Err(e) => panic!("failed to read meta: {e:?}"),
            }
        };
    }
    /// Looks up a tensor by name in `$gguf.tensors`.
    ///
    /// `&*$name` coerces both `&str` literals and `String`s (e.g. from
    /// `format!`) to the map's borrowed key type. Panics if the tensor is
    /// absent — missing tensors indicate an incompatible model file.
    #[macro_export]
    macro_rules! tensor {
        ($gguf:expr => $name:expr) => {
            &$gguf.tensors[&*$name]
        };
    }
}
60 changes: 31 additions & 29 deletions models/llama/common/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ where
nh,
nkvh,
nexp,
nexp_use,
dh,
di,
..
Expand Down Expand Up @@ -230,6 +229,7 @@ where
});

let req_split = requests.iter().map(|req| req.seq_len).collect::<Vec<_>>();
let tok_split = vec![1; nt];

let queue = queue_alloc.queue();
for iblk in 0..nblk {
Expand Down Expand Up @@ -324,38 +324,19 @@ where

Ops::memcpy_d2h(&mut routes_host, routes_dev.get(), queue)
}
let ([], routes, []) = (unsafe { routes_host.align_to_mut::<f16>() }) else {
let ([], mut routes, []) = (unsafe { routes_host.align_to::<f16>() }) else {
unreachable!()
};

for itok in (0..nt).rev() {
// fused topk
let mut routes = routes[itok * nexp..][..nexp]
.iter()
.copied()
.enumerate()
.collect::<Vec<_>>();

routes.sort_unstable_by(|&(_, a), &(_, b)| b.total_cmp(&a));
let max = routes[0].1.to_f32();
let mut sum = 0.;
let mut moe_gate = vec![(0, 0.0f32); nexp_use];
for ((i, x), gate) in std::iter::zip(routes, &mut moe_gate) {
let softmax = (x.to_f32() - max).exp();
*gate = (i, softmax);
sum += softmax
}
for (_, x) in &mut moe_gate {
*x /= sum
}
// mlp
let (buf, workspace) = workspace.split_at_mut(*gate_up.get());
let mut gate_up = gate_up.clone().map(|_| buf);

let mut x = x.map_slice_mut().slice(0, itok, 0, 1);
let x1 = x1.map_slice_mut().slice(0, itok, 0, 1);
let (buf, workspace) = workspace.split_at_mut(*gate_up.get());
let mut gate_up = gate_up.clone().map(|_| buf);

for (iexp, kexp) in moe_gate {
let x = x.split(0, &tok_split);
let x1 = x1.split(0, &tok_split);
for (mut x, x1) in izip!(x, x1) {
let (line, tail) = routes.split_at(nexp);
routes = tail;
for (iexp, kexp) in self.topk_with_index(line) {
let w = self.weights.ffn_gate_up(iblk, iexp, queue);
self.mat_mul(&mut gate_up, 0., &x1, &w, 1., workspace, queue_alloc)?;
drop(w);
Expand Down Expand Up @@ -409,6 +390,27 @@ where
Ops: Operators,
W: WeightLoader<Hardware = Ops::Hardware>,
{
/// Fused top-k + softmax for MoE routing.
///
/// Given one row of raw router scores (`line`, one entry per expert),
/// selects the `nexp_use` highest-scoring experts and normalizes their
/// scores with a softmax computed over the selected entries only.
///
/// Returns `(expert_index, gate_weight)` pairs ordered by descending
/// score; the gate weights sum to 1. Panics if `line` is empty or
/// `nexp_use` is 0 (callers guarantee a MoE model here).
fn topk_with_index(&self, line: &[f16]) -> Vec<(usize, f32)> {
    // Pair every expert index with its score, widened to f32.
    let mut ranked = line
        .iter()
        .enumerate()
        .map(|(i, &score)| (i, score.to_f32()))
        .collect::<Vec<_>>();
    // Highest score first, then keep only the experts that will run.
    ranked.sort_unstable_by(|&(_, a), &(_, b)| b.total_cmp(&a));
    ranked.truncate(self.meta.nexp_use);
    // Softmax over the survivors; subtracting the max keeps exp() from
    // overflowing without changing the normalized result.
    let max = ranked[0].1;
    let mut denom = 0.;
    for entry in &mut ranked {
        entry.1 = (entry.1 - max).exp();
        denom += entry.1
    }
    for entry in &mut ranked {
        entry.1 /= denom
    }
    ranked
}

fn rms_norm<Y, X, W_, QA>(
&self,
y: &mut Tensor<Y>,
Expand Down
69 changes: 26 additions & 43 deletions models/llama/common/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,61 +26,44 @@ pub struct BlkStorage<T> {

impl<'a> Storage<&'a [u8]> {
pub fn from_gguf(gguf: &GGufModel<'a>) -> Self {
macro_rules! meta {
($key:ident) => {
gguf.$key().unwrap()
};
($key:ident; $default:expr) => {
match gguf.$key() {
Ok(val) => val,
Err(gguf::GGufMetaError::NotExist) => $default,
Err(e) => panic!("failed to read meta: {e:?}"),
}
};
}
macro_rules! tensor {
($name:expr) => {
&gguf.tensors[&*$name]
};
}

let token_embd = tensor!["token_embd.weight"];
let output_norm = tensor!["output_norm.weight"];
let qkv0 = tensor!["blk.0.attn_qkv.weight"];
use gguf::{meta, tensor};
let token_embd = tensor![gguf => "token_embd.weight"];
let output_norm = tensor![gguf => "output_norm.weight"];
let qkv0 = tensor![gguf => "blk.0.attn_qkv.weight"];
#[rustfmt::skip]
let meta = LlamaMeta {
dt_embd : token_embd.ty,
dt_norm : output_norm.ty,
dt_mat : qkv0.ty,

nblk : meta!(llm_block_count ),
nctx : meta!(llm_context_length ),
nvoc : meta!(tokenizer_ggml_tokens).len(),
nh : meta!(llm_attention_head_count ),
nkvh : meta!(llm_attention_head_count_kv),
nexp : meta!(llm_expert_count ; 0),
nexp_use: meta!(llm_expert_used_count ; 0),
d : meta!(llm_embedding_length ),
dh : meta!(llm_rope_dimension_count ),
di : meta!(llm_feed_forward_length ),

epsilon : meta!(llm_attention_layer_norm_rms_epsilon; 1e-5),
theta : meta!(llm_rope_freq_base ; 1e4 ),
nblk : meta!(gguf => llm_block_count ),
nctx : meta!(gguf => llm_context_length ),
nvoc : meta!(gguf => tokenizer_ggml_tokens).len(),
nh : meta!(gguf => llm_attention_head_count ),
nkvh : meta!(gguf => llm_attention_head_count_kv),
nexp : meta!(gguf => llm_expert_count ; 0),
nexp_use: meta!(gguf => llm_expert_used_count ; 0),
d : meta!(gguf => llm_embedding_length ),
dh : meta!(gguf => llm_rope_dimension_count ),
di : meta!(gguf => llm_feed_forward_length ),

epsilon : meta!(gguf => llm_attention_layer_norm_rms_epsilon; 1e-5),
theta : meta!(gguf => llm_rope_freq_base ; 1e4 ),
};

#[rustfmt::skip]
let blocks = (0..meta.nblk)
.map(|i| BlkStorage {
attn_norm: tensor![format!("blk.{i}.attn_norm.weight" )].data,
attn_qkv: tensor![format!("blk.{i}.attn_qkv.weight" )].data,
attn_o: tensor![format!("blk.{i}.attn_output.weight")].data,
ffn_norm: tensor![format!("blk.{i}.ffn_norm.weight" )].data,
attn_norm: tensor![gguf => format!("blk.{i}.attn_norm.weight" )].data,
attn_qkv: tensor![gguf => format!("blk.{i}.attn_qkv.weight" )].data,
attn_o: tensor![gguf => format!("blk.{i}.attn_output.weight")].data,
ffn_norm: tensor![gguf => format!("blk.{i}.ffn_norm.weight" )].data,
ffn_gate_inp: if !meta.is_moe() { None }
else { Some(tensor![format!("blk.{i}.ffn_gate_inp.weight" )].data) },
ffn_gate_up : if !meta.is_moe() { tensor![format!("blk.{i}.ffn_gate_up.weight" )].data }
else { tensor![format!("blk.{i}.ffn_gate_up_exps.weight")].data },
ffn_down : if !meta.is_moe() { tensor![format!("blk.{i}.ffn_down.weight" )].data }
else { tensor![format!("blk.{i}.ffn_down_exps.weight" )].data },
else { Some(tensor![gguf => format!("blk.{i}.ffn_gate_inp.weight" )].data) },
ffn_gate_up : if !meta.is_moe() { tensor![gguf => format!("blk.{i}.ffn_gate_up.weight" )].data }
else { tensor![gguf => format!("blk.{i}.ffn_gate_up_exps.weight")].data },
ffn_down : if !meta.is_moe() { tensor![gguf => format!("blk.{i}.ffn_down.weight" )].data }
else { tensor![gguf => format!("blk.{i}.ffn_down_exps.weight" )].data },
})
.collect();

Expand Down

0 comments on commit 2896798

Please sign in to comment.