feat(model-parameters): add a concatenated qkv weight tensor for llama2
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Feb 20, 2024
1 parent d1261a7 commit dd908d0
Showing 5 changed files with 52 additions and 5 deletions.
10 changes: 10 additions & 0 deletions model-parameters/src/lib.rs
@@ -46,6 +46,7 @@ pub trait Llama2 {
 
     fn embed_tokens(&self) -> Tensor<Storage>;
     fn input_layernorm(&self, layer: usize) -> Tensor<Storage>;
+    fn w_qkv(&self, layer: usize) -> Tensor<Storage>;
     fn self_attn_q_proj(&self, layer: usize) -> Tensor<Storage>;
     fn self_attn_k_proj(&self, layer: usize) -> Tensor<Storage>;
     fn self_attn_v_proj(&self, layer: usize) -> Tensor<Storage>;
@@ -91,6 +92,15 @@ impl Storage {
         }
     }
 
+    #[inline]
+    pub fn from_blob(data: impl 'static + AsRef<[u8]>) -> Self {
+        let len = data.as_ref().len();
+        Self {
+            data: Arc::new(data),
+            range: 0..len,
+        }
+    }
+
     #[inline]
     pub fn as_slice(&self) -> &[u8] {
         &self.data.as_ref().as_ref()[self.range.clone()]
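
For context, a minimal caller-side sketch of the new constructor (illustrative, not part of the commit): Storage::from_blob accepts any owned 'static value that exposes bytes and wraps it in an Arc covering the full 0..len range.

// Hypothetical usage of Storage::from_blob with an owned Vec<u8>.
let bytes: Vec<u8> = vec![0u8; 4096];        // any 'static + AsRef<[u8]> buffer works
let storage = Storage::from_blob(bytes);     // Arc and the 0..len range are built internally
assert_eq!(storage.as_slice().len(), 4096);  // as_slice exposes the wrapped bytes
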
4 changes: 1 addition & 3 deletions model-parameters/src/memory/cast.rs
@@ -1,7 +1,6 @@
 use super::Layer;
 use crate::{ConfigJson, DataType, Llama2, Memory, Storage};
 use half::{bf16, f16};
-use std::sync::Arc;
 use tensor::Tensor;
 
 impl Memory {
@@ -88,7 +87,6 @@ fn cast(src: Tensor<Storage>, new_dtype: DataType) -> Tensor<Storage> {
         _ => todo!(),
     }
 
-    let len = data.len();
-    let pysical = Storage::new(Arc::new(data), 0, len);
+    let pysical = Storage::from_blob(data);
    unsafe { src.cast(new_dtype, pysical) }
}
28 changes: 27 additions & 1 deletion model-parameters/src/memory/mod.rs
@@ -3,7 +3,7 @@ mod safe_tensors;
 
 use crate::{ConfigJson, DataType, Llama2, Storage};
 use common::utok;
-use tensor::Tensor;
+use tensor::{udim, Shape, Tensor};
 
 pub use safe_tensors::SafeTensorError;
 pub(crate) use safe_tensors::SafeTensorHeaderJson;
@@ -99,6 +99,32 @@ impl Llama2 for Memory {
         self.layers[layer].input_layernorm.clone()
     }
 
+    #[inline]
+    fn w_qkv(&self, layer: usize) -> Tensor<Storage> {
+        let q = &self.layers[layer].self_attn_q_proj;
+        let k = &self.layers[layer].self_attn_k_proj;
+        let v = &self.layers[layer].self_attn_v_proj;
+        let d = self.hidden_size() as udim;
+        let dkv =
+            (self.hidden_size() * self.num_key_value_heads() / self.num_attention_heads()) as udim;
+        let dt = self.config.torch_dtype.size();
+        debug_assert_eq!(q.shape(), &[d, d]);
+        debug_assert_eq!(k.shape(), &[dkv, d]);
+        debug_assert_eq!(v.shape(), &[dkv, d]);
+        let size = (q.size() + k.size() + v.size()) * dt;
+        let mut data = vec![0u8; size];
+        let (q_, kv_) = data.split_at_mut(q.size() * dt);
+        let (k_, v_) = kv_.split_at_mut(k.size() * dt);
+        q_.copy_from_slice(q.physical().as_slice());
+        k_.copy_from_slice(k.physical().as_slice());
+        v_.copy_from_slice(v.physical().as_slice());
+        Tensor::new(
+            self.config.torch_dtype,
+            Shape::from_vec(vec![d + dkv + dkv, d]),
+            Storage::from_blob(data),
+        )
+    }
+
     #[inline]
     fn self_attn_q_proj(&self, layer: usize) -> Tensor<Storage> {
         self.layers[layer].self_attn_q_proj.clone()
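
The shape arithmetic behind w_qkv, worked through with illustrative numbers (not taken from this commit): d is the hidden size, dkv scales it by the ratio of key-value heads to attention heads, and q, k, v are stacked row-wise into one [d + 2*dkv, d] tensor.

// Illustrative only: grouped-query attention with 32 attention heads and 8 KV heads.
let (hidden_size, n_heads, n_kv_heads) = (4096u32, 32u32, 8u32);
let d = hidden_size;                            // rows of self_attn_q_proj
let dkv = hidden_size * n_kv_heads / n_heads;   // rows of k_proj and v_proj: 1024
assert_eq!(d + dkv + dkv, 6144);                // fused w_qkv shape would be [6144, 4096]

When num_key_value_heads equals num_attention_heads (plain multi-head attention), dkv equals d and the fused shape reduces to [3 * d, d].
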
13 changes: 13 additions & 0 deletions model-parameters/src/memory/safe_tensors.rs
@@ -84,3 +84,16 @@ pub(crate) struct SafeTensorHeaderJson {
     #[serde(rename = "__metadata__")]
     pub meta: Option<HashMap<String, serde_json::Value>>,
 }
+
+#[test]
+fn test_meta() {
+    let header = SafeTensorHeaderJson {
+        tensors: HashMap::new(),
+        meta: Some(
+            [("concat_qkv".to_string(), serde_json::Value::Bool(true))]
+                .into_iter()
+                .collect(),
+        ),
+    };
+    println!("{}", serde_json::to_string_pretty(&header).unwrap());
+}
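
As a hypothetical follow-up (not part of this commit), a loader could read the flag back from a deserialized header's metadata before deciding whether q, k and v are stored as one concatenated tensor:

// Hypothetical: header is a deserialized SafeTensorHeaderJson as defined above.
let concat_qkv = header
    .meta
    .as_ref()
    .and_then(|meta| meta.get("concat_qkv"))
    .and_then(|value| value.as_bool())
    .unwrap_or(false);
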
2 changes: 1 addition & 1 deletion model-parameters/src/save.rs
@@ -62,7 +62,7 @@ pub fn save(model: &dyn Llama2, dir: impl AsRef<Path>) -> io::Result<()> {
     );
     for layer in 0..model.num_hidden_layers() {
         header.tensors.insert(
-            format!("model.layers.{layer}.input_layernorm.weight",),
+            format!("model.layers.{layer}.input_layernorm.weight"),
             TensorInfo {
                 dtype,
                 shape: vec![d],
