Skip to content

Commit

Permalink
feat: 实现转置变换及测试文档全覆盖
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Sep 7, 2024
1 parent 378938d commit 3e0ec2b
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 5 deletions.
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#![doc = include_str!("../README.md")]
#![deny(warnings, missing_docs)]

mod transform;

use std::{
Expand All @@ -9,6 +12,7 @@ use std::{

pub use transform::{IndexArg, SliceArg, Split, TileArg, TileOrder};

/// A tensor layout allow N dimensions inlined.
pub struct TensorLayout<const N: usize = 2> {
order: usize,
content: Union<N>,
Expand Down
9 changes: 6 additions & 3 deletions src/transform/index.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use crate::TensorLayout;
use std::iter::zip;

#[derive(Debug)]
/// 索引变换参数。
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct IndexArg {
axis: usize,
index: usize,
/// 索引的轴。
pub axis: usize,
/// 选择指定轴的第几个元素。
pub index: usize,
}

impl<const N: usize> TensorLayout<N> {
Expand Down
1 change: 1 addition & 0 deletions src/transform/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod merge;
mod slice;
mod split;
mod tile;
mod transpose;

pub use index::IndexArg;
pub use slice::SliceArg;
Expand Down
7 changes: 6 additions & 1 deletion src/transform/slice.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
use crate::TensorLayout;
use std::iter::zip;

#[derive(Clone, PartialEq, Eq, Hash, Debug)]
/// 切片变换参数。
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct SliceArg {
/// 切片的轴。
pub axis: usize,
/// 切片的起始位置。
pub start: usize,
/// 切片的步长。
pub step: isize,
/// 切片的长度。
pub len: usize,
}

Expand Down
1 change: 1 addition & 0 deletions src/transform/split.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::TensorLayout;

/// 切分变换参数。
pub struct Split<'a, const N: usize> {
src: &'a TensorLayout<N>,
axis: usize,
Expand Down
9 changes: 8 additions & 1 deletion src/transform/tile.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
use crate::TensorLayout;
use std::iter::zip;

#[derive(Clone, PartialEq, Eq, Hash, Debug)]
/// 分块变换参数。
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct TileArg<'a> {
/// 分块的轴。
pub axis: usize,
/// 分块的顺序。
pub order: TileOrder,
/// 分块的大小。
pub tiles: &'a [usize],
}

/// 分块顺序。
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum TileOrder {
/// 大端分块,分块后范围更大的维度在形状中更靠前的位置。
BigEndian,
/// 小端分块,分块后范围更小的维度在形状中更靠前的位置。
LittleEndian,
}

Expand Down
43 changes: 43 additions & 0 deletions src/transform/transpose.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use crate::TensorLayout;
use std::{collections::BTreeSet, iter::zip};

impl<const N: usize> TensorLayout<N> {
/// 转置变换允许调换张量的维度顺序,但不改变元素的存储顺序。
///
/// ```rust
/// # use tensor::TensorLayout;
/// let layout = TensorLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).transpose(&[1, 0]);
/// assert_eq!(layout.shape(), &[3, 2, 4]);
/// assert_eq!(layout.strides(), &[4, 12, 1]);
/// assert_eq!(layout.offset(), 0);
/// ```
pub fn transpose(&self, perm: &[usize]) -> Self {
let perm_ = perm.iter().collect::<BTreeSet<_>>();
assert_eq!(perm_.len(), perm.len());

let content = self.content();
let shape = content.shape();
let strides = content.strides();

let ans = Self::with_order(self.order);
let content = ans.content();
content.set_offset(self.offset());
let set = |i, j| {
content.set_shape(i, shape[j]);
content.set_stride(i, strides[j]);
};

let mut last = 0;
for (&i, &j) in zip(perm_, perm) {
for i in last..i {
set(i, i);
}
set(i, j);
last = i + 1;
}
for i in last..shape.len() {
set(i, i);
}
ans
}
}

0 comments on commit 3e0ec2b

Please sign in to comment.