
style(llama): move the todo! items that MoE depends on into each backend implementation
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Dec 27, 2024
1 parent 2896798 commit a0116a3
Showing 6 changed files with 87 additions and 25 deletions.
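
The change is mechanical: the todo!() placeholders that MoE support relies on move out of defaulted trait methods (Operators::memcpy_d2h, WeightLoader::load_moe) and into each backend's own implementation, making both methods required. A minimal sketch of the before/after pattern, with hypothetical names rather than the crate's real signatures:

// Before: the trait ships a panicking default, so a backend can
// silently inherit an unimplemented method.
trait LoaderBefore {
    fn load_moe(&self, _iblk: usize, _iexp: usize) -> &[u8] {
        todo!()
    }
}

// After: the method is required, so every backend must spell out its
// own implementation, even if that is still an explicit todo!().
trait LoaderAfter {
    fn load_moe(&self, iblk: usize, iexp: usize) -> &[u8];
}

struct SomeBackend;

impl LoaderAfter for SomeBackend {
    fn load_moe(&self, _iblk: usize, _iexp: usize) -> &[u8] {
        todo!("MoE weights not yet supported on this backend")
    }
}

With the required-method form, adding a new backend fails to compile until its author makes an explicit choice, instead of inheriting a default that only panics at runtime.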
models/llama/common-cpu/src/lib.rs (9 changes: 5 additions & 4 deletions)
@@ -17,6 +17,7 @@ use std::{
     marker::PhantomData,
     mem::size_of,
     ops::{Deref, Range, RangeBounds},
+    ptr::copy_nonoverlapping,
     slice::{from_raw_parts, from_raw_parts_mut},
 };

@@ -69,7 +70,7 @@
     where
         T: Deref<Target = [ByteOf<Self::Hardware>]>,
     {
-        println!("{tensor}");
+        println!("{tensor}")
     }

     fn memcpy_d2h<T: Copy>(
@@ -79,7 +80,7 @@
     ) {
         let count = size_of_val(dst);
         assert_eq!(size_of_val(src), count);
-        unsafe { std::ptr::copy_nonoverlapping(src.as_ptr(), dst.as_mut_ptr().cast::<u8>(), count) }
+        unsafe { copy_nonoverlapping(src.as_ptr(), dst.as_mut_ptr().cast::<u8>(), count) }
     }
 }

@@ -265,7 +266,7 @@ impl WeightLoader for Weights<'_> {
         match which {
             AttnQKV => dequant(dt_mat, dt_embd, attn_qkv, &mut cache[..size_qkv]),
             AttnO => dequant(dt_mat, dt_embd, attn_o, &mut cache[..size_o]),
-            FfnGateInp => todo!(),
+            FfnGateInp => todo!("dequant ffn gate inp"),
             FfnGateUp | FfnDown => {
                 dequant(dt_mat, dt_embd, ffn_gate_up, &mut cache[..size_gate_up]);
                 dequant(
@@ -284,7 +285,7 @@
         match which {
             AttnQKV => 0..size_qkv,
             AttnO => 0..size_o,
-            FfnGateInp => todo!(),
+            FfnGateInp => todo!("dequant ffn gate inp"),
             FfnGateUp => 0..size_gate_up,
             FfnDown => size_gate_up..size_gate_up + size_down,
             AttnNorm | FfnNorm => unreachable!(),
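
For reference, the CPU backend's memcpy_d2h above is a plain byte copy, since "device" memory here is ordinary host memory. A self-contained sketch of the same contract (a standalone illustration, not code from the crate):

use std::{mem::size_of_val, ptr::copy_nonoverlapping};

/// Stand-in for the CPU backend's `memcpy_d2h`: `dst` and `src` must span
/// the same number of bytes, and the copy is a byte-for-byte memcpy.
fn memcpy_d2h<T: Copy>(dst: &mut [T], src: &[u8]) {
    let count = size_of_val(dst);
    assert_eq!(src.len(), count);
    unsafe { copy_nonoverlapping(src.as_ptr(), dst.as_mut_ptr().cast::<u8>(), count) }
}

fn main() {
    let src = 1.0f32.to_ne_bytes(); // four "device" bytes
    let mut dst = [0.0f32];         // typed host destination
    memcpy_d2h(&mut dst, &src);
    assert_eq!(dst, [1.0]);
}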
models/llama/common/src/compute.rs (22 changes: 9 additions & 13 deletions)
@@ -33,12 +33,10 @@ pub trait Operators {
         T: Deref<Target = [ByteOf<Self::Hardware>]>;

     fn memcpy_d2h<T: Copy>(
-        _dst: &mut [T],
-        _src: &[ByteOf<Self::Hardware>],
-        _queue: &QueueOf<Self::Hardware>,
-    ) {
-        todo!()
-    }
+        dst: &mut [T],
+        src: &[ByteOf<Self::Hardware>],
+        queue: &QueueOf<Self::Hardware>,
+    );

     fn build_sin_cos<QA>(
         dt: DigitLayout,

@@ -81,13 +79,11 @@ pub trait WeightLoader {

     fn load_moe<'a>(
         &'a self,
-        _which: BlkWeight,
-        _iblk: usize,
-        _iexp: usize,
-        _queue: &'a QueueOf<Self::Hardware>,
-    ) -> Self::Weight<'a> {
-        todo!()
-    }
+        which: BlkWeight,
+        iblk: usize,
+        iexp: usize,
+        queue: &'a QueueOf<Self::Hardware>,
+    ) -> Self::Weight<'a>;

     fn output_norm<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> Self::Weight<'a>;
     fn output<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> Self::Weight<'a>;
models/llama/common/src/storage.rs (6 changes: 1 addition & 5 deletions)
@@ -174,11 +174,7 @@ impl<'w> BlkStorage<&'w [u8]> {
                 own(o_.take())
             },
             ffn_norm: borrow(self.ffn_norm),
-            ffn_gate_inp: if len == count {
-                self.ffn_gate_inp.map(borrow)
-            } else {
-                todo!()
-            },
+            ffn_gate_inp: self.ffn_gate_inp.map(borrow),
             ffn_gate_up: if len == count {
                 borrow(self.ffn_gate_up)
             } else {
models/llama/infini/src/lib.rs (24 changes: 23 additions & 1 deletion)
@@ -56,7 +56,15 @@
             queue.synchronize();
             host
         });
-        println!("{tensor}");
+        println!("{tensor}")
     }
+
+    fn memcpy_d2h<T: Copy>(
+        dst: &mut [T],
+        src: &[ByteOf<Self::Hardware>],
+        queue: &QueueOf<Self::Hardware>,
+    ) {
+        queue.get_device().memcpy_d2h(dst, src)
+    }
 }

@@ -160,6 +168,20 @@ impl WeightLoader for Weights {
         }
     }

+    fn load_moe<'a>(
+        &'a self,
+        which: BlkWeight,
+        iblk: usize,
+        _iexp: usize,
+        _queue: &'a QueueOf<Self::Hardware>,
+    ) -> Self::Weight<'a> {
+        let _blk = &self.0.blocks[iblk];
+        match which {
+            BlkWeight::FfnGateUp | BlkWeight::FfnDown => todo!(),
+            _ => unreachable!(),
+        }
+    }
+
     #[inline]
     fn output_norm(&self, _queue: &QueueOf<Self::Hardware>) -> Self::Weight<'_> {
         &self.output_norm
models/llama/nvidia-gpu/src/lib.rs (23 changes: 22 additions & 1 deletion)
@@ -146,7 +146,15 @@
             memcpy_d2h(&mut host, s);
             host
         });
-        println!("{tensor}");
+        println!("{tensor}")
     }
+
+    fn memcpy_d2h<T: Copy>(
+        dst: &mut [T],
+        src: &[ByteOf<Self::Hardware>],
+        _queue: &QueueOf<Self::Hardware>,
+    ) {
+        memcpy_d2h(dst, src)
+    }
 }

@@ -331,6 +339,19 @@ impl<'ctx> WeightLoader for Weights<'ctx> {
         }
     }

+    fn load_moe<'a>(
+        &'a self,
+        which: BlkWeight,
+        _iblk: usize,
+        _iexp: usize,
+        _queue: &'a QueueOf<Self::Hardware>,
+    ) -> Self::Weight<'a> {
+        match which {
+            BlkWeight::FfnGateUp | BlkWeight::FfnDown => todo!(),
+            _ => unreachable!(),
+        }
+    }
+
     #[inline]
     fn output_norm(&self, _queue: &QueueOf<Self::Hardware>) -> Self::Weight<'_> {
         WeightResult::Borrowed(&self.output_norm)
models/llama/opencl/src/lib.rs (28 changes: 27 additions & 1 deletion)
@@ -13,6 +13,7 @@ use operators::{
 use std::{
     marker::PhantomData,
     ops::{Deref, RangeBounds},
+    ptr::copy_nonoverlapping,
 };

 pub struct Operators<N = ClDevice, R = NonAllReduce<ClDevice, Rearrange>>(PhantomData<(N, R)>);

@@ -49,7 +50,18 @@
     {
         let tensor = tensor.as_ref().map(|s| queue.map(s));
         println!("{tensor}");
-        queue.unmap(tensor.take());
+        queue.unmap(tensor.take())
+    }
+
+    fn memcpy_d2h<T: Copy>(
+        dst: &mut [T],
+        src: &[ByteOf<Self::Hardware>],
+        queue: &QueueOf<Self::Hardware>,
+    ) {
+        assert_eq!(size_of_val(dst), size_of_val(src));
+        let svm = queue.map(src);
+        unsafe { copy_nonoverlapping(svm.as_ptr(), dst.as_mut_ptr().cast::<u8>(), size_of_val(dst)) }
+        queue.unmap(svm)
     }
 }

@@ -122,6 +134,20 @@ impl WeightLoader for Weights {
         }
     }

+    fn load_moe<'a>(
+        &'a self,
+        which: BlkWeight,
+        iblk: usize,
+        _iexp: usize,
+        _queue: &'a QueueOf<Self::Hardware>,
+    ) -> Self::Weight<'a> {
+        let _blk = &self.0.blocks[iblk];
+        match which {
+            BlkWeight::FfnGateUp | BlkWeight::FfnDown => todo!(),
+            _ => unreachable!(),
+        }
+    }
+
     #[inline]
     fn output_norm(&self, _queue: &QueueOf<Self::Hardware>) -> Self::Weight<'_> {
         &self.0.output_norm
