From 2754b96424c0c5694fd32839ade7906f996f75ee Mon Sep 17 00:00:00 2001 From: Waldir Leoncio Date: Tue, 17 Sep 2024 14:43:19 +0200 Subject: [PATCH] Created `calc_pg()` (#11) --- src/UpdateGamma_cpp.cpp | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/src/UpdateGamma_cpp.cpp b/src/UpdateGamma_cpp.cpp index 21bbed3..072f050 100644 --- a/src/UpdateGamma_cpp.cpp +++ b/src/UpdateGamma_cpp.cpp @@ -14,6 +14,21 @@ double calc_wa_wb( return product + jitter; } +double calc_pg( + const arma::vec& ga_prop1, + const arma::vec& ga_prop0, + const arma::mat& G_ini, + const double beta, + const double a, + const double tau, + const double cb +) { + double wa = calc_wa_wb(ga_prop1, G_ini, beta, a, tau * cb); + double wb = calc_wa_wb(ga_prop0, G_ini, beta, a, tau); + double w_max = std::max(wa, wb); + return std::exp(wa - w_max) / (std::exp(wa - w_max) + std::exp(wb - w_max)); +} + // [[Rcpp::export]] Rcpp::List UpdateGamma_cpp( const Rcpp::List sobj, @@ -75,11 +90,7 @@ Rcpp::List UpdateGamma_cpp( ga_prop1(j) = 1; ga_prop0(j) = 0; - double wa = calc_wa_wb(ga_prop1, G_ini, beta, a, tau * cb); - double wb = calc_wa_wb(ga_prop0, G_ini, beta, a, tau); - - double w_max = std::max(wa, wb); - double pg = std::exp(wa - w_max) / (std::exp(wa - w_max) + std::exp(wb - w_max)); + double pg = calc_pg(ga_prop1, ga_prop0, G_ini, beta, a, tau, cb); gamma_ini(0, j) = R::runif(0, 1) < pg; post_gamma(0, j) = pg; @@ -106,18 +117,12 @@ Rcpp::List UpdateGamma_cpp( for (arma::uword j = 0; j < p; j++) { double beta = beta_ini(j, g); - // TODO: refactor. same as above arma::vec ga_prop1 = gamma_ini; arma::vec ga_prop0 = gamma_ini; ga_prop1(j, g) = 1; ga_prop0(j, g) = 0; - double wa = calc_wa_wb(ga_prop1, G_ini, beta, a, tau * cb); - double wb = calc_wa_wb(ga_prop0, G_ini, beta, a, tau); - - double w_max = std::max(wa, wb); - double pg = std::exp(wa - w_max) / (std::exp(wa - w_max) + std::exp(wb - w_max)); - + double pg = calc_pg(ga_prop1, ga_prop0, G_ini, beta, a, tau, cb); gamma_ini(j, g) = R::runif(0, 1) < pg; post_gamma(g, j) = pg; }