Skip to content

Commit

Permalink
Merge pull request #98 from venpopov/avoid-nan-transformed-dists
Browse files Browse the repository at this point in the history
Avoid NaN in  transformed dists density and cdf
  • Loading branch information
mitchelloharawild authored Apr 2, 2024
2 parents 5946354 + 7ab2085 commit f09caec
Show file tree
Hide file tree
Showing 11 changed files with 138 additions and 16 deletions.
8 changes: 6 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ Authors@R:
person(given = "Earo",
family = "Wang",
role = c("ctb"),
comment = c(ORCID = "0000-0001-6448-5260")))
comment = c(ORCID = "0000-0001-6448-5260")),
person(given = "Vencislav",
family = "Popov",
role = c("ctb"),
comment = c(ORCID = "0000-0002-8073-4199")))
Description: Vectorised distribution objects with tools for manipulating,
visualising, and using probability distributions. Designed to allow model
prediction outputs to return distributions rather than their parameters,
Expand Down Expand Up @@ -49,4 +53,4 @@ BugReports: https://github.com/mitchelloharawild/distributional/issues
Encoding: UTF-8
Language: en-GB
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ S3method(skewness,distribution)
S3method(sum,distribution)
S3method(support,dist_categorical)
S3method(support,dist_default)
S3method(support,dist_transformed)
S3method(support,distribution)
S3method(variance,default)
S3method(variance,dist_default)
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@

