Skip to content

Commit

Permalink
feat: 为合并变换提供检查
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Sep 26, 2024
1 parent 3de35cc commit 5c6b969
Showing 1 changed file with 26 additions and 14 deletions.
40 changes: 26 additions & 14 deletions src/transform/merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@ impl<const N: usize> ArrayLayout<N> {
///
/// ```rust
/// # use ndarray_layout::ArrayLayout;
/// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).merge(0..3);
/// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).merge(0..3).unwrap();
/// assert_eq!(layout.shape(), &[24]);
/// assert_eq!(layout.strides(), &[1]);
/// assert_eq!(layout.offset(), 0);
/// ```
#[inline]
pub fn merge(&self, range: Range<usize>) -> Self {
pub fn merge(&self, range: Range<usize>) -> Option<Self> {
self.merge_many(&[range])
}

/// 一次对多个阶进行合并变换。
pub fn merge_many(&self, args: &[Range<usize>]) -> Self {
pub fn merge_many(&self, args: &[Range<usize>]) -> Option<Self> {
let content = self.content();
let shape = content.shape();
let strides = content.strides();
Expand All @@ -28,34 +28,46 @@ impl<const N: usize> ArrayLayout<N> {
let mut content = ans.content_mut();
content.set_offset(self.offset());
let mut i = 0;
let mut push = |t, s| {
content.set_shape(i, t);
let mut push = |d, s| {
content.set_shape(i, d);
content.set_stride(i, s);
i += 1;
};

let mut last_end = 0;
for range in args {
assert!(!range.is_empty());
if range.is_empty() {
continue;
}

assert!(range.start >= last_end);
for j in last_end..range.start {
push(shape[j], strides[j]);
}

let mut pairs = zip(&shape[range.clone()], &strides[range.clone()]).collect::<Vec<_>>();
pairs.sort_unstable_by_key(|(_, &s)| s.unsigned_abs());
assert!(pairs.windows(2).all(|slice| {
let &[(&d, &s), (_, &s_)] = slice else {
unreachable!()
};
s_ == s * d as isize
}));
push(pairs.iter().map(|(d, _)| *d).product(), *pairs[0].1);

let (&d, &s) = pairs[0];
let mut d = d;

for i in 1..pairs.len() {
let (&l, &ls) = pairs[i - 1];
let (&r, &rs) = pairs[i];
if l == 1 || s == 1 || ls == rs * r as isize || rs == ls * l as isize {
d *= r;
} else {
return None;
}
}

push(d, s);
last_end = range.end;
}
for j in last_end..shape.len() {
push(shape[j], strides[j]);
}

ans
Some(ans)
}
}

0 comments on commit 5c6b969

Please sign in to comment.