From 98b6e502d45fa84e7362b5d62d1905eb25b601ec Mon Sep 17 00:00:00 2001
From: carefree0910
Date: Sun, 27 Oct 2024 12:53:00 +0800
Subject: [PATCH] =?UTF-8?q?=E2=9C=A8Implemented=20`MmapArray1`?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 Cargo.lock                          | 10 ++++
 cfpyo3_rs_core/Cargo.toml           |  3 +-
 cfpyo3_rs_core/src/toolkit/array.rs | 81 +++++++++++++++++++++++++++--
 3 files changed, 90 insertions(+), 4 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 617670d..5e4abb2 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -188,6 +188,7 @@ dependencies = [
  "futures",
  "itertools 0.13.0",
  "md-5",
+ "memmap2",
  "ndarray-rand",
  "num-traits",
  "numpy",
@@ -923,6 +924,15 @@ version = "2.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
 
+[[package]]
+name = "memmap2"
+version = "0.9.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fd3f7eed9d3848f8b98834af67102b720745c4ec028fcd0aa0239277e7de374f"
+dependencies = [
+ "libc",
+]
+
 [[package]]
 name = "memoffset"
 version = "0.9.1"
diff --git a/cfpyo3_rs_core/Cargo.toml b/cfpyo3_rs_core/Cargo.toml
index 7e0dfbf..149aa2f 100644
--- a/cfpyo3_rs_core/Cargo.toml
+++ b/cfpyo3_rs_core/Cargo.toml
@@ -31,6 +31,7 @@ opendal = { version = "0.50.0", features = ["services-s3"], optional = true }
 rand = { version = "0.8.5", optional = true }
 redis = { version = "0.26.1", features = ["cluster"], optional = true }
 tokio = { version = "1.40.0", features = ["rt", "rt-multi-thread"], optional = true }
+memmap2 = "0.9.5"
 
 [dev-dependencies]
 criterion = { version = "0.5.1", features = ["html_reports"] }
@@ -84,4 +85,4 @@ required-features = ["criterion"]
 [[bench]]
 name = "df"
 harness = false
-required-features = ["criterion"]
\ No newline at end of file
+required-features = ["criterion"]
diff --git a/cfpyo3_rs_core/src/toolkit/array.rs b/cfpyo3_rs_core/src/toolkit/array.rs
index 1104768..0009305 100644
--- a/cfpyo3_rs_core/src/toolkit/array.rs
+++ b/cfpyo3_rs_core/src/toolkit/array.rs
@@ -1,14 +1,20 @@
-use anyhow::Result;
-use core::{mem, ptr};
+use anyhow::{Ok, Result};
+use core::{mem, ptr, slice};
 use itertools::{izip, Itertools};
+use memmap2::{Mmap, MmapOptions};
 use num_traits::{Float, FromPrimitive};
-use numpy::ndarray::{stack, Array1, Array2, ArrayView1, ArrayView2, Axis, ScalarOperand};
+use numpy::{
+    ndarray::{stack, Array1, Array2, ArrayView1, ArrayView2, Axis, ScalarOperand},
+    Element,
+};
 use std::{
     cell::UnsafeCell,
     cmp::Ordering,
     collections::HashMap,
     fmt::{Debug, Display},
+    fs::File,
     iter::zip,
+    marker::PhantomData,
     ops::{AddAssign, MulAssign, SubAssign},
     thread::available_parallelism,
 };
@@ -82,6 +88,52 @@ impl<'a, T> UnsafeSlice<'a, T> {
     }
 }
 
+/// A read-only, memory-mapped 1-D array whose elements are read directly from a file.
+///
+/// The fields are the mapping itself, the number of `T` elements it holds, and a marker
+/// for the element type.
+pub struct MmapArray1<T: Element>(Mmap, usize, PhantomData<T>);
+impl<T: Element> MmapArray1<T> {
+    /// Maps the file at `path` into memory and interprets its bytes as a sequence of `T`.
+    ///
+    /// Trailing bytes that do not make up a whole `T` are ignored.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the file cannot be opened or mapped.
+    ///
+    /// # Safety
+    ///
+    /// The use of `mmap` is unsafe, see the documentation of [`MmapOptions`] for more details.
+    pub unsafe fn new(path: &str) -> Result<Self> {
+        let file = File::open(path)?;
+        let mmap = unsafe { MmapOptions::new().map(&file)? };
+        // `mmap.len()` is in bytes, while the views below count elements of `T`
+        let len = mmap.len().checked_div(mem::size_of::<T>()).unwrap_or(0);
+        Ok(Self(mmap, len, PhantomData))
+    }
+
+    /// Views the mapped data as a slice of `T`.
+    ///
+    /// # Safety
+    ///
+    /// The use of [`slice::from_raw_parts`] is unsafe, see its documentation for more details.
+    pub unsafe fn as_slice(&self) -> &[T] {
+        slice::from_raw_parts(self.0.as_ptr() as *const T, self.1)
+    }
+
+    /// Views the mapped data as a 1-D `ndarray` view.
+    ///
+    /// # Safety
+    ///
+    /// The use of [`ArrayView1::from_shape_ptr`] is unsafe, see its documentation for more details.
+    pub unsafe fn as_array_view(&self) -> ArrayView1<T> {
+        ArrayView1::from_shape_ptr((self.1,), self.0.as_ptr() as *const T)
+    }
+}
+
+// float ops
+
 pub trait AFloat:
     Float
     + AddAssign
@@ -759,6 +811,9 @@ pub fn fast_concat_2d_axis0(
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::toolkit::convert::to_bytes;
+    use std::io::Write;
+    use tempfile::tempdir;
 
     fn assert_allclose<T: AFloat>(a: &[T], b: &[T]) {
         let atol = T::from_f64(1e-6).unwrap();
@@ -773,6 +828,26 @@ mod tests {
         });
     }
 
+    #[test]
+    fn test_mmap() {
+        let dir = tempdir().unwrap();
+        let file_path = dir.path().join("test.cfy");
+        let array = Array1::<f32>::from_shape_vec(3, vec![1., 2., 3.]).unwrap();
+        let bytes = unsafe { to_bytes(array.as_slice().unwrap()) };
+        let mut file = File::create(&file_path).unwrap();
+        file.write_all(bytes).unwrap();
+        let file_path = file_path.to_str().unwrap();
+        let mmap_array = unsafe { MmapArray1::<f32>::new(file_path).unwrap() };
+        // the views must expose exactly as many elements as were written
+        assert_eq!(unsafe { mmap_array.as_slice() }.len(), array.len());
+        assert_eq!(unsafe { mmap_array.as_array_view() }.len(), array.len());
+        assert_allclose(array.as_slice().unwrap(), unsafe { mmap_array.as_slice() });
+        assert_allclose(
+            array.as_slice().unwrap(),
+            unsafe { mmap_array.as_array_view() }.as_slice().unwrap(),
+        );
+    }
+
     macro_rules! test_fast_concat_2d_axis0 {
         ($dtype:ty) => {
             let array_2d_u = ArrayView2::<$dtype>::from_shape((1, 3), &[1., 2., 3.]).unwrap();