Skip to content

Commit

Permalink
feat(tensor): 添加 Split
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Feb 20, 2024
1 parent bba5f80 commit ba4477d
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tensor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ pub type udim = u32;
pub type idim = i32;

pub use data_type::DataType;
pub use operator::{Broadcast, Operator, Slice, Transpose};
pub use operator::{Broadcast, Operator, Slice, Split, Transpose};
pub use tensor::{Affine, Pattern, Shape, Tensor};
2 changes: 2 additions & 0 deletions tensor/src/operator/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod broadcast;
mod slice;
mod split;
mod transpose;

use crate::{udim, Affine, Shape};
Expand All @@ -11,4 +12,5 @@ pub trait Operator {

pub use broadcast::Broadcast;
pub use slice::Slice;
pub use split::Split;
pub use transpose::Transpose;
98 changes: 98 additions & 0 deletions tensor/src/operator/split.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
use super::Operator;
use crate::{idim, udim, Affine, Shape};
use smallvec::SmallVec;

pub struct Split {
axis: udim,
segments: Shape,
}

impl Operator for Split {
fn build(&self, input: &[udim]) -> SmallVec<[(Shape, Affine); 1]> {
debug_assert!(self.axis < input.len() as udim);
debug_assert_eq!(input[self.axis as usize], self.segments.iter().sum());

let n = input.len();
let axis = self.axis as usize;
self.segments
.iter()
.scan(0, |prefix, &seg| {
let shape = input
.iter()
.enumerate()
.map(|(i, &dim)| if i == self.axis as usize { seg } else { dim })
.collect();
let affine = Affine::from_fn(n + 1, n + 1, |r, c| {
if r == c {
1
} else if r == n {
if c == axis {
*prefix
} else {
0
}
} else {
0
}
});
*prefix += seg as idim;
Some((shape, affine))
})
.collect()
}
}

#[test]
fn test() {
let ans = Split {
axis: 1,
segments: Shape::from_slice(&[3, 4, 5]),
}
.build(&[11, 12, 13]);
assert_eq!(ans.len(), 3);
assert_eq!(ans[0].0, Shape::from_slice(&[11, 3, 13]));
assert_eq!(ans[1].0, Shape::from_slice(&[11, 4, 13]));
assert_eq!(ans[2].0, Shape::from_slice(&[11, 5, 13]));
assert_eq!(
ans[0].1,
Affine::from_vec(
4,
4,
vec![
// column major
1, 0, 0, 0, //
0, 1, 0, 0, //
0, 0, 1, 0, //
0, 0, 0, 1, //
]
)
);
assert_eq!(
ans[1].1,
Affine::from_vec(
4,
4,
vec![
// column major
1, 0, 0, 0, //
0, 1, 0, 3, //
0, 0, 1, 0, //
0, 0, 0, 1, //
]
)
);
assert_eq!(
ans[2].1,
Affine::from_vec(
4,
4,
vec![
// column major
1, 0, 0, 0, //
0, 1, 0, 7, //
0, 0, 1, 0, //
0, 0, 0, 1, //
]
)
);
}

0 comments on commit ba4477d

Please sign in to comment.