diff --git a/R/func_MCMC.R b/R/func_MCMC.R index 00730af..ce9b352 100755 --- a/R/func_MCMC.R +++ b/R/func_MCMC.R @@ -178,9 +178,14 @@ func_MCMC <- function(survObj, hyperpar, initial, # update gamma (latent indicators of variable selection) # browser() sampleGam <- UpdateGamma(survObj, hyperpar, ini, S, method, MRF_G, MRF_2b, cpp) - if (is(sampleGam$gamma.ini, "matrix") && method != "Pooled") { - # Workaround because C++ outputs list elements as matrices - sampleGam <- lapply(sampleGam, function(x) as.list(as.data.frame(x))) + if (is(sampleGam$gamma.ini, "matrix")) { + # TEMP Workaround because C++ outputs list elements as matrices and UpdateRPlee11 expects lists (until it's translated) + sampleGam <- lapply(sampleGam, function(x) apply(x, 1, list)) + if (!(method == "Pooled" && MRF_G)) { + sampleGam <- lapply(sampleGam, function(x) lapply(x, unlist)) + } else { + sampleGam <- lapply(sampleGam, unlist) + } } gamma.ini <- ini$gamma.ini <- sampleGam$gamma.ini diff --git a/R/updatePara.R b/R/updatePara.R index b91e063..e030e52 100755 --- a/R/updatePara.R +++ b/R/updatePara.R @@ -28,8 +28,7 @@ UpdateGamma <- function(sobj, hyperpar, ini, S, method, MRF_G, MRF_2b, cpp = FALSE) { # Update latent variable selection indicators gamma with either independent Bernoulli prior # (standard approaches) or with MRF prior. - if (cpp && MRF_G) { - warning("This is not yet fully implemented. Please use cpp = FALSE for production") + if (cpp) { return(UpdateGamma_cpp(sobj, hyperpar, ini, S, method, MRF_G, MRF_2b)) } p <- sobj$p diff --git a/src/UpdateGamma_cpp.cpp b/src/UpdateGamma_cpp.cpp index bf7a399..0207916 100644 --- a/src/UpdateGamma_cpp.cpp +++ b/src/UpdateGamma_cpp.cpp @@ -32,7 +32,7 @@ Rcpp::List UpdateGamma_cpp( double cb = Rcpp::as(hyperpar["cb"]); double pi = Rcpp::as(hyperpar["pi.ga"]); double a = Rcpp::as(hyperpar["a"]); - arma::vec b = Rcpp::as(hyperpar["b"]); + arma::rowvec b = arma::rowvec(p, arma::fill::value(Rcpp::as(hyperpar["b"]))); arma::mat beta_ini = list_to_matrix(ini["beta.ini"]); arma::mat gamma_ini = list_to_matrix(ini["gamma.ini"]); @@ -50,24 +50,25 @@ Rcpp::List UpdateGamma_cpp( // TODO: test this case. Not default! for (arma::uword g = 0; g < S; g++) { arma::uvec g_seq = arma::regspace(g * p, g * p + p - 1); // equivalent to (g - 1) * p + (1:p) - G_ini.submat(g_seq, g_seq) *= b[0]; + G_ini.submat(g_seq, g_seq) *= b[0]; // TODO: replace [] with () for bounds check } for (arma::uword g = 0; g < S - 1; g++) { for (arma::uword r = g; r < S - 1; r++) { arma::uvec g_seq = arma::regspace(g * p, g * p + p - 1); // equivalent to (g - 1) * p + (1:p) arma::uvec r_seq = arma::regspace(r * p + p, r * p + 2 * p - 1); // equivalent to r * p + (1:p) - G_ini.submat(g_seq, r_seq) = b[1] * G_ini.submat(r_seq, g_seq); - G_ini.submat(r_seq, g_seq) *= b[1]; + G_ini.submat(g_seq, r_seq) = b[1] * G_ini.submat(r_seq, g_seq); // TODO: replace [] with () for bounds check1 + G_ini.submat(r_seq, g_seq) *= b[1]; // TODO: replace [] with () for bounds check } } } else if (!MRF_G) { - G_ini = G_ini * b; + G_ini *= b(0); } + arma::mat post_gamma(p, S, arma::fill::zeros); if (method == "Pooled" && MRF_G) { - arma::vec post_gamma = arma::zeros(p); for (arma::uword j = 0; j < p; j++) { - double beta = beta_ini(j); + // FIXME: why are indices flipped here w.r.t. the other cases? + double beta = beta_ini(0, j); arma::vec ga_prop1 = gamma_ini.t(); arma::vec ga_prop0 = gamma_ini.t(); @@ -80,17 +81,15 @@ Rcpp::List UpdateGamma_cpp( double w_max = std::max(wa, wb); double pg = std::exp(wa - w_max) / (std::exp(wa - w_max) + std::exp(wb - w_max)); - gamma_ini(j) = R::runif(0, 1) < pg; - post_gamma(j) = pg; + gamma_ini(0, j) = R::runif(0, 1) < pg; + post_gamma(j, 0) = pg; } Rcpp::List out = Rcpp::List::create( Rcpp::Named("gamma.ini") = gamma_ini, - Rcpp::Named("post.gamma") = post_gamma + Rcpp::Named("post.gamma") = arma::trans(post_gamma) ); return out; } else { - arma::mat post_gamma(p, S, arma::fill::zeros); - if (MRF_G) { for (arma::uword g = 0; g < S; g++) { // loop through subgroups for (arma::uword j = 0; j < p; j++) { @@ -105,30 +104,28 @@ Rcpp::List UpdateGamma_cpp( } else { // CoxBVS-SL or Sub-struct model for (arma::uword g = 0; g < S; g++) { for (arma::uword j = 0; j < p; j++) { - // beta <- (beta.ini[[g]])[j] + double beta = beta_ini(j, g); - // ga.prop1 <- ga.prop0 <- gamma.ini # gamma with gamma_g,j=1 or 0 - // ga.prop1[[g]][j] <- 1 - // ga.prop0[[g]][j] <- 0 - // ga.prop1 <- unlist(ga.prop1) - // ga.prop0 <- unlist(ga.prop0) + // TODO: refactor. same as above + arma::colvec ga_prop1 = gamma_ini; + arma::colvec ga_prop0 = gamma_ini; + ga_prop1(j, g) = 1; + ga_prop0(j, g) = 0; - // wa <- (a * sum(ga.prop1) + t(ga.prop1) %*% G.ini %*% ga.prop1) + - // dnorm(beta, mean = 0, sd = tau * cb, log = TRUE) - // wb <- (a * sum(ga.prop0) + t(ga.prop0) %*% G.ini %*% ga.prop0) + - // dnorm(beta, mean = 0, sd = tau, log = TRUE) + double wa = calc_wa_wb(ga_prop1, G_ini, beta, a, tau * cb); + double wb = calc_wa_wb(ga_prop0, G_ini, beta, a, tau); - // w_max <- max(wa, wb) - // pg <- exp(wa - w_max) / (exp(wa - w_max) + exp(wb - w_max)) + double w_max = std::max(wa, wb); + double pg = std::exp(wa - w_max) / (std::exp(wa - w_max) + std::exp(wb - w_max)); - // gamma.ini[[g]][j] <- as.numeric(runif(1) < pg) - // post.gamma[[g]][j] <- pg + gamma_ini(j, g) = R::runif(0, 1) < pg; + post_gamma(j, g) = pg; } } } Rcpp::List out = Rcpp::List::create( - Rcpp::Named("gamma.ini") = gamma_ini, - Rcpp::Named("post.gamma") = post_gamma + Rcpp::Named("gamma.ini") = arma::trans(gamma_ini), + Rcpp::Named("post.gamma") = arma::trans(post_gamma) ); return out; } diff --git a/tests/testthat/test-cpp_translation.R b/tests/testthat/test-cpp_translation.R index b74c6d1..d62897f 100644 --- a/tests/testthat/test-cpp_translation.R +++ b/tests/testthat/test-cpp_translation.R @@ -16,6 +16,11 @@ data_2S <- list( ## Initial value: null model without covariates initial <- list("gamma.ini" = rep(0, ncol(data$X))) +initial_2S <- list( + initial, + list("gamma.ini" = rep(0, ncol(data_2S[[2]]$X))) +) + # Prior parameters hyperPooled = list( @@ -34,20 +39,27 @@ hyperPooled_2S$G <- Matrix::bdiag(simData$G, simData$G) set.seed(715074) BayesSurvive_wrap <- function( - data, initial, hyper, model = "Pooled", use_cpp = FALSE, n_iter = 30 + data, initial, hyper, model = "Pooled", use_cpp = FALSE, n_iter = 10, + MRF_G = TRUE, verbose = FALSE ) { - suppressWarnings( - BayesSurvive( - survObj = data, model.type = model, MRF.G = TRUE, verbose = FALSE, - hyperpar = hyper, initial = initial, nIter = n_iter, - burnin = floor(n_iter / 2), cpp = use_cpp - ) + if (!MRF_G) { + data <- list(data) + hyper$lambda <- 3 # TODO: mandatory for !MRG.G? Add validation! + hyper$nu0 <- 0.05 + hyper$nu1 <- 5 + } + BayesSurvive( + survObj = data, model.type = model, MRF.G = MRF_G, verbose = verbose, + hyperpar = hyper, initial = initial, nIter = n_iter, + burnin = floor(n_iter / 2), cpp = use_cpp ) } fit_R <- BayesSurvive_wrap(data, initial, hyperPooled) fit_C <- BayesSurvive_wrap(data, initial, hyperPooled, use_cpp = TRUE) fit_R2S <- BayesSurvive_wrap(data_2S, initial, hyperPooled_2S, "CoxBVSSL") fit_C2S <- BayesSurvive_wrap(data_2S, initial, hyperPooled_2S, "CoxBVSSL", use_cpp = TRUE) +fit_R_noMRFG <- BayesSurvive_wrap(data, initial, hyperPooled, MRF_G = FALSE) +fit_C_noMRFG <- BayesSurvive_wrap(data, initial, hyperPooled, MRF_G = FALSE, use_cpp = TRUE) test_that("R and C++ objects are similar", { expect_equal(fit_R$call, fit_C$call) @@ -60,4 +72,9 @@ test_that("R and C++ objects are similar", { for (obj in names(fit_R2S$output)[2]) { expect_equal(fit_R2S$output[[obj]], fit_C2S$output[[obj]], tolerance = 1) } + expect_equal(fit_R_noMRFG$call, fit_C_noMRFG$call) + expect_equal(fit_R_noMRFG$input, fit_C_noMRFG$input) + for (obj in names(fit_R_noMRFG$output)[2]) { + expect_equal(fit_R_noMRFG$output[[obj]], fit_C_noMRFG$output[[obj]], tolerance = 1) + } })