Skip to content

Commit

Permalink
Merge pull request #55 from scalableminds/improve-multi-channel-write…
Browse files Browse the repository at this point in the history
…-performance

Improve multi channel write performance
  • Loading branch information
rschwanhold authored Oct 30, 2020
2 parents c29a789 + cf892a7 commit b4f5262
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 19 deletions.
15 changes: 15 additions & 0 deletions python/tests/test_wkw.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,21 @@ def test_multiple_writes_and_reads():
assert np.all(mem_buffer == read_data)


def test_multi_channel_column_major_order():

with wkw.Dataset.create(
"tests/tmp", wkw.Header(np.uint8, num_channels=3)
) as dataset:
offset = (30, 20, 10)
data_shape = (3, 100, 200, 300)
order = "C"
data = generate_test_data(np.uint8, list(data_shape), order)
dataset.write(offset, data)

read_data = dataset.read(offset, data_shape[1:])
assert np.all(data == read_data)


def test_big_read():
data = np.ones((10, 10, 764), order="C", dtype=np.uint8)
offset = np.array([0, 0, 640])
Expand Down
3 changes: 1 addition & 2 deletions python/wkw/wkw.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,7 @@ def write(self, off, data):
def is_contiguous(data):
return data.flags["F_CONTIGUOUS"] or data.flags["C_CONTIGUOUS"]

# the row-major handling of the rust lib cannot handle num_channels > 1
if self.header.num_channels != 1 or not is_contiguous(data):
if not is_contiguous(data):
data = np.asfortranarray(data)

box = _build_box(off, data.shape[-3:])
Expand Down
5 changes: 0 additions & 5 deletions rust/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,6 @@ impl Dataset {
return Err(format!("Input matrix has invalid voxel size {} != {}", mat.voxel_size, self.header.voxel_size as usize));
}

let num_channels = self.header.voxel_type.size() / self.header.voxel_size as usize;
if num_channels > 1 && mat.data_in_c_order {
return Err(String::from("Cannot write multichannel data if data is in row-major order."));
}

let file_len_vx_log2 = self.header.file_len_vx_log2() as u32;
if self.header.block_type == BlockType::LZ4 || self.header.block_type == BlockType::LZ4HC {
let file_len_vec = Vec3::from(1 << file_len_vx_log2);
Expand Down
93 changes: 81 additions & 12 deletions rust/src/mat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ pub struct Mat<'a> {
pub data_in_c_order: bool,
}

pub fn linearize(channel: usize, x: usize, y: usize, z: usize, stride: &[usize]) -> isize {
(channel * stride[0] + x * stride[1] + y * stride[2] + z * stride[3]) as isize
}

