diff --git a/models/clip/common-cpu/src/test_infer.rs b/models/clip/common-cpu/src/infer.rs
similarity index 97%
rename from models/clip/common-cpu/src/test_infer.rs
rename to models/clip/common-cpu/src/infer.rs
index a5ef1f0..85f2cbd 100644
--- a/models/clip/common-cpu/src/test_infer.rs
+++ b/models/clip/common-cpu/src/infer.rs
@@ -25,7 +25,7 @@ fn test_infer() {
     println!("{meta:#?}");

     let &ClipMeta {
-        dt_embd,
+        dt,

         d_image,
         d_patch,
@@ -42,7 +42,7 @@ fn test_infer() {
     let time = Instant::now();
     let slices = image
         .slice_uhd(9, d_image, d_patch)
-        .normalize(dt_embd, image_mean, image_std);
+        .normalize(dt, image_mean, image_std);
     println!("slice image {:?}", time.elapsed());

     let weights = Weights::new(&storage);
diff --git a/models/clip/common-cpu/src/lib.rs b/models/clip/common-cpu/src/lib.rs
index ac3fe16..e830e31 100644
--- a/models/clip/common-cpu/src/lib.rs
+++ b/models/clip/common-cpu/src/lib.rs
@@ -1,6 +1,6 @@
-use clip::{ClipStorage, WeightLoader};
-use operators::{common_cpu::Cpu, conv, QueueOf, TopoNode};
-use std::marker::PhantomData;
+use clip::{BlkWeight, ClipBlkStorage, ClipStorage, Tensor, WeightLoader};
+use operators::{common_cpu::Cpu, conv, ByteOf, QueueOf, TopoNode};
+use std::{marker::PhantomData, ops::Deref};

 pub struct Operators(PhantomData);
@@ -21,7 +21,16 @@ where
     type TopoNode = Cpu;
     type Conv = conv::common_cpu::ConvIm2Col;
     type AddRows = op!(add_rows);
+    type Rearrange = op!(rearrange);
     type LayerNorm = op!(layer_norm);
+    type MatMul = op!(mat_mul);
+
+    fn debug<T>(tensor: &Tensor<T>)
+    where
+        T: Deref<Target = [ByteOf<Self::Hardware>]>,
+    {
+        println!("{tensor}")
+    }
 }

 impl<'w> Weights<'w> {
@@ -32,18 +41,48 @@ impl<'w> Weights<'w> {
 impl WeightLoader for Weights<'_> {
     type Hardware = Cpu;

-    type Weight<'s>
+    type Memory<'s>
         = &'s [u8]
     where
         Self: 's;

+    fn load_blk(
+        &self,
+        which: BlkWeight,
+        iblk: usize,
+        _queue: &QueueOf<Self::Hardware>,
+    ) -> [Self::Memory<'_>; 2] {
+        let ClipBlkStorage {
+            attn_norm_w,
+            attn_norm_b,
+            attn_qkv_w,
+            attn_qkv_b,
+            attn_o_w,
+            attn_o_b,
+            ffn_norm_w,
+            ffn_norm_b,
+            ffn_up_w,
+            ffn_up_b,
+            ffn_down_w,
+            ffn_down_b,
+        } = &self.0.blocks[iblk];
+        match which {
+            BlkWeight::AttnNorm => [attn_norm_w, attn_norm_b],
+            BlkWeight::AttnQKV => [attn_qkv_w, attn_qkv_b],
+            BlkWeight::AttnO => [attn_o_w, attn_o_b],
+            BlkWeight::FfnNorm => [ffn_norm_w, ffn_norm_b],
+            BlkWeight::FfnUp => [ffn_up_w, ffn_up_b],
+            BlkWeight::FfnDown => [ffn_down_w, ffn_down_b],
+        }
+    }
+
     #[inline]
-    fn patch_embd<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> [Self::Weight<'a>; 2] {
+    fn patch_embd<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> [Self::Memory<'a>; 2] {
         [self.0.patch_embd_w, self.0.patch_embd_b]
     }

     #[inline]
-    fn pos_embd<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> Self::Weight<'a> {
+    fn pos_embd<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> Self::Memory<'a> {
         self.0.pos_embd
     }

@@ -51,7 +90,7 @@
     fn pre_norm<'a>(
         &'a self,
         _queue: &'a QueueOf<Self::Hardware>,
-    ) -> Option<[Self::Weight<'a>; 2]> {
+    ) -> Option<[Self::Memory<'a>; 2]> {
         self.0.pre_norm
     }

@@ -59,10 +98,10 @@
     fn post_norm<'a>(
         &'a self,
         _queue: &'a QueueOf<Self::Hardware>,
-    ) -> Option<[Self::Weight<'a>; 2]> {
+    ) -> Option<[Self::Memory<'a>; 2]> {
         self.0.post_norm
     }
 }

 #[cfg(test)]
-mod test_infer;
+mod infer;
diff --git a/models/clip/common/src/compute.rs b/models/clip/common/src/compute.rs
index 53ebfdf..243308b 100644
--- a/models/clip/common/src/compute.rs
+++ b/models/clip/common/src/compute.rs
@@ -3,33 +3,58 @@ use operators::{
     add_rows::{self, AddRows},
     conv::{self, Conv},
     layer_norm::{self, LayerNorm},
-    ByteOf, Hardware, LaunchError, Operator, QueueAlloc, QueueOf, TopoNode,
+    mat_mul::{self, MatMul},
+    rearrange::{self, Rearrange},
+    ByteOf, Hardware, LaunchError, Operator, QueueAlloc, QueueOf, TopoNode, Workspace,
 };
 use std::{
     ops::{Deref, DerefMut},
     time::Instant,
 };
-use tensor::Tensor;
+use tensor::{split, Tensor};

 pub trait Operators {
     type Hardware: Hardware;
     type TopoNode: TopoNode<Self::Hardware>;
     type Conv: Conv<Self::Hardware>;
     type AddRows: AddRows<Self::Hardware>;
+    type Rearrange: Rearrange<Self::Hardware>;
     type LayerNorm: LayerNorm<Self::Hardware>;
+    type MatMul: MatMul<Self::Hardware>;
+
+    fn debug<T>(tensor: &Tensor<T>)
+    where
+        T: Deref<Target = [ByteOf<Self::Hardware>]>;
+}
+
+#[derive(Clone, Copy, PartialEq, Eq, Debug)]
+pub enum BlkWeight {
+    AttnNorm,
+    AttnQKV,
+    AttnO,
+    FfnNorm,
+    FfnUp,
+    FfnDown,
 }

 pub trait WeightLoader {
     type Hardware: Hardware;
-    type Weight<'s>: Deref<Target = [ByteOf<Self::Hardware>]> + 's
+    type Memory<'s>: Deref<Target = [ByteOf<Self::Hardware>]> + 's
     where
         Self: 's;

-    fn patch_embd<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> [Self::Weight<'a>; 2];
-    fn pos_embd<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> Self::Weight<'a>;
-    fn pre_norm<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> Option<[Self::Weight<'a>; 2]>;
+    fn load_blk(
+        &self,
+        which: BlkWeight,
+        iblk: usize,
+        queue: &QueueOf<Self::Hardware>,
+    ) -> [Self::Memory<'_>; 2];
+
+    fn patch_embd<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> [Self::Memory<'a>; 2];
+    fn pos_embd<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> Self::Memory<'a>;
+    fn pre_norm<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> Option<[Self::Memory<'a>; 2]>;
     fn post_norm<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>)
-        -> Option<[Self::Weight<'a>; 2]>;
+        -> Option<[Self::Memory<'a>; 2]>;
 }

 pub struct ClipWorker<Ops: Operators, W> {
@@ -37,7 +62,9 @@ pub struct ClipWorker<Ops: Operators, W> {
     weights: WeightDecorator<W>,
     conv: Ops::Conv,
     add_rows: Ops::AddRows,
+    rearrange: Ops::Rearrange,
     layer_norm: Ops::LayerNorm,
+    mat_mul: Ops::MatMul,
     pub debug: bool,
 }
@@ -49,7 +76,9 @@ impl<Ops: Operators, W> ClipWorker<Ops, W> {
             meta,
             conv: Ops::Conv::new(processor),
             add_rows: Ops::AddRows::new(processor),
+            rearrange: Ops::Rearrange::new(processor),
             layer_norm: Ops::LayerNorm::new(processor),
+            mat_mul: Ops::MatMul::new(processor),
             debug: true,
         }
     }
@@ -58,6 +87,14 @@ impl<Ops: Operators, W> ClipWorker<Ops, W> {
     pub const fn meta(&self) -> &ClipMeta {
         &self.meta
     }
+
+    pub fn workspace_size(&self, np: usize) -> usize {
+        let ClipMeta { dt, d, .. } = self.meta;
+        let ele = dt.nbytes();
+        let embd = np * d * ele;
+        let qkv = np * 3 * d * ele;
+        embd + qkv
+    }
 }

 impl<Ops, W> ClipWorker<Ops, W>
 where
@@ -77,9 +114,12 @@
     {
         let time = Instant::now();
         let Args { raw, pos } = args;
-        let queue = queue_alloc.queue();
+        let ClipMeta {
+            dt, nblk, d, nh, ..
+        } = self.meta;
+        let dh = d / nh;

-        let ClipMeta { dt_embd, .. } = self.meta;
+        let queue = queue_alloc.queue();

         let [k, b] = self.weights.patch_embd(queue);

         let &[n, _, h, w] = raw.shape() else {
@@ -89,28 +129,73 @@
             unreachable!()
         };

-        let mut embd = Tensor::new(dt_embd, &[n, m, h / hk, w / wk]).map(|s| queue_alloc.alloc(s));
+        let mut embd = Tensor::new(dt, &[n, m, h / hk, w / wk]).map(|s| queue_alloc.alloc(s));

         self.conv(&mut embd, &raw, &k, &b, workspace, queue_alloc)?;
-        let mut embd = embd.merge(2..4).unwrap().transpose(&[2, 1]);
+        let embd_ = embd.merge(2..4).unwrap().transpose(&[2, 1]);
+        let mut embd = Tensor::new(embd_.dt(), embd_.shape()).map(|s| queue_alloc.alloc(s));
+        self.rearrange(&mut embd, &embd_, workspace, queue_alloc)?;

         let pos_embd = self.weights.pos_embd(queue);
         self.add_rows(&mut embd, &pos_embd, &pos, workspace, queue_alloc)?;

+        let mut x = embd.merge(0..2).unwrap();
+
+        let np = x.shape()[0];
+        let workspace_size = self.workspace_size(np);
+        let mut workspace = Workspace::new(queue_alloc, workspace, workspace_size);
+
+        let x1 = Tensor::new(x.dt(), x.shape());
+        let (buf, workspace) = workspace.split_at_mut(*x1.get());
+        let mut x1 = x1.map(|_| buf);
+
+        let qkv = Tensor::new(dt, &[np, 3 * d]);
+
         if let Some([scale, bias]) = self.weights.pre_norm(queue) {
-            let inplace = unsafe { embd.map_slice_static() };
-            self.layer_norm(&mut embd, &inplace, &scale, &bias, workspace, queue_alloc)?;
+            let inplace = unsafe { x.map_slice_static() };
+            self.layer_norm(&mut x, &inplace, &scale, &bias, workspace, queue_alloc)?
         }

-        for _ in 0..self.meta.nblk {}
+        for iblk in 0..nblk {
+            {
+                let [scale, bias] = self.weights.attn_norm(iblk, queue);
+                self.layer_norm(&mut x1, &x, &scale, &bias, workspace, queue_alloc)?
+            }
+            let (buf, workspace) = workspace.split_at_mut(*qkv.get());
+            let mut qkv = qkv.clone().map(|_| buf);
+            {
+                let [scale, bias] = self.weights.attn_qkv(iblk, queue);
+                let bias = bias.broadcast(0, np);
+                self.rearrange(&mut qkv, &bias, workspace, queue_alloc)?;
+                self.mat_mul(&mut qkv, 1., &x1, &scale, 1., workspace, queue_alloc)?
+            }
+            let qkv = qkv.tile(1, &[3 * nh, dh]);
+            split!(qkv => q, k, v; [nh, nh, nh] @ 1);
+            let q = q;
+            let _k = k;
+            let _v = v;
+            {
+                // TODO: attention
+            }
+            {
+                let o = q.map_slice().merge(1..3).unwrap();
+                let [scale, bias] = self.weights.attn_o(iblk, queue);
+                let bias = bias.broadcast(0, np);
+                self.rearrange(&mut x1, &bias, workspace, queue_alloc)?;
+                self.mat_mul(&mut x1, 1., &o, &scale, 1., workspace, queue_alloc)?;
+            }
+
+            let [_w, _b] = self.weights.ffn_norm(iblk, queue);
+            let [_w, _b] = self.weights.ffn_up(iblk, queue);
+            let [_w, _b] = self.weights.ffn_down(iblk, queue);
+        }

         if let Some([scale, bias]) = self.weights.post_norm(queue) {
-            let inplace = unsafe { embd.map_slice_static() };
-            self.layer_norm(&mut embd, &inplace, &scale, &bias, workspace, queue_alloc)?;
+            let inplace = unsafe { x.map_slice_static() };
+            self.layer_norm(&mut x, &inplace, &scale, &bias, workspace, queue_alloc)?
         }

         if self.debug {
-            println!("encode {n} x {h} x {w} image in {:?}", time.elapsed());
+            println!("encode {n} x {h} x {w} image in {:?}", time.elapsed())
         }

         Ok(())
@@ -186,6 +271,30 @@ where
         )
     }

+    fn rearrange<Dst, Src, QA>(
+        &self,
+        dst: &mut Tensor<Dst>,
+        src: &Tensor<Src>,
+        workspace: &mut [ByteOf<Ops::Hardware>],
+        queue_alloc: &QA,
+    ) -> Result<(), LaunchError>
+    where
+        Dst: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
+        Src: Deref<Target = [ByteOf<Ops::Hardware>]>,
+        QA: QueueAlloc<Hardware = Ops::Hardware>,
+    {
+        self.rearrange.launch(
+            &rearrange::Args {
+                dst_layout: dst.layout(),
+                dst_base: dst.base_mut(),
+                src_layout: src.layout(),
+                src_base: src.base(),
+            },
+            workspace,
+            queue_alloc,
+        )
+    }
+
     fn layer_norm<Y, X, WB, QA>(
         &self,
         y: &mut Tensor<Y>,
@@ -218,14 +327,57 @@ where
             queue_alloc,
         )
     }
+
+    fn mat_mul<C, A, B, QA>(
+        &self,
+        c: &mut Tensor<C>,
+        beta: f32,
+        a: &Tensor<A>,
+        b: &Tensor<B>,
+        alpha: f32,
+        workspace: &mut [ByteOf<Ops::Hardware>],
+        queue_alloc: &QA,
+    ) -> Result<(), LaunchError>
+    where
+        C: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
+        A: Deref<Target = [ByteOf<Ops::Hardware>]>,
+        B: Deref<Target = [ByteOf<Ops::Hardware>]>,
+        QA: QueueAlloc<Hardware = Ops::Hardware>,
+    {
+        self.mat_mul.launch(
+            &mat_mul::Args {
+                c_layout: c.layout(),
+                c_base: c.base_mut(),
+                beta,
+                a_layout: a.layout(),
+                a_base: a.base(),
+                b_layout: b.layout(),
+                b_base: b.base(),
+                alpha,
+            },
+            workspace,
+            queue_alloc,
+        )
+    }
 }

 struct WeightDecorator<W> {
-    weights: W,
     patch_embd_w: Tensor<usize>,
     patch_embd_b: Tensor<usize>,
     pos_embd: Tensor<usize>,
     norm: Tensor<usize>,
+
+    attn_qkv_w: Tensor<usize>,
+    attn_qkv_b: Tensor<usize>,
+    attn_o_w: Tensor<usize>,
+    attn_o_b: Tensor<usize>,
+
+    ffn_up_w: Tensor<usize>,
+    ffn_up_b: Tensor<usize>,
+    ffn_down_w: Tensor<usize>,
+    ffn_down_b: Tensor<usize>,
+
+    weights: W,
 }

 impl ClipMeta {
@@ -235,6 +387,16 @@ impl ClipMeta {
             patch_embd_b: self.patch_embd_b(),
             pos_embd: self.pos_embd(),
             norm: self.norm(),
+
+            attn_qkv_w: self.attn_qkv_w(),
+            attn_qkv_b: self.attn_qkv_b(),
+            attn_o_w: self.attn_o_w(),
+            attn_o_b: self.attn_o_b(),
+            ffn_up_w: self.ffn_up_w(),
+            ffn_up_b: self.ffn_up_b(),
+            ffn_down_w: self.ffn_down_w(),
+            ffn_down_b: self.ffn_down_b(),
+
             weights,
         }
     }
@@ -242,7 +404,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
     #[inline]
-    pub fn patch_embd<'a>(&'a self, queue: &'a QueueOf<W::Hardware>) -> [Tensor<W::Weight<'a>>; 2] {
+    pub fn patch_embd<'a>(&'a self, queue: &'a QueueOf<W::Hardware>) -> [Tensor<W::Memory<'a>>; 2] {
         let [w, b] = self.weights.patch_embd(queue);
         [
             self.patch_embd_w.clone().map(|_| w),
@@ -251,7 +413,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
     }

     #[inline]
-    pub fn pos_embd<'a>(&'a self, queue: &'a QueueOf<W::Hardware>) -> Tensor<W::Weight<'a>> {
+    pub fn pos_embd<'a>(&'a self, queue: &'a QueueOf<W::Hardware>) -> Tensor<W::Memory<'a>> {
         let pos_embd = self.weights.pos_embd(queue);
         self.pos_embd.clone().map(|_| pos_embd)
     }
@@ -260,7 +422,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
     pub fn pre_norm<'a>(
         &'a self,
         queue: &'a QueueOf<W::Hardware>,
-    ) -> Option<[Tensor<W::Weight<'a>>; 2]> {
+    ) -> Option<[Tensor<W::Memory<'a>>; 2]> {
         self.weights
             .pre_norm(queue)
             .map(|pair| pair.map(|w| self.norm.clone().map(|_| w)))
@@ -270,9 +432,67 @@ impl<W: WeightLoader> WeightDecorator<W> {
     pub fn post_norm<'a>(
         &'a self,
         queue: &'a QueueOf<W::Hardware>,
-    ) -> Option<[Tensor<W::Weight<'a>>; 2]> {
+    ) -> Option<[Tensor<W::Memory<'a>>; 2]> {
         self.weights
             .post_norm(queue)
             .map(|pair| pair.map(|w| self.norm.clone().map(|_| w)))
     }
+
+    pub fn attn_norm(
+        &self,
+        iblk: usize,
+        queue: &QueueOf<W::Hardware>,
+    ) -> [Tensor<W::Memory<'_>>; 2] {
+        let [w, b] = self.weights.load_blk(BlkWeight::AttnNorm, iblk, queue);
+        [self.norm.clone().map(|_| w), self.norm.clone().map(|_| b)]
+    }
+
+    pub fn attn_qkv(
+        &self,
+        iblk: usize,
+        queue: &QueueOf<W::Hardware>,
+    ) -> [Tensor<W::Memory<'_>>; 2] {
+        let [w, b] = self.weights.load_blk(BlkWeight::AttnQKV, iblk, queue);
+        [
+            self.attn_qkv_w.clone().map(|_| w),
+            self.attn_qkv_b.clone().map(|_| b),
+        ]
+    }
+
+    pub fn attn_o(&self, iblk: usize, queue: &QueueOf<W::Hardware>) -> [Tensor<W::Memory<'_>>; 2] {
+        let [w, b] = self.weights.load_blk(BlkWeight::AttnO, iblk, queue);
+        [
+            self.attn_o_w.clone().map(|_| w),
+            self.attn_o_b.clone().map(|_| b),
+        ]
+    }
+
+    pub fn ffn_norm(
+        &self,
+        iblk: usize,
+        queue: &QueueOf<W::Hardware>,
+    ) -> [Tensor<W::Memory<'_>>; 2] {
+        let [w, b] = self.weights.load_blk(BlkWeight::FfnNorm, iblk, queue);
+        [self.norm.clone().map(|_| w), self.norm.clone().map(|_| b)]
+    }
+
+    pub fn ffn_up(&self, iblk: usize, queue: &QueueOf<W::Hardware>) -> [Tensor<W::Memory<'_>>; 2] {
+        let [w, b] = self.weights.load_blk(BlkWeight::FfnUp, iblk, queue);
+        [
+            self.ffn_up_w.clone().map(|_| w),
+            self.ffn_up_b.clone().map(|_| b),
+        ]
+    }
+
+    pub fn ffn_down(
+        &self,
+        iblk: usize,
+        queue: &QueueOf<W::Hardware>,
+    ) -> [Tensor<W::Memory<'_>>; 2] {
+        let [w, b] = self.weights.load_blk(BlkWeight::FfnDown, iblk, queue);
+        [
+            self.ffn_down_w.clone().map(|_| w),
+            self.ffn_down_b.clone().map(|_| b),
+        ]
+    }
 }
diff --git a/models/clip/common/src/lib.rs b/models/clip/common/src/lib.rs
index d391649..ce041a9 100644
--- a/models/clip/common/src/lib.rs
+++ b/models/clip/common/src/lib.rs
@@ -6,9 +6,9 @@ mod storage;
 use gguf::ggml_quants::digit_layout::DigitLayout;

 pub use args::Args as ClipArgs;
-pub use compute::{ClipWorker, Operators, WeightLoader};
+pub use compute::{BlkWeight, ClipWorker, Operators, WeightLoader};
 pub use image::{Image, ImageGrid};
-pub use storage::Storage as ClipStorage;
+pub use storage::{BlkStorage as ClipBlkStorage, Storage as ClipStorage};
 pub use tensor::Tensor;

 pub mod ext {
     pub use gguf::{
@@ -23,8 +23,6 @@ pub struct ClipMeta {
     pub minicpmv_version: u8,

     pub dt: DigitLayout,
-    pub dt_embd: DigitLayout,
-    pub dt_norm: DigitLayout,

     pub nblk: usize,
     pub d_patch: usize,
@@ -91,11 +89,56 @@ impl ClipMeta {

     pub fn pos_embd(&self) -> Tensor<usize> {
         let &Self { d, .. } = self;
-        Tensor::new(self.dt_embd, &[D_POS_EMBD.pow(2), d])
+        Tensor::new(self.dt, &[D_POS_EMBD.pow(2), d])
     }

     pub fn norm(&self) -> Tensor<usize> {
         let &Self { d, .. } = self;
-        Tensor::new(self.dt_norm, &[d])
+        Tensor::new(self.dt, &[d])
+    }
+
+    pub fn attn_qkv_w(&self) -> Tensor<usize> {
+        let &Self { d, .. } = self;
+        self.mat(3 * d, d)
+    }
+
+    pub fn attn_qkv_b(&self) -> Tensor<usize> {
+        let &Self { d, .. } = self;
+        self.mat(3 * d, 1)
+    }
+
+    pub fn attn_o_w(&self) -> Tensor<usize> {
+        let &Self { d, .. } = self;
+        self.mat(d, d)
+    }
+
+    pub fn attn_o_b(&self) -> Tensor<usize> {
+        let &Self { d, .. } = self;
+        self.mat(d, 1)
+    }
+
+    pub fn ffn_up_w(&self) -> Tensor<usize> {
+        let &Self { d, di, .. } = self;
+        Tensor::new(self.dt, &[di, d])
+    }
+
+    pub fn ffn_up_b(&self) -> Tensor<usize> {
+        let &Self { di, .. } = self;
+        self.mat(di, 1)
+    }
+
+    pub fn ffn_down_w(&self) -> Tensor<usize> {
+        let &Self { d, di, .. } = self;
+        self.mat(d, di)
+    }
+
+    pub fn ffn_down_b(&self) -> Tensor<usize> {
+        let &Self { d, .. } = self;
+        self.mat(d, 1)
+    }
+
+    fn mat(&self, row: usize, col: usize) -> Tensor<usize> {
+        assert_eq!(self.dt.group_size(), 1);
+        Tensor::new(self.dt, &[row, col]).transpose(&[1, 0])
     }
 }
diff --git a/models/clip/common/src/storage.rs b/models/clip/common/src/storage.rs
index 2977aa2..b9f34f3 100644
--- a/models/clip/common/src/storage.rs
+++ b/models/clip/common/src/storage.rs
@@ -9,13 +9,29 @@ pub struct Storage<T> {
     pub pos_embd: T,
     pub pre_norm: Option<[T; 2]>,
     pub post_norm: Option<[T; 2]>,
+    pub blocks: Box<[BlkStorage<T>]>,
+}
+
+#[derive(Clone, Copy)]
+pub struct BlkStorage<T> {
+    pub attn_norm_w: T,
+    pub attn_norm_b: T,
+    pub attn_qkv_w: T,
+    pub attn_qkv_b: T,
+    pub attn_o_w: T,
+    pub attn_o_b: T,
+
+    pub ffn_norm_w: T,
+    pub ffn_norm_b: T,
+    pub ffn_up_w: T,
+    pub ffn_up_b: T,
+    pub ffn_down_w: T,
+    pub ffn_down_b: T,
 }

 impl<'a> Storage<&'a [u8]> {
     pub fn from_gguf(gguf: &GGufModel<'a>) -> Self {
         let pos_embd = &gguf.tensors["v.position_embd.weight"];
-        let patch_embd_w = &gguf.tensors["v.patch_embd.weight"];
-        let patch_embd_b = &gguf.tensors["v.patch_embd.bias"];

         let projector = match gguf.get_str("clip.projector_type").unwrap() {
             "mlp" => ProjectorType::Mlp,
@@ -24,15 +40,12 @@
             "resampler" => ProjectorType::Resampler,
             _ => ProjectorType::Unknown,
         };
-
         #[rustfmt::skip]
         let meta = ClipMeta {
             projector,
             minicpmv_version: gguf.get_usize("clip.minicpmv_version").unwrap() as _,

-            dt     : patch_embd_w.ty,
-            dt_embd: pos_embd.ty,
-            dt_norm: gguf.tensors["v.blk.0.ln1.weight"].ty,
+            dt     : pos_embd.ty,

             nblk   : gguf.get_usize("clip.vision.block_count" ).unwrap(),
             d_patch: gguf.get_usize("clip.vision.patch_size"  ).unwrap(),
@@ -45,11 +58,29 @@
             image_std : get_rgb(gguf, "clip.vision.image_std" ),
             epsilon   : gguf.get_f32("clip.vision.attention.layer_norm_epsilon").unwrap(),
         };
+        #[rustfmt::skip]
+        let blocks = (0..meta.nblk)
+            .map(|i| BlkStorage {
+                attn_norm_w: gguf.tensors[&*format!("v.blk.{i}.ln1.weight"     )].data,
+                attn_norm_b: gguf.tensors[&*format!("v.blk.{i}.ln1.bias"       )].data,
+                attn_qkv_w:  gguf.tensors[&*format!("v.blk.{i}.attn_qkv.weight")].data,
+                attn_qkv_b:  gguf.tensors[&*format!("v.blk.{i}.attn_qkv.bias"  )].data,
+                attn_o_w:    gguf.tensors[&*format!("v.blk.{i}.attn_out.weight")].data,
+                attn_o_b:    gguf.tensors[&*format!("v.blk.{i}.attn_out.bias"  )].data,
+
+                ffn_norm_w:  gguf.tensors[&*format!("v.blk.{i}.ln2.weight"     )].data,
+                ffn_norm_b:  gguf.tensors[&*format!("v.blk.{i}.ln2.bias"       )].data,
+                ffn_up_w:    gguf.tensors[&*format!("v.blk.{i}.ffn_up.weight"  )].data,
+                ffn_up_b:    gguf.tensors[&*format!("v.blk.{i}.ffn_up.bias"    )].data,
+                ffn_down_w:  gguf.tensors[&*format!("v.blk.{i}.ffn_down.weight")].data,
+                ffn_down_b:  gguf.tensors[&*format!("v.blk.{i}.ffn_down.bias"  )].data,
+            })
+            .collect();

         Self {
             meta,
-            patch_embd_w: patch_embd_w.data,
-            patch_embd_b: patch_embd_b.data,
+            patch_embd_w: gguf.tensors["v.patch_embd.weight"].data,
+            patch_embd_b: gguf.tensors["v.patch_embd.bias"].data,
             pos_embd: pos_embd.data,
             pre_norm: gguf
                 .tensors
@@ -59,6 +90,7 @@
                 .tensors
                 .get("v.post_ln.weight")
                 .map(|w| [w.data, gguf.tensors["v.post_ln.bias"].data]),
+            blocks,
         }
     }
 }
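
Usage note: the sketch below is not part of the patch; it is a minimal illustration of how a caller is expected to drive the unified per-block loader this diff introduces. The helper `dump_qkv` and its printed byte sizes are hypothetical, while `WeightLoader`, `BlkWeight`, `load_blk`, and `QueueOf` are exactly as defined above.

    use clip::{BlkWeight, WeightLoader};
    use operators::QueueOf;

    /// Hypothetical helper: fetch the fused QKV projection of every block.
    /// `load_blk` always returns a `[weight, bias]` pair, the same `[w, b]`
    /// destructuring that `WeightDecorator::attn_qkv` performs above.
    fn dump_qkv<W: WeightLoader>(weights: &W, nblk: usize, queue: &QueueOf<W::Hardware>) {
        for iblk in 0..nblk {
            let [w, b] = weights.load_blk(BlkWeight::AttnQKV, iblk, queue);
            // `Memory` derefs to a byte slice, so `.len()` reports raw size.
            println!("blk {iblk}: qkv weight = {} B, bias = {} B", w.len(), b.len());
        }
    }

Funneling all twelve per-block tensors through one `(BlkWeight, iblk)` entry point, rather than twelve separate trait methods, appears intended to keep backend `WeightLoader` implementations down to a single `match`, as in the CPU backend above.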