diff --git a/common/src/lib.rs b/common/src/lib.rs
index 68a70f19..67a9e0d2 100644
--- a/common/src/lib.rs
+++ b/common/src/lib.rs
@@ -1,3 +1,5 @@
+use std::{fmt, io::Write};
+
 /// `utok` for token id.
 #[allow(non_camel_case_types)]
 pub type utok = u32;
@@ -5,3 +7,84 @@ pub type utok = u32;
 /// `upos` for position id.
 #[allow(non_camel_case_types)]
 pub type upos = u32;
+
+#[macro_export]
+macro_rules! slice {
+    ($blob:expr; $width:expr; [$line:expr]) => {
+        $blob[$line * $width..][..$width]
+    };
+}
+
+pub fn write_tensor<T: fmt::LowerExp>(
+    to: &mut impl Write,
+    buf: &[T],
+    shape: &[usize],
+) -> std::io::Result<()> {
+    match shape {
+        [] => {
+            writeln!(to, "<>")?;
+            write_matrix(to, buf, (1, 1))
+        }
+        [len] => {
+            writeln!(to, "<{len}>")?;
+            write_matrix(to, buf, (*len, 1))
+        }
+        [rows, cols] => {
+            writeln!(to, "<{rows}x{cols}>")?;
+            write_matrix(to, buf, (*rows, *cols))
+        }
+        [batch @ .., rows, cols] => {
+            let mut strides = vec![1usize; batch.len()];
+            for i in (1..batch.len()).rev() {
+                strides[i - 1] = strides[i] * batch[i];
+            }
+            let strides = strides.as_slice();
+            for i in 0..batch[0] * strides[0] {
+                let mut which = vec![0usize; strides.len()];
+                let mut rem = i;
+                for (j, &stride) in strides.iter().enumerate() {
+                    which[j] = rem / stride;
+                    rem %= stride;
+                }
+                writeln!(
+                    to,
+                    "<{rows}x{cols}>[{}]",
+                    which
+                        .iter()
+                        .map(usize::to_string)
+                        .collect::<Vec<_>>()
+                        .join(", "),
+                )?;
+                write_matrix(to, &slice!(buf; rows * cols; [i]), (*rows, *cols))?;
+            }
+            Ok(())
+        }
+    }
+}
+
+fn write_matrix<T: fmt::LowerExp>(
+    to: &mut impl Write,
+    buf: &[T],
+    shape: (usize, usize),
+) -> std::io::Result<()> {
+    let (rows, cols) = shape;
+    for r in 0..rows {
+        let row = &slice!(buf; cols; [r]);
+        for it in row {
+            write!(to, "{it:<8.3e} ")?;
+        }
+        writeln!(to)?;
+    }
+    Ok(())
+}
+
+#[test]
+fn test_log() {
+    let array = [
+        1., 2., 3., //
+        4., 5., 6., //
+        7., 8., 9., //
+        10., 11., 12., //
+    ];
+    write_tensor(&mut std::io::stdout(), &array, &[2, 2, 3]).unwrap();
+}
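For context, a minimal usage sketch of the new `slice!` macro and `write_tensor` (not part of the patch; the crate name `common` and the sample values are assumptions):

```rust
use common::{slice, write_tensor};

fn main() -> std::io::Result<()> {
    // A row-major 2x3 matrix; `slice!(buf; width; [line])` expands to
    // `buf[line * width..][..width]`, i.e. row `line` of `width` elements.
    let buf = [1.0f32, 2., 3., 4., 5., 6.];
    assert_eq!(&slice!(buf; 3; [1]), &[4., 5., 6.]);

    // Prints a "<2x3>" header, then each row in `{:<8.3e}` notation.
    // Shapes with more than two dims print one such block per matrix,
    // with the batch indices in brackets after the header.
    write_tensor(&mut std::io::stdout(), &buf, &[2, 3])
}
```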
\"bfloat16\", or \"float32\"", + )), + } + } +} + +impl<'de> Deserialize<'de> for DataType { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_str(DataTypeVisitor) + } +} + +impl DataType { + #[inline] + pub const fn size(&self) -> usize { + match self { + Self::F16 | Self::BF16 => 2, + Self::F32 => 4, + } + } +} diff --git a/model-parameters/src/lib.rs b/model-parameters/src/lib.rs index 971b8317..0591177e 100644 --- a/model-parameters/src/lib.rs +++ b/model-parameters/src/lib.rs @@ -1,23 +1,14 @@ -mod safe_tensors; +mod data_type; +mod memory; -pub enum DataType { - F16, - BF16, - F32, -} +#[macro_use] +extern crate log; -impl DataType { - #[inline] - pub const fn size(&self) -> usize { - match self { - DataType::F16 => 2, - DataType::BF16 => 2, - DataType::F32 => 4, - } - } -} +use common::utok; -pub trait LLama2 { +pub use data_type::DataType; + +pub trait Llama2 { fn hidden_size(&self) -> usize; fn intermediate_size(&self) -> usize; fn max_position_embeddings(&self) -> usize; @@ -41,4 +32,32 @@ pub trait LLama2 { fn lm_head(&self) -> &[u8]; } -pub use safe_tensors::SafeTensors; +pub use memory::{Memory, SafeTensorError}; + +#[derive(serde::Serialize, serde::Deserialize, Debug)] +struct ConfigJson { + pub bos_token_id: utok, + pub eos_token_id: utok, + pub hidden_size: usize, + pub intermediate_size: usize, + pub max_position_embeddings: usize, + pub num_attention_heads: usize, + pub num_hidden_layers: usize, + pub num_key_value_heads: usize, + pub vocab_size: usize, + pub rms_norm_eps: f32, + pub rope_theta: f32, + pub torch_dtype: DataType, +} + +struct LayerParamsOffset { + input_layernorm: usize, + self_attn_q_proj: usize, + self_attn_k_proj: usize, + self_attn_v_proj: usize, + self_attn_o_proj: usize, + post_attention_layernorm: usize, + mlp_gate: usize, + mlp_down: usize, + mlp_up: usize, +} diff --git a/model-parameters/src/memory/inside_memory.rs b/model-parameters/src/memory/inside_memory.rs new file mode 100644 index 00000000..b89c12da --- /dev/null +++ b/model-parameters/src/memory/inside_memory.rs @@ -0,0 +1,84 @@ +use crate::{ConfigJson, DataType, LayerParamsOffset, Llama2, Memory}; +use half::{bf16, f16}; + +impl Memory> { + pub fn cast>(src: &Memory, new_dtype: DataType) -> Self { + let mut blob = Vec::new(); + + let from = src.config.torch_dtype; + let mut append = |src: &[u8]| { + let start = blob.len(); + let end = start + src.len() * new_dtype.size() / from.size(); + blob.resize(end, 0); + cast(from, src, new_dtype, &mut blob[start..end]); + start + }; + + let embed_tokens = append(src.embed_tokens()); + let layers = (0..src.config.num_hidden_layers) + .map(|layer| LayerParamsOffset { + input_layernorm: append(src.input_layernorm(layer)), + self_attn_q_proj: append(src.self_attn_q_proj(layer)), + self_attn_k_proj: append(src.self_attn_k_proj(layer)), + self_attn_v_proj: append(src.self_attn_v_proj(layer)), + self_attn_o_proj: append(src.self_attn_o_proj(layer)), + post_attention_layernorm: append(src.post_attention_layernorm(layer)), + mlp_gate: append(src.mlp_gate(layer)), + mlp_down: append(src.mlp_down(layer)), + mlp_up: append(src.mlp_up(layer)), + }) + .collect(); + let model_norm = append(src.model_norm()); + let lm_head = append(src.lm_head()); + + Self { + config: ConfigJson { + torch_dtype: new_dtype, + ..src.config + }, + blob, + embed_tokens, + layers, + model_norm, + lm_head, + } + } +} + +fn cast(src_dtype: DataType, src: &[u8], dst_dtype: DataType, dst: &mut [u8]) { + macro_rules! 
diff --git a/model-parameters/src/lib.rs b/model-parameters/src/lib.rs
index 971b8317..0591177e 100644
--- a/model-parameters/src/lib.rs
+++ b/model-parameters/src/lib.rs
@@ -1,23 +1,14 @@
-mod safe_tensors;
+mod data_type;
+mod memory;
 
-pub enum DataType {
-    F16,
-    BF16,
-    F32,
-}
+#[macro_use]
+extern crate log;
 
-impl DataType {
-    #[inline]
-    pub const fn size(&self) -> usize {
-        match self {
-            DataType::F16 => 2,
-            DataType::BF16 => 2,
-            DataType::F32 => 4,
-        }
-    }
-}
+use common::utok;
 
-pub trait LLama2 {
+pub use data_type::DataType;
+
+pub trait Llama2 {
     fn hidden_size(&self) -> usize;
     fn intermediate_size(&self) -> usize;
     fn max_position_embeddings(&self) -> usize;
@@ -41,4 +32,32 @@ pub trait LLama2 {
     fn lm_head(&self) -> &[u8];
 }
 
-pub use safe_tensors::SafeTensors;
+pub use memory::{Memory, SafeTensorError};
+
+#[derive(serde::Serialize, serde::Deserialize, Debug)]
+struct ConfigJson {
+    pub bos_token_id: utok,
+    pub eos_token_id: utok,
+    pub hidden_size: usize,
+    pub intermediate_size: usize,
+    pub max_position_embeddings: usize,
+    pub num_attention_heads: usize,
+    pub num_hidden_layers: usize,
+    pub num_key_value_heads: usize,
+    pub vocab_size: usize,
+    pub rms_norm_eps: f32,
+    pub rope_theta: f32,
+    pub torch_dtype: DataType,
+}
+
+struct LayerParamsOffset {
+    input_layernorm: usize,
+    self_attn_q_proj: usize,
+    self_attn_k_proj: usize,
+    self_attn_v_proj: usize,
+    self_attn_o_proj: usize,
+    post_attention_layernorm: usize,
+    mlp_gate: usize,
+    mlp_down: usize,
+    mlp_up: usize,
+}
diff --git a/model-parameters/src/memory/inside_memory.rs b/model-parameters/src/memory/inside_memory.rs
new file mode 100644
index 00000000..b89c12da
--- /dev/null
+++ b/model-parameters/src/memory/inside_memory.rs
@@ -0,0 +1,84 @@
+use crate::{ConfigJson, DataType, LayerParamsOffset, Llama2, Memory};
+use half::{bf16, f16};
+
+impl Memory<Vec<u8>> {
+    pub fn cast<T: AsRef<[u8]>>(src: &Memory<T>, new_dtype: DataType) -> Self {
+        let mut blob = Vec::new();
+
+        let from = src.config.torch_dtype;
+        let mut append = |src: &[u8]| {
+            let start = blob.len();
+            let end = start + src.len() * new_dtype.size() / from.size();
+            blob.resize(end, 0);
+            cast(from, src, new_dtype, &mut blob[start..end]);
+            start
+        };
+
+        let embed_tokens = append(src.embed_tokens());
+        let layers = (0..src.config.num_hidden_layers)
+            .map(|layer| LayerParamsOffset {
+                input_layernorm: append(src.input_layernorm(layer)),
+                self_attn_q_proj: append(src.self_attn_q_proj(layer)),
+                self_attn_k_proj: append(src.self_attn_k_proj(layer)),
+                self_attn_v_proj: append(src.self_attn_v_proj(layer)),
+                self_attn_o_proj: append(src.self_attn_o_proj(layer)),
+                post_attention_layernorm: append(src.post_attention_layernorm(layer)),
+                mlp_gate: append(src.mlp_gate(layer)),
+                mlp_down: append(src.mlp_down(layer)),
+                mlp_up: append(src.mlp_up(layer)),
+            })
+            .collect();
+        let model_norm = append(src.model_norm());
+        let lm_head = append(src.lm_head());
+
+        Self {
+            config: ConfigJson {
+                torch_dtype: new_dtype,
+                ..src.config
+            },
+            blob,
+            embed_tokens,
+            layers,
+            model_norm,
+            lm_head,
+        }
+    }
+}
+
+fn cast(src_dtype: DataType, src: &[u8], dst_dtype: DataType, dst: &mut [u8]) {
+    macro_rules! cast {
+        ($f:expr; $src:expr, $src_ty:ty => $dst:expr, $dst_ty:ty) => {
+            let len = $src.len() / std::mem::size_of::<$src_ty>();
+            assert_eq!(len * std::mem::size_of::<$dst_ty>(), $dst.len());
+            let src = unsafe { std::slice::from_raw_parts($src.as_ptr() as *const $src_ty, len) };
+            let dst =
+                unsafe { std::slice::from_raw_parts_mut($dst.as_mut_ptr() as *mut $dst_ty, len) };
+            src.iter().zip(dst).for_each(|(src, dst)| *dst = $f(*src));
+        };
+    }
+
+    match (src_dtype, dst_dtype) {
+        (DataType::F16, DataType::F16)
+        | (DataType::BF16, DataType::BF16)
+        | (DataType::F32, DataType::F32) => dst.copy_from_slice(src),
+
+        (DataType::F16, DataType::BF16) => {
+            cast!(|x: f16| bf16::from_f32(x.to_f32()); src, f16 => dst, bf16);
+        }
+        (DataType::F16, DataType::F32) => {
+            cast!(|x: f16| x.to_f32(); src, f16 => dst, f32);
+        }
+        (DataType::BF16, DataType::F16) => {
+            cast!(|x: bf16| f16::from_f32(x.to_f32()); src, bf16 => dst, f16);
+        }
+        (DataType::BF16, DataType::F32) => {
+            cast!(|x: bf16| x.to_f32(); src, bf16 => dst, f32);
+        }
+        (DataType::F32, DataType::F16) => {
+            cast!(|x: f32| f16::from_f32(x); src, f32 => dst, f16);
+        }
+        (DataType::F32, DataType::BF16) => {
+            cast!(|x: f32| bf16::from_f32(x); src, f32 => dst, bf16);
+        }
+    }
+}
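The `cast` helper reinterprets the source bytes as typed elements and converts them one by one. A safe standalone sketch of the same F32 -> F16 narrowing (my example: it decodes byte chunks instead of using the patch's pointer reinterpretation, which additionally assumes element-aligned buffers):

```rust
use half::f16;

// Decode each 4-byte little-endian f32, narrow it to f16, re-encode as bytes.
fn f32_bytes_to_f16_bytes(src: &[u8]) -> Vec<u8> {
    src.chunks_exact(4)
        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .flat_map(|x| f16::from_f32(x).to_le_bytes())
        .collect()
}

fn main() {
    let src: Vec<u8> = [1.0f32, 0.5].iter().flat_map(|x| x.to_le_bytes()).collect();
    assert_eq!(f32_bytes_to_f16_bytes(&src).len(), 4); // half the bytes: two f16 values
}
```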
diff --git a/model-parameters/src/memory/mod.rs b/model-parameters/src/memory/mod.rs
new file mode 100644
index 00000000..49140c10
--- /dev/null
+++ b/model-parameters/src/memory/mod.rs
@@ -0,0 +1,171 @@
+mod inside_memory;
+mod safe_tensors;
+
+use crate::{ConfigJson, DataType, LayerParamsOffset, Llama2};
+pub use safe_tensors::SafeTensorError;
+
+pub struct Memory<T> {
+    config: ConfigJson,
+    blob: T,
+    embed_tokens: usize,
+    layers: Vec<LayerParamsOffset>,
+    model_norm: usize,
+    lm_head: usize,
+}
+
+impl<T: AsRef<[u8]>> Llama2 for Memory<T> {
+    #[inline]
+    fn hidden_size(&self) -> usize {
+        self.config.hidden_size
+    }
+
+    #[inline]
+    fn intermediate_size(&self) -> usize {
+        self.config.intermediate_size
+    }
+
+    #[inline]
+    fn max_position_embeddings(&self) -> usize {
+        self.config.max_position_embeddings
+    }
+
+    #[inline]
+    fn num_attention_heads(&self) -> usize {
+        self.config.num_attention_heads
+    }
+
+    #[inline]
+    fn num_hidden_layers(&self) -> usize {
+        self.config.num_hidden_layers
+    }
+
+    #[inline]
+    fn num_key_value_heads(&self) -> usize {
+        self.config.num_key_value_heads
+    }
+
+    #[inline]
+    fn vocab_size(&self) -> usize {
+        self.config.vocab_size
+    }
+
+    #[inline]
+    fn data_type(&self) -> DataType {
+        self.config.torch_dtype
+    }
+
+    #[inline]
+    fn embed_tokens(&self) -> &[u8] {
+        let d = self.config.hidden_size;
+        let dv = self.config.vocab_size;
+        let dt: usize = self.data_type().size();
+        &self.blob.as_ref()[self.embed_tokens..][..dv * d * dt]
+    }
+
+    #[inline]
+    fn input_layernorm(&self, layer: usize) -> &[u8] {
+        let d = self.config.hidden_size;
+        let dt: usize = self.data_type().size();
+        &self.blob.as_ref()[self.layers[layer].input_layernorm..][..d * dt]
+    }
+
+    #[inline]
+    fn self_attn_q_proj(&self, layer: usize) -> &[u8] {
+        let d = self.config.hidden_size;
+        let dt: usize = self.data_type().size();
+        &self.blob.as_ref()[self.layers[layer].self_attn_q_proj..][..d * d * dt]
+    }
+
+    #[inline]
+    fn self_attn_k_proj(&self, layer: usize) -> &[u8] {
+        let d = self.config.hidden_size;
+        let dkv = d * self.config.num_key_value_heads / self.config.num_attention_heads;
+        let dt: usize = self.data_type().size();
+        &self.blob.as_ref()[self.layers[layer].self_attn_k_proj..][..dkv * d * dt]
+    }
+
+    #[inline]
+    fn self_attn_v_proj(&self, layer: usize) -> &[u8] {
+        let d = self.config.hidden_size;
+        let dkv = d * self.config.num_key_value_heads / self.config.num_attention_heads;
+        let dt: usize = self.data_type().size();
+        &self.blob.as_ref()[self.layers[layer].self_attn_v_proj..][..dkv * d * dt]
+    }
+
+    #[inline]
+    fn self_attn_o_proj(&self, layer: usize) -> &[u8] {
+        let d = self.config.hidden_size;
+        let dt: usize = self.data_type().size();
+        &self.blob.as_ref()[self.layers[layer].self_attn_o_proj..][..d * d * dt]
+    }
+
+    #[inline]
+    fn post_attention_layernorm(&self, layer: usize) -> &[u8] {
+        let d = self.config.hidden_size;
+        let dt: usize = self.data_type().size();
+        &self.blob.as_ref()[self.layers[layer].post_attention_layernorm..][..d * dt]
+    }
+
+    #[inline]
+    fn mlp_gate(&self, layer: usize) -> &[u8] {
+        let d = self.config.hidden_size;
+        let di = self.config.intermediate_size;
+        let dt: usize = self.data_type().size();
+        &self.blob.as_ref()[self.layers[layer].mlp_gate..][..di * d * dt]
+    }
+
+    #[inline]
+    fn mlp_down(&self, layer: usize) -> &[u8] {
+        let d = self.config.hidden_size;
+        let di = self.config.intermediate_size;
+        let dt: usize = self.data_type().size();
+        &self.blob.as_ref()[self.layers[layer].mlp_down..][..d * di * dt]
+    }
+
+    #[inline]
+    fn mlp_up(&self, layer: usize) -> &[u8] {
+        let d = self.config.hidden_size;
+        let di = self.config.intermediate_size;
+        let dt: usize = self.data_type().size();
+        &self.blob.as_ref()[self.layers[layer].mlp_up..][..di * d * dt]
+    }
+
+    #[inline]
+    fn model_norm(&self) -> &[u8] {
+        let d = self.config.hidden_size;
+        let dt: usize = self.data_type().size();
+        &self.blob.as_ref()[self.model_norm..][..d * dt]
+    }
+
+    #[inline]
+    fn lm_head(&self) -> &[u8] {
+        let d = self.config.hidden_size;
+        let dv: usize = self.config.vocab_size;
+        let dt: usize = self.data_type().size();
+        &self.blob.as_ref()[self.lm_head..][..dv * d * dt]
+    }
+}
+
+#[test]
+fn test_load() {
+    use std::time::Instant;
+
+    // set env for POWERSHELL: `$env:RUST_LOG="INFO";`
+    env_logger::init();
+
+    let t0 = Instant::now();
+    let safetensors = Memory::load_safetensors("../../TinyLlama-1.1B-Chat-v1.0");
+    let t1 = Instant::now();
+    println!("mmap {:?}", t1 - t0);
+
+    let safetensors = match safetensors {
+        Ok(m) => m,
+        Err(SafeTensorError::Io(e)) if e.kind() == std::io::ErrorKind::NotFound => return,
+        Err(e) => panic!("{e:?}"),
+    };
+
+    let t0 = Instant::now();
+    let _inside_memory = Memory::cast(&safetensors, DataType::F32);
+    let t1 = Instant::now();
+    println!("cast {:?}", t1 - t0);
+}
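Every accessor above is the same two-step slice: jump to the byte offset recorded at load time, then take a length computed from the config. A minimal illustration of the pattern (hypothetical numbers):

```rust
// `&blob[offset..][..len]` takes `len` bytes starting at `offset`,
// panicking rather than silently reading out of bounds.
fn tensor(blob: &[u8], offset: usize, elements: usize, dtype_size: usize) -> &[u8] {
    &blob[offset..][..elements * dtype_size]
}

fn main() {
    let blob = vec![0u8; 64];
    // e.g. a 16-element layernorm weight in f16 at byte offset 32:
    assert_eq!(tensor(&blob, 32, 16, 2).len(), 32);
}
```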
diff --git a/model-parameters/src/memory/safe_tensors.rs b/model-parameters/src/memory/safe_tensors.rs
new file mode 100644
index 00000000..c8f02a14
--- /dev/null
+++ b/model-parameters/src/memory/safe_tensors.rs
@@ -0,0 +1,151 @@
+use super::Memory;
+use crate::{ConfigJson, DataType, LayerParamsOffset};
+use memmap2::Mmap;
+use safetensors::{tensor::TensorInfo, Dtype};
+use std::{collections::HashMap, fs::File, path::Path};
+
+#[derive(Debug)]
+pub enum SafeTensorError {
+    Io(std::io::Error),
+    Serde(serde_json::Error),
+}
+
+impl Memory<Mmap> {
+    pub fn load_safetensors(model_dir: impl AsRef<Path>) -> Result<Self, SafeTensorError> {
+        let dir = model_dir.as_ref();
+        let config = File::open(dir.join("config.json")).map_err(SafeTensorError::Io)?;
+        let model = File::open(dir.join("model.safetensors")).map_err(SafeTensorError::Io)?;
+
+        let config: ConfigJson = serde_json::from_reader(config).map_err(SafeTensorError::Serde)?;
+        let dtype = match config.torch_dtype {
+            DataType::F16 => Dtype::F16,
+            DataType::BF16 => Dtype::BF16,
+            DataType::F32 => Dtype::F32,
+        };
+
+        let mmap = unsafe { Mmap::map(&model) }.map_err(SafeTensorError::Io)?;
+        let len = unsafe { *mmap.as_ptr().cast::<u64>() } as usize;
+        const BASE_OFFSET: usize = std::mem::size_of::<u64>();
+        let header = &mmap[BASE_OFFSET..][..len];
+        let header: SafeTensorHeaderJson =
+            serde_json::from_slice(header).map_err(SafeTensorError::Serde)?;
+
+        let d = config.hidden_size;
+        let kv_dim = d * config.num_key_value_heads / config.num_attention_heads;
+        let di = config.intermediate_size;
+
+        let mut embed_tokens = 0;
+        let mut layers = (0..config.num_hidden_layers)
+            .map(|_| LayerParamsOffset {
+                input_layernorm: 0,
+                self_attn_q_proj: 0,
+                self_attn_k_proj: 0,
+                self_attn_v_proj: 0,
+                self_attn_o_proj: 0,
+                post_attention_layernorm: 0,
+                mlp_gate: 0,
+                mlp_down: 0,
+                mlp_up: 0,
+            })
+            .collect::<Vec<_>>();
+        let mut model_norm = 0;
+        let mut lm_head = 0;
+
+        let header_offset = BASE_OFFSET + len;
+        for (name, tensor) in header.tensors {
+            let path = name.split('.').collect::<Vec<_>>();
+            let offset = header_offset + tensor.data_offsets.0;
+
+            info!(target: "import safetensors", "detect {offset:#010x} -> \"{name}\"");
+            match path.as_slice() {
+                ["model", "embed_tokens", "weight"] => {
+                    assert_eq!(&tensor.shape, &[config.vocab_size, d]);
+                    assert_eq!(tensor.dtype, dtype);
+                    embed_tokens = offset;
+                }
+                ["model", "layers", n, path @ .., "weight"] => {
+                    let layer = n.parse::<usize>().unwrap();
+
+                    match path {
+                        ["input_layernorm"] => {
+                            assert_eq!(&tensor.shape, &[d]);
+                            assert_eq!(tensor.dtype, dtype);
+                            layers[layer].input_layernorm = offset;
+                        }
+                        ["self_attn", "q_proj"] => {
+                            assert_eq!(&tensor.shape, &[d, d]);
+                            assert_eq!(tensor.dtype, dtype);
+                            layers[layer].self_attn_q_proj = offset;
+                        }
+                        ["self_attn", "k_proj"] => {
+                            assert_eq!(&tensor.shape, &[kv_dim, d]);
+                            assert_eq!(tensor.dtype, dtype);
+                            layers[layer].self_attn_k_proj = offset;
+                        }
+                        ["self_attn", "v_proj"] => {
+                            assert_eq!(&tensor.shape, &[kv_dim, d]);
+                            assert_eq!(tensor.dtype, dtype);
+                            layers[layer].self_attn_v_proj = offset;
+                        }
+                        ["self_attn", "o_proj"] => {
+                            assert_eq!(&tensor.shape, &[d, d]);
+                            assert_eq!(tensor.dtype, dtype);
+                            layers[layer].self_attn_o_proj = offset;
+                        }
+                        ["post_attention_layernorm"] => {
+                            assert_eq!(&tensor.shape, &[d]);
+                            assert_eq!(tensor.dtype, dtype);
+                            layers[layer].post_attention_layernorm = offset;
+                        }
+                        ["mlp", "gate_proj"] => {
+                            assert_eq!(&tensor.shape, &[di, d]);
+                            assert_eq!(tensor.dtype, dtype);
+                            layers[layer].mlp_gate = offset;
+                        }
+                        ["mlp", "down_proj"] => {
+                            assert_eq!(&tensor.shape, &[d, di]);
+                            assert_eq!(tensor.dtype, dtype);
+                            layers[layer].mlp_down = offset;
+                        }
+                        ["mlp", "up_proj"] => {
+                            assert_eq!(&tensor.shape, &[di, d]);
+                            assert_eq!(tensor.dtype, dtype);
+                            layers[layer].mlp_up = offset;
+                        }
+                        [..] => {
+                            warn!(target: "import safetensors", "Unknown tensor path: \"{name}\"")
+                        }
+                    };
+                }
+                ["model", "norm", "weight"] => {
+                    assert_eq!(&tensor.shape, &[d]);
+                    assert_eq!(tensor.dtype, dtype);
+                    model_norm = offset;
+                }
+                ["lm_head", "weight"] => {
+                    assert_eq!(&tensor.shape, &[config.vocab_size, d]);
+                    assert_eq!(tensor.dtype, dtype);
+                    lm_head = offset;
+                }
+                [..] => warn!(target: "import safetensors", "Unknown tensor path: \"{name}\""),
+            }
+        }
+
+        Ok(Self {
+            config,
+            blob: mmap,
+            embed_tokens,
+            layers,
+            model_norm,
+            lm_head,
+        })
+    }
+}
+
+#[derive(serde::Serialize, serde::Deserialize, Debug)]
+struct SafeTensorHeaderJson {
+    #[serde(flatten)]
+    pub tensors: HashMap<String, TensorInfo>,
+    #[serde(rename = "__metadata__")]
+    pub meta: Option<HashMap<String, String>>,
+}
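The loader leans on the safetensors layout: the file begins with a little-endian u64 giving the JSON header's byte length, the header follows, and each tensor's `data_offsets` are relative to the first byte after the header (hence `header_offset = BASE_OFFSET + len`). A sketch that decodes the length explicitly instead of the patch's native-endian pointer read:

```rust
// Returns the byte range occupied by the JSON header of a safetensors file.
fn header_bounds(file: &[u8]) -> (usize, usize) {
    let len = u64::from_le_bytes(file[..8].try_into().unwrap()) as usize;
    (8, 8 + len) // tensor data starts at 8 + len
}

fn main() {
    // A 12-byte dummy header preceded by its little-endian length.
    let mut file = 12u64.to_le_bytes().to_vec();
    file.extend_from_slice(b"{\"__x__\":{}}");
    assert_eq!(header_bounds(&file), (8, 20));
}
```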
diff --git a/model-parameters/src/safe_tensors/config.rs b/model-parameters/src/safe_tensors/config.rs
deleted file mode 100644
index 17ed16b0..00000000
--- a/model-parameters/src/safe_tensors/config.rs
+++ /dev/null
@@ -1,27 +0,0 @@
-use common::utok;
-use safetensors::tensor::TensorInfo;
-use std::collections::HashMap;
-
-#[derive(serde::Serialize, serde::Deserialize, Debug)]
-pub(crate) struct ConfigJson {
-    pub bos_token_id: utok,
-    pub eos_token_id: utok,
-    pub hidden_size: usize,
-    pub intermediate_size: usize,
-    pub max_position_embeddings: usize,
-    pub num_attention_heads: usize,
-    pub num_hidden_layers: usize,
-    pub num_key_value_heads: usize,
-    pub vocab_size: usize,
-    pub rms_norm_eps: f32,
-    pub rope_theta: f32,
-    pub torch_dtype: String,
-}
-
-#[derive(serde::Serialize, serde::Deserialize, Debug)]
-pub(crate) struct SafeTensorHeaderJson {
-    #[serde(flatten)]
-    pub tensors: HashMap<String, TensorInfo>,
-    #[serde(rename = "__metadata__")]
-    pub meta: Option<HashMap<String, String>>,
-}
diff --git a/model-parameters/src/safe_tensors/mod.rs b/model-parameters/src/safe_tensors/mod.rs
deleted file mode 100644
index 6a81a624..00000000
--- a/model-parameters/src/safe_tensors/mod.rs
+++ /dev/null
@@ -1,331 +0,0 @@
-mod config;
-
-use crate::{DataType, LLama2};
-use config::{ConfigJson, SafeTensorHeaderJson};
-use log::{info, warn};
-use memmap2::Mmap;
-use safetensors::Dtype;
-use std::{fs::File, path::Path};
-
-pub struct SafeTensors {
-    config: ConfigJson,
-    mmap: Mmap,
-    embed_tokens: usize,
-    layers: Vec<LayerParamsOffset>,
-    model_norm: usize,
-    lm_head: usize,
-}
-
-struct LayerParamsOffset {
-    input_layernorm: usize,
-    self_attn_q_proj: usize,
-    self_attn_k_proj: usize,
-    self_attn_v_proj: usize,
-    self_attn_o_proj: usize,
-    post_attention_layernorm: usize,
-    mlp_gate: usize,
-    mlp_down: usize,
-    mlp_up: usize,
-}
-
-macro_rules! slice {
-    ($mmap:expr; $offset:expr, $len:expr) => {
-        $mmap[$offset..][..$len]
-    };
-}
-
-impl LLama2 for SafeTensors {
-    #[inline]
-    fn hidden_size(&self) -> usize {
-        self.config.hidden_size
-    }
-
-    #[inline]
-    fn intermediate_size(&self) -> usize {
-        self.config.intermediate_size
-    }
-
-    #[inline]
-    fn max_position_embeddings(&self) -> usize {
-        self.config.max_position_embeddings
-    }
-
-    #[inline]
-    fn num_attention_heads(&self) -> usize {
-        self.config.num_attention_heads
-    }
-
-    #[inline]
-    fn num_hidden_layers(&self) -> usize {
-        self.config.num_hidden_layers
-    }
-
-    #[inline]
-    fn num_key_value_heads(&self) -> usize {
-        self.config.num_key_value_heads
-    }
-
-    #[inline]
-    fn vocab_size(&self) -> usize {
-        self.config.vocab_size
-    }
-
-    #[inline]
-    fn data_type(&self) -> DataType {
-        match self.config.torch_dtype.as_str() {
-            "float16" => DataType::F16,
-            "bfloat16" => DataType::BF16,
-            "float32" => DataType::F32,
-            t => panic!("Unsupported dtype: \"{t}\""),
-        }
-    }
-
-    #[inline]
-    fn embed_tokens(&self) -> &[u8] {
-        let d = self.config.hidden_size;
-        let dv = self.config.vocab_size;
-        let dt = self.data_type().size();
-        &slice!(self.mmap; self.embed_tokens, dv * d * dt)
-    }
-
-    #[inline]
-    fn input_layernorm(&self, layer: usize) -> &[u8] {
-        let d = self.config.hidden_size;
-        let dt: usize = self.data_type().size();
-        &slice!(self.mmap; self.layers[layer].input_layernorm, d * dt)
-    }
-
-    #[inline]
-    fn self_attn_q_proj(&self, layer: usize) -> &[u8] {
-        let d = self.config.hidden_size;
-        let dt: usize = self.data_type().size();
-        &slice!(self.mmap; self.layers[layer].self_attn_q_proj, d * d * dt)
-    }
-
-    #[inline]
-    fn self_attn_k_proj(&self, layer: usize) -> &[u8] {
-        let d = self.config.hidden_size;
-        let dkv = d * self.config.num_key_value_heads / self.config.num_attention_heads;
-        let dt: usize = self.data_type().size();
-        &slice!(self.mmap; self.layers[layer].self_attn_k_proj, dkv * d * dt)
-    }
-
-    #[inline]
-    fn self_attn_v_proj(&self, layer: usize) -> &[u8] {
-        let d = self.config.hidden_size;
-        let dkv = d * self.config.num_key_value_heads / self.config.num_attention_heads;
-        let dt: usize = self.data_type().size();
-        &slice!(self.mmap; self.layers[layer].self_attn_v_proj, dkv * d * dt)
-    }
-
-    #[inline]
-    fn self_attn_o_proj(&self, layer: usize) -> &[u8] {
-        let d = self.config.hidden_size;
-        let dt: usize = self.data_type().size();
-        &slice!(self.mmap; self.layers[layer].self_attn_o_proj, d * d * dt)
-    }
-
-    #[inline]
-    fn post_attention_layernorm(&self, layer: usize) -> &[u8] {
-        let d = self.config.hidden_size;
-        let dt: usize = self.data_type().size();
-        &slice!(self.mmap; self.layers[layer].post_attention_layernorm, d * dt)
-    }
-
-    #[inline]
-    fn mlp_gate(&self, layer: usize) -> &[u8] {
-        let d = self.config.hidden_size;
-        let di = self.config.intermediate_size;
-        let dt: usize = self.data_type().size();
-        &slice!(self.mmap; self.layers[layer].mlp_gate, di * d * dt)
-    }
-
-    #[inline]
-    fn mlp_down(&self, layer: usize) -> &[u8] {
-        let d = self.config.hidden_size;
-        let di = self.config.intermediate_size;
-        let dt: usize = self.data_type().size();
-        &slice!(self.mmap; self.layers[layer].mlp_down, d * di * dt)
-    }
-
-    #[inline]
-    fn mlp_up(&self, layer: usize) -> &[u8] {
-        let d = self.config.hidden_size;
-        let di = self.config.intermediate_size;
-        let dt: usize = self.data_type().size();
-        &slice!(self.mmap; self.layers[layer].mlp_up, di * d * dt)
-    }
-
-    #[inline]
-    fn model_norm(&self) -> &[u8] {
-        let d = self.config.hidden_size;
-        let dt: usize = self.data_type().size();
-        &slice!(self.mmap; self.model_norm, d * dt)
-    }
-
-    #[inline]
-    fn lm_head(&self) -> &[u8] {
-        let d = self.config.hidden_size;
-        let dv: usize = self.config.vocab_size;
-        let dt: usize = self.data_type().size();
-        &slice!(self.mmap; self.lm_head, dv * d * dt)
-    }
-}
-
-#[derive(Debug)]
-pub enum SafeTensorError {
-    Io(std::io::Error),
-    Serde(serde_json::Error),
-}
-
-impl SafeTensors {
-    pub fn new(model_dir: impl AsRef<Path>) -> Result<Self, SafeTensorError> {
-        let dir = model_dir.as_ref();
-        let config = File::open(dir.join("config.json")).map_err(SafeTensorError::Io)?;
-        let model = File::open(dir.join("model.safetensors")).map_err(SafeTensorError::Io)?;
-
-        let config: ConfigJson = serde_json::from_reader(config).map_err(SafeTensorError::Serde)?;
-        let dtype = match config.torch_dtype.as_str() {
-            "float16" => Dtype::F16,
-            "bfloat16" => Dtype::BF16,
-            "float32" => Dtype::F32,
-            _ => panic!("Unsupported dtype: {}", config.torch_dtype),
-        };
-
-        let mmap = unsafe { Mmap::map(&model) }.map_err(SafeTensorError::Io)?;
-        let len = unsafe { *mmap.as_ptr().cast::<u64>() } as usize;
-        const BASE_OFFSET: usize = std::mem::size_of::<u64>();
-        let header = &mmap[BASE_OFFSET..][..len];
-        let header: SafeTensorHeaderJson =
-            serde_json::from_slice(header).map_err(SafeTensorError::Serde)?;
-
-        let d = config.hidden_size;
-        let kv_dim = d * config.num_key_value_heads / config.num_attention_heads;
-        let di = config.intermediate_size;
-
-        let mut embed_tokens = 0;
-        let mut layers = (0..config.num_hidden_layers)
-            .map(|_| LayerParamsOffset {
-                input_layernorm: 0,
-                self_attn_q_proj: 0,
-                self_attn_k_proj: 0,
-                self_attn_v_proj: 0,
-                self_attn_o_proj: 0,
-                post_attention_layernorm: 0,
-                mlp_gate: 0,
-                mlp_down: 0,
-                mlp_up: 0,
-            })
-            .collect::<Vec<_>>();
-        let mut model_norm = 0;
-        let mut lm_head = 0;
-
-        let header_offset = BASE_OFFSET + len;
-        for (name, tensor) in header.tensors {
-            let path = name.split('.').collect::<Vec<_>>();
-            let offset = header_offset + tensor.data_offsets.0;
-
-            info!(target: "import safetensors", "detect {offset:#010x} -> \"{name}\"");
-            match path.as_slice() {
-                ["model", "embed_tokens", "weight"] => {
-                    assert_eq!(&tensor.shape, &[config.vocab_size, d]);
-                    assert_eq!(tensor.dtype, dtype);
-                    embed_tokens = offset;
-                }
-                ["model", "layers", n, path @ .., "weight"] => {
-                    let layer = n.parse::<usize>().unwrap();
-
-                    match path {
-                        ["input_layernorm"] => {
-                            assert_eq!(&tensor.shape, &[d]);
-                            assert_eq!(tensor.dtype, dtype);
-                            layers[layer].input_layernorm = offset;
-                        }
-                        ["self_attn", "q_proj"] => {
-                            assert_eq!(&tensor.shape, &[d, d]);
-                            assert_eq!(tensor.dtype, dtype);
-                            layers[layer].self_attn_q_proj = offset;
-                        }
-                        ["self_attn", "k_proj"] => {
-                            assert_eq!(&tensor.shape, &[kv_dim, d]);
-                            assert_eq!(tensor.dtype, dtype);
-                            layers[layer].self_attn_k_proj = offset;
-                        }
-                        ["self_attn", "v_proj"] => {
-                            assert_eq!(&tensor.shape, &[kv_dim, d]);
-                            assert_eq!(tensor.dtype, dtype);
-                            layers[layer].self_attn_v_proj = offset;
-                        }
-                        ["self_attn", "o_proj"] => {
-                            assert_eq!(&tensor.shape, &[d, d]);
-                            assert_eq!(tensor.dtype, dtype);
-                            layers[layer].self_attn_o_proj = offset;
-                        }
-                        ["post_attention_layernorm"] => {
-                            assert_eq!(&tensor.shape, &[d]);
-                            assert_eq!(tensor.dtype, dtype);
-                            layers[layer].post_attention_layernorm = offset;
-                        }
-                        ["mlp", "gate_proj"] => {
-                            assert_eq!(&tensor.shape, &[di, d]);
-                            assert_eq!(tensor.dtype, dtype);
-                            layers[layer].mlp_gate = offset;
-                        }
-                        ["mlp", "down_proj"] => {
-                            assert_eq!(&tensor.shape, &[d, di]);
-                            assert_eq!(tensor.dtype, dtype);
-                            layers[layer].mlp_down = offset;
-                        }
-                        ["mlp", "up_proj"] => {
-                            assert_eq!(&tensor.shape, &[di, d]);
-                            assert_eq!(tensor.dtype, dtype);
-                            layers[layer].mlp_up = offset;
-                        }
-                        [..] => {
-                            warn!(target: "import safetensors", "Unknown tensor path: \"{name}\"")
-                        }
-                    };
-                }
-                ["model", "norm", "weight"] => {
-                    assert_eq!(&tensor.shape, &[d]);
-                    assert_eq!(tensor.dtype, dtype);
-                    model_norm = offset;
-                }
-                ["lm_head", "weight"] => {
-                    assert_eq!(&tensor.shape, &[config.vocab_size, d]);
-                    assert_eq!(tensor.dtype, dtype);
-                    lm_head = offset;
-                }
-                [..] => warn!(target: "import safetensors", "Unknown tensor path: \"{name}\""),
-            }
-        }
-
-        Ok(Self {
-            config,
-            mmap,
-            embed_tokens,
-            layers,
-            model_norm,
-            lm_head,
-        })
-    }
-}
-
-#[test]
-fn test_load() {
-    use std::time::Instant;
-
-    // set env for POWERSHELL: `$env:RUST_LOG="INFO";`
-    env_logger::init();
-
-    let t0 = Instant::now();
-    let safetensors = SafeTensors::new("../../TinyLlama-1.1B-Chat-v1.0");
-    let t1 = Instant::now();
-    println!("{:?}", t1 - t0);
-
-    match safetensors {
-        Ok(_) => {}
-        Err(SafeTensorError::Io(e)) if e.kind() == std::io::ErrorKind::NotFound => {}
-        Err(e) => panic!("{e:?}"),
-    }
-}