From c428236158ec96ce29d21da670205eb66b92377c Mon Sep 17 00:00:00 2001 From: Waldir Leoncio Date: Mon, 16 Sep 2024 13:19:29 +0200 Subject: [PATCH] `UpdateGamma_cpp()` `MRF_G` and different `method` (#11) --- R/func_MCMC.R | 4 ++++ R/updatePara.R | 2 +- src/UpdateGamma_cpp.cpp | 18 ++++++++---------- tests/testthat/test-cpp_translation.R | 5 +++++ 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/R/func_MCMC.R b/R/func_MCMC.R index 68db759..00730af 100755 --- a/R/func_MCMC.R +++ b/R/func_MCMC.R @@ -178,6 +178,10 @@ func_MCMC <- function(survObj, hyperpar, initial, # update gamma (latent indicators of variable selection) # browser() sampleGam <- UpdateGamma(survObj, hyperpar, ini, S, method, MRF_G, MRF_2b, cpp) + if (is(sampleGam$gamma.ini, "matrix") && method != "Pooled") { + # Workaround because C++ outputs list elements as matrices + sampleGam <- lapply(sampleGam, function(x) as.list(as.data.frame(x))) + } gamma.ini <- ini$gamma.ini <- sampleGam$gamma.ini # update beta (regression parameters) diff --git a/R/updatePara.R b/R/updatePara.R index d0144b6..b91e063 100755 --- a/R/updatePara.R +++ b/R/updatePara.R @@ -28,7 +28,7 @@ UpdateGamma <- function(sobj, hyperpar, ini, S, method, MRF_G, MRF_2b, cpp = FALSE) { # Update latent variable selection indicators gamma with either independent Bernoulli prior # (standard approaches) or with MRF prior. - if (cpp && method == "Pooled" && MRF_G) { + if (cpp && MRF_G) { warning("This is not yet fully implemented. Please use cpp = FALSE for production") return(UpdateGamma_cpp(sobj, hyperpar, ini, S, method, MRF_G, MRF_2b)) } diff --git a/src/UpdateGamma_cpp.cpp b/src/UpdateGamma_cpp.cpp index 5c0c815..bf7a399 100644 --- a/src/UpdateGamma_cpp.cpp +++ b/src/UpdateGamma_cpp.cpp @@ -39,6 +39,7 @@ Rcpp::List UpdateGamma_cpp( arma::mat G_ini = arma::zeros(p, p); if (method == "Pooled" && MRF_G) { + // G_ini is not needed if method != "Pooled" and MRF_G G_ini = Rcpp::as(hyperpar["G"]); } else if (!MRF_G) { G_ini = Rcpp::as(ini["G.ini"]); @@ -88,20 +89,17 @@ Rcpp::List UpdateGamma_cpp( ); return out; } else { - Rcpp::List post_gamma = Rcpp::List::create(S); - for (arma::uword g = 0; g < S; g++) { - post_gamma[g] = arma::zeros(p); - } + arma::mat post_gamma(p, S, arma::fill::zeros); if (MRF_G) { for (arma::uword g = 0; g < S; g++) { // loop through subgroups for (arma::uword j = 0; j < p; j++) { - // wa <- dnorm((beta.ini[[g]])[j], mean = 0, sd = cb * tau) * pi - // wb <- dnorm((beta.ini[[g]])[j], mean = 0, sd = tau) * (1 - pi) - // pgam <- wa / (wa + wb) - // u <- runif(1) - // gamma.ini[[g]][j] <- ifelse(u < pgam, 1, 0) - // post.gamma[[g]][j] <- pgam + double wa = R::dnorm(beta_ini(j, g), 0, tau * cb, true) * pi; + double wb = R::dnorm(beta_ini(j, g), 0, tau, true) * (1 - pi); + double pgam = wa / (wa + wb); + double u = R::runif(0, 1); + gamma_ini(j, g) = u < pgam; + post_gamma(j, g) = pgam; } } } else { // CoxBVS-SL or Sub-struct model diff --git a/tests/testthat/test-cpp_translation.R b/tests/testthat/test-cpp_translation.R index 819a362..89200ad 100644 --- a/tests/testthat/test-cpp_translation.R +++ b/tests/testthat/test-cpp_translation.R @@ -55,4 +55,9 @@ test_that("R and C++ objects are similar", { for (obj in names(fit_R$output)[2]) { expect_equal(fit_R$output[[obj]], fit_C$output[[obj]], tolerance = 1) } + expect_equal(fit_R2S$call, fit_C2S$call) + expect_equal(fit_R2S$input, fit_C2S$input) + for (obj in names(fit_R2S$output)[2]) { + expect_equal(fit_R2S$output[[obj]], fit_C2S$output[[obj]], tolerance = 1) + } })