From dd908d07d38b68ed6a04c58338770d23bf918502 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Tue, 20 Feb 2024 15:25:11 +0800 Subject: [PATCH] =?UTF-8?q?feat(model-parameters):=20=E4=B8=BA=20llama2=20?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=B8=80=E4=B8=AA=E8=81=94=E5=90=88=20qkv=20?= =?UTF-8?q?=E7=9A=84=E6=9D=83=E9=87=8D=E5=BC=A0=E9=87=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- model-parameters/src/lib.rs | 10 ++++++++ model-parameters/src/memory/cast.rs | 4 +-- model-parameters/src/memory/mod.rs | 28 ++++++++++++++++++++- model-parameters/src/memory/safe_tensors.rs | 13 ++++++++++ model-parameters/src/save.rs | 2 +- 5 files changed, 52 insertions(+), 5 deletions(-) diff --git a/model-parameters/src/lib.rs b/model-parameters/src/lib.rs index 8e577ca1..f83b8f2a 100644 --- a/model-parameters/src/lib.rs +++ b/model-parameters/src/lib.rs @@ -46,6 +46,7 @@ pub trait Llama2 { fn embed_tokens(&self) -> Tensor; fn input_layernorm(&self, layer: usize) -> Tensor; + fn w_qkv(&self, layer: usize) -> Tensor; fn self_attn_q_proj(&self, layer: usize) -> Tensor; fn self_attn_k_proj(&self, layer: usize) -> Tensor; fn self_attn_v_proj(&self, layer: usize) -> Tensor; @@ -91,6 +92,15 @@ impl Storage { } } + #[inline] + pub fn from_blob(data: impl 'static + AsRef<[u8]>) -> Self { + let len = data.as_ref().len(); + Self { + data: Arc::new(data), + range: 0..len, + } + } + #[inline] pub fn as_slice(&self) -> &[u8] { &self.data.as_ref().as_ref()[self.range.clone()] diff --git a/model-parameters/src/memory/cast.rs b/model-parameters/src/memory/cast.rs index 01c39077..caeb9bb2 100644 --- a/model-parameters/src/memory/cast.rs +++ b/model-parameters/src/memory/cast.rs @@ -1,7 +1,6 @@ use super::Layer; use crate::{ConfigJson, DataType, Llama2, Memory, Storage}; use half::{bf16, f16}; -use std::sync::Arc; use tensor::Tensor; impl Memory { @@ -88,7 +87,6 @@ fn cast(src: Tensor, new_dtype: DataType) -> Tensor { _ => 
todo!(), } - let len = data.len(); - let pysical = Storage::new(Arc::new(data), 0, len); + let pysical = Storage::from_blob(data); unsafe { src.cast(new_dtype, pysical) } } diff --git a/model-parameters/src/memory/mod.rs b/model-parameters/src/memory/mod.rs index 68916205..a0bc2e4d 100644 --- a/model-parameters/src/memory/mod.rs +++ b/model-parameters/src/memory/mod.rs @@ -3,7 +3,7 @@ mod safe_tensors; use crate::{ConfigJson, DataType, Llama2, Storage}; use common::utok; -use tensor::Tensor; +use tensor::{udim, Shape, Tensor}; pub use safe_tensors::SafeTensorError; pub(crate) use safe_tensors::SafeTensorHeaderJson; @@ -99,6 +99,32 @@ impl Llama2 for Memory { self.layers[layer].input_layernorm.clone() } + #[inline] + fn w_qkv(&self, layer: usize) -> Tensor { + let q = &self.layers[layer].self_attn_q_proj; + let k = &self.layers[layer].self_attn_k_proj; + let v = &self.layers[layer].self_attn_v_proj; + let d = self.hidden_size() as udim; + let dkv = + (self.hidden_size() * self.num_key_value_heads() / self.num_attention_heads()) as udim; + let dt = self.config.torch_dtype.size(); + debug_assert_eq!(q.shape(), &[d, d]); + debug_assert_eq!(k.shape(), &[dkv, d]); + debug_assert_eq!(v.shape(), &[dkv, d]); + let size = (q.size() + k.size() + v.size()) * dt; + let mut data = vec![0u8; size]; + let (q_, kv_) = data.split_at_mut(q.size() * dt); + let (k_, v_) = kv_.split_at_mut(k.size() * dt); + q_.copy_from_slice(q.physical().as_slice()); + k_.copy_from_slice(k.physical().as_slice()); + v_.copy_from_slice(v.physical().as_slice()); + Tensor::new( + self.config.torch_dtype, + Shape::from_vec(vec![d + dkv + dkv, d]), + Storage::from_blob(data), + ) + } + #[inline] fn self_attn_q_proj(&self, layer: usize) -> Tensor { self.layers[layer].self_attn_q_proj.clone() diff --git a/model-parameters/src/memory/safe_tensors.rs b/model-parameters/src/memory/safe_tensors.rs index 5c491a4a..40bf91e9 100644 --- a/model-parameters/src/memory/safe_tensors.rs +++ 
b/model-parameters/src/memory/safe_tensors.rs @@ -84,3 +84,16 @@ pub(crate) struct SafeTensorHeaderJson { #[serde(rename = "__metadata__")] pub meta: Option<HashMap<String, serde_json::Value>>, } + +#[test] +fn test_meta() { + let header = SafeTensorHeaderJson { + tensors: HashMap::new(), + meta: Some( + [("concat_qkv".to_string(), serde_json::Value::Bool(true))] + .into_iter() + .collect(), + ), + }; + println!("{}", serde_json::to_string_pretty(&header).unwrap()); +} diff --git a/model-parameters/src/save.rs b/model-parameters/src/save.rs index 6b5ef011..2aee65fd 100644 --- a/model-parameters/src/save.rs +++ b/model-parameters/src/save.rs @@ -62,7 +62,7 @@ pub fn save(model: &dyn Llama2, dir: impl AsRef<Path>) -> io::Result<()> { ); for layer in 0..model.num_hidden_layers() { header.tensors.insert( - format!("model.layers.{layer}.input_layernorm.weight",), + format!("model.layers.{layer}.input_layernorm.weight"), TensorInfo { dtype, shape: vec![d],