Skip to content

Commit

Permalink
Added example for S > 1 (#18, #11)
Browse files Browse the repository at this point in the history
  • Loading branch information
wleoncio committed Aug 19, 2024
1 parent c7306f3 commit aa4742c
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/func_MCMC_graph_cpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ Rcpp::List func_MCMC_graph_cpp(
Rcpp::List C = Rcpp::as<Rcpp::List>(ini["C.ini"]);

Rcpp::List gamma_ini_list = Rcpp::as<Rcpp::List>(ini["gamma.ini"]);
arma::vec gamma_ini = Rcpp::as<arma::vec>(gamma_ini_list[0]);
arma::vec gamma_ini = Rcpp::as<arma::vec>(gamma_ini_list[0]); // TODO: Test for S > 1, not sure the results are correct.

if (MRF_2b) {
// two different values for b in MRF prior for subgraphs G_ss and G_rs
Expand Down
32 changes: 24 additions & 8 deletions tests/testthat/test-cpp_translation.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,22 @@ dataset <- list(
"t" = simData[[1]]$time,
"di" = simData[[1]]$status
)

dataset_2S <- list(
dataset,
list(
"X" = simData[[2]]$X,
"t" = simData[[2]]$time,
"di" = simData[[2]]$status
)
)
# Run a Bayesian Cox model

## Initial value: null model without covariates
initial <- list("gamma.ini" = rep(0, ncol(dataset$X)))

initial_2S <- list(
initial,
list("gamma.ini" = rep(0, ncol(dataset_2S[[2]]$X)))
)
# Prior parameters
hyperparPooled = list(
"c0" = 2, # prior of baseline hazard
Expand All @@ -20,21 +30,27 @@ hyperparPooled = list(
"b" = 0.1, # hyperparameter in MRF prior
"G" = simData$G # hyperparameter in MRF prior
)
hyperparPooled_2S = hyperparPooled
hyperparPooled_2S$G = Matrix::bdiag(simData$G, simData$G)

# Run a 'Pooled' Bayesian Cox model with graphical learning

set.seed(715074)
BayesSurvive_wrap <- function(use_cpp = FALSE) {
BayesSurvive_wrap <- function(
data, initial, hyper, model = "Pooled", use_cpp = FALSE, n_iter = 30
) {
suppressWarnings(
BayesSurvive(
survObj = dataset, model.type = "Pooled", MRF.G = TRUE, verbose = TRUE,
hyperpar = hyperparPooled, initial = initial, nIter = 100, burnin = 100,
cpp = use_cpp
survObj = data, model.type = model, MRF.G = TRUE, verbose = FALSE,
hyperpar = hyper, initial = initial, nIter = n_iter,
burnin = floor(n_iter / 2), cpp = use_cpp
)
)
}
fit_R <- BayesSurvive_wrap(use_cpp = FALSE)
fit_C <- BayesSurvive_wrap(use_cpp = TRUE)
fit_R <- BayesSurvive_wrap(dataset, initial, hyperparPooled)
fit_C <- BayesSurvive_wrap(dataset, initial, hyperparPooled, use_cpp = TRUE)
fit_R2S <- BayesSurvive_wrap(dataset_2S, initial_2S, hyperparPooled_2S, "CoxBVSSL")
fit_C2S <- BayesSurvive_wrap(dataset_2S, initial_2S, hyperparPooled_2S, "CoxBVSSL", use_cpp = TRUE)

test_that("R and C++ objects are similar", {
expect_equal(fit_R$call, fit_C$call)
Expand Down

0 comments on commit aa4742c

Please sign in to comment.