Skip to content

Commit

Permalink
feat(clip): 分片图片批量编码
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Nov 27, 2024
1 parent 2e5d211 commit 9bfe2a5
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 30 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ itertools = "0.13"
build-script-cfg = "0.0"

ndarray-layout = { git = "https://github.com/YdrMaster/ndarray-layout", rev = "48d36c5" }
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "1b08473", default-features = false }
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "d73a53e", default-features = false }

search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "6846d52" }
search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "e2ec203" }
Expand Down
24 changes: 10 additions & 14 deletions models/clip/common-cpu/src/test_infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,15 @@ fn test_infer() {
)
.unwrap();

let [x, y] = slices.grid();
for i in 0..y {
for j in 0..x {
let patch = slices.patch(j, i);
worker
.launch(
ClipArgs {
raw: patch.to_nchw(),
},
&mut [],
&ThisThread,
)
.unwrap();
}
if let Some(patches) = slices.patches_nchw() {
worker
.launch(
ClipArgs {
raw: patches.map_slice(),
},
&mut [],
&ThisThread,
)
.unwrap();
}
}
14 changes: 12 additions & 2 deletions models/clip/common/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ use operators::{
conv::{self, Conv},
ByteOf, Hardware, LaunchError, Operator, QueueAlloc, QueueOf, TopoNode,
};
use std::ops::{Deref, DerefMut};
use std::{
ops::{Deref, DerefMut},
time::Instant,
};
use tensor::Tensor;

pub trait Operators {
Expand Down Expand Up @@ -60,6 +63,7 @@ where
where
QA: QueueAlloc<Hardware = Ops::Hardware>,
{
let time = Instant::now();
let Args { raw } = args;
let queue = queue_alloc.queue();

Expand All @@ -74,7 +78,13 @@ where
};

let mut embd = Tensor::new(dt_embd, &[n, m, h / hk, w / wk]).map(|s| queue_alloc.alloc(s));
self.conv(&mut embd, &raw, &k, &b, workspace, queue_alloc)
self.conv(&mut embd, &raw, &k, &b, workspace, queue_alloc)?;

if self.debug {
println!("encode {n} x {h} x {w} image in {:?}", time.elapsed());
}

Ok(())
}
}

Expand Down
34 changes: 27 additions & 7 deletions models/clip/common/src/image.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use def::*;
use common::{borrow, own, Contiguous};
use def::*;
use gguf::ggml_quants::{
digit_layout::{types as ty, DigitLayout},
f16,
Expand All @@ -7,7 +8,7 @@ use image::ImageReader;
use itertools::izip;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use std::{iter::zip, ops::Deref, path::Path, slice::from_raw_parts_mut};
use tensor::{Blob, Tensor};
use tensor::{rearrange, Blob, Tensor};

#[repr(transparent)]
pub struct Image<T>(Tensor<T>);
Expand Down Expand Up @@ -161,11 +162,7 @@ where

/// NHWC rgb Tensor -> NCHW value Tensor
pub fn to_nchw(&self) -> Tensor<&[u8]> {
self.0
.destruct_array()
.map(|t| &**t)
.transpose(&[2, 0, 1])
.tile(0, &[1, 3])
rgb_to_chw(&self.0).tile(0, &[1, 3])
}
}

Expand Down Expand Up @@ -198,6 +195,19 @@ impl ImageGrid {
)
}

/// Packs every patch of the grid into one tensor with the two leading
/// grid axes merged into a single axis.
///
/// Returns `None` when no grid data is stored.
///
/// The grid is first viewed through `rgb_to_chw`. If the two leading axes
/// of that view can be merged as-is, the existing bytes are borrowed;
/// otherwise the data is rearranged into a newly created `Blob` and the
/// merge is applied to that copy instead.
pub fn patches_nchw(&self) -> Option<Tensor<Contiguous<Blob>>> {
    let grid = self.grid.as_ref()?;
    let xychw = rgb_to_chw(grid);

    let patches = match xychw.as_ref().merge(0..2) {
        // The leading axes merge in place: no copy required.
        Some(nchw) => nchw.map(|s| borrow(s)),
        // Merge refused on the current layout: copy into fresh storage first.
        None => {
            let mut blob = Tensor::new(xychw.dt(), xychw.shape()).map(Blob::new);
            rearrange(&mut blob, &xychw);
            // NOTE(review): presumably a freshly created tensor is dense so
            // this merge cannot fail — confirm against `Tensor::new`.
            blob.merge(0..2).unwrap().map(own)
        }
    };
    Some(patches)
}

/// [urgb] 转 [frgb]
pub fn normalize(&self, dt: DigitLayout, mean: frgb96, std: frgb96) -> Self {
let dt = match dt {
Expand Down Expand Up @@ -317,6 +327,16 @@ where
ans
}

/// Re-views a tensor of array-typed elements as a tensor of scalar values
/// with the element axis moved in front of the last two original axes.
///
/// The storage is only borrowed as a byte slice; the steps used here are
/// `map_slice`, `destruct_array` and `transpose`.
///
/// NOTE(review): presumably this turns a trailing `H x W x C` layout into
/// `C x H x W` while leaving any leading axes untouched — confirm against
/// the `transpose` semantics of the layout crate.
/// NOTE(review): assumes `data` has at least two axes; `rank - 2` would
/// underflow otherwise.
fn rgb_to_chw<T>(data: &Tensor<T>) -> Tensor<&[u8]>
where
    T: Deref<Target = [u8]>,
{
    // Rank before the array element type is opened into its own axis.
    let rank = data.shape().len();
    let scalars = data.map_slice().destruct_array();
    // Index `rank` is the axis appended by `destruct_array`.
    scalars.transpose(&[rank, rank - 2, rank - 1])
}

#[test]
fn test() {
use std::time::Instant;
Expand Down
23 changes: 17 additions & 6 deletions tensor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,30 @@ impl Tensor<usize> {
/// access
impl<T> Tensor<T> {
/// 打开数组数据类型
pub fn destruct_array(&self) -> Tensor<&T> {
pub fn destruct_array(self) -> Self {
use ggus::ggml_quants::digit_layout::LayoutContent::{Real, Unsigned};
use std::iter::once;

let len = self.dt.group_size();
let dt = match self.dt.decode() {
let Self {
dt,
layout,
physical,
} = self;

let len = dt.group_size();
let dt = match dt.decode() {
Unsigned { width } if len > 1 => DigitLayout::unsigned(width as _, 1),
Real { exponent, mantissa } if len > 1 => {
DigitLayout::real(exponent as _, mantissa as _, 1)
}
_ => return self.as_ref(),
_ => {
return Self {
dt,
layout,
physical,
}
}
};
let layout = &self.layout;
let shape = layout
.shape()
.iter()
Expand All @@ -63,7 +74,7 @@ impl<T> Tensor<T> {
Tensor {
dt,
layout: ArrayLayout::new(&shape, &strides, offset),
physical: &self.physical,
physical,
}
}

Expand Down

0 comments on commit 9bfe2a5

Please sign in to comment.