Skip to content

Commit

Permalink
reafactor: 区分可变和不可变的 Content 以提升安全性
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Sep 9, 2024
1 parent 3e0ec2b commit 5598bbc
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 37 deletions.
65 changes: 39 additions & 26 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,17 @@
#![doc = include_str!("../README.md")]
#![deny(warnings, missing_docs)]

mod transform;

use std::{
alloc::{alloc, dealloc, Layout},
iter::zip,
ptr::{copy_nonoverlapping, NonNull},
slice::from_raw_parts,
};

pub use transform::{IndexArg, SliceArg, Split, TileArg, TileOrder};

/// A tensor layout allow N dimensions inlined.
pub struct TensorLayout<const N: usize = 2> {
order: usize,
content: Union<N>,
}

union Union<const N: usize> {
ptr: NonNull<usize>,
_inlined: (usize, [usize; N], [isize; N]),
}

impl<const N: usize> Clone for TensorLayout<N> {
#[inline]
fn clone(&self) -> Self {
Expand Down Expand Up @@ -68,8 +62,8 @@ impl<const N: usize> TensorLayout<N> {
})
.all(|off| off >= 0));

let ans = Self::with_order(shape.len());
let content = ans.content();
let mut ans = Self::with_order(shape.len());
let mut content = ans.content_mut();
content.set_offset(offset);
content.copy_shape(shape);
content.copy_strides(strides);
Expand All @@ -93,7 +87,19 @@ impl<const N: usize> TensorLayout<N> {
pub fn strides(&self) -> &[isize] {
self.content().strides()
}
}

mod transform;
pub use transform::{IndexArg, SliceArg, Split, TileArg, TileOrder};

use std::{
alloc::{alloc, dealloc, Layout},
iter::zip,
ptr::{copy_nonoverlapping, NonNull},
slice::from_raw_parts,
};

impl<const N: usize> TensorLayout<N> {
#[inline]
fn ptr_allocated(&self) -> Option<NonNull<usize>> {
const { assert!(N > 0) }
Expand All @@ -105,7 +111,17 @@ impl<const N: usize> TensorLayout<N> {
}

#[inline]
fn content(&self) -> Content {
fn content(&self) -> Content<false> {
Content {
ptr: self
.ptr_allocated()
.unwrap_or(unsafe { NonNull::new_unchecked(&self.content as *const _ as _) }),
ord: self.order,
}
}

#[inline]
fn content_mut(&mut self) -> Content<true> {
Content {
ptr: self
.ptr_allocated()
Expand All @@ -132,17 +148,12 @@ impl<const N: usize> TensorLayout<N> {
}
}

union Union<const N: usize> {
ptr: NonNull<usize>,
_inlined: (usize, [usize; N], [isize; N]),
}

struct Content {
struct Content<const MUT: bool> {
ptr: NonNull<usize>,
ord: usize,
}

impl Content {
impl Content<false> {
#[inline]
fn as_slice(&self) -> &[usize] {
unsafe { from_raw_parts(self.ptr.as_ptr(), 1 + self.ord * 2) }
Expand All @@ -162,32 +173,34 @@ impl Content {
fn strides<'a>(&self) -> &'a [isize] {
unsafe { from_raw_parts(self.ptr.add(1 + self.ord).cast().as_ptr(), self.ord) }
}
}

impl Content<true> {
#[inline]
fn set_offset(&self, val: usize) {
fn set_offset(&mut self, val: usize) {
unsafe { self.ptr.write(val) }
}

#[inline]
fn set_shape(&self, idx: usize, val: usize) {
fn set_shape(&mut self, idx: usize, val: usize) {
assert!(idx < self.ord);
unsafe { self.ptr.add(1 + idx).write(val) }
}

#[inline]
fn set_stride(&self, idx: usize, val: isize) {
fn set_stride(&mut self, idx: usize, val: isize) {
assert!(idx < self.ord);
unsafe { self.ptr.add(1 + idx + self.ord).cast().write(val) }
}

#[inline]
fn copy_shape(&self, val: &[usize]) {
fn copy_shape(&mut self, val: &[usize]) {
assert!(val.len() == self.ord);
unsafe { copy_nonoverlapping(val.as_ptr(), self.ptr.add(1).as_ptr(), self.ord) }
}

#[inline]
fn copy_strides(&self, val: &[isize]) {
fn copy_strides(&mut self, val: &[isize]) {
assert!(val.len() == self.ord);
unsafe {
copy_nonoverlapping(
Expand Down
4 changes: 2 additions & 2 deletions src/transform/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ impl<const N: usize> TensorLayout<N> {
return self.clone();
}

let ans = Self::with_order(self.order - args.len());
let content = ans.content();
let mut ans = Self::with_order(self.order - args.len());
let mut content = ans.content_mut();
let mut j = 0;
for (i, (&d, &s)) in iter {
match *args {
Expand Down
4 changes: 2 additions & 2 deletions src/transform/merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ impl<const N: usize> TensorLayout<N> {
let strides = content.strides();

let merged = args.iter().map(|range| range.len()).sum::<usize>();
let ans = Self::with_order(self.order + args.len() - merged);
let mut ans = Self::with_order(self.order + args.len() - merged);

let content = ans.content();
let mut content = ans.content_mut();
content.set_offset(self.offset());
let mut i = 0;
let mut push = |t, s| {
Expand Down
4 changes: 2 additions & 2 deletions src/transform/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ impl<const N: usize> TensorLayout<N> {
let mut offset = content.offset() as isize;
let iter = zip(content.shape(), content.strides()).enumerate();

let ans = Self::with_order(self.order);
let content = ans.content();
let mut ans = Self::with_order(self.order);
let mut content = ans.content_mut();
for (i, (&d, &s)) in iter {
match args {
[arg, tail @ ..] if arg.axis == i => {
Expand Down
4 changes: 2 additions & 2 deletions src/transform/tile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ impl<const N: usize> TensorLayout<N> {
last_axis = arg.axis;
}

let ans = Self::with_order(self.order + new_orders - args.len());
let mut ans = Self::with_order(self.order + new_orders - args.len());

let content = ans.content();
let mut content = ans.content_mut();
content.set_offset(self.offset());
let mut j = 0;
let mut push = |t, s| {
Expand Down
6 changes: 3 additions & 3 deletions src/transform/transpose.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ impl<const N: usize> TensorLayout<N> {
let shape = content.shape();
let strides = content.strides();

let ans = Self::with_order(self.order);
let content = ans.content();
let mut ans = Self::with_order(self.order);
let mut content = ans.content_mut();
content.set_offset(self.offset());
let set = |i, j| {
let mut set = |i, j| {
content.set_shape(i, shape[j]);
content.set_stride(i, strides[j]);
};
Expand Down

0 comments on commit 5598bbc

Please sign in to comment.