From e600cc6477d564d190d169ad8f40303d9d8f0de4 Mon Sep 17 00:00:00 2001 From: Uchida Mizuki Date: Sun, 8 Dec 2024 09:42:55 +0900 Subject: [PATCH] Fix matrixOps #26 --- R/methods.R | 28 +++++++++++++++++++--------- tests/testthat/test-matrix.R | 16 ++++++++++++++++ 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/R/methods.R b/R/methods.R index ea6d7b5..9d3e633 100644 --- a/R/methods.R +++ b/R/methods.R @@ -45,26 +45,36 @@ matrixOps_dibble <- function(e1, e2) { e2 <- as_ddf_col(e2) class <- class(e1) - dim_names_x <- dimnames(e1) - dim_names_y <- dimnames(e2) + dim_names_e1 <- dimnames(e1) + dim_names_e2 <- dimnames(e2) - if (vec_size(dim_names_x) == 1L) { + size_dim_names_e1 <- vec_size(dim_names_e1) + size_dim_names_e2 <- vec_size(dim_names_e2) + + new_dim_name <- union_dim_names(list(dim_names_e1[size_dim_names_e1], dim_names_e2[1])) + dim_names_e1[size_dim_names_e1] <- new_dim_name + dim_names_e2[1] <- new_dim_name + + e1 <- broadcast(e1, dim_names_e1) + e2 <- broadcast(e2, dim_names_e2) + + if (vec_size(dim_names_e1) == 1L) { e1 <- as.vector(e1) - dim_names_x <- NULL + dim_names_e1 <- NULL } else { e1 <- as.matrix(e1) - dim_names_x <- dim_names_x[1L] + dim_names_e1 <- dim_names_e1[1L] } - if (vec_size(dim_names_y) == 1L) { + if (vec_size(dim_names_e2) == 1L) { e2 <- as.vector(e2) - dim_names_y <- NULL + dim_names_e2 <- NULL } else { e2 <- as.matrix(e2) - dim_names_y <- dim_names_y[2L] + dim_names_e2 <- dim_names_e2[2L] } - new_dim_names <- purrr::compact(c(dim_names_x, dim_names_y)) + new_dim_names <- purrr::compact(c(dim_names_e1, dim_names_e2)) out <- NextMethod() diff --git a/tests/testthat/test-matrix.R b/tests/testthat/test-matrix.R index 3a81d1b..7360256 100644 --- a/tests/testthat/test-matrix.R +++ b/tests/testthat/test-matrix.R @@ -1,4 +1,10 @@ test_that("`%*%`() works", { + rev_axis <- function(x, axis) { + dim_names <- dimnames(x) + dim_names[[axis]] <- rev(dim_names[[axis]]) + broadcast(x, dim_names) + } + # mat %*% mat mat_x <- matrix(1:9, 3, dimnames = list(axis1 = 1:3, @@ -9,10 +15,14 @@ test_that("`%*%`() works", { ddf_x <- as_dibble(mat_x) ddf_y <- as_dibble(mat_y) expect_equal(as.matrix(ddf_x %*% ddf_y), unname(mat_x %*% mat_y)) + expect_equal(as.matrix(rev_axis(ddf_x, 2) %*% ddf_y), unname(mat_x %*% mat_y)) + expect_equal(as.matrix(ddf_x %*% rev_axis(ddf_y, 1)), unname(mat_x %*% mat_y)) ddf_x <- dibble(x = ddf_x) ddf_y <- dibble(x = ddf_y) expect_equal(as.matrix(ddf_x %*% ddf_y), unname(mat_x %*% mat_y)) + expect_equal(as.matrix(rev_axis(ddf_x, 2) %*% ddf_y), unname(mat_x %*% mat_y)) + expect_equal(as.matrix(ddf_x %*% rev_axis(ddf_y, 1)), unname(mat_x %*% mat_y)) # vec %*% mat vec_x <- array(1:3, 3, @@ -23,6 +33,8 @@ test_that("`%*%`() works", { ddf_x <- as_dibble(vec_x) ddf_y <- as_dibble(mat_y) expect_equal(as.matrix(ddf_x %*% ddf_y), t(unname(vec_x %*% mat_y))) + expect_equal(as.matrix(rev_axis(ddf_x, 1) %*% ddf_y), t(unname(vec_x %*% mat_y))) + expect_equal(as.matrix(ddf_x %*% rev_axis(ddf_y, 1)), t(unname(vec_x %*% mat_y))) # mat %*% vec mat_x <- matrix(1:9, 3, @@ -33,6 +45,8 @@ test_that("`%*%`() works", { ddf_x <- as_dibble(mat_x) ddf_y <- as_dibble(vec_y) expect_equal(as.matrix(ddf_x %*% ddf_y), unname(mat_x %*% vec_y)) + expect_equal(as.matrix(rev_axis(ddf_x, 2) %*% ddf_y), unname(mat_x %*% vec_y)) + expect_equal(as.matrix(ddf_x %*% rev_axis(ddf_y, 1)), unname(mat_x %*% vec_y)) # vec %*% vec vec_x <- array(1:3, 3, @@ -42,6 +56,8 @@ test_that("`%*%`() works", { ddf_x <- as_dibble(vec_x) ddf_y <- as_dibble(vec_y) expect_equal(ddf_x %*% ddf_y, as.vector(vec_x %*% vec_y)) + expect_equal(rev_axis(ddf_x, 1) %*% ddf_y, as.vector(vec_x %*% vec_y)) + expect_equal(ddf_x %*% rev_axis(ddf_y, 1), as.vector(vec_x %*% vec_y)) }) test_that("t() works", {