From cca04e71f1e7d7534ac41256eed92da4db851d71 Mon Sep 17 00:00:00 2001 From: dagou Date: Thu, 4 Jul 2024 11:08:48 +0800 Subject: [PATCH] bug fix --- .github/workflows/rust.yml | 2 +- kr2r/src/bin/annotate.rs | 4 ++-- kr2r/src/compact_hash.rs | 20 ++++++++------------ seqkmer/src/parallel.rs | 13 +++++++------ 4 files changed, 18 insertions(+), 21 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index f73e863..7fd850f 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -10,7 +10,7 @@ on: env: CARGO_TERM_COLOR: always BINARIES_LIST: 'ncbi kun_peng' - PROJECT_PREFIX: 'kraken2-rust-' + PROJECT_PREFIX: 'Kun-peng-' jobs: build-and-release: diff --git a/kr2r/src/bin/annotate.rs b/kr2r/src/bin/annotate.rs index b4613b8..d69f75d 100644 --- a/kr2r/src/bin/annotate.rs +++ b/kr2r/src/bin/annotate.rs @@ -111,7 +111,7 @@ where reader, num_cpus::get(), batch_size, - |dataset: &[Slot]| { + |dataset: Vec>| { let mut results: HashMap> = HashMap::new(); for slot in dataset { let indx = slot.idx & idx_mask; @@ -176,6 +176,7 @@ fn process_chunk_file>( let start = Instant::now(); + println!("start load table..."); let config = HashConfig::from_hash_header(&args.database.join("hash_config.k2d"))?; let chtm = CHTable::from_range( config, @@ -209,7 +210,6 @@ pub fn run(args: Args) -> Result<()> { let start = Instant::now(); println!("annotate start..."); for chunk_file in chunk_files { - println!("chunk_file {:?}", chunk_file); process_chunk_file(&args, chunk_file, &hash_files)?; } // 计算持续时间 diff --git a/kr2r/src/compact_hash.rs b/kr2r/src/compact_hash.rs index 288be41..a51197d 100644 --- a/kr2r/src/compact_hash.rs +++ b/kr2r/src/compact_hash.rs @@ -277,7 +277,7 @@ fn read_page_from_file>(filename: P) -> Result { let capacity = LittleEndian::read_u64(&buffer[8..16]) as usize; // 读取数据部分 - let mut data = vec![0u32; capacity]; + let mut data = vec![0u32; capacity + 1024 * 1024]; let data_bytes = unsafe { std::slice::from_raw_parts_mut( data.as_mut_ptr() as *mut u8, @@ -299,7 +299,7 @@ fn read_first_block_from_file>(filename: P) -> Result { let capacity = LittleEndian::read_u64(&buffer[8..16]) as usize; let mut first_zero_end = capacity; - let chunk_size = 1024; // Define the chunk size for reading + let chunk_size = 1024 * 4; let mut found_zero = false; let mut data = vec![0u32; capacity]; let mut read_pos = 0; @@ -373,17 +373,14 @@ impl Page { value_bits: usize, value_mask: usize, ) -> u32 { - // let compacted_key = value.left(value_bits) as u32; let mut idx = index; - if idx > self.size { - return u32::default(); + if idx >= self.size { + return 0; } loop { if let Some(cell) = self.data.get(idx) { - if cell.right(value_mask) == u32::default() - || cell.left(value_bits) == compacted_key - { + if cell.right(value_mask) == 0 || cell.left(value_bits) == compacted_key { return cell.right(value_mask); } @@ -392,11 +389,10 @@ impl Page { break; } } else { - // 如果get(idx)失败,返回默认值 - return u32::default(); + return 0; } } - u32::default() + 0 } } @@ -428,7 +424,7 @@ impl CHTable { for i in start..end { let mut hash_file = &hash_sorted_files[i]; let mut page = read_page_from_file(&hash_file)?; - let next_page = if page.data.last().map_or(false, |&x| x == 0) { + let next_page = if page.data.last().map_or(false, |&x| x != 0) { if kd_type { hash_file = &hash_sorted_files[(i + 1) % parition] } diff --git a/seqkmer/src/parallel.rs b/seqkmer/src/parallel.rs index ba4f859..290a8b2 100644 --- a/seqkmer/src/parallel.rs +++ b/seqkmer/src/parallel.rs @@ -7,7 +7,6 @@ use scoped_threadpool::Pool; use std::collections::HashMap; use std::io::Result; use std::sync::Arc; - pub struct ParallelResult

where P: Send, @@ -108,16 +107,16 @@ pub fn buffer_read_parallel( func: F, ) -> Result<()> where - D: Send + Sized + Sync, + D: Send + Sized + Sync + Clone, R: std::io::Read + Send, O: Send, Out: Send + Default, - W: Send + Sync + Fn(&[D]) -> Option, + W: Send + Sync + Fn(Vec) -> Option, F: FnOnce(&mut ParallelResult>) -> Out + Send, { assert!(n_threads > 2); let buffer_len = n_threads + 2; - let (sender, receiver) = bounded::<&[D]>(buffer_len); + let (sender, receiver) = bounded::>(buffer_len); let (done_send, done_recv) = bounded::>(buffer_len); let receiver = Arc::new(receiver); // 使用 Arc 来共享 receiver let done_send = Arc::new(done_send); @@ -140,7 +139,9 @@ where let slots = unsafe { std::slice::from_raw_parts(batch_buffer.as_ptr() as *const D, slots_in_batch) }; - sender.send(slots).expect("Failed to send sequences"); + sender + .send(slots.to_vec()) + .expect("Failed to send sequences"); } }); @@ -163,7 +164,7 @@ where let _ = func(&mut parallel_result); }); - pool_scope.join_all(); + // pool_scope.join_all(); }); Ok(())