Skip to content

Commit

Permalink
feat: Implement pausing and unpausing of samplers
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Mar 18, 2024
1 parent 30ae7ef commit a24dffa
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 110 deletions.
174 changes: 85 additions & 89 deletions python/nutpie/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,29 +142,32 @@ class _BackgroundSampler:
def __init__(
self,
compiled_model,
sampler,
settings,
init_mean,
chains,
cores,
seed,
draws,
tune,
*,
progress_bar=True,
save_warmup=True,
return_raw_trace=False,
):
self._sampler = sampler
self._num_divs = 0
self._tune = tune
self._draws = draws
self._tune = settings.num_tune
self._draws = settings.num_draws
self._settings = settings
self._chains_tuning = chains
self._chains_finished = 0
self._chains = chains
self._sampler = sampler
self._compiled_model = compiled_model
self._save_warmup = save_warmup
self._return_raw_trace = return_raw_trace
total_draws = (self._draws + self._tune) * self._chains
self._progress = fastprogress.progress_bar(
sampler,
total=chains * (draws + tune),
range(total_draws),
total=total_draws,
display=progress_bar,
)
# fastprogress seems to reset the progress bar
Expand All @@ -176,83 +179,65 @@ def __init__(
self._pause_event = Event()
self._continue = Condition()

def show_progress():
for info in self._bar:
if info.draw == self._tune - 1:
self._chains_tuning -= 1
if info.draw == self._tune + self._draws - 1:
self._chains_finished += 1
if info.is_diverging and info.draw > self._tune:
self._num_divs += 1
if self._chains_tuning > 0:
count = self._chains_tuning
divs = self._num_divs
self._progress.comment = (
f" Chains in warmup: {count}, Divergences: {divs}"
)
else:
count = self._chains - self._chains_finished
divs = self._num_divs
self._progress.comment = (
f" Sampling chains: {count}, Divergences: {divs}"
)

if timeout is not None:
current_time = time.time()
if current_time - start_time > timeout:
raise TimeoutError("Sampling did not finish")


self._thread = Thread(target=show_progress)
self._finished_draws = 0

next(self._bar)

def progress_callback(info):
if info.draw == self._tune - 1:
self._chains_tuning -= 1
if info.draw == self._tune + self._draws - 1:
self._chains_finished += 1
if info.is_diverging and info.draw > self._tune:
self._num_divs += 1
if self._chains_tuning > 0:
count = self._chains_tuning
divs = self._num_divs
self._progress.comment = (
f" Chains in warmup: {count}, Divergences: {divs}"
)
else:
count = self._chains - self._chains_finished
divs = self._num_divs
self._progress.comment = (
f" Sampling chains: {count}, Divergences: {divs}"
)
try:
next(self._bar)
except StopIteration:
pass
self._finished_draws += 1

if progress_bar:
callback = progress_callback
else:
callback = None

self._sampler = compiled_model._make_sampler(
settings,
init_mean,
chains,
cores,
seed,
callback=callback,
)

def wait(self, *, timeout=None):
"""Wait until sampling is finished.
"""Wait until sampling is finished and return the trace.
KeyboardInterrupt will lead to interrupt the waiting.
This will return after `timeout` seconds even if sampling is
not finished at this point.
This resumes the sampler in case it had been paused.
"""
if self._sampler is None:
raise ValueError("Sampler is already finalized")

start_time = time.time()

try:
for info in self._bar:
if info.draw == self._tune - 1:
self._chains_tuning -= 1
if info.draw == self._tune + self._draws - 1:
self._chains_finished += 1
if info.is_diverging and info.draw > self._tune:
self._num_divs += 1
if self._chains_tuning > 0:
count = self._chains_tuning
divs = self._num_divs
self._progress.comment = (
f" Chains in warmup: {count}, Divergences: {divs}"
)
else:
count = self._chains - self._chains_finished
divs = self._num_divs
self._progress.comment = (
f" Sampling chains: {count}, Divergences: {divs}"
)

if timeout is not None:
current_time = time.time()
if current_time - start_time > timeout:
raise TimeoutError("Sampling did not finish")
except KeyboardInterrupt:
pass

def finalize(self):
"""Free resources of the sampler and return the trace produced so far."""
if self._sampler is None:
raise ValueError("Sampler has already been finalized")

results = self._sampler.finalize()
self._sampler = None
self._sampler.wait(timeout)
self._sampler.finalize()
return self._extract()

def _extract(self):
results = self._sampler.extract_results()

dims = {name: list(dim) for name, dim in self._compiled_model.dims.items()}
dims["mass_matrix_inv"] = ["unconstrained_parameter"]
Expand All @@ -278,20 +263,30 @@ def finalize(self):
save_warmup=self._save_warmup,
)

def pause(self):
"""Pause the sampler."""
self._sampler.pause()

def resume(self):
"""Resume a paused sampler."""
self._sampler.resume()

@property
def is_finished(self):
if self._sampler is None:
return True
return self._sampler.is_finished()

def abort(self):
"""Abort sampling and return the trace produced so far."""
self._sampler.abort()
return self._extract()

def cancel(self):
"""Abort sampling and discard progress."""
if self._sampler is not None:
self._sampler.finalize()
self._sampler = None
self._sampler.abort()

def __del__(self):
self.abort()
if not self._sampler.is_empty():
self.cancel()


@overload
Expand Down Expand Up @@ -417,12 +412,13 @@ def sample(
if init_mean is None:
init_mean = np.zeros(compiled_model.n_dim)

sampler = compiled_model._make_sampler(settings, init_mean, chains, cores, seed)

sampler = _BackgroundSampler(
compiled_model,
sampler,
settings,
init_mean,
chains,
cores,
seed,
draws,
tune,
progress_bar=progress_bar,
Expand All @@ -434,11 +430,11 @@ def sample(
return sampler

try:
sampler.wait()
result = sampler.wait()
except KeyboardInterrupt:
pass
result = sampler.abort()
except:
sampler.abort()
sampler.cancel()
raise

return sampler.finalize()
return result
26 changes: 22 additions & 4 deletions src/sampler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ use std::{
sync::{
self,
mpsc::{
channel, sync_channel, Receiver, RecvError, RecvTimeoutError, Sender, SyncSender,
TryRecvError,
channel, sync_channel, Receiver, RecvError, RecvTimeoutError, SyncSender, TryRecvError,
},
Arc,
},
Expand Down Expand Up @@ -355,7 +354,26 @@ impl SamplerControl {
result
}

pub(crate) fn is_finished(&self) -> bool {
self.poll_thread.is_finished()
/// Try to finish the sampler without blocking.
///
/// The sampler is first asked to resume (any error from that is ignored),
/// then the result channel is polled exactly once:
///
/// * If no result is available yet, the control handle is handed back to
///   the caller inside [`SamplerWaitResult::Timeout`].
/// * Otherwise the command and result channels are dropped, the polling
///   thread is joined, and the outcome is returned inside
///   [`SamplerWaitResult::Result`]. A disconnected result channel or a
///   panicked polling thread is reported as an error in that variant.
pub(crate) fn try_finalize(mut self) -> SamplerWaitResult {
    // We want to ignore this error, because this only
    // fails if the sampler thread is done, and in this case
    // we don't need to do anything.
    let _ = self.resume();

    let result = match self.results.try_recv() {
        Err(err @ TryRecvError::Disconnected) => Err(err)
            .context("Could not get sampler result in try_finalize, sampler thread is dead."),
        // Nothing to collect yet: give ownership back so the caller can retry.
        Err(TryRecvError::Empty) => return SamplerWaitResult::Timeout(self),
        Ok(result) => result,
    };

    // Close both channels before joining the polling thread.
    drop(self.commands);
    drop(self.results);

    if self.poll_thread.join().is_err() {
        return SamplerWaitResult::Result(Err(anyhow!("Sample polling thread panicked.")));
    }
    SamplerWaitResult::Result(result)
}
}
61 changes: 46 additions & 15 deletions src/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use anyhow::{Context, Result};
use arrow2::{array::Array, datatypes::Field};
use nuts_rs::{SampleStats, SamplerArgs};
use pyo3::{
exceptions::PyValueError,
exceptions::{PyTimeoutError, PyValueError},
ffi::Py_uintptr_t,
prelude::*,
types::{PyList, PyTuple},
Expand Down Expand Up @@ -267,11 +267,23 @@ impl PySampler {
})
}

fn is_finished(&mut self) -> bool {
if let SamplerState::Running(ref control) = self.state {
control.is_finished()
} else {
true
/// Check, without blocking, whether sampling has finished.
///
/// Returns `Ok(true)` when the sampler is not in the `Running` state, or
/// when its result could be collected right now (the state then becomes
/// `Finished`). Returns `Ok(false)` when the sampler is still running; the
/// control handle is put back so that later calls can poll again.
///
/// NOTE(review): the `try_finalize` shown in `src/sampler.rs` calls
/// `resume()` before polling, so checking this on a paused sampler
/// presumably un-pauses it -- confirm that this is intended.
fn is_finished(&mut self) -> PyResult<bool> {
    // `try_finalize` consumes the control handle, so temporarily move the
    // state out and leave `Empty` behind as a placeholder.
    let state = std::mem::replace(&mut self.state, SamplerState::Empty);

    let SamplerState::Running(control) = state else {
        // Not running: restore whatever state we had.
        self.state = state;
        return Ok(true);
    };

    let (next_state, finished) = match control.try_finalize() {
        SamplerWaitResult::Result(result) => (SamplerState::Finished(result), true),
        SamplerWaitResult::Timeout(control) => (SamplerState::Running(control), false),
    };
    self.state = next_state;
    Ok(finished)
}

Expand All @@ -293,10 +305,12 @@ impl PySampler {
})
}

fn wait_timeout(&mut self, py: Python<'_>, timeout_seconds: f64) -> PyResult<()> {
fn wait(&mut self, py: Python<'_>, timeout_seconds: Option<f64>) -> PyResult<()> {
py.allow_threads(|| {
let timeout =
Duration::try_from_secs_f64(timeout_seconds).context("Invalid timeout")?;
let timeout = match timeout_seconds {
Some(val) => Some(Duration::try_from_secs_f64(val).context("Invalid timeout")?),
None => None,
};

let state = std::mem::replace(&mut self.state, SamplerState::Empty);

Expand All @@ -310,11 +324,20 @@ impl PySampler {

let (final_state, retval) = loop {
let time_so_far = Instant::now().saturating_duration_since(start_time);
let Some(remaining) = timeout.checked_sub(time_so_far) else {
break (SamplerState::Running(control), Ok(()));
let next_timeout = match timeout {
Some(timeout) => {
let Some(remaining) = timeout.checked_sub(time_so_far) else {
break (
SamplerState::Running(control),
Err(PyTimeoutError::new_err(
"Timeout while waiting for sampler to finish",
)),
);
};
remaining.min(step)
}
None => step,
};
let next_timeout = remaining.min(step);
dbg!(timeout, time_so_far, remaining, next_timeout);

match control.wait_timeout(next_timeout) {
SamplerWaitResult::Result(result) => {
Expand All @@ -341,7 +364,7 @@ impl PySampler {

let SamplerState::Running(control) = state else {
let _ = std::mem::replace(&mut self.state, state);
return Err(anyhow::anyhow!("Sampler is already finalized"))?;
return Ok(());
};

let result = control.abort();
Expand All @@ -356,7 +379,7 @@ impl PySampler {

let SamplerState::Running(control) = state else {
let _ = std::mem::replace(&mut self.state, state);
return Err(anyhow::anyhow!("Sampler is already finalized"))?;
return Ok(());
};

let result = control.finalize();
Expand Down Expand Up @@ -398,6 +421,14 @@ impl PySampler {
);
Ok(list.into_py(py))
}

/// Report whether the sampler state is `Empty`.
///
/// Returns `false` for both the `Running` and the `Finished` state.
fn is_empty(&self) -> bool {
    // Spell out every variant so that adding a new state forces a decision here.
    match &self.state {
        SamplerState::Empty => true,
        SamplerState::Running(_) | SamplerState::Finished(_) => false,
    }
}
}

fn export_array(py: Python<'_>, name: String, data: Box<dyn Array>) -> PyResult<PyObject> {
Expand Down
Loading

0 comments on commit a24dffa

Please sign in to comment.