Skip to content

Commit

Permalink
Add tests, fix bugs and implement matrixOps
Browse files Browse the repository at this point in the history
  • Loading branch information
UchidaMizuki committed Jun 16, 2024
1 parent fcc24ab commit 361b970
Show file tree
Hide file tree
Showing 20 changed files with 406 additions and 195 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Suggests:
testthat (>= 3.0.0)
Config/testthat/edition: 3
Depends:
R (>= 2.10)
R (>= 4.4)
URL: https://github.com/UchidaMizuki/dibble,
https://uchidamizuki.github.io/dibble/
BugReports: https://github.com/UchidaMizuki/dibble/issues
4 changes: 0 additions & 4 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
S3method("!",ddf_col)
S3method("!",tbl_ddf)
S3method("$",tbl_ddf)
S3method("%*%",ddf_col)
S3method("%*%",default)
S3method("%*%",tbl_ddf)
S3method("[",tbl_ddf)
S3method("[[",tbl_ddf)
S3method("diag<-",ddf_col)
Expand Down Expand Up @@ -115,7 +112,6 @@ S3method(zeros,array)
S3method(zeros,ddf_col)
S3method(zeros,default)
S3method(zeros,tbl_ddf)
export("%*%")
export("diag<-")
export(apply)
export(as_dibble)
Expand Down
4 changes: 1 addition & 3 deletions R/broadcast.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,11 @@ broadcast.default <- function(x,
is_dim_names(dim_names)
)

class <- class(x)
dim <- list_sizes_unnamed(dim_names)
x <- array(vec_recycle(x, prod(dim)),
dim = dim)

new_ddf_col(x, dim_names,
class = class)
new_ddf_col(x, dim_names)
} else {
broadcast(as_dibble(x), dim_names)
}
Expand Down
9 changes: 5 additions & 4 deletions R/dibble.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ dibble <- function(...,
x
}
})
args <- vec_c(!!!args)
args <- list_unchop(args)

if (!is_named(args)) {
stopifnot(
Expand Down Expand Up @@ -399,9 +399,10 @@ find_index <- function(x, names) {
} else {
stopifnot(is_call(x))

out <- purrr::map(x[-1L], find_index,
names = names)
vec_c(!!!out)
out <- purrr::map(as.list(x[-1L]),
\(x) find_index(x,
names = names))
list_unchop(out)
}
}

Expand Down
4 changes: 2 additions & 2 deletions R/dim_names.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ is_dim_names <- function(x) {
}

union_dim_names <- function(x) {
x <- vec_c(!!!x)
x <- list_unchop(x)
nms <- names(x)
nms_unique <- unique(nms)
out <- purrr::map(nms_unique,
function(nm_unique) {
unique(vec_c(!!!unname(x[nms == nm_unique])))
unique(list_unchop(unname(x[nms == nm_unique])))
})
names(out) <- nms_unique
out
Expand Down
8 changes: 4 additions & 4 deletions R/extremes.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ pmax_dibble <- function(..., na.rm) {
as.array(broadcast(x, dim_names))
})

new_ddf_col(exec(base::pmax, !!!args),
new_ddf_col(exec(base::pmax, !!!args, na.rm = na.rm),
dim_names,
class = class)
class = setdiff(class, "tbl_ddf"))
}

#' @rdname extremes
Expand Down Expand Up @@ -101,7 +101,7 @@ pmin_dibble <- function(..., na.rm) {
as.array(broadcast(x, dim_names))
})

new_ddf_col(exec(base::pmin, !!!args),
new_ddf_col(exec(base::pmin, !!!args, na.rm = na.rm),
dim_names,
class = class)
class = setdiff(class, "tbl_ddf"))
}
73 changes: 0 additions & 73 deletions R/matrix.R
Original file line number Diff line number Diff line change
@@ -1,76 +1,3 @@
#' Matrix Multiplication
#'
#' Multiplies two matrices, if they are conformable.
#'
#' `%*%` overrides [`base::%*%`] to make it generic. The default method
#' calls the base version.
#'
#' @param x Numeric or complex dibble, matrices or vectors.
#' @param y Numeric or complex dibble, matrices or vectors.
#'
#' @return A dibble if x or y is a dibble of a matrix. A scalar numeric if both
#' x and y are dibbles of vectors. See [`base::%*%`] for the return value of the
#' default method.
#'
#' @seealso [`base::%*%`]
#'
#' @export
`%*%` <- function(x, y) {
UseMethod("%*%")
}

#' @export
`%*%.default` <- function(x, y) {
base::`%*%`(x, y)
}

#' @export
`%*%.tbl_ddf` <- function(x, y) {
matmult_dibble(x, y)
}

#' @export
`%*%.ddf_col` <- function(x, y) {
matmult_dibble(x, y)
}

matmult_dibble <- function(x, y) {
x <- as_ddf_col(x)
y <- as_ddf_col(y)

class <- class(x)
dim_names_x <- dimnames(x)
dim_names_y <- dimnames(y)

if (vec_size(dim_names_x) == 1L) {
x <- as.vector(x)
dim_names_x <- NULL
} else {
x <- as.matrix(x)
dim_names_x <- dim_names_x[1L]
}

if (vec_size(dim_names_y) == 1L) {
y <- as.vector(y)
dim_names_y <- NULL
} else {
y <- as.matrix(y)
dim_names_y <- dim_names_y[2L]
}

new_dim_names <- purrr::compact(c(dim_names_x, dim_names_y))

out <- x %*% y

if (vec_is_empty(new_dim_names)) {
as.vector(out)
} else {
dim(out) <- list_sizes_unnamed(new_dim_names)
new_ddf_col(out, new_dim_names,
class = class)
}
}

