Skip to content

Commit

Permalink
Merge pull request #10 from balena-io-experimental/multi-threading
Browse files Browse the repository at this point in the history
Make implementation multi-threading compatible
  • Loading branch information
pipex authored Feb 16, 2025
2 parents 5dadeb0 + 95d21b8 commit 62f13a8
Show file tree
Hide file tree
Showing 10 changed files with 115 additions and 111 deletions.
71 changes: 41 additions & 30 deletions src/dag.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use std::cell::RefCell;
use std::fmt;
use std::ops::Add;
use std::rc::Rc;
use std::sync::{Arc, RwLock};

type Link<T> = Option<Rc<RefCell<Node<T>>>>;
type Link<T> = Option<Arc<RwLock<Node<T>>>>;

/**
* A node in a DAG is a recursive data structure that
Expand Down Expand Up @@ -41,7 +40,7 @@ impl<T> Node<T> {
}

pub fn into_link(self) -> Link<T> {
Some(Rc::new(RefCell::new(self)))
Some(Arc::new(RwLock::new(self)))
}
}

Expand All @@ -54,12 +53,12 @@ pub(crate) struct Iter<T> {
}

impl<T> Iterator for Iter<T> {
type Item = Rc<RefCell<Node<T>>>;
type Item = Arc<RwLock<Node<T>>>;

fn next(&mut self) -> Option<Self::Item> {
while let Some((link, branching)) = self.stack.pop() {
if let Some(node_rc) = link {
let node_ref = node_rc.borrow();
let node_ref = node_rc.read().unwrap();
match &*node_ref {
Node::Item { next, .. } => {
// Push the next node onto the stack for continuation
Expand Down Expand Up @@ -153,7 +152,7 @@ impl<T> Dag<T> {
let Dag { head, tail } = dag;

if let Some(tail_rc) = tail {
match *tail_rc.borrow_mut() {
match *tail_rc.write().unwrap() {
Node::Item { ref mut next, .. } => *next = self.head,
Node::Join { ref mut next } => *next = self.head,
_ => unreachable!(),
Expand All @@ -173,7 +172,7 @@ impl<T> Dag<T> {
pub fn concat(self, other: Dag<T>) -> Self {
debug_assert!(self.tail.is_some());
if let Some(tail_node) = self.tail {
match *tail_node.borrow_mut() {
match *tail_node.write().unwrap() {
Node::Item { ref mut next, .. } => {
*next = other.head;
}
Expand Down Expand Up @@ -205,7 +204,7 @@ impl<T> Dag<T> {
/// condition given as argument
pub fn some(&self, condition: impl Fn(&T) -> bool) -> bool {
for node in self.iter() {
if let Node::Item { value, .. } = &*node.borrow() {
if let Node::Item { value, .. } = &*node.read().unwrap() {
if condition(value) {
return true;
}
Expand All @@ -218,7 +217,7 @@ impl<T> Dag<T> {
/// as argument
pub fn every(&self, condition: impl Fn(&T) -> bool) -> bool {
for node in self.iter() {
if let Node::Item { value, .. } = &*node.borrow() {
if let Node::Item { value, .. } = &*node.read().unwrap() {
if !condition(value) {
return false;
}
Expand Down Expand Up @@ -256,7 +255,7 @@ impl<T> Dag<T> {
pub fn fold<U>(&self, initial: U, fold_fn: impl Fn(U, &T) -> U) -> U {
let mut acc = initial;
for node in self.iter() {
if let Node::Item { value, .. } = &*node.borrow() {
if let Node::Item { value, .. } = &*node.read().unwrap() {
acc = fold_fn(acc, value);
}
}
Expand Down Expand Up @@ -290,7 +289,7 @@ impl<T> Dag<T> {
for value in iter {
let new_node = Node::item(value.into(), None).into_link();
if let Some(tail_node) = tail {
if let Node::Item { ref mut next, .. } = *tail_node.borrow_mut() {
if let Node::Item { ref mut next, .. } = *tail_node.write().unwrap() {
*next = new_node.clone();
}
}
Expand Down Expand Up @@ -340,7 +339,7 @@ impl<T> Dag<T> {
debug_assert!(branch.tail.is_some());
// Link each branch tail to the join node
if let Some(tail_rc) = branch.tail {
match *tail_rc.borrow_mut() {
match *tail_rc.write().unwrap() {
Node::Item { ref mut next, .. } => {
*next = tail.clone();
}
Expand Down Expand Up @@ -393,7 +392,7 @@ impl<T> Dag<T> {

while let Some((head, prev, branching)) = stack.pop() {
if let Some(node_rc) = head.clone() {
match *node_rc.borrow_mut() {
match *node_rc.write().unwrap() {
Node::Item { ref mut next, .. } => {
// copy the next node to continue the operation
let newhead = next.clone();
Expand Down Expand Up @@ -500,7 +499,7 @@ impl<T: fmt::Display> fmt::Display for Dag<T> {
write!(f, "- {}", value)?;

if let Some(next_rc) = next {
fmt_node(f, &*next_rc.borrow(), indent, index + 1, branching)?;
fmt_node(f, &*next_rc.read().unwrap(), indent, index + 1, branching)?;
}
}
Node::Fork { next } => {
Expand All @@ -515,7 +514,13 @@ impl<T: fmt::Display> fmt::Display for Dag<T> {
let mut updated_branching = branching.clone();
updated_branching.push((index, br_idx == next.len() - 1));

fmt_node(f, &*branch_head.borrow(), indent + 2, 0, updated_branching)?;
fmt_node(
f,
&*branch_head.read().unwrap(),
indent + 2,
0,
updated_branching,
)?;
}
}
}
Expand All @@ -524,7 +529,13 @@ impl<T: fmt::Display> fmt::Display for Dag<T> {
if let Some((index, is_last)) = branching.pop() {
if is_last {
if let Some(next_rc) = next {
fmt_node(f, &*next_rc.borrow(), indent - 2, index + 1, branching)?;
fmt_node(
f,
&*next_rc.read().unwrap(),
indent - 2,
index + 1,
branching,
)?;
}
}
}
Expand All @@ -536,7 +547,7 @@ impl<T: fmt::Display> fmt::Display for Dag<T> {
if let Some(root) = &self.head {
fmt_node(
f,
&*root.borrow(),
&*root.read().unwrap(),
0, // Initial indent level
0, // Initial index
Vec::new(), // Initial branching
Expand Down Expand Up @@ -565,8 +576,8 @@ mod tests {
use super::*;
use dedent::dedent;

fn is_item<T>(node: &Rc<RefCell<Node<T>>>) -> bool {
if let Node::Item { .. } = &*node.borrow() {
fn is_item<T>(node: &Arc<RwLock<Node<T>>>) -> bool {
if let Node::Item { .. } = &*node.read().unwrap() {
return true;
}
false
Expand All @@ -591,7 +602,7 @@ mod tests {
if let Node::Item {
value: node_value,
next,
} = &*head_rc.borrow()
} = &*head_rc.read().unwrap()
{
assert_eq!(*node_value, value);
head = next.clone();
Expand All @@ -609,14 +620,14 @@ mod tests {

assert!(dag.head.is_some());
if let Some(head_rc) = dag.head {
if let Node::Item { value, .. } = &*head_rc.borrow() {
if let Node::Item { value, .. } = &*head_rc.read().unwrap() {
assert_eq!(value, &1)
}
}

assert!(dag.tail.is_some());
if let Some(tail_rc) = dag.tail {
if let Node::Item { value, .. } = &*tail_rc.borrow() {
if let Node::Item { value, .. } = &*tail_rc.read().unwrap() {
assert_eq!(value, &4)
}
}
Expand All @@ -631,7 +642,7 @@ mod tests {
let mut result = Vec::new();

for node in dag.iter() {
let node_ref = node.borrow();
let node_ref = node.read().unwrap();
match &*node_ref {
Node::Item { value, .. } => result.push(*value), // Collect the value
Node::Fork { .. } => panic!("unexpected fork node in a linear graph"),
Expand All @@ -654,7 +665,7 @@ mod tests {
let elems: Vec<i32> = dag
.iter()
.filter(is_item)
.map(|node| match &*node.borrow() {
.map(|node| match &*node.read().unwrap() {
Node::Item { value, .. } => *value,
_ => unreachable!(),
})
Expand All @@ -668,7 +679,7 @@ mod tests {
let elems: Vec<i32> = dag
.iter()
.filter(is_item)
.map(|node| match &*node.borrow() {
.map(|node| match &*node.read().unwrap() {
Node::Item { value, .. } => *value,
_ => unreachable!(),
})
Expand All @@ -687,7 +698,7 @@ mod tests {
let elems: Vec<i32> = dag
.iter()
.filter(is_item)
.map(|node| match &*node.borrow() {
.map(|node| match &*node.read().unwrap() {
Node::Item { value, .. } => *value,
_ => unreachable!(),
})
Expand All @@ -707,7 +718,7 @@ mod tests {
let elems: Vec<i32> = dag
.iter()
.filter(is_item)
.map(|node| match &*node.borrow() {
.map(|node| match &*node.read().unwrap() {
Node::Item { value, .. } => *value,
_ => unreachable!(),
})
Expand All @@ -727,7 +738,7 @@ mod tests {
let mut reverse: Vec<i32> = dag
.iter()
.filter(is_item)
.map(|node| match &*node.borrow() {
.map(|node| match &*node.read().unwrap() {
Node::Item { value, .. } => *value,
_ => unreachable!(),
})
Expand All @@ -738,7 +749,7 @@ mod tests {
.reverse()
.iter()
.filter(is_item)
.map(|node| match &*node.borrow() {
.map(|node| match &*node.read().unwrap() {
Node::Item { value, .. } => *value,
_ => unreachable!(),
})
Expand Down
2 changes: 1 addition & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pub enum Error {
WorkflowInterrupted(#[from] super::worker::Interrupted),

#[error(transparent)]
Other(#[from] Box<dyn std::error::Error>),
Other(#[from] Box<dyn std::error::Error + Send>),
}

pub trait IntoError {
Expand Down
6 changes: 3 additions & 3 deletions src/task/boxed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ impl BoxedIntoTask {
pub fn from_action<A, T, I>(action: A) -> Self
where
A: Handler<T, Patch, I>,
I: 'static,
I: Send + 'static,
{
Self(Box::new(MakeIntoTask {
handler: action,
Expand Down Expand Up @@ -37,7 +37,7 @@ impl Clone for BoxedIntoTask {
}
}

trait ErasedIntoTask: Send {
trait ErasedIntoTask: Send + Sync {
fn clone_box(&self) -> Box<dyn ErasedIntoTask>;

fn into_task(self: Box<Self>, id: &'static str, context: Context) -> Task;
Expand All @@ -50,7 +50,7 @@ struct MakeIntoTask<H> {

impl<H> ErasedIntoTask for MakeIntoTask<H>
where
H: Send + Clone + 'static,
H: Send + Sync + Clone + 'static,
{
fn clone_box(&self) -> Box<dyn ErasedIntoTask> {
Box::new(Self {
Expand Down
26 changes: 16 additions & 10 deletions src/task/effect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ impl<O, E, I> IntoEffect<O, E, I> for Effect<O, E, I> {
}
}

type IOResult<O, E> = Pin<Box<dyn Future<Output = Result<O, E>>>>;
type IO<O, E = Infallible, I = O> = Box<dyn FnOnce(I) -> IOResult<O, E>>;
type Pure<O, E, I> = Box<dyn FnOnce(I) -> Result<O, E>>;
type IOResult<O, E> = Pin<Box<dyn Future<Output = Result<O, E>> + Send>>;
type IO<O, E = Infallible, I = O> = Box<dyn FnOnce(I) -> IOResult<O, E> + Send>;
type Pure<O, E, I> = Box<dyn FnOnce(I) -> Result<O, E> + Send>;

pub enum Effect<O, E = Infallible, I = O> {
Pure(Result<O, E>),
Expand All @@ -32,12 +32,15 @@ impl<O, E> Effect<O, E> {
Effect::Pure(Ok(o))
}

pub fn with_io<F: FnOnce(O) -> Res + 'static, Res: Future<Output = Result<O, E>>>(
pub fn with_io<
F: FnOnce(O) -> Res + Send + 'static,
Res: Future<Output = Result<O, E>> + Send,
>(
self,
f: F,
) -> Effect<O, E>
where
O: 'static,
O: Send + 'static,
{
let io: IO<O, E> = Box::new(|o| Box::pin(async { f(o).await }));
let pure = Box::new(|o| Ok(o));
Expand All @@ -52,12 +55,12 @@ impl<O, E> Effect<O, E> {
}
}

impl<T: 'static, E: 'static, I: 'static> Effect<T, E, I> {
impl<T: 'static, E: 'static, I: Send + 'static> Effect<T, E, I> {
pub fn from_error(e: E) -> Self {
Effect::Pure(Err(e))
}

pub fn map<O, F: FnOnce(T) -> O + Clone + 'static>(self, fu: F) -> Effect<O, E, I> {
pub fn map<O, F: FnOnce(T) -> O + Clone + Send + 'static>(self, fu: F) -> Effect<O, E, I> {
match self {
Effect::Pure(output) => Effect::Pure(output.map(fu)),
Effect::IO { input, pure, io } => {
Expand All @@ -76,7 +79,7 @@ impl<T: 'static, E: 'static, I: 'static> Effect<T, E, I> {
}
}

pub fn map_io<F: FnOnce(T) -> T + 'static>(self, fu: F) -> Effect<T, E, I> {
pub fn map_io<F: FnOnce(T) -> T + Send + 'static>(self, fu: F) -> Effect<T, E, I> {
match self {
Effect::Pure(output) => Effect::Pure(output),
Effect::IO { input, pure, io } => Effect::IO {
Expand All @@ -92,7 +95,7 @@ impl<T: 'static, E: 'static, I: 'static> Effect<T, E, I> {
}
}

pub fn and_then<O, F: FnOnce(T) -> Result<O, E> + Clone + 'static>(
pub fn and_then<O, F: FnOnce(T) -> Result<O, E> + Clone + Send + 'static>(
self,
fu: F,
) -> Effect<O, E, I> {
Expand All @@ -114,7 +117,10 @@ impl<T: 'static, E: 'static, I: 'static> Effect<T, E, I> {
}
}

pub fn map_err<E1, F: FnOnce(E) -> E1 + Clone + 'static>(self, fe: F) -> Effect<T, E1, I> {
pub fn map_err<E1, F: FnOnce(E) -> E1 + Clone + Send + 'static>(
self,
fe: F,
) -> Effect<T, E1, I> {
match self {
Effect::Pure(output) => Effect::Pure(output.map_err(fe)),
Effect::IO { input, pure, io } => {
Expand Down
Loading

0 comments on commit 62f13a8

Please sign in to comment.