* Fixed error when using '-' as a unary operator on a distribution different from
`dist_normal()` by @venpopov (#95)
* Density for transformed distributions now correctly gives 0 instead of NaNs for
values outside the support of the distribution (#97); by @venpopov

## New features

* support() now shows whether the interval of support is open or closed (#97); by @venpopov

# distributional 0.4.0

Expand Down
5 changes: 4 additions & 1 deletion R/default.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,12 @@ family.dist_default <- function(object, ...) {

#' @export
support.dist_default <- function(x, ...) {
qs <- quantile(x, c(0, 1))
ds <- density(x, qs)
new_support_region(
list(vctrs::vec_init(generate(x, 1), n = 0L)),
list(quantile(x, c(0, 1)))
list(qs),
list(!near(ds, 0))
)
}

Expand Down
27 changes: 17 additions & 10 deletions R/support.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
#'
#' @param x A list of prototype vectors defining the distribution type.
#' @param limits A list of value limits for the distribution.
#' @param closed A list of logical(2L) indicating whether the limits are closed.
#'
new_support_region <- function(x, limits = NULL) {
vctrs::new_rcrd(list(x = x, lim = limits), class = "support_region")
new_support_region <- function(x = numeric(), limits = list(), closed = list()) {
vctrs::new_rcrd(list(x = x, lim = limits, closed = closed), class = "support_region")
}

#' @export
format.support_region <- function(x, ...) {
format.support_region <- function(x, digits = 3, ...) {
type <- vapply(field(x, "x"), function(z) {
out <- if(is.integer(z)) "Z"
else if(is.numeric(z)) "R"
Expand All @@ -21,15 +22,21 @@ format.support_region <- function(x, ...) {
}
out
}, FUN.VALUE = character(1L))
mapply(function(type, z) {
if(any(is.na(z)) || all(is.infinite(z))) type
else if (type == "Z" && identical(z[2], Inf)) {
if(z[1] == 0L) "N0" else if (z[2] == 1L) "N+" else paste0("[", z[1], ",", z[1]+1L, ",...,", z[2], "]")
brackets <- list(c("(","["), c(")","]"))
mapply(function(type, z, closed) {
br1 <- brackets[[1]][closed[1] + 1L]
br2 <- brackets[[2]][closed[2] + 1L]
fz <- sapply(z, function(x) format(x, digits = digits))
fz <- gsub("3.14", "pi", fz, fixed = TRUE)
if (any(is.na(z)) || all(is.infinite(z))) type
else if (type == "Z") {
if (identical(z, c(0L, Inf))) "N0"
else if (identical(z, c(1L, Inf))) "N+"
else paste0("{", z[1], ",", z[1]+1L, ",...,", z[2], "}")
}
else if (type == "R") paste0("[", z[1], ",", z[2], "]")
else if (type == "Z") paste0("[", z[1], ",", z[1]+1L, ",...,", z[2], "]")
else if (type == "R") paste0(br1, fz[1], ",", fz[2], br2)
else type
}, type, field(x, "lim"))
}, type, field(x, "lim"), field(x, "closed"))
}

#' @export
Expand Down
33 changes: 31 additions & 2 deletions R/transformed.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,43 @@ format.dist_transformed <- function(x, ...){
)
}

#' @export
support.dist_transformed <- function(x, ...) {
support <- support(x[["dist"]])
lim <- field(support, "lim")[[1]]
lim <- suppressWarnings(x[['transform']](lim))
if (all(!is.na(lim))) {
lim <- sort(lim)
}
field(support, "lim")[[1]] <- lim
support
}

#' @export
density.dist_transformed <- function(x, at, ...){
density(x[["dist"]], x[["inverse"]](at))*abs(vapply(at, numDeriv::jacobian, numeric(1L), func = x[["inverse"]]))
inv <- function(v) suppressWarnings(x[["inverse"]](v))
jacobian <- vapply(at, numDeriv::jacobian, numeric(1L), func = inv)
d <- density(x[["dist"]], inv(at)) * abs(jacobian)
limits <- field(support(x), "lim")[[1]]
closed <- field(support(x), "closed")[[1]]
if (!any(is.na(limits))) {
`%less_than%` <- if (closed[1]) `<` else `<=`
`%greater_than%` <- if (closed[2]) `>` else `>=`
d[which(at %less_than% limits[1] | at %greater_than% limits[2])] <- 0
}
d
}

#' @export
cdf.dist_transformed <- function(x, q, ...){
cdf(x[["dist"]], x[["inverse"]](q), ...)
inv <- function(v) suppressWarnings(x[["inverse"]](v))
p <- cdf(x[["dist"]], inv(q), ...)
limits <- field(support(x), "lim")[[1]]
if (!any(is.na(limits))) {
p[q <= limits[1]] <- 0
p[q >= limits[2]] <- 1
}
p
}

#' @export
Expand Down
5 changes: 5 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,8 @@ restore_rng <- function(expr, seed = NULL) {

expr
}

near <- function(x, y) {
tol <- .Machine$double.eps^0.5
abs(x - y) < tol
}
1 change: 1 addition & 0 deletions man/distributional-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion man/new_support_region.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions tests/testthat/test-support.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
test_that("support gives the correct bounds", {
s <- support(c(dist_normal(),
dist_gamma(1, 1),
dist_gamma(2, 1),
dist_lognormal(),
dist_beta(1, 1),
dist_beta(1, 2),
dist_beta(2, 1),
dist_beta(2, 2),
exp(dist_wrap('norm')),
2*atan(dist_normal())))
out <- unname(format(s))
expect_equal(out, c("R","[0,Inf)","(0,Inf)","(0,Inf)","[0,1]","[0,1)","(0,1]","(0,1)","(0,Inf)","(-pi,pi)"))
})
50 changes: 50 additions & 0 deletions tests/testthat/test-transformations.R
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,25 @@ test_that("inverses are applied automatically", {

})

test_that("transformed distributions' density is 0 outside of the support region", {
dist <- dist_wrap('norm')
expect_equal(density(exp(dist), 0)[[1]], 0)
expect_equal(density(exp(dist), -1)[[1]], 0)

dist <- dist_wrap('gamma', shape = 1, rate = 1)
expect_equal(density(exp(dist), 0)[[1]], 0)
expect_equal(density(exp(dist), 1)[[1]], 1)
})


test_that("transformed distributions' cdf is 0/1 outside of the support region", {
dist <- dist_wrap('norm')
expect_equal(cdf(exp(dist), 0)[[1]], 0)
expect_equal(cdf(exp(dist), -1)[[1]], 0)
expect_equal(cdf(-1*exp(dist), 0)[[1]], 1)
expect_equal(cdf(-1*exp(dist), 2)[[1]], 1)
})

test_that("unary negation operator works", {
dist <- dist_normal(1,1)
expect_equal(density(-dist, 0.5), density(dist, -0.5))
Expand All @@ -129,3 +148,34 @@ test_that("unary negation operator works", {
dist <- dist_student_t(3, mu = 1)
expect_equal(density(-dist, 0.5), density(dist, -0.5))
})

test_that("transformed distributions pdf integrates to 1", {
dist_names <- c('norm', 'gamma', 'beta', 'chisq', 'exp',
'logis', 't', 'unif', 'weibull')
dist_args <- list(list(mean = 1, sd = 1), list(shape = 2, rate = 1),
list(shape1 = 3, shape2 = 5), list(df = 5),
list(rate = 1),
list(location = 1.5, scale = 1), list(df = 10),
list(min = 0, max = 1), list(shape = 3, scale = 1))
names(dist_args) <- dist_names
dist <- lapply(dist_names, function(x) do.call(dist_wrap, c(x, dist_args[[x]])))
dist <- do.call(c, dist)
dfun <- function(x, id, transform) density(get(transform)(dist[id]), x)[[1]]
twoexp <- function(x) 2^x
square <- function(x) x^2
mult2 <- function(x) 2*x
identity <- function(x) x
tol <- 1e-5
for (i in 1:length(dist)) {
expect_equal(integrate(dfun, -Inf, Inf, id = i, transform = 'identity')$value, 1, tolerance = tol)
expect_equal(integrate(dfun, -Inf, Inf, id = i, transform = 'exp')$value, 1, tolerance = tol)
expect_equal(integrate(dfun, -Inf, Inf, id = i, transform = 'twoexp')$value, 1, tolerance = tol)
expect_equal(integrate(dfun, -Inf, Inf, id = i, transform = 'mult2')$value, 1, tolerance = tol)
lower_bound <- field(support(dist[[i]]), "lim")[[1]][1]
if (near(lower_bound, 0)) {
expect_equal(integrate(dfun, -Inf, 5, id = i, transform = 'log')$value, 1, tolerance = tol)
expect_equal(integrate(dfun, -Inf, Inf, id = i, transform = 'square')$value, 1, tolerance = tol)
}
}
})

0 comments on commit f09caec

Please sign in to comment.