diff --git a/DESCRIPTION b/DESCRIPTION index 5e7b2bc7b..2e3203327 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -2,8 +2,8 @@ Package: brms Encoding: UTF-8 Type: Package Title: Bayesian Regression Models using 'Stan' -Version: 2.22.8 -Date: 2024-12-05 +Version: 2.22.9 +Date: 2025-02-01 Authors@R: c(person("Paul-Christian", "Bürkner", email = "paul.buerkner@gmail.com", role = c("aut", "cre")), diff --git a/NAMESPACE b/NAMESPACE index 98da81f54..73546e6a0 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -410,6 +410,7 @@ export(dhurdle_negbinomial) export(dhurdle_poisson) export(dinv_gaussian) export(dirichlet) +export(dirichlet_multinomial) export(dlogistic_normal) export(dmulti_normal) export(dmulti_student_t) diff --git a/NEWS.md b/NEWS.md index 6c9e3a3b4..2f4de2041 100644 --- a/NEWS.md +++ b/NEWS.md @@ -5,8 +5,10 @@ * Fit extended-support Beta models via family `xbeta` thanks to Ioannis Kosmidis. (#1698) * Add a `seed` argument to `loo_R2` thanks to Marco Colombo. (#1713) +* Add family `dirichlet_multinomial` to fit overdispersed +multinomial data thanks to Tom Peatman. (#1729) * Add the `int_step` R function to match the corresponding Stan -function. (#1734) +function thanks to Daniel Sabanes Bove. (#1734) ### Bug Fixes @@ -18,7 +20,7 @@ models, `log_lik` will not use `option(mc.cores)` anymore. These changes may be reverted once the underlying causes of this issue have been fixed. (#1658) * Align the definition of the R function `step()` with the definition in Stan, -such that `step(0) == 1` (#1734) +such that `step(0) == 1` thanks to Daniel Sabanes Bove. 
(#1734) ### Other Changes diff --git a/R/distributions.R b/R/distributions.R index fefa76a34..09a26d9bf 100644 --- a/R/distributions.R +++ b/R/distributions.R @@ -2185,6 +2185,24 @@ dmultinomial <- function(x, eta, log = FALSE) { out } +# density of the dirichlet-multinomial distribution with the softmax transform +# @param x integer vector of category counts; its sum is used as the number of trials +# @param eta the linear predictor (of length or ncol ncat) +# @param phi the precision parameter (i.e., sum of dirichlet alphas) +# @param log return values on the log scale? +ddirichletmultinomial <- function(x, eta, phi, log = FALSE) { + require_package("extraDistr") + if (is.null(dim(eta))) { + eta <- matrix(eta, nrow = 1) + } + if (length(dim(eta)) != 2L) { + stop2("eta must be a numeric vector or matrix.") + } + alpha <- softmax(eta) * phi + size <- sum(x) + extraDistr::ddirmnom(x, size = size, alpha = alpha, log = log) +} + # density of the cumulative distribution # # @param x Integer vector containing response category indices to return the diff --git a/R/families.R b/R/families.R index 68e97b3f2..fe45bfb22 100644 --- a/R/families.R +++ b/R/families.R @@ -18,7 +18,7 @@ #' \code{inverse.gaussian}, \code{exponential}, \code{weibull}, #' \code{frechet}, \code{Beta}, \code{dirichlet}, \code{von_mises}, #' \code{asym_laplace}, \code{gen_extreme_value}, \code{categorical}, -#' \code{multinomial}, \code{cumulative}, \code{cratio}, \code{sratio}, +#' \code{multinomial}, \code{dirichlet_multinomial}, \code{cumulative}, \code{cratio}, \code{sratio}, #' \code{acat}, \code{hurdle_poisson}, \code{hurdle_negbinomial}, #' \code{hurdle_gamma}, \code{hurdle_lognormal}, \code{hurdle_cumulative}, #' \code{zero_inflated_binomial}, \code{zero_inflated_beta_binomial}, @@ -51,8 +51,9 @@ #' consecutive thresholds to the same value, and #' \code{"sum_to_zero"} ensures the thresholds sum to zero. 
#' @param refcat Optional name of the reference response category used in -#' \code{categorical}, \code{multinomial}, \code{dirichlet} and -#' \code{logistic_normal} models. If \code{NULL} (the default), the first +#' \code{categorical}, \code{multinomial}, \code{dirichlet}, +#' \code{dirichlet_multinomial} and \code{logistic_normal} models. +#' If \code{NULL} (the default), the first #' category is used as the reference. If \code{NA}, all categories will be #' predicted, which requires strong priors or carefully specified predictor #' terms in order to lead to an identified model. @@ -76,8 +77,9 @@ #' can be used for binary regression (i.e., most commonly logistic #' regression).} #' -#' \item{Families \code{categorical} and \code{multinomial} can be used for -#' multi-logistic regression when there are more than two possible outcomes.} +#' \item{Families \code{categorical}, \code{multinomial} and +#' \code{dirichlet_multinomial} can be used for multi-logistic regression +#' when there are more than two possible outcomes.} #' #' \item{Families \code{cumulative}, \code{cratio} ('continuation ratio'), #' \code{sratio} ('stopping ratio'), and \code{acat} ('adjacent category') @@ -150,8 +152,8 @@ #' \code{acat}, and \code{hurdle_cumulative} support \code{logit}, #' \code{probit}, \code{probit_approx}, \code{cloglog}, and \code{cauchit}.} #' -#' \item{Families \code{categorical}, \code{multinomial}, and \code{dirichlet} -#' support \code{logit}.} +#' \item{Families \code{categorical}, \code{multinomial}, +#' \code{dirichlet_multinomial} and \code{dirichlet} support \code{logit}.} #' #' \item{Families \code{Gamma}, \code{weibull}, \code{exponential}, #' \code{frechet}, and \code{hurdle_gamma} support @@ -812,6 +814,15 @@ multinomial <- function(link = "logit", refcat = NULL) { .brmsfamily("multinomial", link = link, slink = slink, refcat = refcat) } +#' @rdname brmsfamily +#' @export +dirichlet_multinomial <- function(link = "logit", link_phi = "log", + refcat = NULL) { 
+ slink <- substitute(link) + .brmsfamily("dirichlet_multinomial", link = link, slink = slink, + link_phi = link_phi, refcat = refcat) +} + #' @rdname brmsfamily #' @export cumulative <- function(link = "logit", link_disc = "log", diff --git a/R/family-lists.R b/R/family-lists.R index 8b67a785a..a9e867bfe 100644 --- a/R/family-lists.R +++ b/R/family-lists.R @@ -138,6 +138,20 @@ ) } +.family_dirichlet_multinomial <- function() { + list( + links = "logit", + dpars = "phi", + multi_dpars = "mu", # size determined by the data + type = "int", ybounds = c(-Inf, Inf), + closed = c(NA, NA), + ad = c("weights", "subset", "trials", "index"), + specials = c("multinomial", "joint_link"), + include = "fun_dirichlet_multinomial_logit.stan", + normalized = "" + ) +} + .family_beta <- function() { list( links = c( diff --git a/R/log_lik.R b/R/log_lik.R index 13e969219..c00e19ee5 100644 --- a/R/log_lik.R +++ b/R/log_lik.R @@ -837,6 +837,15 @@ log_lik_multinomial <- function(i, prep) { log_lik_weight(out, i = i, prep = prep) } +log_lik_dirichlet_multinomial <- function(i, prep) { + stopifnot(prep$family$link == "logit") + eta <- get_Mu(prep, i = i) + eta <- insert_refcat(eta, refcat = prep$refcat) + phi <- get_dpar(prep, "phi", i = i) + out <- ddirichletmultinomial(prep$data$Y[i, ], eta = eta, phi = phi, log = TRUE) + log_lik_weight(out, i = i, prep = prep) +} + log_lik_dirichlet <- function(i, prep) { stopifnot(prep$family$link == "logit") eta <- get_Mu(prep, i = i) diff --git a/R/posterior_epred.R b/R/posterior_epred.R index 8d3b90d6e..50136dc32 100644 --- a/R/posterior_epred.R +++ b/R/posterior_epred.R @@ -376,7 +376,6 @@ posterior_epred_binomial <- function(prep) { } posterior_epred_beta_binomial <- function(prep) { - # beta part included in mu trials <- data2draws(prep$data$trials, dim_mu(prep)) prep$dpars$mu * trials } @@ -587,6 +586,12 @@ posterior_epred_multinomial <- function(prep) { out } +posterior_epred_dirichlet_multinomial <- function(prep) { + # mean of 
dirichlet-multinomial is equal to multinomial + # (phi only affects variance of distribution) + posterior_epred_multinomial(prep) +} + posterior_epred_dirichlet <- function(prep) { get_probs <- function(i) { eta <- insert_refcat(slice_col(eta, i), refcat = prep$refcat) diff --git a/R/posterior_predict.R b/R/posterior_predict.R index f63587727..a791b718b 100644 --- a/R/posterior_predict.R +++ b/R/posterior_predict.R @@ -882,6 +882,16 @@ posterior_predict_multinomial <- function(i, prep, ...) { rblapply(seq_rows(p), function(s) t(rmultinom(1, size, p[s, ]))) } +posterior_predict_dirichlet_multinomial <- function(i, prep, ...) { + eta <- get_Mu(prep, i = i) + eta <- insert_refcat(eta, refcat = prep$refcat) + phi <- get_dpar(prep, "phi", i = i) + alpha <- dcategorical(seq_len(prep$data$ncat), eta = eta) * phi + p <- rdirichlet(prep$ndraws, alpha = alpha) + size <- prep$data$trials[i] + rblapply(seq_rows(p), function(s) t(rmultinom(1, size, p[s, ]))) +} + posterior_predict_dirichlet <- function(i, prep, ...) { eta <- get_Mu(prep, i = i) eta <- insert_refcat(eta, refcat = prep$refcat) diff --git a/R/stan-likelihood.R b/R/stan-likelihood.R index f4555b64c..06d575d8a 100644 --- a/R/stan-likelihood.R +++ b/R/stan-likelihood.R @@ -777,6 +777,14 @@ stan_log_lik_multinomial <- function(bterms, ...) { sdist("multinomial_logit2", p$mu, vec = FALSE) } +stan_log_lik_dirichlet_multinomial <- function(bterms, ...) { + stopifnot(bterms$family$link == "logit") + mu <- stan_log_lik_dpars(bterms, reqn = TRUE, dpars = "mu", type = "multi")$mu + reqn_phi <- is_pred_dpar(bterms, "phi") + phi <- stan_log_lik_dpars(bterms, reqn = reqn_phi, dpars = "phi")$phi + sdist("dirichlet_multinomial_logit2", mu, phi, vec = FALSE) +} + stan_log_lik_dirichlet <- function(bterms, ...) 
{ stopifnot(bterms$family$link == "logit") mu <- stan_log_lik_dpars(bterms, reqn = TRUE, dpars = "mu", type = "multi")$mu diff --git a/inst/chunks/fun_dirichlet_multinomial_logit.stan b/inst/chunks/fun_dirichlet_multinomial_logit.stan new file mode 100644 index 000000000..1195f6cb3 --- /dev/null +++ b/inst/chunks/fun_dirichlet_multinomial_logit.stan @@ -0,0 +1,19 @@ + /* dirichlet-multinomial-logit log-PMF + * Args: + * y: array of integer response values + * mu: vector of category logit probabilities + * phi: precision parameter (sum of Dirichlet alphas) + * Returns: + * a scalar to be added to the log posterior + */ + real dirichlet_multinomial_logit2_lpmf(array[] int y, vector mu, real phi) { + // get Dirichlet alphas + int N = num_elements(mu); + vector[N] alpha = phi * softmax(mu); + + // get trials from y + real T = sum(y); + + return lgamma(phi) + lgamma(T + 1.0) - lgamma(T + phi) + + sum(lgamma(to_vector(y) + alpha)) - sum(lgamma(alpha)) - sum(lgamma(to_vector(y) + 1)); + } diff --git a/man/brmsfamily.Rd b/man/brmsfamily.Rd index fbf81457b..9d6c28fd4 100644 --- a/man/brmsfamily.Rd +++ b/man/brmsfamily.Rd @@ -36,6 +36,7 @@ \alias{zero_inflated_beta_binomial} \alias{categorical} \alias{multinomial} +\alias{dirichlet_multinomial} \alias{cumulative} \alias{sratio} \alias{cratio} @@ -153,6 +154,8 @@ categorical(link = "logit", refcat = NULL) multinomial(link = "logit", refcat = NULL) +dirichlet_multinomial(link = "logit", link_phi = "log", refcat = NULL) + cumulative(link = "logit", link_disc = "log", threshold = "flexible") sratio(link = "logit", link_disc = "log", threshold = "flexible") @@ -171,7 +174,7 @@ supported: \code{gaussian}, \code{student}, \code{binomial}, \code{inverse.gaussian}, \code{exponential}, \code{weibull}, \code{frechet}, \code{Beta}, \code{dirichlet}, \code{von_mises}, \code{asym_laplace}, \code{gen_extreme_value}, \code{categorical}, -\code{multinomial}, \code{cumulative}, \code{cratio}, \code{sratio}, +\code{multinomial}, 
\code{dirichlet_multinomial}, \code{cumulative}, \code{cratio}, \code{sratio}, \code{acat}, \code{hurdle_poisson}, \code{hurdle_negbinomial}, \code{hurdle_gamma}, \code{hurdle_lognormal}, \code{hurdle_cumulative}, \code{zero_inflated_binomial}, \code{zero_inflated_beta_binomial}, @@ -224,8 +227,9 @@ consecutive thresholds to the same value, and \code{"sum_to_zero"} ensures the thresholds sum to zero.} \item{refcat}{Optional name of the reference response category used in -\code{categorical}, \code{multinomial}, \code{dirichlet} and -\code{logistic_normal} models. If \code{NULL} (the default), the first +\code{categorical}, \code{multinomial}, \code{dirichlet}, +\code{dirichlet_multinomial} and \code{logistic_normal} models. +If \code{NULL} (the default), the first category is used as the reference. If \code{NA}, all categories will be predicted, which requires strong priors or carefully specified predictor terms in order to lead to an identified model.} @@ -259,8 +263,9 @@ Below, we list common use cases for the different families. can be used for binary regression (i.e., most commonly logistic regression).} - \item{Families \code{categorical} and \code{multinomial} can be used for - multi-logistic regression when there are more than two possible outcomes.} + \item{Families \code{categorical}, \code{multinomial} and + \code{dirichlet_multinomial} can be used for multi-logistic regression + when there are more than two possible outcomes.} \item{Families \code{cumulative}, \code{cratio} ('continuation ratio'), \code{sratio} ('stopping ratio'), and \code{acat} ('adjacent category') @@ -333,8 +338,8 @@ Below, we list common use cases for the different families. 
\code{acat}, and \code{hurdle_cumulative} support \code{logit}, \code{probit}, \code{probit_approx}, \code{cloglog}, and \code{cauchit}.} - \item{Families \code{categorical}, \code{multinomial}, and \code{dirichlet} - support \code{logit}.} + \item{Families \code{categorical}, \code{multinomial}, + \code{dirichlet_multinomial} and \code{dirichlet} support \code{logit}.} \item{Families \code{Gamma}, \code{weibull}, \code{exponential}, \code{frechet}, and \code{hurdle_gamma} support diff --git a/tests/local/tests.models-4.R b/tests/local/tests.models-4.R index 412c0e322..adc567d3d 100644 --- a/tests/local/tests.models-4.R +++ b/tests/local/tests.models-4.R @@ -181,6 +181,37 @@ test_that("multinomial models work correctly", suppressWarnings({ expect_ggplot(plot(ce, ask = FALSE)[[1]]) })) +test_that("dirichlet_multinomial models work correctly", suppressWarnings({ + skip_if_not_installed("extraDistr") # optional dependency: skip instead of erroring + set.seed(1245) + N <- 100 + dat <- as.data.frame(extraDistr::rdirmnom(N, 10, c(10, 5, 1))) + names(dat) <- paste0("y", 1:3) + dat$size <- with(dat, y1 + y2 + y3) + dat$x <- rnorm(N) + dat$y <- with(dat, cbind(y1, y2, y3)) + + fit <- brm( + y | trials(size) ~ x, data = dat, + family = dirichlet_multinomial(), + prior = prior("exponential(0.01)", "phi") + ) + print(summary(fit)) + + pred <- predict(fit) + expect_equal(dim(pred), c(nobs(fit), 4, 3)) + expect_equal(dimnames(pred)[[3]], c("y1", "y2", "y3")) + + pred_mean <- fitted(fit) + expect_equal(dim(pred_mean), c(nobs(fit), 4, 3)) + + waic <- waic(fit) + expect_range(waic$estimates[3, 1], 550, 650) + + ce <- conditional_effects(fit, categorical = TRUE) + expect_ggplot(plot(ce, ask = FALSE)[[1]]) +})) + test_that("dirichlet models work correctly", suppressWarnings({ set.seed(1246) N <- 100 diff --git a/tests/testthat/tests.log_lik.R b/tests/testthat/tests.log_lik.R index b71722110..8aa702cfd 100644 --- a/tests/testthat/tests.log_lik.R +++ b/tests/testthat/tests.log_lik.R @@ -445,6 +445,12 @@ test_that("log_lik for categorical and 
related models runs without erros", { ll <- sapply(1:nobs, brms:::log_lik_multinomial, prep = prep) expect_equal(dim(ll), c(ns, nobs)) + prep$data$trials <- sample(1:20, nobs) + prep$dpars$phi <- rexp(ns, 10) + prep$family <- dirichlet_multinomial() + ll <- sapply(1:nobs, brms:::log_lik_dirichlet_multinomial, prep = prep) + expect_equal(dim(ll), c(ns, nobs)) + prep$data$Y <- prep$data$Y / rowSums(prep$data$Y) prep$dpars$phi <- rexp(ns, 10) prep$family <- dirichlet() diff --git a/tests/testthat/tests.posterior_epred.R b/tests/testthat/tests.posterior_epred.R index 714b5d255..116960ff4 100644 --- a/tests/testthat/tests.posterior_epred.R +++ b/tests/testthat/tests.posterior_epred.R @@ -182,7 +182,7 @@ test_that("posterior_epred for advanced count data distributions runs without er expect_equal(dim(pred), c(ns, nobs)) }) -test_that("posterior_epred for multinomial and dirichlet models runs without errors", { +test_that("posterior_epred for multinomial, dirichlet_multinomial and dirichlet models runs without errors", { ns <- 15 nobs <- 8 ncat <- 3 @@ -198,6 +198,10 @@ test_that("posterior_epred for multinomial and dirichlet models runs without err pred <- brms:::posterior_epred_multinomial(prep = prep) expect_equal(dim(pred), c(ns, nobs, ncat)) + prep$family <- dirichlet_multinomial() + pred <- brms:::posterior_epred_dirichlet_multinomial(prep = prep) + expect_equal(dim(pred), c(ns, nobs, ncat)) + prep$family <- dirichlet() pred <- brms:::posterior_epred_dirichlet(prep = prep) expect_equal(dim(pred), c(ns, nobs, ncat)) diff --git a/tests/testthat/tests.posterior_predict.R b/tests/testthat/tests.posterior_predict.R index a496455ab..8c2bea1be 100644 --- a/tests/testthat/tests.posterior_predict.R +++ b/tests/testthat/tests.posterior_predict.R @@ -346,6 +346,12 @@ test_that("posterior_predict for categorical and related models runs without err pred <- brms:::posterior_predict_multinomial(i = sample(1:nobs, 1), prep = prep) expect_equal(dim(pred), c(ns, ncat)) + 
prep$data$trials <- sample(1:20, nobs) + prep$dpars$phi <- rexp(ns, 1) + prep$family <- dirichlet_multinomial() + pred <- brms:::posterior_predict_dirichlet_multinomial(i = sample(1:nobs, 1), prep = prep) + expect_equal(dim(pred), c(ns, ncat)) + prep$dpars$phi <- rexp(ns, 1) prep$family <- dirichlet() pred <- brms:::posterior_predict_dirichlet(i = sample(1:nobs, 1), prep = prep) diff --git a/tests/testthat/tests.stancode.R b/tests/testthat/tests.stancode.R index 88b221c4a..b541148e6 100644 --- a/tests/testthat/tests.stancode.R +++ b/tests/testthat/tests.stancode.R @@ -678,6 +678,29 @@ test_that("Stan code for multinomial models is correct", { expect_match2(scode, "lprior += normal_lpdf(Intercept_muy3 | 0, 2);") }) +test_that("Stan code for dirichlet_multinomial models is correct", { + N <- 15 + dat <- data.frame( + y1 = rbinom(N, 10, 0.3), y2 = rbinom(N, 10, 0.5), + y3 = rbinom(N, 10, 0.7), x = rnorm(N) + ) + dat$size <- with(dat, y1 + y2 + y3) + dat$y <- with(dat, cbind(y1, y2, y3)) + prior <- prior(normal(0, 10), "b", dpar = muy2) + + prior(cauchy(0, 1), "Intercept", dpar = muy2) + + prior(normal(0, 2), "Intercept", dpar = muy3) + + prior(exponential(10), "phi") + scode <- stancode(bf(y | trials(size) ~ 1, muy2 ~ x), data = dat, + family = dirichlet_multinomial(), prior = prior) + expect_match2(scode, "array[N, ncat] int Y;") + expect_match2(scode, "target += dirichlet_multinomial_logit2_lpmf(Y[n] | mu[n], phi);") + expect_match2(scode, "muy2 += Intercept_muy2 + Xc_muy2 * b_muy2;") + expect_match2(scode, "lprior += normal_lpdf(b_muy2 | 0, 10);") + expect_match2(scode, "lprior += cauchy_lpdf(Intercept_muy2 | 0, 1);") + expect_match2(scode, "lprior += normal_lpdf(Intercept_muy3 | 0, 2);") + expect_match2(scode, "lprior += exponential_lpdf(phi | 10);") +}) + test_that("Stan code for dirichlet models is correct", { N <- 15 dat <- as.data.frame(rdirichlet(N, c(3, 2, 1))) diff --git a/tests/testthat/tests.standata.R b/tests/testthat/tests.standata.R index 
c16e3d11f..dee92a0f5 100644 --- a/tests/testthat/tests.standata.R +++ b/tests/testthat/tests.standata.R @@ -976,7 +976,7 @@ test_that("reserved variables 'Intercept' is handled correctly", { expect_true(all(sdata$X[, "Intercept"] == 1)) }) -test_that("data for multinomial and dirichlet models is correct", { +test_that("data for multinomial, dirichlet_multinomial and dirichlet models is correct", { N <- 15 dat <- as.data.frame(rdirichlet(N, c(3, 2, 1))) names(dat) <- c("y1", "y2", "y3") @@ -993,6 +993,11 @@ test_that("data for multinomial and dirichlet models is correct", { expect_equal(sdata$ncat, 3) expect_equal(sdata$Y, unname(dat$t)) + sdata <- standata(t | trials(size) ~ x, dat, dirichlet_multinomial()) + expect_equal(sdata$trials, as.array(dat$size)) + expect_equal(sdata$ncat, 3) + expect_equal(sdata$Y, unname(dat$t)) + sdata <- standata(y ~ x, data = dat, family = dirichlet()) expect_equal(sdata$ncat, 3) expect_equal(sdata$Y, unname(dat$y)) @@ -1001,6 +1006,10 @@ test_that("data for multinomial and dirichlet models is correct", { standata(t | trials(10) ~ x, data = dat, family = multinomial()), "Number of trials does not match the number of events" ) + expect_error( + standata(t | trials(10) ~ x, data = dat, family = dirichlet_multinomial()), + "Number of trials does not match the number of events" + ) expect_error(standata(t ~ x, data = dat, family = dirichlet()), "Response values in simplex models must sum to 1") })