Skip to content

Commit

Permalink
refactor: offset 改为 isize,并将地址范围控制移到独立函数
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Nov 4, 2024
1 parent 68ab1db commit 48d36c5
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 17 deletions.
41 changes: 26 additions & 15 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub struct ArrayLayout<const N: usize = 2> {

union Union<const N: usize> {
ptr: NonNull<usize>,
_inlined: (usize, [usize; N], [isize; N]),
_inlined: (isize, [usize; N], [isize; N]),
}

impl<const N: usize> Clone for ArrayLayout<N> {
Expand Down Expand Up @@ -55,21 +55,13 @@ impl<const N: usize> ArrayLayout<N> {
/// assert_eq!(layout.shape(), &[2, 3, 4]);
/// assert_eq!(layout.strides(), &[12, -4, 1]);
/// ```
pub fn new(shape: &[usize], strides: &[isize], offset: usize) -> Self {
pub fn new(shape: &[usize], strides: &[isize], offset: isize) -> Self {
// check
assert_eq!(
shape.len(),
strides.len(),
"shape and strides must have the same length"
);
assert!(
zip(shape, strides)
.filter(|(_, &s)| s < 0)
.fold(offset as isize, |offset, (&d, &s)| {
offset + (d - 1) as isize * s
})
>= 0
);

let mut ans = Self::with_ndim(shape.len());
let mut content = ans.content_mut();
Expand Down Expand Up @@ -113,7 +105,7 @@ impl<const N: usize> ArrayLayout<N> {

/// Gets offset.
#[inline]
pub fn offset(&self) -> usize {
pub fn offset(&self) -> isize {
self.content().offset()
}

Expand All @@ -128,14 +120,33 @@ impl<const N: usize> ArrayLayout<N> {
pub fn strides(&self) -> &[isize] {
self.content().strides()
}

/// Calculate the range of data in bytes to determine the location of the memory area that the tensor needs to access.
pub fn data_range(&self) -> RangeInclusive<isize> {
let content = self.content();
let mut start = content.offset();
let mut end = content.offset();
for (&d, s) in zip(content.shape(), content.strides()) {
use std::cmp::Ordering::{Equal, Greater, Less};
let i = d as isize - 1;
match s.cmp(&0) {
Equal => {}
Less => start += s * i,
Greater => end += s * i,
}
}
start..=end
}
}

mod transform;
pub use transform::{IndexArg, SliceArg, Split, TileArg};

use std::{
alloc::{alloc, dealloc, Layout},
isize,
iter::zip,
ops::RangeInclusive,
ptr::{copy_nonoverlapping, NonNull},
slice::from_raw_parts,
};
Expand Down Expand Up @@ -201,8 +212,8 @@ impl Content<false> {
}

#[inline]
fn offset(&self) -> usize {
unsafe { self.ptr.read() }
fn offset(&self) -> isize {
unsafe { self.ptr.cast().read() }
}

#[inline]
Expand All @@ -218,8 +229,8 @@ impl Content<false> {

impl Content<true> {
#[inline]
fn set_offset(&mut self, val: usize) {
unsafe { self.ptr.write(val) }
fn set_offset(&mut self, val: isize) {
unsafe { self.ptr.cast().write(val) }
}

#[inline]
Expand Down
2 changes: 1 addition & 1 deletion src/transform/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ impl<const N: usize> ArrayLayout<N> {
/// 一次对多个阶进行索引变换。
pub fn index_many(&self, mut args: &[IndexArg]) -> Self {
let content = self.content();
let mut offset = content.offset() as isize;
let mut offset = content.offset();
let shape = content.shape();
let iter = zip(shape, content.strides()).enumerate();

Expand Down
2 changes: 1 addition & 1 deletion src/transform/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl<const N: usize> ArrayLayout<N> {
/// 一次对多个阶进行切片变换。
pub fn slice_many(&self, mut args: &[SliceArg]) -> Self {
let content = self.content();
let mut offset = content.offset() as isize;
let mut offset = content.offset();
let iter = zip(content.shape(), content.strides()).enumerate();

let mut ans = Self::with_ndim(self.ndim);
Expand Down

0 comments on commit 48d36c5

Please sign in to comment.