impl<'a> Mat<'a> {
pub fn new(
data: &mut [u8],
Expand Down Expand Up @@ -79,38 +83,40 @@ impl<'a> Mat<'a> {

let x_length = self.shape.x as usize;
let y_length = self.shape.y as usize;
let z_length = self.shape.z as usize;
let num_channel = self.voxel_size / self.voxel_type.size();
let item_size = self.voxel_size / num_channel;

let row_major_stride: Vec<usize> = vec![
z_length * y_length * self.voxel_size,
z_length * self.voxel_size,
item_size,
y_length * x_length * self.voxel_size,
y_length * self.voxel_size,
self.voxel_size,
];

let column_major_stride: Vec<usize> = vec![
item_size,
self.voxel_size,
x_length * self.voxel_size,
x_length * y_length * self.voxel_size,
];

fn linearize(x: usize, y: usize, z: usize, stride: &[usize]) -> isize {
(x * stride[0] + y * stride[1] + z * stride[2]) as isize
}
let src_ptr = self.data.as_ptr();
let dst_ptr = buffer_data.as_mut_ptr();

let from = src_bbox.min();
let to = src_bbox.max();

// Do continuous read in z. Last dim in Row-Major is continuous.
let stripe_len = item_size * num_channel;
for x in from.x as usize..to.x as usize {
for y in from.y as usize..to.y as usize {
for z in from.z as usize..to.z as usize {
let row_major_index = linearize(x, y, z, &row_major_stride);
let column_major_index = linearize(x, y, z, &column_major_stride);
let row_major_index = linearize(0, x, y, z, &row_major_stride);
let column_major_index = linearize(0, x, y, z, &column_major_stride);
unsafe {
let cur_src_ptr = src_ptr.offset(row_major_index);
let cur_dst_ptr = dst_ptr.offset(column_major_index);
ptr::copy_nonoverlapping(cur_src_ptr, cur_dst_ptr, self.voxel_size);
ptr::copy_nonoverlapping(src_ptr.offset(row_major_index), dst_ptr.offset(column_major_index), stripe_len);
}

}
}
}
Expand All @@ -129,7 +135,14 @@ impl<'a> Mat<'a> {
}

if src.data_in_c_order {
intermediate_buffer.copy_from(dst_pos, src, src_box)?;
let num_channel = self.voxel_size / self.voxel_type.size();
if num_channel == 1 {
// if the data has only one channel, copy_from is a bit faster because it copies more items simultaneously
intermediate_buffer.copy_from(dst_pos, src, src_box)?;
} else {
// putting the channels to the back avoids that the indices (in copy_as_fortran_order) make too big jumps
intermediate_buffer.copy_from_and_put_channels_last(dst_pos, src, src_box)?;
}
let dst_bbox = Box3::new(dst_pos, dst_pos + src_box.width())?;
intermediate_buffer.copy_as_fortran_order(self, dst_bbox)
} else {
Expand Down Expand Up @@ -208,4 +221,60 @@ impl<'a> Mat<'a> {
}
Ok(())
}

pub fn copy_from_and_put_channels_last(&mut self, dst_pos: Vec3, src: &Mat, src_box: Box3) -> Result<()> {
// make sure that matrices are matching
if self.voxel_size != src.voxel_size {
return Err(format!("Matrices mismatch in voxel size {} != {}", self.voxel_size, src.voxel_size));
}
if self.voxel_type != src.voxel_type {
return Err(format!("Matrices mismatch in voxel type {:?} != {:?}", self.voxel_type, src.voxel_type));
}
if !(src_box.max() < (src.shape + 1)) {
return Err(String::from("Reading out of bounds"));
}
if !(dst_pos + src_box.width() < (self.shape + 1)) {
return Err(String::from("Writing out of bounds"));
}
if !(self.data_in_c_order & src.data_in_c_order) {
return Err(String::from("Source and destination have to be in c-order"));
}

let length = src_box.width();

let num_channel = self.voxel_size / self.voxel_type.size();
let item_size = self.voxel_size / num_channel;

let channel_last_stride: Vec<usize> = vec![
(src.shape.z * src.shape.y * src.shape.x) as usize * item_size,
(src.shape.z * src.shape.y) as usize * item_size,
src.shape.z as usize * item_size,
item_size,
];

let channel_first_stride: Vec<usize> = vec![
item_size,
(self.shape.z * self.shape.y) as usize * self.voxel_size,
self.shape.z as usize * self.voxel_size,
self.voxel_size,
];

unsafe {
let src_ptr = src.data.as_ptr().add(src.offset(src_box.min()) / num_channel);
let dst_ptr = self.data.as_mut_ptr().add(self.offset(dst_pos));

for channel in 0..num_channel {
for x in 0..length.x {
for y in 0..length.y {
for z in 0..length.z {
let channel_last_index = linearize(channel, x as usize, y as usize, z as usize, &channel_last_stride);
let channel_first_index = linearize(channel, x as usize, y as usize, z as usize, &channel_first_stride);
ptr::copy_nonoverlapping(src_ptr.offset(channel_last_index), dst_ptr.offset(channel_first_index), item_size);
}
}
}
}
}
Ok(())
}
}

0 comments on commit b4f5262

Please sign in to comment.