Implement first version of the Worker
Change-type: minor
pipex committed Feb 14, 2025
1 parent 2ad78e6 commit b2d1336
Showing 7 changed files with 266 additions and 11 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
@@ -12,6 +12,7 @@ matchit = "0.8.4"
serde = { version = "1.0.197" }
serde_json = "1.0.120"
thiserror = "2"
tokio = { version = "1.43.0", features = ["rt"] }

[dev-dependencies]
dedent = "0.1.1"
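
The worker added in src/worker/mod.rs below relies on tokio::task::spawn_local and tokio::time::sleep, and spawn_local only works inside a tokio LocalSet. A minimal driver sketch, assuming tokio's "time" feature is enabled alongside "rt" (enable_time and tokio::time::sleep need it); the worker and target construction are elided because they depend on crate-specific Domain/Intent APIs:

use tokio::task::LocalSet;

fn main() {
    // Current-thread runtime; enable_time() assumes the "time" feature is on,
    // since the worker's seek loop calls tokio::time::sleep.
    let rt = tokio::runtime::Builder::new_current_thread()
        .enable_time()
        .build()
        .expect("failed to build tokio runtime");

    let local = LocalSet::new();
    local.block_on(&rt, async {
        // Worker::seek spawns with tokio::task::spawn_local, which panics
        // outside a LocalSet, so it would be called from here, e.g.:
        //   let running = worker.seek(target);
    });
}
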
3 changes: 2 additions & 1 deletion src/lib.rs
@@ -1,7 +1,8 @@
mod dag;
mod error;
mod path;

pub mod error;
pub use error::*;
pub mod extract;
pub mod system;
pub mod task;
1 change: 0 additions & 1 deletion src/system/mod.rs
@@ -54,7 +54,6 @@ pub struct System {
state: Value,
}

// TODO: replace with TryFrom implementation
impl<S> From<S> for System
where
S: Serialize,
4 changes: 2 additions & 2 deletions src/task/boxed.rs
@@ -37,7 +37,7 @@ impl Clone for BoxedIntoTask {
}
}

trait ErasedIntoTask {
trait ErasedIntoTask: Send {
fn clone_box(&self) -> Box<dyn ErasedIntoTask>;

fn into_task(self: Box<Self>, id: &'static str, context: Context) -> Task;
@@ -50,7 +50,7 @@ struct MakeIntoTask<H> {

impl<H> ErasedIntoTask for MakeIntoTask<H>
where
H: Clone + 'static,
H: Send + Clone + 'static,
{
fn clone_box(&self) -> Box<dyn ErasedIntoTask> {
Box::new(Self {
221 changes: 221 additions & 0 deletions src/worker/mod.rs
@@ -4,6 +4,227 @@ mod intent;
mod planner;
mod workflow;

use log::warn;
use serde_json::Value;
use std::{
future::{self, Future},
pin::Pin,
};
use tokio::task::{self, JoinHandle};

pub use domain::*;
pub use intent::*;
pub use planner::*;
use serde::{de::DeserializeOwned, Serialize};

use crate::{error::Error, system::System};

pub struct WorkerOpts {
/// The maximum number of attempts to reach the target before giving up.
/// Defaults to infinite tries (0).
max_retries: u32,
/// The minimum time to wait between re-plans. Defaults to 1 second
min_wait_ms: u64,
/// The maximum time to wait between re-plans. Defaults to 5 minutes
max_wait_ms: u64,
}

impl WorkerOpts {
pub fn max_retries(self, max_retries: u32) -> Self {
let mut opts = self;
opts.max_retries = max_retries;
opts
}

pub fn min_wait_ms(self, min_wait_ms: u64) -> Self {
let mut opts = self;
opts.min_wait_ms = min_wait_ms;
opts
}

pub fn max_wait_ms(self, max_wait_ms: u64) -> Self {
let mut opts = self;
opts.max_wait_ms = max_wait_ms;
opts
}
}

pub trait WorkerState {}

pub struct Uninitialized {
domain: Domain,
opts: WorkerOpts,
}

pub struct Ready {
planner: Planner,
system: System,
opts: WorkerOpts,
}

pub struct Running {
task: JoinHandle<(Planner, System)>,
}

impl WorkerState for Uninitialized {}
impl WorkerState for Ready {}
impl WorkerState for Running {}

impl Default for WorkerOpts {
fn default() -> Self {
WorkerOpts {
max_retries: 0,
min_wait_ms: 1000,
max_wait_ms: 300_000,
}
}
}

pub struct Worker<T, S: WorkerState = Uninitialized> {
inner: S,
_marker: std::marker::PhantomData<T>,
}

impl<T> Default for Worker<T, Uninitialized> {
fn default() -> Self {
Worker::from_inner(Uninitialized {
domain: Domain::new(),
opts: WorkerOpts::default(),
})
}
}

impl<T, S: WorkerState> Worker<T, S> {
fn from_inner(inner: S) -> Self {
Worker {
inner,
_marker: std::marker::PhantomData,
}
}
}

impl<T> Worker<T, Uninitialized> {
pub fn new() -> Self {
Worker::default()
}

pub fn job(self, route: &'static str, intent: Intent) -> Self {
let Self { mut inner, .. } = self;
inner.domain = inner.domain.job(route, intent);
Worker::from_inner(inner)
}

pub fn with_domain(self, domain: Domain) -> Worker<T, Uninitialized> {
let Self { mut inner, .. } = self;
inner.domain = domain;
Worker::from_inner(inner)
}

pub fn with_opts(self, opts: WorkerOpts) -> Worker<T, Uninitialized> {
let Self { mut inner, .. } = self;
inner.opts = opts;
Worker::from_inner(inner)
}

pub fn with_state(self, state: T) -> Worker<T, Ready>
where
T: Serialize + DeserializeOwned,
{
let Uninitialized { domain, opts, .. } = self.inner;
// this can panic if serializing the state fails
let system = System::from(state);
Worker::from_inner(Ready {
planner: Planner::new(domain),
system,
opts,
})
}
}

impl<T: Serialize + DeserializeOwned> Worker<T, Ready> {
pub fn state(self) -> T {
self.inner.system.state().unwrap()
}

pub fn seek(self, tgt: T) -> Worker<T, Running> {
let Ready {
planner,
system,
opts,
..
} = self.inner;

// TODO: handle the error
let tgt = serde_json::to_value(tgt).unwrap();

enum RuntimeResult {
Done,
Continue,
Cancelled,
}

async fn seek_target(
planner: &Planner,
system: &mut System,
tgt: &Value,
) -> Result<RuntimeResult, Error> {
// TODO: maybe use a timeout when finding the plan
let workflow = planner.find_workflow(system, tgt.clone())?;
if workflow.is_empty() {
return Ok(RuntimeResult::Done);
}

// TODO: run the plan and update the system
workflow.execute(system).await?;

Ok(RuntimeResult::Continue)
}

// TODO: while we have not reached the target state
// find a plan to the target using the planner
// if the plan is empty, we have reached the target
// if we cannot find a plan, try again
// if something fails during the plan, try again
let task = task::spawn_local(async move {
let mut system = system;
// TODO: allow cancelling the planning
// we need to use a channel for this
let mut tries = 0;
loop {
let found = match seek_target(&planner, &mut system, &tgt).await {
Ok(RuntimeResult::Done) => break,
Ok(RuntimeResult::Continue) => true,
// TODO: we probably need to return an error here
Ok(RuntimeResult::Cancelled) => break,
// TODO: Handle errors
Err(Error::PlanSearchFailed(PlanningError::WorkflowNotFound)) => false,
Err(Error::PlanSearchFailed(e)) => {
if cfg!(debug_assertions) {
panic!("unexpected error during planning: {}", e);
} else {
warn!("unexpected error during planning: {}", e);
}
true
}
// TODO: log the error
Err(_) => false,
};

if !found && tries >= opts.max_retries {
// TODO: handle the error
panic!("could not find a plan to the target state");
}

// Exponential backoff
let wait = std::cmp::min(opts.min_wait_ms * 2u64.pow(tries), opts.max_wait_ms);
tokio::time::sleep(tokio::time::Duration::from_millis(wait)).await;

// Only increase the backoff if we did not find a plan
tries += if found { 0 } else { 1 };
}
(planner, system)
});

Worker::from_inner(Running { task })
}
}
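
A standalone sketch of the retry schedule used in the seek loop above. The formula is the same; the saturating arithmetic is added here only so the example cannot overflow for very large retry counts. With the WorkerOpts defaults (1 s minimum, 5 min maximum) the waits double from 1 s and cap at 300 s:

fn backoff_ms(tries: u32, min_wait_ms: u64, max_wait_ms: u64) -> u64 {
    // min_wait_ms * 2^tries, capped at max_wait_ms (same as the seek loop)
    std::cmp::min(
        min_wait_ms.saturating_mul(2u64.saturating_pow(tries)),
        max_wait_ms,
    )
}

fn main() {
    for tries in 0..10 {
        // prints 1000, 2000, 4000, ... then 300000 once the cap is reached
        println!("retry {tries}: wait {} ms", backoff_ms(tries, 1_000, 300_000));
    }
}
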
16 changes: 9 additions & 7 deletions src/worker/planner.rs
@@ -1,6 +1,7 @@
use json_patch::Patch;
use log::warn;
use serde::Serialize;
use serde_json::Value;
use std::fmt::{self, Display};

use crate::error::{Error, IntoError};
@@ -13,9 +14,7 @@ use super::domain::Domain;
use super::workflow::{Action, Workflow};
use super::{DomainSearchError, Operation};

pub struct Planner {
domain: Domain,
}
pub struct Planner(Domain);

#[derive(Debug, PartialEq)]
pub enum PlanningError {
@@ -65,7 +64,7 @@ impl IntoError for PlanningError {

impl Planner {
pub fn new(domain: Domain) -> Self {
Self { domain }
Self(domain)
}

fn try_task(
@@ -120,7 +119,7 @@ impl Planner {
t = t.with_arg(k, v)
}
let path = self
.domain
.0
.get_path(t.id(), &t.context().args)
// The user may not have put the child task in the
// domain, in which case we need to return an error
@@ -152,9 +151,12 @@ impl Planner {
let tgt = serde_json::to_value(tgt)?;

let system = System::new(cur);
self.find_workflow(&system, tgt)
}

pub(crate) fn find_workflow(&self, system: &System, tgt: Value) -> Result<Workflow, Error> {
// Store the initial state and an empty plan on the stack
let mut stack = vec![(system, Workflow::default(), 0)];
let mut stack = vec![(system.clone(), Workflow::default(), 0)];

// TODO: we should merge non-conflicting workflows
// for parallelism
@@ -177,7 +179,7 @@ for op in distance.iter() {
for op in distance.iter() {
// Find applicable tasks
let path = Path::new(op.path());
let matching = self.domain.at(path.to_str());
let matching = self.0.at(path.to_str());
if let Some((args, intents)) = matching {
// Calculate the target for the job path
let pointer = path.as_ref();
31 changes: 31 additions & 0 deletions src/worker/workflow.rs
@@ -3,8 +3,14 @@ use serde_json::Value;
use std::collections::hash_map::DefaultHasher;
use std::fmt::{self, Display};
use std::hash::{Hash, Hasher};
use std::{
future::{self, Future},
pin::Pin,
};

use crate::dag::Dag;
use crate::error::Error;
use crate::system::System;
use crate::task::Task;

#[derive(Hash)]
@@ -86,6 +92,31 @@ impl Workflow {
pending,
}
}

pub fn is_empty(&self) -> bool {
self.dag.is_empty()
}

pub(crate) async fn execute(self, system: &mut System) -> Result<(), Error> {
let sys = system.clone();
// For now, execute the Workflow in a sequential manner
// TODO: Implement parallel execution
*system = self
.dag
.fold::<Pin<Box<dyn Future<Output = Result<System, Error>>>>>(
Box::pin(future::ready(Ok(sys))),
|sys, Action { task, .. }| {
let task = task.clone();
Box::pin(async move {
let mut sys = sys.await?;
task.run(&mut sys).await?;
Ok(sys)
})
},
)
.await?;
Ok(())
}
}

impl Display for Workflow {
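
Workflow::execute above chains the DAG's actions into a single future by folding, threading the System through each Task::run call. A minimal standalone sketch of the same pattern, using a Vec of string steps and a String accumulator as hypothetical stand-ins for the DAG and the System:

use std::future::{self, Future};
use std::pin::Pin;

type BoxFut = Pin<Box<dyn Future<Output = Result<String, String>>>>;

fn chain(steps: Vec<&'static str>) -> BoxFut {
    // Fold the steps into one future that runs them in order, passing the
    // accumulated state from step to step (as execute does with the System).
    steps.into_iter().fold(
        Box::pin(future::ready(Ok(String::new()))) as BoxFut,
        |acc, step| {
            Box::pin(async move {
                let mut state = acc.await?;
                state.push_str(step); // stand-in for task.run(&mut sys).await?
                Ok(state)
            })
        },
    )
}

fn main() {
    let rt = tokio::runtime::Builder::new_current_thread()
        .build()
        .expect("failed to build tokio runtime");
    assert_eq!(rt.block_on(chain(vec!["a", "b", "c"])), Ok("abc".to_string()));
}
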