#' @export
t.tbl_ddf <- function(x) {
new_tbl_ddf(purrr::modify(undibble(x), t),
Expand Down
37 changes: 37 additions & 0 deletions R/methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,43 @@ Ops_dibble <- function(e1, e2) {
class = class)
}

matrixOps_dibble <- function(e1, e2) {
e1 <- as_ddf_col(e1)
e2 <- as_ddf_col(e2)

class <- class(e1)
dim_names_x <- dimnames(e1)
dim_names_y <- dimnames(e2)

if (vec_size(dim_names_x) == 1L) {
e1 <- as.vector(e1)
dim_names_x <- NULL
} else {
e1 <- as.matrix(e1)
dim_names_x <- dim_names_x[1L]
}

if (vec_size(dim_names_y) == 1L) {
e2 <- as.vector(e2)
dim_names_y <- NULL
} else {
e2 <- as.matrix(e2)
dim_names_y <- dim_names_y[2L]
}

new_dim_names <- purrr::compact(c(dim_names_x, dim_names_y))

out <- NextMethod()

if (vec_is_empty(new_dim_names)) {
as.vector(out)
} else {
dim(out) <- list_sizes_unnamed(new_dim_names)
new_ddf_col(out, new_dim_names,
class = class)
}
}

methods_dibble <- function(x, ...) {
x <- as_ddf_col(x)
NextMethod()
Expand Down
3 changes: 2 additions & 1 deletion R/tbl_ddf.R
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ mutate.tbl_ddf <- function(.data, ...) {
nms <- names(dots)

dim_names <- dimnames(.data)
class <- class(.data)
data <- as.list(.data)

.data <- undibble(.data)
Expand All @@ -150,7 +151,7 @@ mutate.tbl_ddf <- function(.data, ...) {
.data[[nm]] <- undibble(data_nm)
}
new_tbl_ddf(.data, dim_names,
class = class(.data))
class = class)
}

#' @importFrom dplyr select
Expand Down
7 changes: 5 additions & 2 deletions R/zzz.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.onLoad <- function(...) {
.onLoad <- function(...) { # nocov start
as_dim_names <<- memoise::memoise(as_dim_names)
union_dim_names <<- memoise::memoise(union_dim_names)
broadcast_dim_names_impl <<- memoise::memoise(broadcast_dim_names_impl)
Expand All @@ -9,7 +9,10 @@
registerS3method("Ops", "ddf_col", Ops_dibble)
registerS3method("Ops", "tbl_ddf", Ops_dibble)

registerS3method("matrixOps", "ddf_col", matrixOps_dibble)
registerS3method("matrixOps", "tbl_ddf", matrixOps_dibble)

registerS3method("Math", "tbl_ddf", methods_dibble)

registerS3method("Summary", "tbl_ddf", methods_dibble)
}
} # nocov end
28 changes: 0 additions & 28 deletions man/grapes-times-grapes.Rd

This file was deleted.

9 changes: 8 additions & 1 deletion tests/testthat/test-apply.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
test_that("apply", {
test_that("apply() works", {
arr <- array(1:24, 2:4,
list(axis1 = 1:2,
axis2 = 1:3,
Expand All @@ -18,4 +18,11 @@ test_that("apply", {

test_apply(ddf_col)
test_apply(tbl_ddf)

# Test that the class is preserved
class(ddf_col) <- c("my_class", class(ddf_col))
expect_s3_class(apply(ddf_col, 2, sum), c("my_class", "ddf_col"))

class(tbl_ddf) <- c("my_class", class(tbl_ddf))
expect_s3_class(apply(tbl_ddf, 2, sum), c("my_class", "ddf_col"))
})
27 changes: 21 additions & 6 deletions tests/testthat/test-broadcast.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
test_that("broadcast", {
test_that("broadcast() works", {
x <- broadcast(1:2,
list(axis1 = letters[1:2]))
expect_equal(as.array(x),
Expand All @@ -9,22 +9,37 @@ test_that("broadcast", {
expect_equal(as.array(y),
array(1:3, 3))

expect_silent(broadcast(x * y, c("axis1", "axis2")))
z <- broadcast(1:4,
list(axis3 = letters[1:4]))
expect_equal(as.array(z),
array(1:4, 4))

xy <- expect_silent(broadcast(x * y, c("axis1", "axis2")))
expect_equal(as.array(xy), outer(as.array(x), as.array(y)))

xyz <- expect_silent(broadcast(x * y * z, c("axis1", "axis2", "axis3")))
expect_equal(as.array(xyz), outer(outer(as.array(x), as.array(y)), as.array(z)))

# Test that the class is preserved
class(x) <- c("my_class", class(x))
xy <- broadcast(x * y, c("axis1", "axis2"))
expect_s3_class(xy, class(x))
})

test_that("broadcast-warn", {
test_that("broadcast() warns", {
x <- broadcast(1:4,
list(axis1 = 1:2,
axis2 = 1:2))
y <- x

expect_silent(x * y)
xy1 <- expect_silent(x * y)

y <- t(x)

expect_warning(x * y)
expect_silent(broadcast(x * y,
c("axis1", "axis2")))
xy2 <- expect_silent(broadcast(x * y,
c("axis1", "axis2")))
expect_equal(xy1, xy2)

y <- broadcast(1:6,
list(axis1 = 1:2,
Expand Down
Loading

0 comments on commit 361b970

Please sign in to comment.