diff --git a/src/cpu_potential.rs b/src/cpu_potential.rs index 0d0edce..1218e32 100644 --- a/src/cpu_potential.rs +++ b/src/cpu_potential.rs @@ -116,6 +116,8 @@ impl Hamiltonian for EuclideanPotential { let div_info = DivergenceInfo { logp_function_error: Some(Box::new(logp_error)), start_location: Some(start.q.clone()), + start_gradient: Some(start.grad.clone()), + start_momentum: Some(start.p.clone()), end_location: None, start_idx_in_trajectory: Some(start.idx_in_trajectory), end_idx_in_trajectory: None, @@ -142,7 +144,9 @@ impl Hamiltonian for EuclideanPotential { let divergence_info = DivergenceInfo { logp_function_error: None, start_location: Some(start.q.clone()), + start_gradient: Some(start.grad.clone()), end_location: Some(out.q.clone()), + start_momentum: Some(out.p.clone()), start_idx_in_trajectory: Some(start.index_in_trajectory()), end_idx_in_trajectory: Some(out.index_in_trajectory()), energy_error: Some(energy_error), diff --git a/src/cpu_state.rs b/src/cpu_state.rs index 42567db..4055a30 100644 --- a/src/cpu_state.rs +++ b/src/cpu_state.rs @@ -243,6 +243,10 @@ impl crate::nuts::State for State { out.copy_from_slice(&self.grad); } + fn write_momentum(&self, out: &mut [f64]) { + out.copy_from_slice(&self.p); + } + fn energy(&self) -> f64 { self.kinetic_energy + self.potential_energy } diff --git a/src/nuts.rs b/src/nuts.rs index 353e968..3531dd3 100644 --- a/src/nuts.rs +++ b/src/nuts.rs @@ -1,4 +1,4 @@ -use arrow2::array::{MutableFixedSizeListArray, TryPush}; +use arrow2::array::{MutableFixedSizeListArray, MutableUtf8Array, TryPush}; #[cfg(feature = "arrow")] use arrow2::{ array::{MutableArray, MutableBooleanArray, MutablePrimitiveArray, StructArray}, @@ -37,7 +37,9 @@ pub type Result = std::result::Result; /// failed) #[derive(Debug)] pub struct DivergenceInfo { + pub start_momentum: Option>, pub start_location: Option>, + pub start_gradient: Option>, pub end_location: Option>, pub energy_error: Option, pub end_idx_in_trajectory: Option, @@ -152,6 +154,9 @@ pub trait State: Clone + Debug { /// Write the gradient stored in the state to a different location fn write_gradient(&self, out: &mut [f64]); + /// Write the momentum in the state to a different location + fn write_momentum(&self, out: &mut [f64]); + /// Compute the termination criterion for NUTS fn is_turning(&self, other: &Self) -> bool; @@ -523,6 +528,11 @@ pub struct StatsBuilder { hamiltonian: ::Builder, adapt: ::Builder, diverging: MutableBooleanArray, + divergence_start: Option>>, + divergence_start_grad: Option>>, + divergence_end: Option>>, + divergence_momentum: Option>>, + divergence_msg: Option>, } #[cfg(feature = "arrow")] @@ -548,6 +558,40 @@ impl StatsBuilder { None }; + let (div_start, div_start_grad, div_end, div_mom, div_msg) = if settings.store_divergences { + let start_location_prim = MutablePrimitiveArray::new(); + let start_location_list = + MutableFixedSizeListArray::new_with_field(start_location_prim, "item", false, dim); + + let start_grad_prim = MutablePrimitiveArray::new(); + let start_grad_list = + MutableFixedSizeListArray::new_with_field(start_grad_prim, "item", false, dim); + + let end_location_prim = MutablePrimitiveArray::new(); + let end_location_list = + MutableFixedSizeListArray::new_with_field(end_location_prim, "item", false, dim); + + let momentum_location_prim = MutablePrimitiveArray::new(); + let momentum_location_list = MutableFixedSizeListArray::new_with_field( + momentum_location_prim, + "item", + false, + dim, + ); + + let msg_list = MutableUtf8Array::new(); + + ( + Some(start_location_list), + Some(start_grad_list), + Some(end_location_list), + Some(momentum_location_list), + Some(msg_list), + ) + } else { + (None, None, None, None, None) + }; + Self { depth: MutablePrimitiveArray::with_capacity(capacity), maxdepth_reached: MutableBooleanArray::with_capacity(capacity), @@ -561,6 +605,11 @@ impl StatsBuilder { hamiltonian: ::new_builder(dim, settings), adapt: ::new_builder(dim, settings), diverging: MutableBooleanArray::with_capacity(capacity), + divergence_start: div_start, + divergence_start_grad: div_start_grad, + divergence_end: div_end, + divergence_momentum: div_mom, + divergence_msg: div_msg, } } } @@ -601,6 +650,58 @@ impl ArrowBuilder ArrowBuilder