Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid NaN in transformed dists density and cdf #98

Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ Authors@R:
person(given = "Earo",
family = "Wang",
role = c("ctb"),
comment = c(ORCID = "0000-0001-6448-5260")))
comment = c(ORCID = "0000-0001-6448-5260")),
person(given = "Vencislav",
family = "Popov",
role = c("ctb"),
comment = c(ORCID = "0000-0002-8073-4199")))
Description: Vectorised distribution objects with tools for manipulating,
visualising, and using probability distributions. Designed to allow model
prediction outputs to return distributions rather than their parameters,
Expand Down Expand Up @@ -49,4 +53,4 @@ BugReports: https://github.com/mitchelloharawild/distributional/issues
Encoding: UTF-8
Language: en-GB
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ S3method(skewness,distribution)
S3method(sum,distribution)
S3method(support,dist_categorical)
S3method(support,dist_default)
S3method(support,dist_transformed)
S3method(support,distribution)
S3method(variance,default)
S3method(variance,dist_default)
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@

* Fixed error when using '-' as a unary operator on a distribution different from
`dist_normal()` by @venpopov (#95)
* Density and CDF for transformed distributions now correctly give 0 (or 1 for
the CDF above the support) instead of NaNs for values outside the support of
the distribution (#97); by @venpopov

## New features

* `support()` now shows whether the interval of support is open or closed (#97); by @venpopov

# distributional 0.4.0

Expand Down
5 changes: 4 additions & 1 deletion R/default.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,12 @@ family.dist_default <- function(object, ...) {

#' @export
support.dist_default <- function(x, ...) {
  # Default support region for a distribution.
  #
  # x:   a distribution object.
  # ...: unused here.
  #
  # The extreme quantiles are evaluated once: they are reported as the limits
  # and are also the points at which the density is probed below.
  qs <- quantile(x, c(0, 1))
  # An endpoint is flagged as closed when the density evaluated at it is not
  # (within the tolerance of near()) zero.
  ds <- density(x, qs)
  new_support_region(
    # Zero-length prototype of a generated sample, used as the value type.
    list(vctrs::vec_init(generate(x, 1), n = 0L)),
    list(qs),
    list(!near(ds, 0))
  )
}

Expand Down
27 changes: 17 additions & 10 deletions R/support.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
#'
#' @param x A list of prototype vectors defining the distribution type.
#' @param limits A list of value limits for the distribution.
#' @param closed A list of logical(2L) indicating whether the limits are closed.
#'
new_support_region <- function(x = numeric(), limits = list(), closed = list()) {
  # Bundle the three parallel fields into a vctrs record of class
  # "support_region". Note that `limits` is stored under the field name "lim".
  vctrs::new_rcrd(
    list(x = x, lim = limits, closed = closed),
    class = "support_region"
  )
}

#' @export
format.support_region <- function(x, ...) {
format.support_region <- function(x, digits = 3, ...) {
type <- vapply(field(x, "x"), function(z) {
out <- if(is.integer(z)) "Z"
else if(is.numeric(z)) "R"
Expand All @@ -21,15 +22,21 @@ format.support_region <- function(x, ...) {
}
out
}, FUN.VALUE = character(1L))
mapply(function(type, z) {
if(any(is.na(z)) || all(is.infinite(z))) type
else if (type == "Z" && identical(z[2], Inf)) {
if(z[1] == 0L) "N0" else if (z[2] == 1L) "N+" else paste0("[", z[1], ",", z[1]+1L, ",...,", z[2], "]")
brackets <- list(c("(","["), c(")","]"))
mapply(function(type, z, closed) {
br1 <- brackets[[1]][closed[1] + 1L]
br2 <- brackets[[2]][closed[2] + 1L]
fz <- sapply(z, function(x) format(x, digits = digits))
fz <- gsub("3.14", "pi", fz, fixed = TRUE)
if (any(is.na(z)) || all(is.infinite(z))) type
else if (type == "Z") {
if (identical(z, c(0L, Inf))) "N0"
else if (identical(z, c(1L, Inf))) "N+"
else paste0("{", z[1], ",", z[1]+1L, ",...,", z[2], "}")
}
else if (type == "R") paste0("[", z[1], ",", z[2], "]")
else if (type == "Z") paste0("[", z[1], ",", z[1]+1L, ",...,", z[2], "]")
else if (type == "R") paste0(br1, fz[1], ",", fz[2], br2)
else type
}, type, field(x, "lim"))
}, type, field(x, "lim"), field(x, "closed"))
}

#' @export
Expand Down
33 changes: 31 additions & 2 deletions R/transformed.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,43 @@ format.dist_transformed <- function(x, ...){
)
}

#' @export
support.dist_transformed <- function(x, ...) {
  # Support of a transformed distribution: the support of the underlying
  # distribution with its limits pushed through the transformation.
  #
  # x:   a transformed distribution holding `dist` and `transform`.
  # ...: unused here.
  support <- support(x[["dist"]])
  lim <- field(support, "lim")[[1]]
  # The transformation may warn when evaluated at a limit; those warnings are
  # silenced and an NA/NaN result is left untouched below.
  lim <- suppressWarnings(x[["transform"]](lim))
  if (!anyNA(lim)) {
    # A decreasing transformation swaps the endpoints. Reorder the
    # closed/open flags together with the limits so each flag keeps
    # describing the endpoint it was computed for.
    ord <- order(lim)
    closed <- field(support, "closed")[[1]]
    if (length(closed) == length(lim)) {
      field(support, "closed")[[1]] <- closed[ord]
    }
    lim <- lim[ord]
  }
  field(support, "lim")[[1]] <- lim
  support
}

#' @export
density.dist_transformed <- function(x, at, ...){
  # Density of a transformed distribution by change of variables: the density
  # of the underlying distribution at the inverse-transformed points, times
  # the absolute derivative of the inverse.
  #
  # x:   a transformed distribution holding `dist` and `inverse`.
  # at:  points at which to evaluate the density.
  # ...: unused here.
  #
  # The inverse may warn for points outside the support; the warnings are
  # silenced and the affected values are overwritten below.
  inv <- function(v) suppressWarnings(x[["inverse"]](v))
  jacobian <- vapply(at, numDeriv::jacobian, numeric(1L), func = inv)
  d <- density(x[["dist"]], inv(at)) * abs(jacobian)

  # Look the support up once and read both fields from the same object.
  region <- support(x)
  limits <- field(region, "lim")[[1]]
  closed <- field(region, "closed")[[1]]
  if (!anyNA(limits)) {
    # A closed endpoint belongs to the support, so only points strictly
    # beyond it are zeroed; an open endpoint is zeroed as well.
    below <- if (closed[1]) at < limits[1] else at <= limits[1]
    above <- if (closed[2]) at > limits[2] else at >= limits[2]
    # which() drops NA comparisons instead of indexing with them.
    d[which(below | above)] <- 0
  }
  d
}

#' @export
cdf.dist_transformed <- function(x, q, ...){
  # CDF of a transformed distribution: the CDF of the underlying distribution
  # evaluated at the inverse-transformed quantiles.
  #
  # x:   a transformed distribution holding `dist` and `inverse`.
  # q:   quantiles at which to evaluate the CDF.
  # ...: forwarded to the underlying distribution's cdf() method.
  #
  # The inverse may warn for quantiles outside the support; the warnings are
  # silenced and the affected values are overwritten below.
  inv <- function(v) suppressWarnings(x[["inverse"]](v))
  p <- cdf(x[["dist"]], inv(q), ...)
  limits <- field(support(x), "lim")[[1]]
  if (!anyNA(limits)) {
    # At or below the lower limit the CDF is set to 0, at or above the upper
    # limit it is set to 1.
    p[q <= limits[1]] <- 0
    p[q >= limits[2]] <- 1
  }
  p
}

#' @export
Expand Down
5 changes: 5 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,8 @@ restore_rng <- function(expr, seed = NULL) {

expr
}

near <- function(x, y) {
  # Element-wise approximate equality: TRUE where `x` and `y` differ by less
  # than the square root of the double-precision machine epsilon.
  abs(x - y) < sqrt(.Machine$double.eps)
}
1 change: 1 addition & 0 deletions man/distributional-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion man/new_support_region.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions tests/testthat/test-support.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
test_that("support gives the correct bounds", {
  # Distributions chosen to cover the open/closed bracket combinations,
  # followed by two transformed distributions.
  dists <- c(
    dist_normal(),
    dist_gamma(1, 1),
    dist_gamma(2, 1),
    dist_lognormal(),
    dist_beta(1, 1),
    dist_beta(1, 2),
    dist_beta(2, 1),
    dist_beta(2, 2),
    exp(dist_wrap("norm")),
    2 * atan(dist_normal())
  )
  expected <- c(
    "R", "[0,Inf)", "(0,Inf)", "(0,Inf)",
    "[0,1]", "[0,1)", "(0,1]", "(0,1)",
    "(0,Inf)", "(-pi,pi)"
  )
  formatted <- unname(format(support(dists)))
  expect_equal(formatted, expected)
})
50 changes: 50 additions & 0 deletions tests/testthat/test-transformations.R
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,25 @@ test_that("inverses are applied automatically", {

})

test_that("transformed distributions' density is 0 outside of the support region", {
  # exp() of a wrapped normal: probe at and below the lower limit 0.
  exp_norm <- exp(dist_wrap("norm"))
  expect_equal(density(exp_norm, 0)[[1]], 0)
  expect_equal(density(exp_norm, -1)[[1]], 0)

  # exp() of a wrapped gamma(1, 1): probe at 0 and at 1.
  exp_gamma <- exp(dist_wrap("gamma", shape = 1, rate = 1))
  expect_equal(density(exp_gamma, 0)[[1]], 0)
  expect_equal(density(exp_gamma, 1)[[1]], 1)
})


test_that("transformed distributions' cdf is 0/1 outside of the support region", {
  base_dist <- dist_wrap("norm")

  # Positive-valued transform: the CDF is 0 at and below the lower limit.
  positive <- exp(base_dist)
  expect_equal(cdf(positive, 0)[[1]], 0)
  expect_equal(cdf(positive, -1)[[1]], 0)

  # Negated transform: the CDF is 1 at and above the upper limit.
  negative <- -1 * exp(base_dist)
  expect_equal(cdf(negative, 0)[[1]], 1)
  expect_equal(cdf(negative, 2)[[1]], 1)
})

test_that("unary negation operator works", {
dist <- dist_normal(1,1)
expect_equal(density(-dist, 0.5), density(dist, -0.5))
Expand All @@ -129,3 +148,34 @@ test_that("unary negation operator works", {
dist <- dist_student_t(3, mu = 1)
expect_equal(density(-dist, 0.5), density(dist, -0.5))
})

test_that("transformed distributions pdf integrates to 1", {
  # Wrapped base-R distributions and the arguments each one is built with.
  dist_names <- c("norm", "gamma", "beta", "chisq", "exp",
                  "logis", "t", "unif", "weibull")
  dist_args <- list(list(mean = 1, sd = 1), list(shape = 2, rate = 1),
                    list(shape1 = 3, shape2 = 5), list(df = 5),
                    list(rate = 1),
                    list(location = 1.5, scale = 1), list(df = 10),
                    list(min = 0, max = 1), list(shape = 3, scale = 1))
  names(dist_args) <- dist_names
  dist <- lapply(dist_names, function(x) do.call(dist_wrap, c(x, dist_args[[x]])))
  dist <- do.call(c, dist)

  # Transformations are looked up by name in an explicit list rather than
  # with get(), so no base function is shadowed and the lookup does not
  # depend on the calling environment.
  transforms <- list(
    identity = function(x) x,
    exp = exp,
    twoexp = function(x) 2^x,
    mult2 = function(x) 2 * x,
    log = log,
    square = function(x) x^2
  )
  dfun <- function(x, id, transform) {
    density(transforms[[transform]](dist[id]), x)[[1]]
  }

  tol <- 1e-5
  # seq_along() stays empty for a zero-length vector, unlike 1:length().
  for (i in seq_along(dist)) {
    for (transform in c("identity", "exp", "twoexp", "mult2")) {
      expect_equal(integrate(dfun, -Inf, Inf, id = i, transform = transform)$value, 1, tolerance = tol)
    }
    # log and square are only checked for distributions whose lower support
    # limit is (numerically) zero.
    lower_bound <- field(support(dist[[i]]), "lim")[[1]][1]
    if (near(lower_bound, 0)) {
      expect_equal(integrate(dfun, -Inf, 5, id = i, transform = "log")$value, 1, tolerance = tol)
      expect_equal(integrate(dfun, -Inf, Inf, id = i, transform = "square")$value, 1, tolerance = tol)
    }
  }
})

Loading