Skip to content

Commit

Permalink
feat(clip): 支持图片归一化到不同的精度
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Nov 22, 2024
1 parent 897470e commit 843f8b4
Showing 1 changed file with 57 additions and 36 deletions.
93 changes: 57 additions & 36 deletions models/clip/common/src/image.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use def::*;
use gguf::ggml_quants::{
digit_layout::{types as ty, DigitLayout},
f16,
};
use image::ImageReader;
use itertools::izip;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use std::{iter::zip, ops::Deref, path::Path, slice::from_raw_parts_mut};
use tensor::{rearrange, Blob, Tensor};
use tensor::{Blob, Tensor};

#[repr(transparent)]
pub struct Image<T>(Tensor<T>);
Expand All @@ -15,25 +19,26 @@ pub struct ImageGrid {

#[allow(non_camel_case_types)]
mod def {
use gguf::ggml_quants::digit_layout::layout;
layout!(Urgb u(8) ; 3);
layout!(Frgb e(8)m(23); 3);
use gguf::ggml_quants::{digit_layout::layout, f16};
layout!(Urgb24 u(8) ; 3);
layout!(Frgb48 e(5)m(10); 3);
layout!(Frgb96 e(8)m(23); 3);

pub type urgb24 = [u8; 3];
pub type frgb48 = [f16; 3];
pub type frgb96 = [f32; 3];

/// 有图像尺寸参与的浮点运算应使用这个类型
pub type fdim = f64;
/// 基本的 rgb 类型
pub type urgb = [u8; 3];
/// 归一化浮点表示的 rgb 类型
pub type frgb = [f32; 3];
}

impl Image<Vec<u8>> {
/// 从文件加载
pub fn load(path: impl AsRef<Path>) -> Self {
let rgb8 = ImageReader::open(path).unwrap().decode().unwrap().to_rgb8();
let (x, y) = rgb8.dimensions();
assert_eq!(rgb8.as_raw().len(), Urgb.nbytes() * (x * y) as usize);
Self(Tensor::new(Urgb, &[y as usize, x as usize]).map(|_| rgb8.into_raw()))
assert_eq!(rgb8.as_raw().len(), Urgb24.nbytes() * (x * y) as usize);
Self(Tensor::new(Urgb24, &[y as usize, x as usize]).map(|_| rgb8.into_raw()))
}
}

Expand Down Expand Up @@ -92,9 +97,9 @@ where

/// 双三次插值缩放
fn bicubic_resize(&self, [w_, h_]: [usize; 2]) -> Image<Blob> {
assert_eq!(self.0.dt(), Urgb);
assert_eq!(self.0.dt(), Urgb24);

let mut ans = Image(Tensor::new(Urgb, &[h_, w_]).map(Blob::new));
let mut ans = Image(Tensor::new(Urgb24, &[h_, w_]).map(Blob::new));
let data = ans.0.get_mut();
let ptr = data.as_mut_ptr() as usize;
let len = data.len();
Expand Down Expand Up @@ -156,17 +161,12 @@ where

/// NHWC rgb Tensor -> NCHW value Tensor
#[inline]
pub fn to_nchw(&self) -> Tensor<Blob> {
let src = self
.0
pub fn to_nchw(&self) -> Tensor<&[u8]> {
self.0
.destruct_array()
.map(|t| &**t)
.transpose(&[2, 0, 1])
.tile(0, &[1, 3]);

let mut ans = Tensor::new(src.dt(), src.shape()).map(Blob::new);
rearrange(&mut ans, &src);
ans
.tile(0, &[1, 3])
}
}

Expand Down Expand Up @@ -199,10 +199,18 @@ impl ImageGrid {
}

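/// Converts [`urgb24`] pixels into normalized floating-point pixels: [`frgb48`] when `dt` is `F16`, [`frgb96`] when `dt` is `F32`; any other `dt` panics.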
pub fn normalize(&self, mean: frgb, std: frgb) -> Self {
pub fn normalize(&self, dt: DigitLayout, mean: frgb96, std: frgb96) -> Self {
let dt = match dt {
ty::F16 => Frgb48,
ty::F32 => Frgb96,
_ => panic!("Unsupported type {dt}"),
};
Self {
grid: self.grid.as_ref().map(|data| normalize(data, mean, std)),
whole: Image(normalize(&self.whole.0, mean, std)),
grid: self
.grid
.as_ref()
.map(|data| normalize(data, dt, mean, std)),
whole: Image(normalize(&self.whole.0, dt, mean, std)),
}
}
}
Expand Down Expand Up @@ -273,26 +281,39 @@ fn refine_patch_size(
}

/// 将整型表示的 rgb 值转换为归一化浮点表示
fn normalize<T>(data: &Tensor<T>, mean: frgb, std: frgb) -> Tensor<Blob>
fn normalize<T>(data: &Tensor<T>, dt: DigitLayout, mean: frgb96, std: frgb96) -> Tensor<Blob>
where
T: Deref<Target = [u8]>,
{
assert_eq!(data.dt(), Urgb);
let mut ans = data.cast(Frgb).map(Blob::new);

let ([], src, []) = (unsafe { data.get().align_to::<urgb>() }) else {
unreachable!()
};
let ([], dst, []) = (unsafe { ans.get_mut().align_to_mut::<frgb>() }) else {
assert_eq!(data.dt(), Urgb24);
let ([], src, []) = (unsafe { data.get().align_to::<urgb24>() }) else {
unreachable!()
};

for (dst, src) in zip(dst, src) {
for (dst, &src, mean, std) in izip!(dst, src, mean, std) {
*dst = (src as f32 / 255. - mean) / std;
let mut ans = data.cast(dt).map(Blob::new);
match dt {
def::Frgb48 => {
let ([], dst, []) = (unsafe { ans.get_mut().align_to_mut::<frgb48>() }) else {
unreachable!()
};
for (dst, src) in zip(dst, src) {
for (dst, &src, mean, std) in izip!(dst, src, mean, std) {
*dst = f16::from_f32((src as f32 / 255. - mean) / std);
}
}
}
def::Frgb96 => {
let ([], dst, []) = (unsafe { ans.get_mut().align_to_mut::<frgb96>() }) else {
unreachable!()
};
for (dst, src) in zip(dst, src) {
for (dst, &src, mean, std) in izip!(dst, src, mean, std) {
*dst = (src as f32 / 255. - mean) / std;
}
}
}
_ => unreachable!(),
}

ans
}

Expand All @@ -311,7 +332,7 @@ fn test() {
let time = Instant::now();
let slices = image
.slice_uhd(9, 448, 14)
.normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]);
.normalize(ty::F32, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]);
println!("slice image {:?}", time.elapsed());

let [x, y] = slices.grid();
Expand Down

0 comments on commit 843f8b4

Please sign in to comment.