Commit 430044c
style(gpt2): tidy code, align with clip
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Dec 26, 2024
1 parent d09838b commit 430044c
Showing 4 changed files with 64 additions and 100 deletions.
1 change: 0 additions & 1 deletion gguf/src/lib.rs
@@ -57,7 +57,6 @@ pub struct GGufModel<'a> {

/// GGuf tensor.
#[derive(Clone, Debug)]
- #[allow(missing_docs)]
pub struct GGufTensor<'a> {
pub ty: DigitLayout,
pub shape: Box<[usize]>,
66 changes: 24 additions & 42 deletions models/gpt2/common/src/compute.rs
@@ -23,6 +23,7 @@ pub trait Operators {
type AllReduce: AllReduce<Self::Hardware, Self::TopoNode>;
type AddRows: AddRows<Self::Hardware>;
type Mlp: Gpt2Mlp<Self::Hardware>;
+
fn debug<T>(tensor: &Tensor<T>)
where
T: Deref<Target = [ByteOf<Self::Hardware>]>;
@@ -66,6 +67,7 @@ pub struct Gpt2Worker<Ops: Operators, W> {
all_reduce: Ops::AllReduce,
add_rows: Ops::AddRows,
mlp: Ops::Mlp,
+ pub debug: bool,
}

impl<Ops: Operators, W> Gpt2Worker<Ops, W> {
@@ -81,6 +83,7 @@ impl<Ops: Operators, W> Gpt2Worker<Ops, W> {
all_reduce: Ops::AllReduce::new(node),
add_rows: Ops::AddRows::new(processor),
mlp: Ops::Mlp::new(processor),
+ debug: true,
}
}

@@ -136,7 +139,6 @@ where
idx,
idx_add,
} = args;
-
let Gpt2Meta {
dt_embd,
nblk,
@@ -145,6 +147,7 @@
dh,
..
} = self.meta;
+
let workspace_size = self.workspace_size(nt, max_seq_len, max_att_len);
let mut workspace = Workspace::new(queue_alloc, workspace, workspace_size);
let queue = queue_alloc.queue();
@@ -161,7 +164,7 @@
token_embd = token_embd.merge(0..2).unwrap();
}
let mut x = token_embd;
- let x1 = Tensor::new(dt_embd, x.shape());
+ let x1 = Tensor::new(x.dt(), x.shape());
let (buf, workspace) = workspace.split_at_mut(*x1.get());
let mut x1 = x1.map(|_| buf);
let qkv = Tensor::new(dt_embd, &[nt, (nh + nkvh + nkvh) * dh]);
@@ -177,10 +180,9 @@
let mut qkv = qkv.clone().map(|_| buf);
{
let [scale, bias] = self.weights.attn_qkv(iblk, queue);
- let cols = bias.shape()[0];
- let bias = bias.tile(0, &[1, cols]).broadcast(0, nt);
+ let bias = bias.broadcast(0, nt);
self.rearrange(&mut qkv, &bias, workspace, queue_alloc)?;
- self.mat_mul(&mut qkv, 1., &x1, &scale, 1., workspace, queue_alloc)?;
+ self.mat_mul(&mut qkv, 1., &x1, &scale, 1., workspace, queue_alloc)?
}
let qkv = qkv.tile(1, &[nh + nkvh + nkvh, dh]);
split!(qkv => q, k, v; [nh, nkvh, nkvh] @ 1);
@@ -215,14 +217,13 @@
req.pos,
workspace,
queue_alloc,
- )?;
+ )?
}
}
{
let o = q.map_slice().merge(1..3).unwrap();
let [scale, bias] = self.weights.attn_o(iblk, queue);
- let cols = bias.shape()[0];
- let bias = bias.tile(0, &[1, cols]).broadcast(0, nt);
+ let bias = bias.broadcast(0, nt);
self.rearrange(&mut x1, &bias, workspace, queue_alloc)?;
self.mat_mul(&mut x1, 1., &o, &scale, 1., workspace, queue_alloc)?;
}
@@ -506,50 +507,40 @@
}

struct WeightDecorator<W> {
- attn_norm_w: Tensor<usize>,
- attn_norm_b: Tensor<usize>,
+ pos_embd: Tensor<usize>,
+ output_weight: Tensor<usize>,
+ norm: Tensor<usize>,

attn_qkv_w: Tensor<usize>,
attn_qkv_b: Tensor<usize>,
attn_o_w: Tensor<usize>,
attn_o_b: Tensor<usize>,

- ffn_norm_w: Tensor<usize>,
- ffn_norm_b: Tensor<usize>,
ffn_up_w: Tensor<usize>,
ffn_up_b: Tensor<usize>,
ffn_down_w: Tensor<usize>,
ffn_down_b: Tensor<usize>,

- output_norm_w: Tensor<usize>,
- output_norm_b: Tensor<usize>,
- output_weight: Tensor<usize>,
- pos_embd: Tensor<usize>,

weights: W,
}

impl Gpt2Meta {
fn decorator<W>(&self, weights: W) -> WeightDecorator<W> {
use crate::TensorUsage::Computation;
WeightDecorator {
- attn_norm_w: self.attn_norm_w(),
- attn_norm_b: self.attn_norm_b(),
+ pos_embd: self.pos_embd(),
+ output_weight: self.output_weight(),
+ norm: self.norm(),

attn_qkv_w: self.attn_qkv_w(Computation),
- attn_qkv_b: self.attn_qkv_b(),
+ attn_qkv_b: self.attn_qkv_b(Computation),
attn_o_w: self.attn_o_w(Computation),
- attn_o_b: self.attn_o_b(),
+ attn_o_b: self.attn_o_b(Computation),

- ffn_norm_w: self.ffn_norm_w(),
- ffn_norm_b: self.ffn_norm_b(),
ffn_up_w: self.ffn_up_w(Computation),
- ffn_up_b: self.ffn_up_b(),
+ ffn_up_b: self.ffn_up_b(Computation),
ffn_down_w: self.ffn_down_w(Computation),
- ffn_down_b: self.ffn_down_b(),

- output_norm_w: self.output_norm_w(),
- output_norm_b: self.output_norm_b(),
- output_weight: self.output_weight(),
- pos_embd: self.pos_embd(),
+ ffn_down_b: self.ffn_down_b(Computation),

weights,
}
Expand All @@ -563,10 +554,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
queue: &QueueOf<W::Hardware>,
) -> [Tensor<W::Memory<'_>>; 2] {
let [w, b] = self.weights.load_blk(BlkWeight::AttnNorm, iblk, queue);
- [
- self.attn_norm_w.clone().map(|_| w),
- self.attn_norm_b.clone().map(|_| b),
- ]
+ [self.norm.clone().map(|_| w), self.norm.clone().map(|_| b)]
}

pub fn attn_qkv(
@@ -595,10 +583,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
queue: &QueueOf<W::Hardware>,
) -> [Tensor<W::Memory<'_>>; 2] {
let [w, b] = self.weights.load_blk(BlkWeight::FfnNorm, iblk, queue);
- [
- self.ffn_norm_w.clone().map(|_| w),
- self.ffn_norm_b.clone().map(|_| b),
- ]
+ [self.norm.clone().map(|_| w), self.norm.clone().map(|_| b)]
}

pub fn ffn_up(&self, iblk: usize, queue: &QueueOf<W::Hardware>) -> [Tensor<W::Memory<'_>>; 2] {
@@ -623,10 +608,7 @@

pub fn output_norm(&self, queue: &QueueOf<W::Hardware>) -> [Tensor<W::Memory<'_>>; 2] {
let [w, b] = self.weights.output_norm(queue);
- [
- self.output_norm_w.clone().map(|_| w),
- self.output_norm_b.clone().map(|_| b),
- ]
+ [self.norm.clone().map(|_| w), self.norm.clone().map(|_| b)]
}

pub fn output_weight(&self, queue: &QueueOf<W::Hardware>) -> Tensor<W::Memory<'_>> {
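
Note on the two attention hunks above: the bias now comes out of the weight decorator with a broadcastable leading dimension, so bias.broadcast(0, nt) replaces the old tile-then-broadcast pair, and the call order makes the bias part of the GEMM output: rearrange copies the broadcast bias into the destination buffer, then mat_mul accumulates x1 * scale on top of it (alpha = 1, beta = 1). Below is a minimal self-contained sketch of that bias-then-GEMM pattern in plain Vec math; it is an illustration, not the crate's Tensor/Operators API.

// Toy illustration (plain Vec math, not the crate's Tensor API) of the
// pattern in the attention hunks: the one-row bias is broadcast into the
// output first ("rearrange"), then the GEMM accumulates on top of it,
// matching mat_mul(&mut qkv, 1., &x1, &scale, 1., ...) with beta = 1.
fn bias_then_gemm(x: &[Vec<f32>], w_cols: &[Vec<f32>], bias: &[f32]) -> Vec<Vec<f32>> {
    // "rearrange": broadcast the [1, cols] bias over the nt token rows
    let mut y: Vec<Vec<f32>> = x.iter().map(|_| bias.to_vec()).collect();
    // "mat_mul" with alpha = 1, beta = 1: y += x · w
    for (yr, xr) in y.iter_mut().zip(x) {
        for (yc, wc) in yr.iter_mut().zip(w_cols) {
            *yc += xr.iter().zip(wc).map(|(a, b)| a * b).sum::<f32>();
        }
    }
    y
}

fn main() {
    let x = vec![vec![1.0, 2.0], vec![3.0, 4.0]]; // nt = 2, d = 2
    let w_cols = vec![vec![1.0, 0.0], vec![0.0, 1.0]]; // identity weight, one Vec per output column
    let y = bias_then_gemm(&x, &w_cols, &[10.0, 20.0]);
    assert_eq!(y, vec![vec![11.0, 22.0], vec![13.0, 24.0]]);
}
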
73 changes: 28 additions & 45 deletions models/gpt2/common/src/lib.rs
@@ -84,71 +84,54 @@ impl Gpt2Meta {
pub fn position_embd(&self) -> Tensor<usize> {
self.embd(self.nctx)
}
- // ln1_weight
- pub fn attn_norm_w(&self) -> Tensor<usize> {
- self.norm()
- }
- // ln1_bias
- pub fn attn_norm_b(&self) -> Tensor<usize> {
- self.norm()
- }
- // attn_qkvw
+
pub fn attn_qkv_w(&self, usage: TensorUsage) -> Tensor<usize> {
- self.mat(3 * self.d, self.d, usage)
+ let &Self { d, .. } = self;
+ self.mat(3 * d, d, usage)
}
- // attn_qkvb
- pub fn attn_qkv_b(&self) -> Tensor<usize> {
- Tensor::new(self.dt_embd, &[3 * self.d])
+
+ pub fn attn_qkv_b(&self, usage: TensorUsage) -> Tensor<usize> {
+ let &Self { d, .. } = self;
+ self.mat(3 * d, 1, usage)
}
- // attn_projw
+
pub fn attn_o_w(&self, usage: TensorUsage) -> Tensor<usize> {
- self.mat(self.d, self.d, usage)
+ let &Self { d, .. } = self;
+ self.mat(d, d, usage)
}
- // attn_projb
- pub fn attn_o_b(&self) -> Tensor<usize> {
- Tensor::new(self.dt_embd, &[self.d])
- }
- // ln2_weight
- pub fn ffn_norm_w(&self) -> Tensor<usize> {
- self.norm()
- }
- // ln2_bias
- pub fn ffn_norm_b(&self) -> Tensor<usize> {
- self.norm()
+
+ pub fn attn_o_b(&self, usage: TensorUsage) -> Tensor<usize> {
+ let &Self { d, .. } = self;
+ self.mat(d, 1, usage)
}
- // fcw
+
pub fn ffn_up_w(&self, usage: TensorUsage) -> Tensor<usize> {
- self.mat(4 * self.d, self.d, usage)
+ let &Self { d, di, .. } = self;
+ self.mat(di, d, usage)
}
- // fcb
- pub fn ffn_up_b(&self) -> Tensor<usize> {
- Tensor::new(self.dt_embd, &[4 * self.d])
+
+ pub fn ffn_up_b(&self, _usage: TensorUsage) -> Tensor<usize> {
+ Tensor::new(self.dt_embd, &[self.di])
}
- // fcprojw
+
pub fn ffn_down_w(&self, usage: TensorUsage) -> Tensor<usize> {
- self.mat(self.d, 4 * self.d, usage)
+ let &Self { d, di, .. } = self;
+ self.mat(d, di, usage)
}
- // fcprojb
- pub fn ffn_down_b(&self) -> Tensor<usize> {
+
+ pub fn ffn_down_b(&self, _usage: TensorUsage) -> Tensor<usize> {
Tensor::new(self.dt_embd, &[self.d])
}
- // lnfw
- pub fn output_norm_w(&self) -> Tensor<usize> {
- self.norm()
- }
- // lnfb
- pub fn output_norm_b(&self) -> Tensor<usize> {
- self.norm()
- }
- // output.weight
+
pub fn output_weight(&self) -> Tensor<usize> {
Tensor::new(self.dt_embd, &[self.nvoc, self.d])
}

- fn norm(&self) -> Tensor<usize> {
+ pub fn norm(&self) -> Tensor<usize> {
let &Self { dt_norm, d, .. } = self;
Tensor::new(dt_norm, &[d])
}

pub fn pos_embd(&self) -> Tensor<usize> {
let &Self { nvoc, d, .. } = self;
Tensor::new(self.dt_embd, &[nvoc, d])
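
For orientation, the hunk above converges on two conventions: a single norm() shape shared by attn_norm, ffn_norm, and output_norm, and bias shapes built by the same mat helper as their weights. The miniature below assumes a mat(n, k) -> [k, n] layout (chosen so x1: [nt, d] times attn_qkv_w lands on [nt, 3d] and the qkv bias becomes a [1, 3d] row that broadcasts over nt) and GPT-2-small sizes d = 768, di = 3072; both are illustrative assumptions, not the crate's real Tensor/TensorUsage API.

// Stand-in types for illustration only; the real methods also carry
// dtype and TensorUsage, which this sketch omits.
#[derive(Debug, PartialEq)]
struct Shape(Vec<usize>);

struct Meta {
    d: usize,  // embedding dimension
    di: usize, // FFN intermediate dimension (replaces the hard-coded 4 * d)
}

impl Meta {
    // One shape serves attn_norm, ffn_norm and output_norm, weight and bias alike.
    fn norm(&self) -> Shape {
        Shape(vec![self.d])
    }
    // Shared constructor for weights and biases; [k, n] layout assumed.
    fn mat(&self, n: usize, k: usize) -> Shape {
        Shape(vec![k, n])
    }
    fn attn_qkv_w(&self) -> Shape { self.mat(3 * self.d, self.d) }
    fn attn_qkv_b(&self) -> Shape { self.mat(3 * self.d, 1) }
    fn ffn_up_w(&self) -> Shape { self.mat(self.di, self.d) }
}

fn main() {
    let m = Meta { d: 768, di: 3072 }; // GPT-2 small sizes, assumed
    assert_eq!(m.norm(), Shape(vec![768]));
    assert_eq!(m.attn_qkv_w(), Shape(vec![768, 2304]));
    assert_eq!(m.attn_qkv_b(), Shape(vec![1, 2304])); // one row, broadcasts over nt
    assert_eq!(m.ffn_up_w(), Shape(vec![768, 3072]));
}
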
24 changes: 12 additions & 12 deletions models/gpt2/common/src/storage.rs
@@ -60,19 +60,19 @@ impl<'a> Storage<&'a [u8]> {
#[rustfmt::skip]
let blocks = (0..meta.nblk)
.map(|i| BlkStorage {
- attn_qkv_b: gguf.tensors[&*format!("blk.{i}.attn_qkv.bias" )].data,
- attn_qkv_w: gguf.tensors[&*format!("blk.{i}.attn_qkv.weight" )].data,
- attn_o_b: gguf.tensors[&*format!("blk.{i}.attn_output.bias" )].data,
- attn_o_w: gguf.tensors[&*format!("blk.{i}.attn_output.weight")].data,
- attn_norm_b: gguf.tensors[&*format!("blk.{i}.attn_norm.bias" )].data,
- attn_norm_w: gguf.tensors[&*format!("blk.{i}.attn_norm.weight" )].data,
+ attn_norm_w: gguf.tensors[&*format!("blk.{i}.attn_norm.weight" )].data,
+ attn_norm_b: gguf.tensors[&*format!("blk.{i}.attn_norm.bias" )].data,
+ attn_qkv_w: gguf.tensors[&*format!("blk.{i}.attn_qkv.weight" )].data,
+ attn_qkv_b: gguf.tensors[&*format!("blk.{i}.attn_qkv.bias" )].data,
+ attn_o_w: gguf.tensors[&*format!("blk.{i}.attn_output.weight")].data,
+ attn_o_b: gguf.tensors[&*format!("blk.{i}.attn_output.bias" )].data,

- ffn_up_b: gguf.tensors[&*format!("blk.{i}.ffn_up.bias" )].data,
- ffn_up_w: gguf.tensors[&*format!("blk.{i}.ffn_up.weight" )].data,
- ffn_down_b: gguf.tensors[&*format!("blk.{i}.ffn_down.bias" )].data,
- ffn_down_w: gguf.tensors[&*format!("blk.{i}.ffn_down.weight" )].data,
- ffn_norm_b: gguf.tensors[&*format!("blk.{i}.ffn_norm.bias" )].data,
- ffn_norm_w: gguf.tensors[&*format!("blk.{i}.ffn_norm.weight" )].data,
+ ffn_norm_w: gguf.tensors[&*format!("blk.{i}.ffn_norm.weight" )].data,
+ ffn_norm_b: gguf.tensors[&*format!("blk.{i}.ffn_norm.bias" )].data,
+ ffn_up_w: gguf.tensors[&*format!("blk.{i}.ffn_up.weight" )].data,
+ ffn_up_b: gguf.tensors[&*format!("blk.{i}.ffn_up.bias" )].data,
+ ffn_down_w: gguf.tensors[&*format!("blk.{i}.ffn_down.weight" )].data,
+ ffn_down_b: gguf.tensors[&*format!("blk.{i}.ffn_down.bias" )].data,
})
.collect();

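
The hunk above is a pure reorder (12 additions, 12 deletions, grouping norm/qkv/output and norm/up/down to match the other files), but it shows the loader's naming scheme: each per-block tensor is fetched from the GGUF map by a formatted key. A tiny sketch of that lookup follows, with a plain HashMap standing in for gguf.tensors (a hypothetical stand-in, not the gguf crate's type).

use std::collections::HashMap;

// Stand-in for gguf.tensors: tensor name -> raw bytes. The real map yields
// GGufTensor values whose .data field storage.rs takes.
fn blk_tensor<'a>(tensors: &HashMap<String, &'a [u8]>, i: usize, name: &str) -> &'a [u8] {
    // Same key scheme as the hunk: "blk.{i}.attn_qkv.weight", "blk.{i}.ffn_up.bias", ...
    tensors[&format!("blk.{i}.{name}")]
}

fn main() {
    let bytes: &[u8] = &[1, 2, 3, 4];
    let mut tensors = HashMap::new();
    tensors.insert("blk.0.attn_norm.weight".to_string(), bytes);
    assert_eq!(blk_tensor(&tensors, 0, "attn_norm.weight"), bytes);
}
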
