Skip to content

Commit

Permalink
UpdateGamma_cpp() MRF_G and different method (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
wleoncio committed Sep 16, 2024
1 parent 2895778 commit c428236
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 11 deletions.
4 changes: 4 additions & 0 deletions R/func_MCMC.R
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,10 @@ func_MCMC <- function(survObj, hyperpar, initial,
# update gamma (latent indicators of variable selection)
# browser()
sampleGam <- UpdateGamma(survObj, hyperpar, ini, S, method, MRF_G, MRF_2b, cpp)
if (is(sampleGam$gamma.ini, "matrix") && method != "Pooled") {
# Workaround because C++ outputs list elements as matrices
sampleGam <- lapply(sampleGam, function(x) as.list(as.data.frame(x)))
}
gamma.ini <- ini$gamma.ini <- sampleGam$gamma.ini

# update beta (regression parameters)
Expand Down
2 changes: 1 addition & 1 deletion R/updatePara.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
UpdateGamma <- function(sobj, hyperpar, ini, S, method, MRF_G, MRF_2b, cpp = FALSE) {
# Update latent variable selection indicators gamma with either independent Bernoulli prior
# (standard approaches) or with MRF prior.
if (cpp && method == "Pooled" && MRF_G) {
if (cpp && MRF_G) {
warning("This is not yet fully implemented. Please use cpp = FALSE for production")
return(UpdateGamma_cpp(sobj, hyperpar, ini, S, method, MRF_G, MRF_2b))
}
Expand Down
18 changes: 8 additions & 10 deletions src/UpdateGamma_cpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ Rcpp::List UpdateGamma_cpp(

arma::mat G_ini = arma::zeros<arma::mat>(p, p);
if (method == "Pooled" && MRF_G) {
// G_ini is not needed if method != "Pooled" and MRF_G
G_ini = Rcpp::as<arma::mat>(hyperpar["G"]);
} else if (!MRF_G) {
G_ini = Rcpp::as<arma::mat>(ini["G.ini"]);
Expand Down Expand Up @@ -88,20 +89,17 @@ Rcpp::List UpdateGamma_cpp(
);
return out;
} else {
Rcpp::List post_gamma = Rcpp::List::create(S);
for (arma::uword g = 0; g < S; g++) {
post_gamma[g] = arma::zeros<arma::vec>(p);
}
arma::mat post_gamma(p, S, arma::fill::zeros);

if (MRF_G) {
for (arma::uword g = 0; g < S; g++) { // loop through subgroups
for (arma::uword j = 0; j < p; j++) {
// wa <- dnorm((beta.ini[[g]])[j], mean = 0, sd = cb * tau) * pi
// wb <- dnorm((beta.ini[[g]])[j], mean = 0, sd = tau) * (1 - pi)
// pgam <- wa / (wa + wb)
// u <- runif(1)
// gamma.ini[[g]][j] <- ifelse(u < pgam, 1, 0)
// post.gamma[[g]][j] <- pgam
double wa = R::dnorm(beta_ini(j, g), 0, tau * cb, true) * pi;
double wb = R::dnorm(beta_ini(j, g), 0, tau, true) * (1 - pi);
double pgam = wa / (wa + wb);
double u = R::runif(0, 1);
gamma_ini(j, g) = u < pgam;
post_gamma(j, g) = pgam;
}
}
} else { // CoxBVS-SL or Sub-struct model
Expand Down
5 changes: 5 additions & 0 deletions tests/testthat/test-cpp_translation.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,9 @@ test_that("R and C++ objects are similar", {
for (obj in names(fit_R$output)[2]) {
expect_equal(fit_R$output[[obj]], fit_C$output[[obj]], tolerance = 1)
}
expect_equal(fit_R2S$call, fit_C2S$call)
expect_equal(fit_R2S$input, fit_C2S$input)
for (obj in names(fit_R2S$output)[2]) {
expect_equal(fit_R2S$output[[obj]], fit_C2S$output[[obj]], tolerance = 1)
}
})

0 comments on commit c428236

Please sign in to comment.