diff --git a/python/tests/test_wkw.py b/python/tests/test_wkw.py index 94874da..0f7244e 100644 --- a/python/tests/test_wkw.py +++ b/python/tests/test_wkw.py @@ -288,6 +288,21 @@ def test_multiple_writes_and_reads(): assert np.all(mem_buffer == read_data) +def test_multi_channel_column_major_order(): + + with wkw.Dataset.create( + "tests/tmp", wkw.Header(np.uint8, num_channels=3) + ) as dataset: + offset = (30, 20, 10) + data_shape = (3, 100, 200, 300) + order = "C" + data = generate_test_data(np.uint8, list(data_shape), order) + dataset.write(offset, data) + + read_data = dataset.read(offset, data_shape[1:]) + assert np.all(data == read_data) + + def test_big_read(): data = np.ones((10, 10, 764), order="C", dtype=np.uint8) offset = np.array([0, 0, 640]) diff --git a/python/wkw/wkw.py b/python/wkw/wkw.py index e421694..f3b7634 100644 --- a/python/wkw/wkw.py +++ b/python/wkw/wkw.py @@ -209,8 +209,7 @@ def write(self, off, data): def is_contiguous(data): return data.flags["F_CONTIGUOUS"] or data.flags["C_CONTIGUOUS"] - # the row-major handling of the rust lib cannot handle num_channels > 1 - if self.header.num_channels != 1 or not is_contiguous(data): + if not is_contiguous(data): data = np.asfortranarray(data) box = _build_box(off, data.shape[-3:]) diff --git a/rust/src/dataset.rs b/rust/src/dataset.rs index 24f714b..924f42d 100644 --- a/rust/src/dataset.rs +++ b/rust/src/dataset.rs @@ -117,11 +117,6 @@ impl Dataset { return Err(format!("Input matrix has invalid voxel size {} != {}", mat.voxel_size, self.header.voxel_size as usize)); } - let num_channels = self.header.voxel_type.size() / self.header.voxel_size as usize; - if num_channels > 1 && mat.data_in_c_order { - return Err(String::from("Cannot write multichannel data if data is in row-major order.")); - } - let file_len_vx_log2 = self.header.file_len_vx_log2() as u32; if self.header.block_type == BlockType::LZ4 || self.header.block_type == BlockType::LZ4HC { let file_len_vec = Vec3::from(1 << file_len_vx_log2); diff --git a/rust/src/mat.rs b/rust/src/mat.rs index df85d8f..b1f1b50 100644 --- a/rust/src/mat.rs +++ b/rust/src/mat.rs @@ -11,6 +11,10 @@ pub struct Mat<'a> { pub data_in_c_order: bool, } +pub fn linearize(channel: usize, x: usize, y: usize, z: usize, stride: &[usize]) -> isize { + (channel * stride[0] + x * stride[1] + y * stride[2] + z * stride[3]) as isize +} + impl<'a> Mat<'a> { pub fn new( data: &mut [u8], @@ -79,38 +83,40 @@ impl<'a> Mat<'a> { let x_length = self.shape.x as usize; let y_length = self.shape.y as usize; - let z_length = self.shape.z as usize; + let num_channel = self.voxel_size / self.voxel_type.size(); + let item_size = self.voxel_size / num_channel; let row_major_stride: Vec = vec![ - z_length * y_length * self.voxel_size, - z_length * self.voxel_size, + item_size, + y_length * x_length * self.voxel_size, + y_length * self.voxel_size, self.voxel_size, ]; + let column_major_stride: Vec = vec![ + item_size, self.voxel_size, x_length * self.voxel_size, x_length * y_length * self.voxel_size, ]; - fn linearize(x: usize, y: usize, z: usize, stride: &[usize]) -> isize { - (x * stride[0] + y * stride[1] + z * stride[2]) as isize - } let src_ptr = self.data.as_ptr(); let dst_ptr = buffer_data.as_mut_ptr(); let from = src_bbox.min(); let to = src_bbox.max(); + // Do continuous read in z. Last dim in Row-Major is continuous. + let stripe_len = item_size * num_channel; for x in from.x as usize..to.x as usize { for y in from.y as usize..to.y as usize { for z in from.z as usize..to.z as usize { - let row_major_index = linearize(x, y, z, &row_major_stride); - let column_major_index = linearize(x, y, z, &column_major_stride); + let row_major_index = linearize(0, x, y, z, &row_major_stride); + let column_major_index = linearize(0, x, y, z, &column_major_stride); unsafe { - let cur_src_ptr = src_ptr.offset(row_major_index); - let cur_dst_ptr = dst_ptr.offset(column_major_index); - ptr::copy_nonoverlapping(cur_src_ptr, cur_dst_ptr, self.voxel_size); + ptr::copy_nonoverlapping(src_ptr.offset(row_major_index), dst_ptr.offset(column_major_index), stripe_len); } + } } } @@ -129,7 +135,14 @@ impl<'a> Mat<'a> { } if src.data_in_c_order { - intermediate_buffer.copy_from(dst_pos, src, src_box)?; + let num_channel = self.voxel_size / self.voxel_type.size(); + if num_channel == 1 { + // if the data has only one channel, copy_from is a bit faster because it copies more items simultaneously + intermediate_buffer.copy_from(dst_pos, src, src_box)?; + } else { + // putting the channels to the back avoids that the indices (in copy_as_fortran_order) make too big jumps + intermediate_buffer.copy_from_and_put_channels_last(dst_pos, src, src_box)?; + } let dst_bbox = Box3::new(dst_pos, dst_pos + src_box.width())?; intermediate_buffer.copy_as_fortran_order(self, dst_bbox) } else { @@ -208,4 +221,60 @@ impl<'a> Mat<'a> { } Ok(()) } + + pub fn copy_from_and_put_channels_last(&mut self, dst_pos: Vec3, src: &Mat, src_box: Box3) -> Result<()> { + // make sure that matrices are matching + if self.voxel_size != src.voxel_size { + return Err(format!("Matrices mismatch in voxel size {} != {}", self.voxel_size, src.voxel_size)); + } + if self.voxel_type != src.voxel_type { + return Err(format!("Matrices mismatch in voxel type {:?} != {:?}", self.voxel_type, src.voxel_type)); + } + if !(src_box.max() < (src.shape + 1)) { + return Err(String::from("Reading out of bounds")); + } + if !(dst_pos + src_box.width() < (self.shape + 1)) { + return Err(String::from("Writing out of bounds")); + } + if !(self.data_in_c_order & src.data_in_c_order) { + return Err(String::from("Source and destination have to be in c-order")); + } + + let length = src_box.width(); + + let num_channel = self.voxel_size / self.voxel_type.size(); + let item_size = self.voxel_size / num_channel; + + let channel_last_stride: Vec = vec![ + (src.shape.z * src.shape.y * src.shape.x) as usize * item_size, + (src.shape.z * src.shape.y) as usize * item_size, + src.shape.z as usize * item_size, + item_size, + ]; + + let channel_first_stride: Vec = vec![ + item_size, + (self.shape.z * self.shape.y) as usize * self.voxel_size, + self.shape.z as usize * self.voxel_size, + self.voxel_size, + ]; + + unsafe { + let src_ptr = src.data.as_ptr().add(src.offset(src_box.min()) / num_channel); + let dst_ptr = self.data.as_mut_ptr().add(self.offset(dst_pos)); + + for channel in 0..num_channel { + for x in 0..length.x { + for y in 0..length.y { + for z in 0..length.z { + let channel_last_index = linearize(channel, x as usize, y as usize, z as usize, &channel_last_stride); + let channel_first_index = linearize(channel, x as usize, y as usize, z as usize, &channel_first_stride); + ptr::copy_nonoverlapping(src_ptr.offset(channel_last_index), dst_ptr.offset(channel_first_index), item_size); + } + } + } + } + } + Ok(()) + } }