Skip to content

Commit

Permalink
Merge pull request #5 from aseyboldt/performance
Browse files Browse the repository at this point in the history
Some more performance improvements
  • Loading branch information
ssoudan authored Nov 27, 2023
2 parents ab99c4d + 973ed31 commit a01b24b
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 12 deletions.
14 changes: 5 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,22 +200,18 @@ pub fn run_with(
panic!("x and y must have at least one element");
}

let x0 = x[0];
// Use the middle of the time period as reference
// to prevent strong correlations between alpha and beta
let x0 = x.iter().sum::<f64>() / x.len() as f64;

let x = x.iter().map(|x| x - x0).collect::<Vec<_>>();

let model = Regression::new(x.clone(), y.clone());

// y = alpha + beta * x + noise
let guessed_beta = y.iter().sum::<f64>() / x.iter().sum::<f64>();
let guessed_beta = 0.;//y.iter().sum::<f64>() / x.iter().sum::<f64>();
let guessed_alpha = y.iter().sum::<f64>() / y.len() as f64;
let guessed_sigma = x
.iter()
.zip(y.iter())
.map(|(x, y)| (y - guessed_alpha - guessed_beta * x).powi(2))
.sum::<f64>()
.sqrt()
/ y.len() as f64;
let guessed_sigma = 1.;
let initial_position = vec![guessed_alpha, guessed_beta, guessed_sigma];
log(format!("initial_position = {:?}", initial_position).as_str());

Expand Down
4 changes: 2 additions & 2 deletions src/model/regression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ impl CpuLogpFunc for Regression {
let logp_beta = log_pdf_normal_propto(beta, 10f64.ln(), 0.01);
let logp_sigma = 0.; // flat prior

let mut d_logp_d_alpha = 0.;
let mut d_logp_d_beta = 0.;
let mut d_logp_d_alpha = -alpha / 100.;
let mut d_logp_d_beta = -beta / 100.;
let mut d_logp_d_sigma = 0.;

let mut logp_y = 0.;
Expand Down
2 changes: 1 addition & 1 deletion src/sampler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ where
.expect("Unrecoverable error during init");

// Burn the first x samples to get away from the initial position
for _ in 0..50 {
for _ in 0..num_tune {
sampler.draw().expect("Unrecoverable error during burning");
}

Expand Down

0 comments on commit a01b24b

Please sign in to comment.