feat(model-parameters): support converting model data types in memory
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Feb 15, 2024
1 parent c1ccdcd commit 1c9e761
Showing 9 changed files with 600 additions and 376 deletions.
83 changes: 83 additions & 0 deletions common/src/lib.rs
@@ -1,7 +1,90 @@
use std::{fmt, io::Write};

/// `utok` for token id.
#[allow(non_camel_case_types)]
pub type utok = u32;

/// `upos` for position id.
#[allow(non_camel_case_types)]
pub type upos = u32;

#[macro_export]
macro_rules! slice {
($blob:expr; $width:expr; [$line:expr]) => {
$blob[$line * $width..][..$width]
};
}

pub fn write_tensor<T: fmt::LowerExp>(
to: &mut impl Write,
buf: &[T],
shape: &[usize],
) -> std::io::Result<()> {
match shape {
[] => {
writeln!(to, "<>")?;
write_matrix(to, buf, (1, 1))
}
[len] => {
writeln!(to, "<{len}>")?;
write_matrix(to, buf, (*len, 1))
}
[rows, cols] => {
writeln!(to, "<{rows}x{cols}>")?;
write_matrix(to, buf, (*rows, *cols))
}
[batch @ .., rows, cols] => {
let mut strides = vec![1usize; batch.len()];
for i in (1..batch.len()).rev() {
strides[i - 1] = strides[i] * batch[i];
}
let strides = strides.as_slice();
for i in 0..batch[0] * strides[0] {
let mut which = vec![0usize; strides.len()];
let mut rem = i;
for (j, &stride) in strides.iter().enumerate() {
which[j] = rem / stride;
rem %= stride;
}
writeln!(
to,
"<{rows}x{cols}>[{}]",
which
.iter()
.map(usize::to_string)
.collect::<Vec<_>>()
.join(", "),
)?;
write_matrix(to, &slice!(buf; rows * cols; [i]), (*rows, *cols))?;
}
Ok(())
}
}
}

fn write_matrix<T: fmt::LowerExp>(
to: &mut impl Write,
buf: &[T],
shape: (usize, usize),
) -> std::io::Result<()> {
let (rows, cols) = shape;
for r in 0..rows {
let row = &slice!(buf; cols; [r]);
for it in row {
write!(to, "{it:<8.3e} ")?;
}
writeln!(to)?;
}
Ok(())
}

#[test]
fn test_log() {
let array = [
1., 2., 3., //
4., 5., 6., //
7., 8., 9., //
10., 11., 12., //
];
write_tensor(&mut std::io::stdout(), &array, &[2, 2, 3]).unwrap();
}
1 change: 1 addition & 0 deletions model-parameters/Cargo.toml
@@ -9,6 +9,7 @@ authors = ["YdrMaster <ydrml@hotmail.com>"]
[dependencies]
common = { path = "../common" }
log = "0.4"
half = "2.3"
memmap2 = "0.9"
safetensors = "0.4"
serde_json = "1.0"
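The new half dependency supplies the f16/bf16 types used by the cast in model-parameters/src/memory/inside_memory.rs. A minimal sketch of the conversions it provides (the value used is illustrative):

use half::{bf16, f16};

fn main() {
    // 3.5 is exactly representable in both half-precision formats.
    let x = f16::from_f32(3.5);
    let y = bf16::from_f32(x.to_f32());
    assert_eq!(y.to_f32(), 3.5);
}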
73 changes: 73 additions & 0 deletions model-parameters/src/data_type.rs
@@ -0,0 +1,73 @@
use serde::{
de::{Unexpected, Visitor},
Deserialize, Deserializer, Serialize, Serializer,
};
use std::fmt;

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
#[repr(u8)]
pub enum DataType {
F16,
BF16,
F32,
}

impl Serialize for DataType {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match self {
Self::F16 => serializer.serialize_str("float16"),
Self::BF16 => serializer.serialize_str("bfloat16"),
Self::F32 => serializer.serialize_str("float32"),
}
}
}

struct DataTypeVisitor;

impl<'de> Visitor<'de> for DataTypeVisitor {
type Value = DataType;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(
formatter,
"pytorch dtype string: \"float16\", \"bfloat16\", or \"float32\""
)
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
match v {
"float16" => Ok(DataType::F16),
"bfloat16" => Ok(DataType::BF16),
"float32" => Ok(DataType::F32),
_ => Err(E::invalid_value(
Unexpected::Str(v),
&"\"float16\", \"bfloat16\", or \"float32\"",
)),
}
}
}

impl<'de> Deserialize<'de> for DataType {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_str(DataTypeVisitor)
}
}

impl DataType {
#[inline]
pub const fn size(&self) -> usize {
match self {
Self::F16 | Self::BF16 => 2,
Self::F32 => 4,
}
}
}
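The custom serde impls map DataType to the PyTorch dtype strings found in config.json. A small round-trip sketch (not part of this commit; it assumes the library crate is imported as model_parameters and that serde_json is available, as in the workspace):

use model_parameters::DataType;

fn main() {
    // Serialization uses the PyTorch dtype names.
    assert_eq!(serde_json::to_string(&DataType::BF16).unwrap(), "\"bfloat16\"");

    // Deserialization accepts the same strings and nothing else.
    let dt: DataType = serde_json::from_str("\"float32\"").unwrap();
    assert_eq!(dt, DataType::F32);
    assert_eq!(dt.size(), 4);
}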
55 changes: 37 additions & 18 deletions model-parameters/src/lib.rs
@@ -1,23 +1,14 @@
mod safe_tensors;
mod data_type;
mod memory;

pub enum DataType {
F16,
BF16,
F32,
}
#[macro_use]
extern crate log;

impl DataType {
#[inline]
pub const fn size(&self) -> usize {
match self {
DataType::F16 => 2,
DataType::BF16 => 2,
DataType::F32 => 4,
}
}
}
use common::utok;

pub trait LLama2 {
pub use data_type::DataType;

pub trait Llama2 {
fn hidden_size(&self) -> usize;
fn intermediate_size(&self) -> usize;
fn max_position_embeddings(&self) -> usize;
@@ -41,4 +32,32 @@ pub trait LLama2 {
fn lm_head(&self) -> &[u8];
}

pub use safe_tensors::SafeTensors;
pub use memory::{Memory, SafeTensorError};

#[derive(serde::Serialize, serde::Deserialize, Debug)]
struct ConfigJson {
pub bos_token_id: utok,
pub eos_token_id: utok,
pub hidden_size: usize,
pub intermediate_size: usize,
pub max_position_embeddings: usize,
pub num_attention_heads: usize,
pub num_hidden_layers: usize,
pub num_key_value_heads: usize,
pub vocab_size: usize,
pub rms_norm_eps: f32,
pub rope_theta: f32,
pub torch_dtype: DataType,
}

struct LayerParamsOffset {
input_layernorm: usize,
self_attn_q_proj: usize,
self_attn_k_proj: usize,
self_attn_v_proj: usize,
self_attn_o_proj: usize,
post_attention_layernorm: usize,
mlp_gate: usize,
mlp_down: usize,
mlp_up: usize,
}
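ConfigJson mirrors a HuggingFace-style config.json, with torch_dtype parsed through the new DataType deserializer. A crate-internal sketch of the expected mapping (the numeric values below are illustrative, not from this commit):

let text = r#"{
    "bos_token_id": 1,
    "eos_token_id": 2,
    "hidden_size": 4096,
    "intermediate_size": 11008,
    "max_position_embeddings": 4096,
    "num_attention_heads": 32,
    "num_hidden_layers": 32,
    "num_key_value_heads": 32,
    "vocab_size": 32000,
    "rms_norm_eps": 1e-5,
    "rope_theta": 10000.0,
    "torch_dtype": "float16"
}"#;
let config: ConfigJson = serde_json::from_str(text).unwrap();
assert_eq!(config.torch_dtype, DataType::F16);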
84 changes: 84 additions & 0 deletions model-parameters/src/memory/inside_memory.rs
@@ -0,0 +1,84 @@
use crate::{ConfigJson, DataType, LayerParamsOffset, Llama2, Memory};
use half::{bf16, f16};

impl Memory<Vec<u8>> {
pub fn cast<T: AsRef<[u8]>>(src: &Memory<T>, new_dtype: DataType) -> Self {
let mut blob = Vec::new();

let from = src.config.torch_dtype;
let mut append = |src: &[u8]| {
let start = blob.len();
let end = start + src.len() * new_dtype.size() / from.size();
blob.resize(end, 0);
cast(from, src, new_dtype, &mut blob[start..end]);
start
};

let embed_tokens = append(src.embed_tokens());
let layers = (0..src.config.num_hidden_layers)
.map(|layer| LayerParamsOffset {
input_layernorm: append(src.input_layernorm(layer)),
self_attn_q_proj: append(src.self_attn_q_proj(layer)),
self_attn_k_proj: append(src.self_attn_k_proj(layer)),
self_attn_v_proj: append(src.self_attn_v_proj(layer)),
self_attn_o_proj: append(src.self_attn_o_proj(layer)),
post_attention_layernorm: append(src.post_attention_layernorm(layer)),
mlp_gate: append(src.mlp_gate(layer)),
mlp_down: append(src.mlp_down(layer)),
mlp_up: append(src.mlp_up(layer)),
})
.collect();
let model_norm = append(src.model_norm());
let lm_head = append(src.lm_head());

Self {
config: ConfigJson {
torch_dtype: new_dtype,
..src.config
},
blob,
embed_tokens,
layers,
model_norm,
lm_head,
}
}
}

fn cast(src_dtype: DataType, src: &[u8], dst_dtype: DataType, dst: &mut [u8]) {
macro_rules! cast {
($f:expr; $src:expr, $src_ty:ty => $dst:expr, $dst_ty:ty) => {
let len = $src.len() / std::mem::size_of::<$src_ty>();
assert_eq!(len * std::mem::size_of::<$dst_ty>(), $dst.len());
let src = unsafe { std::slice::from_raw_parts($src.as_ptr() as *const $src_ty, len) };
let dst =
unsafe { std::slice::from_raw_parts_mut($dst.as_mut_ptr() as *mut $dst_ty, len) };
src.iter().zip(dst).for_each(|(src, dst)| *dst = $f(*src));
};
}

match (src_dtype, dst_dtype) {
(DataType::F16, DataType::F16)
| (DataType::BF16, DataType::BF16)
| (DataType::F32, DataType::F32) => dst.copy_from_slice(src),

(DataType::F16, DataType::BF16) => {
cast!(|x: f16| bf16::from_f32(x.to_f32()); src, f16 => dst, bf16);
}
(DataType::F16, DataType::F32) => {
cast!(|x: f16| x.to_f32(); src, f16 => dst, f32);
}
(DataType::BF16, DataType::F16) => {
cast!(|x: bf16| f16::from_f32(x.to_f32()); src, bf16 => dst, f16);
}
(DataType::BF16, DataType::F32) => {
cast!(|x: bf16| x.to_f32(); src, bf16 => dst, f32);
}
(DataType::F32, DataType::F16) => {
cast!(|x: f32| f16::from_f32(x); src, f32 => dst, f16);
}
(DataType::F32, DataType::BF16) => {
cast!(|x: f32| bf16::from_f32(x); src, f32 => dst, bf16);
}
}
}
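A minimal usage sketch for the new in-memory conversion (assuming the library crate is imported as model_parameters); only Memory::cast and DataType come from this commit, and the wrapper name is illustrative:

use model_parameters::{DataType, Memory};

/// Re-encode a loaded model's weights as f32 in a freshly allocated blob.
fn upcast_to_f32<T: AsRef<[u8]>>(src: &Memory<T>) -> Memory<Vec<u8>> {
    Memory::cast(src, DataType::F32)
}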