Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid NaN in transformed dists density and cdf #98

Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@ BugReports: https://github.com/mitchelloharawild/distributional/issues
Encoding: UTF-8
Language: en-GB
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ S3method(skewness,distribution)
S3method(sum,distribution)
S3method(support,dist_categorical)
S3method(support,dist_default)
S3method(support,dist_transformed)
S3method(support,distribution)
S3method(variance,default)
S3method(variance,dist_default)
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@

* Fixed error when using '-' as a unary operator on a distribution different from
`dist_normal()` by @venpopov (#95)
* Density for transformed distributions now correctly gives 0 instead of NaNs for
values outside the support of the distribution (#97); by @venpopov

## New features

* support() now shows whether the interval of support is open or closed (#97); by @venpopov

# distributional 0.4.0

Expand Down
5 changes: 4 additions & 1 deletion R/default.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,12 @@ family.dist_default <- function(object, ...) {

#' @export
support.dist_default <- function(x, ...) {
qs <- quantile(x, c(0, 1))
ds <- density(x, qs)
new_support_region(
list(vctrs::vec_init(generate(x, 1), n = 0L)),
list(quantile(x, c(0, 1)))
list(qs),
list(ifelse(near(ds, 0),"open","closed"))
)
}

Expand Down
21 changes: 13 additions & 8 deletions R/support.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
#'
#' @param x A list of prototype vectors defining the distribution type.
#' @param limits A list of value limits for the distribution.
#' @param interval A list of interval types for the distribution.
#'
new_support_region <- function(x, limits = NULL) {
vctrs::new_rcrd(list(x = x, lim = limits), class = "support_region")
new_support_region <- function(x, limits = NULL, interval = list(c('closed','closed'))) {
venpopov marked this conversation as resolved.
Show resolved Hide resolved
vctrs::new_rcrd(list(x = x, lim = limits, interval = interval), class = "support_region")
}

#' @export
format.support_region <- function(x, ...) {
format.support_region <- function(x, digits = 3, ...) {
type <- vapply(field(x, "x"), function(z) {
out <- if(is.integer(z)) "Z"
else if(is.numeric(z)) "R"
Expand All @@ -21,15 +22,19 @@ format.support_region <- function(x, ...) {
}
out
}, FUN.VALUE = character(1L))
mapply(function(type, z) {
mapply(function(type, z, i) {
br1 <- switch(i[1], open = "(", closed = "[")
br2 <- switch(i[2], open = ")", closed = "]")
fz <- sapply(z, function(x) format(x, digits = digits))
fz <- gsub("3.14", "pi", fz, fixed = TRUE)
if(any(is.na(z)) || all(is.infinite(z))) type
else if (type == "Z" && identical(z[2], Inf)) {
if(z[1] == 0L) "N0" else if (z[2] == 1L) "N+" else paste0("[", z[1], ",", z[1]+1L, ",...,", z[2], "]")
if(z[1] == 0L) "N0" else if (z[1] == 1L) "N+" else paste0(br1, z[1], ",", z[1]+1L, ",...,", z[2], br2)
}
else if (type == "R") paste0("[", z[1], ",", z[2], "]")
else if (type == "Z") paste0("[", z[1], ",", z[1]+1L, ",...,", z[2], "]")
else if (type == "R") paste0(br1, fz[1], ",", fz[2], br2)
else if (type == "Z") paste0(br1, z[1], ",", z[1]+1L, ",...,", z[2], br2)
else type
}, type, field(x, "lim"))
}, type, field(x, "lim"), field(x, "interval"))
}

#' @export
Expand Down
36 changes: 34 additions & 2 deletions R/transformed.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,46 @@ format.dist_transformed <- function(x, ...){
)
}

#' @export
support.dist_transformed <- function(x, ...) {
source_supp <- vec_data(support(x[["dist"]]))
new_support_region(
list(vctrs::vec_init(generate(x, 1), n = 0L)),
list(suppressWarnings(sort(x[['transform']](source_supp$lim[[1]])))),
venpopov marked this conversation as resolved.
Show resolved Hide resolved
list(source_supp$interval[[1]])
venpopov marked this conversation as resolved.
Show resolved Hide resolved
)
}

#' @export
density.dist_transformed <- function(x, at, ...){
density(x[["dist"]], x[["inverse"]](at))*abs(vapply(at, numDeriv::jacobian, numeric(1L), func = x[["inverse"]]))
inv <- function(v) suppressWarnings(x[["inverse"]](v))
jacobian <- vapply(at, numDeriv::jacobian, numeric(1L), func = inv)
d <- density(x[["dist"]], inv(at)) * abs(jacobian)
supp <- vec_data(support(x))
venpopov marked this conversation as resolved.
Show resolved Hide resolved
limits <- supp$lim[[1]]
interval <- supp$interval[[1]]
if (interval[1] == "closed") {
d[which(at < limits[1])] <- 0
} else {
d[which(at <= limits[1])] <- 0
}
if (interval[2] == "closed") {
d[which(at > limits[2])] <- 0
} else {
d[which(at >= limits[2])] <- 0
}
d
}

#' @export
cdf.dist_transformed <- function(x, q, ...){
cdf(x[["dist"]], x[["inverse"]](q), ...)
inv <- function(v) suppressWarnings(x[["inverse"]](v))
p <- cdf(x[["dist"]], inv(q), ...)
supp <- vec_data(support(x))
limits <- supp$lim[[1]]
p[which(q <= limits[1])] <- 0
p[which(q >= limits[2])] <- 1
p
}

#' @export
Expand Down
5 changes: 5 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,8 @@ restore_rng <- function(expr, seed = NULL) {

expr
}

near <- function(x, y) {
tol <- .Machine$double.eps^0.5
abs(x - y) < tol
}
4 changes: 3 additions & 1 deletion man/new_support_region.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions tests/testthat/test-support.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
test_that("support gives the correct bounds", {
s <- support(c(dist_normal(),
dist_gamma(1, 1),
dist_gamma(2, 1),
dist_lognormal(),
dist_beta(1, 1),
dist_beta(1, 2),
dist_beta(2, 1),
dist_beta(2, 2),
exp(dist_wrap('norm')),
2*atan(dist_normal())))
out <- unname(format(s))
expect_equal(out, c("R","[0,Inf)","(0,Inf)","(0,Inf)","[0,1]","[0,1)","(0,1]","(0,1)","(0,Inf)","(-pi,pi)"))
})
49 changes: 49 additions & 0 deletions tests/testthat/test-transformations.R
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,25 @@ test_that("inverses are applied automatically", {

})

test_that("transformed distributions' density is 0 outside of the support region", {
dist <- dist_wrap('norm')
expect_equal(density(exp(dist), 0)[[1]], 0)
expect_equal(density(exp(dist), -1)[[1]], 0)

dist <- dist_wrap('gamma', shape = 1, rate = 1)
expect_equal(density(exp(dist), 0)[[1]], 0)
expect_equal(density(exp(dist), 1)[[1]], 1)
})


test_that("transformed distributions' cdf is 0/1 outside of the support region", {
dist <- dist_wrap('norm')
expect_equal(cdf(exp(dist), 0)[[1]], 0)
expect_equal(cdf(exp(dist), -1)[[1]], 0)
expect_equal(cdf(-1*exp(dist), 0)[[1]], 1)
expect_equal(cdf(-1*exp(dist), 2)[[1]], 1)
})

test_that("unary negation operator works", {
dist <- dist_normal(1,1)
expect_equal(density(-dist, 0.5), density(dist, -0.5))
Expand All @@ -129,3 +148,33 @@ test_that("unary negation operator works", {
dist <- dist_student_t(3, mu = 1)
expect_equal(density(-dist, 0.5), density(dist, -0.5))
})

test_that("transformed distributions pdf integrates to 1", {
dist_names <- c('norm', 'gamma', 'beta', 'chisq', 'exp',
'logis', 't', 'unif', 'weibull')
dist_args <- list(list(mean = 1, sd = 1), list(shape = 2, rate = 1),
list(shape1 = 3, shape2 = 5), list(df = 5),
list(rate = 1),
list(location = 1.5, scale = 1), list(df = 10),
list(min = 0, max = 1), list(shape = 3, scale = 1))
names(dist_args) <- dist_names
dist <- lapply(dist_names, function(x) do.call(dist_wrap, c(x, dist_args[[x]])))
dist <- do.call(c, dist)
dfun <- function(x, id, transform) density(get(transform)(dist[id]), x)[[1]]
twoexp <- function(x) 2^x
square <- function(x) x^2
mult2 <- function(x) 2*x
identity <- function(x) x
tol <- 1e-5
for (i in 1:length(dist)) {
expect_equal(integrate(dfun, -Inf, Inf, id = i, transform = 'identity')$value, 1, tolerance = tol)
expect_equal(integrate(dfun, -Inf, Inf, id = i, transform = 'exp')$value, 1, tolerance = tol)
expect_equal(integrate(dfun, -Inf, Inf, id = i, transform = 'twoexp')$value, 1, tolerance = tol)
expect_equal(integrate(dfun, -Inf, Inf, id = i, transform = 'mult2')$value, 1, tolerance = tol)
if (near(vec_data(support(dist[[i]]))$lim[[1]][1],0)) {
expect_equal(integrate(dfun, -Inf, 5, id = i, transform = 'log')$value, 1, tolerance = tol)
expect_equal(integrate(dfun, -Inf, Inf, id = i, transform = 'square')$value, 1, tolerance = tol)
}
}
})

Loading