Skip to content

Commit

Permalink
Translated remaining UpdateGamma() (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
wleoncio committed Sep 17, 2024
1 parent 0c12669 commit 098449e
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 40 deletions.
11 changes: 8 additions & 3 deletions R/func_MCMC.R
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,14 @@ func_MCMC <- function(survObj, hyperpar, initial,
# update gamma (latent indicators of variable selection)
# browser()
sampleGam <- UpdateGamma(survObj, hyperpar, ini, S, method, MRF_G, MRF_2b, cpp)
if (is(sampleGam$gamma.ini, "matrix") && method != "Pooled") {
# Workaround because C++ outputs list elements as matrices
sampleGam <- lapply(sampleGam, function(x) as.list(as.data.frame(x)))
if (is(sampleGam$gamma.ini, "matrix")) {
# TEMP Workaround because C++ outputs list elements as matrices and UpdateRPlee11 expects lists (until it's translated)
sampleGam <- lapply(sampleGam, function(x) apply(x, 1, list))
if (!(method == "Pooled" && MRF_G)) {
sampleGam <- lapply(sampleGam, function(x) lapply(x, unlist))
} else {
sampleGam <- lapply(sampleGam, unlist)
}
}
gamma.ini <- ini$gamma.ini <- sampleGam$gamma.ini

Expand Down
3 changes: 1 addition & 2 deletions R/updatePara.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@
UpdateGamma <- function(sobj, hyperpar, ini, S, method, MRF_G, MRF_2b, cpp = FALSE) {
# Update latent variable selection indicators gamma with either independent Bernoulli prior
# (standard approaches) or with MRF prior.
if (cpp && MRF_G) {
warning("This is not yet fully implemented. Please use cpp = FALSE for production")
if (cpp) {
return(UpdateGamma_cpp(sobj, hyperpar, ini, S, method, MRF_G, MRF_2b))
}
p <- sobj$p
Expand Down
53 changes: 25 additions & 28 deletions src/UpdateGamma_cpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Rcpp::List UpdateGamma_cpp(
double cb = Rcpp::as<double>(hyperpar["cb"]);
double pi = Rcpp::as<double>(hyperpar["pi.ga"]);
double a = Rcpp::as<double>(hyperpar["a"]);
arma::vec b = Rcpp::as<arma::vec>(hyperpar["b"]);
arma::rowvec b = arma::rowvec(p, arma::fill::value(Rcpp::as<double>(hyperpar["b"])));

arma::mat beta_ini = list_to_matrix(ini["beta.ini"]);
arma::mat gamma_ini = list_to_matrix(ini["gamma.ini"]);
Expand All @@ -50,24 +50,25 @@ Rcpp::List UpdateGamma_cpp(
// TODO: test this case. Not default!
for (arma::uword g = 0; g < S; g++) {
arma::uvec g_seq = arma::regspace<arma::uvec>(g * p, g * p + p - 1); // equivalent to (g - 1) * p + (1:p)
G_ini.submat(g_seq, g_seq) *= b[0];
G_ini.submat(g_seq, g_seq) *= b[0]; // TODO: replace [] with () for bounds check
}
for (arma::uword g = 0; g < S - 1; g++) {
for (arma::uword r = g; r < S - 1; r++) {
arma::uvec g_seq = arma::regspace<arma::uvec>(g * p, g * p + p - 1); // equivalent to (g - 1) * p + (1:p)
arma::uvec r_seq = arma::regspace<arma::uvec>(r * p + p, r * p + 2 * p - 1); // equivalent to r * p + (1:p)
G_ini.submat(g_seq, r_seq) = b[1] * G_ini.submat(r_seq, g_seq);
G_ini.submat(r_seq, g_seq) *= b[1];
G_ini.submat(g_seq, r_seq) = b[1] * G_ini.submat(r_seq, g_seq); // TODO: replace [] with () for bounds check1
G_ini.submat(r_seq, g_seq) *= b[1]; // TODO: replace [] with () for bounds check
}
}
} else if (!MRF_G) {
G_ini = G_ini * b;
G_ini *= b(0);
}

arma::mat post_gamma(p, S, arma::fill::zeros);
if (method == "Pooled" && MRF_G) {
arma::vec post_gamma = arma::zeros<arma::vec>(p);
for (arma::uword j = 0; j < p; j++) {
double beta = beta_ini(j);
// FIXME: why are indices flipped here w.r.t. the other cases?
double beta = beta_ini(0, j);

arma::vec ga_prop1 = gamma_ini.t();
arma::vec ga_prop0 = gamma_ini.t();
Expand All @@ -80,17 +81,15 @@ Rcpp::List UpdateGamma_cpp(
double w_max = std::max(wa, wb);
double pg = std::exp(wa - w_max) / (std::exp(wa - w_max) + std::exp(wb - w_max));

gamma_ini(j) = R::runif(0, 1) < pg;
post_gamma(j) = pg;
gamma_ini(0, j) = R::runif(0, 1) < pg;
post_gamma(j, 0) = pg;
}
Rcpp::List out = Rcpp::List::create(
Rcpp::Named("gamma.ini") = gamma_ini,
Rcpp::Named("post.gamma") = post_gamma
Rcpp::Named("post.gamma") = arma::trans(post_gamma)
);
return out;
} else {
arma::mat post_gamma(p, S, arma::fill::zeros);

if (MRF_G) {
for (arma::uword g = 0; g < S; g++) { // loop through subgroups
for (arma::uword j = 0; j < p; j++) {
Expand All @@ -105,30 +104,28 @@ Rcpp::List UpdateGamma_cpp(
} else { // CoxBVS-SL or Sub-struct model
for (arma::uword g = 0; g < S; g++) {
for (arma::uword j = 0; j < p; j++) {
// beta <- (beta.ini[[g]])[j]
double beta = beta_ini(j, g);

// ga.prop1 <- ga.prop0 <- gamma.ini # gamma with gamma_g,j=1 or 0
// ga.prop1[[g]][j] <- 1
// ga.prop0[[g]][j] <- 0
// ga.prop1 <- unlist(ga.prop1)
// ga.prop0 <- unlist(ga.prop0)
// TODO: refactor. same as above
arma::colvec ga_prop1 = gamma_ini;
arma::colvec ga_prop0 = gamma_ini;
ga_prop1(j, g) = 1;
ga_prop0(j, g) = 0;

// wa <- (a * sum(ga.prop1) + t(ga.prop1) %*% G.ini %*% ga.prop1) +
// dnorm(beta, mean = 0, sd = tau * cb, log = TRUE)
// wb <- (a * sum(ga.prop0) + t(ga.prop0) %*% G.ini %*% ga.prop0) +
// dnorm(beta, mean = 0, sd = tau, log = TRUE)
double wa = calc_wa_wb(ga_prop1, G_ini, beta, a, tau * cb);
double wb = calc_wa_wb(ga_prop0, G_ini, beta, a, tau);

// w_max <- max(wa, wb)
// pg <- exp(wa - w_max) / (exp(wa - w_max) + exp(wb - w_max))
double w_max = std::max(wa, wb);
double pg = std::exp(wa - w_max) / (std::exp(wa - w_max) + std::exp(wb - w_max));

// gamma.ini[[g]][j] <- as.numeric(runif(1) < pg)
// post.gamma[[g]][j] <- pg
gamma_ini(j, g) = R::runif(0, 1) < pg;
post_gamma(j, g) = pg;
}
}
}
Rcpp::List out = Rcpp::List::create(
Rcpp::Named("gamma.ini") = gamma_ini,
Rcpp::Named("post.gamma") = post_gamma
Rcpp::Named("gamma.ini") = arma::trans(gamma_ini),
Rcpp::Named("post.gamma") = arma::trans(post_gamma)
);
return out;
}
Expand Down
31 changes: 24 additions & 7 deletions tests/testthat/test-cpp_translation.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ data_2S <- list(

## Initial value: null model without covariates
initial <- list("gamma.ini" = rep(0, ncol(data$X)))
initial_2S <- list(
initial,
list("gamma.ini" = rep(0, ncol(data_2S[[2]]$X)))
)


# Prior parameters
hyperPooled = list(
Expand All @@ -34,20 +39,27 @@ hyperPooled_2S$G <- Matrix::bdiag(simData$G, simData$G)

set.seed(715074)
BayesSurvive_wrap <- function(
data, initial, hyper, model = "Pooled", use_cpp = FALSE, n_iter = 30
data, initial, hyper, model = "Pooled", use_cpp = FALSE, n_iter = 10,
MRF_G = TRUE, verbose = FALSE
) {
suppressWarnings(
BayesSurvive(
survObj = data, model.type = model, MRF.G = TRUE, verbose = FALSE,
hyperpar = hyper, initial = initial, nIter = n_iter,
burnin = floor(n_iter / 2), cpp = use_cpp
)
if (!MRF_G) {
data <- list(data)
hyper$lambda <- 3 # TODO: mandatory for !MRG.G? Add validation!
hyper$nu0 <- 0.05
hyper$nu1 <- 5
}
BayesSurvive(
survObj = data, model.type = model, MRF.G = MRF_G, verbose = verbose,
hyperpar = hyper, initial = initial, nIter = n_iter,
burnin = floor(n_iter / 2), cpp = use_cpp
)
}
fit_R <- BayesSurvive_wrap(data, initial, hyperPooled)
fit_C <- BayesSurvive_wrap(data, initial, hyperPooled, use_cpp = TRUE)
fit_R2S <- BayesSurvive_wrap(data_2S, initial, hyperPooled_2S, "CoxBVSSL")
fit_C2S <- BayesSurvive_wrap(data_2S, initial, hyperPooled_2S, "CoxBVSSL", use_cpp = TRUE)
fit_R_noMRFG <- BayesSurvive_wrap(data, initial, hyperPooled, MRF_G = FALSE)
fit_C_noMRFG <- BayesSurvive_wrap(data, initial, hyperPooled, MRF_G = FALSE, use_cpp = TRUE)

test_that("R and C++ objects are similar", {
expect_equal(fit_R$call, fit_C$call)
Expand All @@ -60,4 +72,9 @@ test_that("R and C++ objects are similar", {
for (obj in names(fit_R2S$output)[2]) {
expect_equal(fit_R2S$output[[obj]], fit_C2S$output[[obj]], tolerance = 1)
}
expect_equal(fit_R_noMRFG$call, fit_C_noMRFG$call)
expect_equal(fit_R_noMRFG$input, fit_C_noMRFG$input)
for (obj in names(fit_R_noMRFG$output)[2]) {
expect_equal(fit_R_noMRFG$output[[obj]], fit_C_noMRFG$output[[obj]], tolerance = 1)
}
})

0 comments on commit 098449e

Please sign in to comment.