Skip to content

Commit

Permalink
feat: Add more sample stats about divergences
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Jul 21, 2023
1 parent fc83b9f commit d736a9b
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/cpu_potential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ impl<F: CpuLogpFunc, M: MassMatrix> Hamiltonian for EuclideanPotential<F, M> {
let div_info = DivergenceInfo {
logp_function_error: Some(Box::new(logp_error)),
start_location: Some(start.q.clone()),
start_gradient: Some(start.grad.clone()),
start_momentum: Some(start.p.clone()),
end_location: None,
start_idx_in_trajectory: Some(start.idx_in_trajectory),
end_idx_in_trajectory: None,
Expand All @@ -142,7 +144,9 @@ impl<F: CpuLogpFunc, M: MassMatrix> Hamiltonian for EuclideanPotential<F, M> {
let divergence_info = DivergenceInfo {
logp_function_error: None,
start_location: Some(start.q.clone()),
start_gradient: Some(start.grad.clone()),
end_location: Some(out.q.clone()),
start_momentum: Some(out.p.clone()),
start_idx_in_trajectory: Some(start.index_in_trajectory()),
end_idx_in_trajectory: Some(out.index_in_trajectory()),
energy_error: Some(energy_error),
Expand Down
4 changes: 4 additions & 0 deletions src/cpu_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,10 @@ impl crate::nuts::State for State {
out.copy_from_slice(&self.grad);
}

fn write_momentum(&self, out: &mut [f64]) {
out.copy_from_slice(&self.p);
}

fn energy(&self) -> f64 {
self.kinetic_energy + self.potential_energy
}
Expand Down
148 changes: 147 additions & 1 deletion src/nuts.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use arrow2::array::{MutableFixedSizeListArray, TryPush};
use arrow2::array::{MutableFixedSizeListArray, MutableUtf8Array, TryPush};
#[cfg(feature = "arrow")]
use arrow2::{
array::{MutableArray, MutableBooleanArray, MutablePrimitiveArray, StructArray},
Expand Down Expand Up @@ -37,7 +37,9 @@ pub type Result<T> = std::result::Result<T, NutsError>;
/// failed)
#[derive(Debug)]
pub struct DivergenceInfo {
pub start_momentum: Option<Box<[f64]>>,
pub start_location: Option<Box<[f64]>>,
pub start_gradient: Option<Box<[f64]>>,
pub end_location: Option<Box<[f64]>>,
pub energy_error: Option<f64>,
pub end_idx_in_trajectory: Option<i64>,
Expand Down Expand Up @@ -152,6 +154,9 @@ pub trait State: Clone + Debug {
/// Write the gradient stored in the state to a different location
fn write_gradient(&self, out: &mut [f64]);

/// Write the momentum in the state to a different location
fn write_momentum(&self, out: &mut [f64]);

/// Compute the termination criterion for NUTS
fn is_turning(&self, other: &Self) -> bool;

Expand Down Expand Up @@ -523,6 +528,11 @@ pub struct StatsBuilder<H: Hamiltonian, A: AdaptStrategy> {
hamiltonian: <H::Stats as ArrowRow>::Builder,
adapt: <A::Stats as ArrowRow>::Builder,
diverging: MutableBooleanArray,
divergence_start: Option<MutableFixedSizeListArray<MutablePrimitiveArray<f64>>>,
divergence_start_grad: Option<MutableFixedSizeListArray<MutablePrimitiveArray<f64>>>,
divergence_end: Option<MutableFixedSizeListArray<MutablePrimitiveArray<f64>>>,
divergence_momentum: Option<MutableFixedSizeListArray<MutablePrimitiveArray<f64>>>,
divergence_msg: Option<MutableUtf8Array<i64>>,
}

#[cfg(feature = "arrow")]
Expand All @@ -548,6 +558,40 @@ impl<H: Hamiltonian, A: AdaptStrategy> StatsBuilder<H, A> {
None
};

let (div_start, div_start_grad, div_end, div_mom, div_msg) = if settings.store_divergences {
let start_location_prim = MutablePrimitiveArray::new();
let start_location_list =
MutableFixedSizeListArray::new_with_field(start_location_prim, "item", false, dim);

let start_grad_prim = MutablePrimitiveArray::new();
let start_grad_list =
MutableFixedSizeListArray::new_with_field(start_grad_prim, "item", false, dim);

let end_location_prim = MutablePrimitiveArray::new();
let end_location_list =
MutableFixedSizeListArray::new_with_field(end_location_prim, "item", false, dim);

let momentum_location_prim = MutablePrimitiveArray::new();
let momentum_location_list = MutableFixedSizeListArray::new_with_field(
momentum_location_prim,
"item",
false,
dim,
);

let msg_list = MutableUtf8Array::new();

(
Some(start_location_list),
Some(start_grad_list),
Some(end_location_list),
Some(momentum_location_list),
Some(msg_list),
)
} else {
(None, None, None, None, None)
};

Self {
depth: MutablePrimitiveArray::with_capacity(capacity),
maxdepth_reached: MutableBooleanArray::with_capacity(capacity),
Expand All @@ -561,6 +605,11 @@ impl<H: Hamiltonian, A: AdaptStrategy> StatsBuilder<H, A> {
hamiltonian: <H::Stats as ArrowRow>::new_builder(dim, settings),
adapt: <A::Stats as ArrowRow>::new_builder(dim, settings),
diverging: MutableBooleanArray::with_capacity(capacity),
divergence_start: div_start,
divergence_start_grad: div_start_grad,
divergence_end: div_end,
divergence_momentum: div_mom,
divergence_msg: div_msg,
}
}
}
Expand Down Expand Up @@ -601,6 +650,58 @@ impl<H: Hamiltonian, A: AdaptStrategy> ArrowBuilder<NutsSampleStats<H::Stats, A:
.unwrap();
}

let info_option = value.divergence_info();
if let Some(div_start) = self.divergence_start.as_mut() {
div_start
.try_push(info_option.and_then(|info| {
info.start_location
.as_ref()
.map(|vals| vals.iter().map(|&x| Some(x)))
}))
.unwrap();
}

let info_option = value.divergence_info();
if let Some(div_grad) = self.divergence_start_grad.as_mut() {
div_grad
.try_push(info_option.and_then(|info| {
info.start_gradient
.as_ref()
.map(|vals| vals.iter().map(|&x| Some(x)))
}))
.unwrap();
}

if let Some(div_end) = self.divergence_end.as_mut() {
div_end
.try_push(info_option.and_then(|info| {
info.end_location
.as_ref()
.map(|vals| vals.iter().map(|&x| Some(x)))
}))
.unwrap();
}

if let Some(div_mom) = self.divergence_momentum.as_mut() {
div_mom
.try_push(info_option.and_then(|info| {
info.start_momentum
.as_ref()
.map(|vals| vals.iter().map(|&x| Some(x)))
}))
.unwrap();
}

if let Some(div_msg) = self.divergence_msg.as_mut() {
div_msg
.try_push(info_option.and_then(|info| {
info.logp_function_error
.as_ref()
.map(|err| format!("{}", err))
}))
.unwrap();
}

self.hamiltonian.append_value(&value.potential_stats);
self.adapt.append_value(&value.strategy_stats);
}
Expand Down Expand Up @@ -655,6 +756,51 @@ impl<H: Hamiltonian, A: AdaptStrategy> ArrowBuilder<NutsSampleStats<H::Stats, A:
arrays.push(unconstrained.as_box());
}

if let Some(mut div_start) = self.divergence_start.take() {
fields.push(Field::new(
"divergence_start",
div_start.data_type().clone(),
true,
));
arrays.push(div_start.as_box());
}

if let Some(mut div_start_grad) = self.divergence_start_grad.take() {
fields.push(Field::new(
"divergence_start_gradient",
div_start_grad.data_type().clone(),
true,
));
arrays.push(div_start_grad.as_box());
}

if let Some(mut div_end) = self.divergence_end.take() {
fields.push(Field::new(
"divergence_end",
div_end.data_type().clone(),
true,
));
arrays.push(div_end.as_box());
}

if let Some(mut div_mom) = self.divergence_momentum.take() {
fields.push(Field::new(
"divergence_momentum",
div_mom.data_type().clone(),
true,
));
arrays.push(div_mom.as_box());
}

if let Some(mut div_msg) = self.divergence_msg.take() {
fields.push(Field::new(
"divergence_message",
div_msg.data_type().clone(),
true,
));
arrays.push(div_msg.as_box());
}

Some(StructArray::new(DataType::Struct(fields), arrays, None))
}
}
Expand Down

0 comments on commit d736a9b

Please sign in to comment.