Skip to content

Commit

Permalink
💥🗑️Simplified rt stuffs
Browse files Browse the repository at this point in the history
turns out that initializing runtimes is pretty fast!
  • Loading branch information
carefree0910 committed Oct 21, 2024
1 parent 9619195 commit 38b58ba
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 47 deletions.
44 changes: 12 additions & 32 deletions cfpyo3_rs_core/src/toolkit/misc.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use md5::{Digest, Md5};
#[cfg(feature = "tokio")]
use std::sync::LazyLock;
use anyhow::Result;
use md5::{Digest, Md5};
use std::{collections::HashMap, fmt, sync::RwLock, time::Instant};
#[cfg(feature = "tokio")]
use tokio::runtime::{Builder, Runtime};
Expand Down Expand Up @@ -190,34 +190,14 @@ impl NamedTrackers {
// tokio utils

#[cfg(feature = "tokio")]
fn init_rt(num_threads: usize) -> Runtime {
if num_threads <= 1 {
return Builder::new_current_thread().enable_all().build().unwrap();
}
Builder::new_multi_thread()
.worker_threads(num_threads)
.enable_all()
.build()
.unwrap()
}

#[cfg(feature = "tokio")]
static RT_POOL: LazyLock<HashMap<usize, Runtime>> = LazyLock::new(|| {
let mut pool = HashMap::new();
pool.insert(1, init_rt(1));
pool.insert(2, init_rt(2));
pool.insert(4, init_rt(4));
pool
});

/// Get a tokio runtime with specific number of threads.
///
/// # Panics
///
/// Currently, only 1, 2, and 4 threads are supported, other numbers will cause panic.
#[cfg(feature = "tokio")]
pub fn get_rt<'a>(num_threads: usize) -> &'a Runtime {
RT_POOL
.get(&num_threads)
.unwrap_or_else(|| panic!("No runtime for {} threads", num_threads))
pub fn init_rt(num_threads: usize) -> Result<Runtime> {
let rt = if num_threads <= 1 {
Builder::new_current_thread().enable_all().build()?
} else {
Builder::new_multi_thread()
.worker_threads(num_threads)
.enable_all()
.build()?
};
Ok(rt)
}
27 changes: 12 additions & 15 deletions cfpyo3_rs_core/src/toolkit/queue.rs
Original file line number Diff line number Diff line change
@@ -1,52 +1,50 @@
use super::misc::get_rt;
use super::misc::init_rt;
use anyhow::Result;
use core::marker::PhantomData;
use std::{
collections::HashMap,
sync::{Arc, Mutex, RwLock},
};
use tokio::task::JoinHandle;
use tokio::{runtime::Runtime, task::JoinHandle};

pub trait WithQueueThreads {
fn get_queue_threads(&self) -> usize;
}
pub trait Worker<T, R>: Send + Sync
where
T: Send + Sync + WithQueueThreads,
T: Send + Sync,
R: Send + Sync,
{
fn process(&self, cursor: usize, data: T) -> Result<R>;
}

pub struct AsyncQueue<T, R>
where
T: Send + Sync + WithQueueThreads,
T: Send + Sync,
R: Send + Sync,
{
rt: Runtime,
worker: Arc<RwLock<Box<dyn Worker<T, R>>>>,
results: Arc<Mutex<HashMap<usize, Result<R>>>>,
pending: Vec<JoinHandle<()>>,
phantom_task_data: PhantomData<T>,
}
impl<T, R> AsyncQueue<T, R>
where
T: Send + Sync + WithQueueThreads + 'static,
T: Send + Sync + 'static,
R: Send + Sync + 'static,
{
pub fn new(worker: Box<dyn Worker<T, R>>) -> Self {
Self {
pub fn new(worker: Box<dyn Worker<T, R>>, num_threads: usize) -> Result<Self> {
Ok(Self {
rt: init_rt(num_threads)?,
worker: Arc::new(RwLock::new(worker)),
results: Arc::new(Mutex::new(HashMap::new())),
pending: Vec::new(),
phantom_task_data: PhantomData,
}
})
}

pub fn submit(&mut self, cursor: usize, data: T) {
let worker = Arc::clone(&self.worker);
let results = Arc::clone(&self.results);
let rt = get_rt(data.get_queue_threads());
let handle = rt.spawn(async move {
let handle = self.rt.spawn(async move {
let result = worker.read().unwrap().process(cursor, data);
results.lock().unwrap().insert(cursor, result);
});
Expand All @@ -60,12 +58,11 @@ where
pub fn reset(&mut self, block_after_abort: bool) -> Result<()> {
use anyhow::Ok;

let rt = get_rt(1);
self.results.lock().unwrap().clear();
self.pending.drain(..).try_for_each(|handle| {
handle.abort();
if block_after_abort {
rt.block_on(handle)?;
self.rt.block_on(handle)?;
}
Ok(())
})?;
Expand Down

0 comments on commit 38b58ba

Please sign in to comment.