Skip to content

Commit

Permalink
Add support for constructing target distribution from formula (#66)
Browse files Browse the repository at this point in the history
* Add function for constructing target distribution from formula

* Wrap trace function into target distribution arg

* Ensure gradient a vector and add to NAMESPACE

* Combine BridgeStan interface functions for constructing target distribution and trace function

* Allow passing formula or Stan model directly to sample_chain

* Use dummy variable declaration to avoid check note

* Qualify deriv call with stats package name

* Test target distribution from formula function

* Use base::inherits in place of methods::is

* Remove removed trace_function argument from sample_chain docs

Clarify trace_function  allowable output types

* Test using invalid target distribution with sample_chain raises error

* Test passing explicit trace function to sample_chain

* Test using sample_chain with Stan model and log density formula works
  • Loading branch information
matt-graham authored Jan 6, 2025
1 parent 95acd2d commit 7f87e4b
Show file tree
Hide file tree
Showing 13 changed files with 423 additions and 249 deletions.
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ export(sample_chain)
export(scale_adapter)
export(shape_adapter)
export(stochastic_approximation_scale_adapter)
export(target_distribution_from_log_density_formula)
export(target_distribution_from_stan_model)
export(trace_function_from_stan_model)
export(variance_shape_adapter)
119 changes: 77 additions & 42 deletions R/bridges.R
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
#' Construct target distribution from a BridgeStan `StanModel` object.
#'
#' @param model Stan model object to use for target (posterior) distribution.
#' @param include_log_density Whether to include an entry `log_density`
#' corresponding to current log density for target distribution in values
#' returned by trace function.
#' @param include_generated_quantities Whether to included generated quantities
#' in Stan model definition in values returned by trace function.
#' @param include_transformed_parameters Whether to include transformed
#' parameters in Stan model definition in values returned by trace function.
#'
#' @return A list with entries
#' * `log_density`: A function to evaluate log density function for target
#' distribution given current position vector.
#' * `value_and_gradient_log_density`: A function to evaluate value and gradient
#' of log density function for target distribution given current position
#' vector, returning as a list with entries `value` and `gradient`.
#' * `trace_function`: A function which given a `chain_state()` object returns a
#' named vector of values to trace during sampling. The constrained parameter
#' values of model will always be included.
#'
#' @export
#'
Expand All @@ -18,46 +28,13 @@
#' 876287L, state <- chain_state(stats::rnorm(model$param_unc_num()))
#' )
#' state$log_density(target_distribution)
target_distribution_from_stan_model <- function(model) {
list(
log_density = model$log_density,
value_and_gradient_log_density = function(position) {
value_and_gradient <- model$log_density_gradient(position)
names(value_and_gradient) <- c("value", "gradient")
value_and_gradient
}
)
}

#' Construct trace function from a BridgeStan `StanModel` object.
#'
#' @param model Stan model object to use to generate (constrained) parameters to
#' trace.
#' @param include_log_density Whether to include an entry `log_density`
#' corresponding to current log density for target distribution in values
#' returned by trace function.
#' @param include_generated_quantities Whether to included generated quantities
#' in Stan model definition in values returned by trace function.
#' @param include_transformed_parameters Whether to include transformed
#' parameters in Stan model definition in values returned by trace function.
#'
#' @return A function which given `chain_state` object returns a named vector of
#' values to trace during sampling. The constrained parameter values of model
#' will always be included.
#'
#' @export
#'
#' @examplesIf requireNamespace("bridgestan", quietly = TRUE)
#' model <- example_gaussian_stan_model()
#' trace_function <- trace_function_from_stan_model(model)
#' withr::with_seed(876287L, state <- chain_state(rnorm(model$param_unc_num())))
#' trace_function(state)
trace_function_from_stan_model <- function(
#' target_distribution$trace_function(state)
target_distribution_from_stan_model <- function(
model,
include_log_density = TRUE,
include_generated_quantities = FALSE,
include_transformed_parameters = FALSE) {
function(state) {
trace_function <- function(state) {
position <- state$position()
trace_values <- model$param_constrain(
position, include_transformed_parameters, include_generated_quantities
Expand All @@ -68,15 +45,25 @@ trace_function_from_stan_model <- function(
}
trace_values
}
list(
log_density = model$log_density,
value_and_gradient_log_density = function(position) {
value_and_gradient <- model$log_density_gradient(position)
names(value_and_gradient) <- c("value", "gradient")
value_and_gradient
},
trace_function = trace_function
)
}

#' Construct an example BridgeStan `StanModel` object for a Gaussian model.
#'
#' Requires BridgeStan package to be installed. Generative model is assumed to
#' be of the form `y ~ normal(mu, sigma)` for unknown `mu` and `sigma`.
#' be of the form `y ~ normal(mu, sigma)` for unknown `mu ~ normal(0, 3)` and
#' `sigma ~ half_normal(0, 3)`.
#'
#' @param n_data Number of independent data points `y` to generate and condition
#' model against.
#' model against from `normal(0, 1)`.
#' @param seed Integer seed for Stan model.
#'
#' @return BridgeStan StanModel object.
Expand All @@ -88,20 +75,23 @@ trace_function_from_stan_model <- function(
#' model$param_names()
example_gaussian_stan_model <- function(n_data = 50, seed = 1234L) {
rlang::check_installed("bridgestan", reason = "to use this function")
model_string <- "data {
int<lower=0> N;
vector[N] y;
model_string <- "
data {
int<lower=0> N;
vector[N] y;
}
parameters {
real mu;
real<lower=0> sigma;
}
model {
mu ~ normal(0, 3);
sigma ~ normal(0, 3);
y ~ normal(mu, sigma);
}"
withr::with_seed(seed, y <- stats::rnorm(n_data))
data_string <- sprintf('{"N": %i, "y": [%s]}', n_data, toString(y))
model_file <- tempfile("gaussian", fileext = ".stan")
model_file <- NULL # to avoid 'no visible binding for global variable' note
withr::with_tempfile("model_file",
{
writeLines(model_string, model_file)
Expand All @@ -111,3 +101,48 @@ example_gaussian_stan_model <- function(n_data = 50, seed = 1234L) {
fileext = ".stan"
)
}

#' Construct target distribution from a formula specifying log density.
#'
#' @param log_density_formula Formula for which right-hand side specifies
#' expression for logarithm of (unnormalized) density of target distribution.
#'
#' @return A list with entries
#' * `log_density`: A function to evaluate log density function for target
#' distribution given current position vector.
#' * `value_and_gradient_log_density`: A function to evaluate value and gradient
#' of log density function for target distribution given current position
#' vector, returning as a list with entries `value` and `gradient`.
#'
#' @export
#'
#' @examples
#' target_distribution <- target_distribution_from_log_density_formula(
#' ~ (-(x^2 + y^2) / 8 - (x^2 - y)^2 - (x - 1)^2 / 10)
#' )
#' target_distribution$value_and_gradient_log_density(c(0.1, -0.3))
target_distribution_from_log_density_formula <- function(log_density_formula) {
variables <- all.vars(log_density_formula)
deriv_log_density <- stats::deriv(log_density_formula, variables, func = TRUE)
value_and_gradient_log_density <- function(position) {
names(position) <- variables
value <- rlang::inject(deriv_log_density(!!!position))
gradient <- drop(attr(value, "gradient"))
attr(value, "gradient") <- NULL
list(value = value, gradient = gradient)
}
log_density <- function(position) {
value_and_gradient_log_density(position)$value
}
trace_function <- function(state) {
trace_values <- state$position()
names(trace_values) <- variables
trace_values["log_density"] <- log_density(state$position())
trace_values
}
list(
log_density = log_density,
value_and_gradient_log_density = value_and_gradient_log_density,
trace_function = trace_function
)
}
47 changes: 42 additions & 5 deletions R/chains.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,31 @@
#' target distribution and proposal (defaulting to Barker proposal), optionally
#' adapting proposal parameters in a warm-up stage.
#'
#' @inheritParams sample_metropolis_hastings
#' @param target_distribution Target stationary distribution for chain. One of:
#' * A one-sided formula specifying expression for log density of target
#' distribution which will be passed to
#' [target_distribution_from_log_density_formula()] to construct functions
#' to evaluate log density and its gradient using [deriv()].
#' * A `bridgestan::StanModel` instance (requires `bridgestan` to be
#' installed) specifying target model and data. Will be passed to
#' [target_distribution_from_stan_model()] using default values for optional
#' arguments - to override call [target_distribution_from_stan_model()]
#' directly and pass the returned list as the `target_distribution` argument
#' here.
#' * A list with named entries `log_density` and `gradient_log_density`
#' corresponding to respectively functions for evaluating the logarithm of
#' the (potentially unnormalized) density of the target distribution and its
#' gradient (only required for gradient-based proposals). As an alternative
#' to `gradient_log_density` an entry `value_and_gradient_log_density` may
#' instead be provided which is a function returning both the value and
#' gradient of the logarithm of the (unnormalized) density of the target
#' distribution as a list under the names `value` and `gradient`
#' respectively. The list may also contain a named entry `trace_function`,
#' correspond to a function which given current chain state outputs a named
#' vector or list of variables to trace on each main (non-adaptive) chain
#' iteration. If a `trace_function` entry is not specified, then the default
#' behaviour is to trace the position component of the chain state along
#' with the log density of the target distribution.
#' @param initial_state Initial chain state. Either a vector specifying just
#' the position component of the chain state or a list output by `chain_state`
#' specifying the full chain state.
Expand Down Expand Up @@ -32,8 +56,6 @@
#' coerce the average acceptance rate to a target value using a dual-averaging
#' algorithm, and adapting the shape to an estimate of the covariance of the
#' target distribution.
#' @param trace_function Function which given current chain state outputs list
#' of variables to trace on each main (non-adaptive) chain iteration.
#' @param show_progress_bar Whether to show progress bars during sampling.
#' Requires `progress` package to be installed to have an effect.
#' @param trace_warm_up Whether to record chain traces and adaptation /
Expand Down Expand Up @@ -78,7 +100,6 @@ sample_chain <- function(
n_main_iteration,
proposal = barker_proposal(),
adapters = list(scale_adapter(), shape_adapter()),
trace_function = NULL,
show_progress_bar = TRUE,
trace_warm_up = FALSE) {
progress_available <- requireNamespace("progress", quietly = TRUE)
Expand All @@ -90,8 +111,24 @@ sample_chain <- function(
} else {
stop("initial_state must be a vector or list with an entry named position.")
}
if (is.null(trace_function)) {
if (inherits(target_distribution, "formula")) {
target_distribution <- target_distribution_from_log_density_formula(
target_distribution
)
} else if (inherits(target_distribution, "StanModel")) {
target_distribution <- target_distribution_from_stan_model(
target_distribution
)
} else if (
!is.list(target_distribution) ||
!("log_density" %in% names(target_distribution))
) {
stop("target_distribution invalid - see documentation for allowable types.")
}
if (is.null(target_distribution$trace_function)) {
trace_function <- default_trace_function(target_distribution)
} else {
trace_function <- target_distribution$trace_function
}
statistic_names <- list("accept_prob")
warm_up_results <- chain_loop(
Expand Down
26 changes: 6 additions & 20 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -72,36 +72,22 @@ cat(
```

As a second example, the snippet below demonstrates sampling from a two-dimensional banana shaped distribution based on the [Rosenbrock function](https://en.wikipedia.org/wiki/Rosenbrock_function) and plotting the generated chain samples.
Here we use the default values of the `proposal` and `adapters` arguments to `sample_chain`,
Here we use the default values of the `proposal` and `adapters` arguments to `sample_chain()`,
corresponding respectively to the Barker proposal, and adapters for tuning the proposal scale to coerce the average acceptance rate using a dual-averaging algorithm,
and for tuning the proposal shape based on an estimate of the target distribution covariance matrix.
The `target_distribution` argument to `sample_chain()` is passed a formula specifying the log density of the target distribution, which is passed to `target_distribution_from_log_density_formula()` to construct necessary functions,
using `stats::deriv()` to symbolically compute derivatives.


```{r banana-samples, fig.width=6, fig.height=4}
library(rmcmc)
set.seed(651239L)
target_distribution <- list(
log_density = function(x) -sum(x^2) / 8 - (x[1]^2 - x[2])^2 - (x[1] - 1)^2 / 10,
gradient_log_density = function(x) {
c(
-x[1] / 4 + 4 * x[1] * (x[2] - x[1]^2) - 0.2 * x[1] + 0.2,
-x[2] / 4 + 2 * x[1]^2 - 2 * x[2]
)
}
)
results <- sample_chain(
target_distribution = target_distribution,
target_distribution = ~ (-(x^2 + y^2) / 8 - (x^2 - y)^2 - (x - 1)^2 / 100),
initial_state = rnorm(2),
n_warm_up_iteration = 10000,
n_main_iteration = 10000,
)
plot(
results$traces[, "position1"],
results$traces[, "position2"],
xlab = expression(x[1]),
ylab = expression(x[2]),
col = "#1f77b4",
pch = 20
n_main_iteration = 10000
)
plot(results$traces[, "x"], results$traces[, "y"], col = "#1f77b4", pch = 20)
```
35 changes: 12 additions & 23 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,39 +76,28 @@ As a second example, the snippet below demonstrates sampling from a
two-dimensional banana shaped distribution based on the [Rosenbrock
function](https://en.wikipedia.org/wiki/Rosenbrock_function) and
plotting the generated chain samples. Here we use the default values of
the `proposal` and `adapters` arguments to `sample_chain`, corresponding
respectively to the Barker proposal, and adapters for tuning the
proposal scale to coerce the average acceptance rate using a
the `proposal` and `adapters` arguments to `sample_chain()`,
corresponding respectively to the Barker proposal, and adapters for
tuning the proposal scale to coerce the average acceptance rate using a
dual-averaging algorithm, and for tuning the proposal shape based on an
estimate of the target distribution covariance matrix.
estimate of the target distribution covariance matrix. The
`target_distribution` argument to `sample_chain()` is passed a formula
specifying the log density of the target distribution, which is passed
to `target_distribution_from_log_density_formula()` to construct
necessary functions, using `stats::deriv()` to symbolically compute
derivatives.

``` r
library(rmcmc)

set.seed(651239L)
target_distribution <- list(
log_density = function(x) -sum(x^2) / 8 - (x[1]^2 - x[2])^2 - (x[1] - 1)^2 / 10,
gradient_log_density = function(x) {
c(
-x[1] / 4 + 4 * x[1] * (x[2] - x[1]^2) - 0.2 * x[1] + 0.2,
-x[2] / 4 + 2 * x[1]^2 - 2 * x[2]
)
}
)
results <- sample_chain(
target_distribution = target_distribution,
target_distribution = ~ (-(x^2 + y^2) / 8 - (x^2 - y)^2 - (x - 1)^2 / 100),
initial_state = rnorm(2),
n_warm_up_iteration = 10000,
n_main_iteration = 10000,
)
plot(
results$traces[, "position1"],
results$traces[, "position2"],
xlab = expression(x[1]),
ylab = expression(x[2]),
col = "#1f77b4",
pch = 20
n_main_iteration = 10000
)
plot(results$traces[, "x"], results$traces[, "y"], col = "#1f77b4", pch = 20)
```

<img src="man/figures/README-banana-samples-1.png" width="100%" />
5 changes: 3 additions & 2 deletions man/example_gaussian_stan_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Binary file modified man/figures/README-banana-samples-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 7f87e4b

Please sign in to comment.