refactor(llama): break up mlp, use the operator library with mlp removed, in preparation for implementing moe
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Dec 27, 2024
1 parent 69bc397 commit 8286ecb
Showing 6 changed files with 49 additions and 48 deletions.
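For orientation: the fused mlp operator used to run the whole gated feed-forward block in one launch; after this commit the worker composes it from primitives, namely the gate/up projection via mat_mul, the swiglu activation, and the down projection (with the residual folded into the down mat_mul's scaling factor). Below is a minimal single-token sketch of the decomposed computation, assuming swiglu(gate, up) = silu(gate) ⊙ up (the usual convention) and row-major weights; this is hypothetical reference code, not the operators-rs API.

/// Hypothetical reference for the decomposed feed-forward block (not the
/// operators-rs API). Assumes `w_gate_up` is row-major [d, 2*di] with the gate
/// columns first, and `w_down` is row-major [di, d]. The residual add that the
/// worker folds into the down mat_mul is omitted here.
fn ffn(x: &[f32], w_gate_up: &[f32], w_down: &[f32], d: usize, di: usize) -> Vec<f32> {
    // gate_up = x · w_gate_up, split into gate | up along the last dimension
    let mut gate = vec![0.0f32; di];
    let mut up = vec![0.0f32; di];
    for j in 0..di {
        for k in 0..d {
            gate[j] += x[k] * w_gate_up[k * 2 * di + j];
            up[j] += x[k] * w_gate_up[k * 2 * di + di + j];
        }
    }
    // swiglu in place: gate ← silu(gate) ⊙ up (assumed operator semantics)
    for j in 0..di {
        let silu = gate[j] / (1.0 + (-gate[j]).exp());
        gate[j] = silu * up[j];
    }
    // down projection: y = gate · w_down
    let mut y = vec![0.0f32; d];
    for k in 0..di {
        for j in 0..d {
            y[j] += gate[k] * w_down[k * d + j];
        }
    }
    y
}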
4 changes: 2 additions & 2 deletions Cargo.toml
@@ -34,8 +34,8 @@ itertools = "0.13"
env_logger = "0.11"
build-script-cfg = "0.0"

operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "7821269", default-features = false }
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "91be1cc", default-features = false }

search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "9b6289d" }
search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "f40bcb5" }
search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "041badf" }
search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "5aec761" }
2 changes: 1 addition & 1 deletion models/llama/common-cpu/src/lib.rs
@@ -60,7 +60,7 @@ where
type MatMul = op!(mat_mul);
type Rope = op!(rope);
type AttnKVCached = op!(attention_kv_cached);
type Mlp = op!(mlp);
type Swiglu = op!(swiglu);
type Rearrange = op!(rearrange);
type AllReduce = R;

85 changes: 43 additions & 42 deletions models/llama/common/src/compute.rs
@@ -5,10 +5,10 @@ use operators::{
all_reduce::{self, AllReduce, ReduceOp},
attention_kv_cached::{self, AttnKVCached},
mat_mul::{self, MatMul},
mlp::{self, Mlp},
rearrange::{self, Rearrange},
rms_norm::{self, RmsNorm},
rope::{self, Rope, Seq, SinCosTable},
swiglu::{self, Swiglu},
ByteOf, Hardware, LaunchError, Operator, QueueAlloc, QueueOf, TopoNode, Workspace,
};
use std::ops::{Deref, DerefMut};
@@ -21,7 +21,7 @@ pub trait Operators {
type MatMul: MatMul<Self::Hardware>;
type Rope: Rope<Self::Hardware>;
type AttnKVCached: AttnKVCached<Self::Hardware>;
type Mlp: Mlp<Self::Hardware>;
type Swiglu: Swiglu<Self::Hardware>;
type Rearrange: Rearrange<Self::Hardware>;
type AllReduce: AllReduce<Self::Hardware, Self::TopoNode>;

@@ -80,7 +80,7 @@ pub struct LlamaWorker<Ops: Operators, W> {
mat_mul: Ops::MatMul,
rope: Ops::Rope,
attn_kv_cached: Ops::AttnKVCached,
mlp: Ops::Mlp,
swiglu: Ops::Swiglu,
rearrange: Ops::Rearrange,
all_reduce: Ops::AllReduce,
residual: bool,
@@ -103,7 +103,7 @@ impl<Ops: Operators, W> LlamaWorker<Ops, W> {
mat_mul: Ops::MatMul::new(processor),
rope: Ops::Rope::new(processor),
attn_kv_cached: Ops::AttnKVCached::new(processor),
mlp: Ops::Mlp::new(processor),
swiglu: Ops::Swiglu::new(processor),
rearrange: Ops::Rearrange::new(processor),
all_reduce: Ops::AllReduce::new(node),
residual,
@@ -162,24 +162,24 @@
max_att_len,
} = args;
let LlamaMeta {
dt_embd,
nblk,
nh,
nkvh,
dh,
di,
..
} = self.meta;
let beta = if self.residual { 1. } else { 0. };
let residual = if self.residual { 1. } else { 0. };

let workspace_size = self.workspace_size(nt, max_seq_len, max_att_len);
let mut workspace = Workspace::new(queue_alloc, workspace, workspace_size);

let mut x = embd;
let x1 = Tensor::new(dt_embd, x.shape());
let x1 = Tensor::new(x.dt(), x.shape());
let (buf, workspace) = workspace.split_at_mut(*x1.get());
let mut x1 = x1.map(|_| buf);

let qkv = Tensor::new(dt_embd, &[nt, (nh + nkvh + nkvh) * dh]);
let qkv = Tensor::new(x.dt(), &[nt, (nh + nkvh + nkvh) * dh]);

let sin = sin_cos.clone().index(0, 0);
let cos = sin_cos.index(0, 1);
@@ -250,26 +250,39 @@ where
req.pos,
workspace,
queue_alloc,
)?;
)?
}
}

let o = q.merge(1..3).unwrap();
let w = self.weights.attn_o(iblk, queue);
self.mat_mul(&mut x, beta, &o, &w, 1., workspace, queue_alloc)?;
drop(w);

self.all_reduce(&mut x, workspace, queue_alloc)?;
self.mat_mul(&mut x, residual, &o, &w, 1., workspace, queue_alloc)?
}
{
self.all_reduce(&mut x, workspace, queue_alloc)?;

if !self.meta.is_moe() {
let w = self.weights.ffn_norm(iblk, queue);
self.rms_norm(&mut x1, &x, &w, workspace, queue_alloc)?;
drop(w);

self.mlp(&mut x, &x1, iblk, self.residual, workspace, queue_alloc)?;
let gate_up = Tensor::new(x.dt(), &[nt, di * 2]);
let (buf, workspace) = workspace.split_at_mut(*gate_up.get());
let mut gate_up = gate_up.map(|_| buf);

self.all_reduce(&mut x, workspace, queue_alloc)?;
let w = self.weights.ffn_gate_up(iblk, queue);
self.mat_mul(&mut gate_up, 0., &x1, &w, 1., workspace, queue_alloc)?;
drop(w);

split!(gate_up => gate, up; [di, di] @ 1);
let mut gate = gate;
self.swiglu(&mut gate, &up, workspace, queue_alloc)?;

let w = self.weights.ffn_down(iblk, queue);
self.mat_mul(&mut x, residual, &gate, &w, 1., workspace, queue_alloc)?
} else {
todo!()
}
self.all_reduce(&mut x, workspace, queue_alloc)?
}
if logits.shape()[0] == 0 {
return Ok(());
@@ -285,9 +285,9 @@ where
if src != dst {
let src = unsafe { x.map_slice_static() }.index(0, src);
let mut dst = x.map_slice_mut().index(0, dst);
self.rearrange(&mut dst, &src, workspace, queue_alloc)?;
self.rearrange(&mut dst, &src, workspace, queue_alloc)?
}
dst += 1;
dst += 1
}
}
assert_eq!(dst, logits.shape()[0]);
@@ -296,7 +296,7 @@
{
let inplace = unsafe { x.map_slice_static() };
let w = self.weights.output_norm(queue);
self.rms_norm(&mut x, &inplace, &w, workspace, queue_alloc)?;
self.rms_norm(&mut x, &inplace, &w, workspace, queue_alloc)?
}
let w = self.weights.output(queue);
self.mat_mul(&mut logits, 0., &x, &w, 1., workspace, queue_alloc)
@@ -445,36 +445,24 @@ where
)
}

fn mlp<Y, X, QA>(
fn swiglu<Gate, Up, QA>(
&self,
y: &mut Tensor<Y>,
x: &Tensor<X>,
iblk: usize,
residual: bool,
gate: &mut Tensor<Gate>,
up: &Tensor<Up>,
workspace: &mut [ByteOf<Ops::Hardware>],
queue_alloc: &QA,
) -> Result<(), LaunchError>
where
Y: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
X: Deref<Target = [ByteOf<Ops::Hardware>]>,
Gate: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
Up: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
QA: QueueAlloc<Hardware = Ops::Hardware>,
{
let queue = queue_alloc.queue();
let w_gate_up = self.weights.ffn_gate_up(iblk, queue);
let w_down = self.weights.ffn_down(iblk, queue);

self.mlp.launch(
&mlp::Args {
y_layout: y.layout(),
y_base: y.base_mut(),
x_layout: x.layout(),
x_base: x.base(),
w_gate_up_layout: w_gate_up.layout(),
w_gate_up_base: w_gate_up.base(),
w_down_layout: w_down.layout(),
w_down_base: w_down.base(),
down_alpha: 1.,
residual,
self.swiglu.launch(
&swiglu::Args {
gate_layout: gate.layout(),
gate_base: gate.base_mut(),
up_layout: up.layout(),
up_base: up.base(),
},
workspace,
queue_alloc,
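The dense branch above now spells the FFN out operator by operator, while the `is_moe()` branch is left as `todo!()`, matching the commit's stated goal of preparing for MoE. Purely as an illustration of where that branch could go, here is a hypothetical top-k routed feed-forward; the routing scheme, top-k = 2, softmax mixing, and weight layouts are all assumptions, not this repository's design, and it reuses the `ffn` sketch shown earlier.

// Hypothetical sketch of what the todo!() MoE branch might eventually compute;
// everything here is an assumption. `ffn` is the dense helper sketched above.
fn moe_ffn(
    x: &[f32],
    router: &[f32],                   // [d, n_expert] routing matrix, row-major
    experts: &[(Vec<f32>, Vec<f32>)], // per-expert (w_gate_up, w_down)
    d: usize,
    di: usize,
) -> Vec<f32> {
    let n_expert = experts.len();
    // score every expert for this token, then keep the top 2
    let mut scores: Vec<(usize, f32)> = (0..n_expert)
        .map(|e| (e, (0..d).map(|k| x[k] * router[k * n_expert + e]).sum()))
        .collect();
    scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    let topk = &scores[..n_expert.min(2)];
    // softmax over the selected scores to get mixture weights
    let max = topk.iter().map(|s| s.1).fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = topk.iter().map(|s| (s.1 - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    // weighted sum of each selected expert's gate_up → swiglu → down path
    let mut y = vec![0.0f32; d];
    for (i, &(e, _)) in topk.iter().enumerate() {
        let (w_gate_up, w_down) = &experts[e];
        let out = ffn(x, w_gate_up, w_down, d, di);
        for j in 0..d {
            y[j] += exps[i] / sum * out[j];
        }
    }
    y
}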
2 changes: 1 addition & 1 deletion models/llama/infini/src/lib.rs
@@ -42,7 +42,7 @@ where
type MatMul = op!(mat_mul);
type Rope = op!(rope);
type AttnKVCached = op!(attention_kv_cached);
type Mlp = op!(mlp);
type Swiglu = op!(swiglu);
type Rearrange = op!(rearrange);
type AllReduce = R;

2 changes: 1 addition & 1 deletion models/llama/nvidia-gpu/src/lib.rs
@@ -133,7 +133,7 @@ where
type MatMul = op!(mat_mul);
type Rope = op!(rope);
type AttnKVCached = op!(attention_kv_cached);
type Mlp = op!(mlp);
type Swiglu = op!(swiglu);
type Rearrange = op!(rearrange);
type AllReduce = R;

2 changes: 1 addition & 1 deletion models/llama/opencl/src/lib.rs
@@ -39,7 +39,7 @@ where
type MatMul = op!(mat_mul);
type Rope = op!(rope);
type AttnKVCached = op!(attention_kv_cached);
type Mlp = op!(mlp);
type Swiglu = op!(swiglu);
type Rearrange = op!(rearrange);
type AllReduce = R;

