Skip to content

Commit

Permalink
various bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
macartan committed Oct 21, 2024
1 parent 8ebc0d6 commit 6c25661
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 32 deletions.
54 changes: 30 additions & 24 deletions R/make_models.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,21 +114,21 @@ make_model <- function(statement,
}

# generate DAG
x <- make_dag(statement)
.dag <- make_dag(statement)

# clean dag statement
statement <- paste(paste(x$v, x$e, x$w), collapse = "; ")

if (nrow(x) == 0) {
dag <- data.frame(v = statement, w = NA)
} else {
dag <- x |>
dplyr::filter(e == "->") |>
statement <- ifelse(nrow(.dag) ==1 & all(is.na(.dag$e)),
statement,
paste(paste(.dag$v, .dag$e, .dag$w), collapse = "; ")
)

# parent child data.frame
dag <- .dag |>
dplyr::filter(e == "->" | is.na(e)) |>
dplyr::select(v, w)
}

# disallow dangling confound e.g. X -> M <-> Y
if (nrow(x) > 0 && any(!(unlist(x[, 1:2]) %in% unlist(dag)))) {
# disallow dangling confound e.g. X -> M <-> Y (single nodes allowed)
if (any(!(unlist(.dag[, 1:2]) %in% unlist(dag)))) {
stop("Graph should not contain isolates.")
}

Expand Down Expand Up @@ -255,10 +255,10 @@ make_model <- function(statement,

# Add confounds if any provided

if (any(x$e == "<->")) {
if (grepl("<->", statement)) {
confounds <- NULL

z <- x |>
z <- .dag |>
dplyr::filter(e == "<->") |>
dplyr::select(v, w)
z$v <- as.character(z$v)
Expand Down Expand Up @@ -435,13 +435,13 @@ clean_statement <- function(statement) {

#' Helper to run a causal statement specifying a DAG into a \code{data.frame} of
#' pairwise parent child relations between nodes specified by a respective edge.
#' This function reproduces the result of the following \code{dagitty} operations:
#' This function can substitute for the following \code{dagitty} operations:
#' \code{dagitty::dagitty() |> dagitty::edges()}
#'
#' @param statement character string. Statement describing causal
#' relations using \code{dagitty} syntax. Only directed relations are
#' permitted. For instance "X -> Y" or "X1 -> Y <- X2; X1 -> X2"
#' @return a \code{data.frame} with columns v,w,e specifying parent, child and
#' @return a \code{data.frame} with columns v, w, e specifying parent, child and
#' edge respectively
#' @keywords internal

Expand Down Expand Up @@ -485,25 +485,31 @@ make_dag <- function(statement) {
dag <- dags[!vapply(dags, is.null, logical(1))]

if(length(dag) == 0) {
dag <- data.frame()
# Single node case
data.frame(v = statement, w = NA, e = NA)

} else {
dag <- dag |>

dag |>
dplyr::bind_rows() |>
dplyr::arrange(v)
dplyr::arrange(v) |>
distinct() |>
remove_duplicates()
}

# remove duplicates
remove_duplicates(distinct(dag))
}



remove_duplicates <- function(df) {
# Create a normalized version of v and w
df$normalized_v <- pmin(df$v, df$w)
df$normalized_w <- pmax(df$v, df$w)
if(nrow(df) == 1) return(df)

df <- df |> mutate(
normalized_v = ifelse(e == "<->", pmin(v, w), v),
normalized_w = ifelse(e == "<->", pmax(v, w), w)
)
# Remove duplicates (eg X<-Y; Y<->X)

# Remove duplicates
df <- df[!duplicated(df[, c("normalized_v", "normalized_w", "e")]), ]

df[, c("v", "w", "e")]
Expand Down
11 changes: 9 additions & 2 deletions R/plot_dag.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,16 @@ plot_model <- function(model = NULL,
dplyr::rename(x = v, y = w) |>
dplyr::mutate(weight = 1)

# Figure our ggraph data structure
coords <- (dag |> ggraph::ggraph(layout = "sugiyama"))$data |>
# Figure out ggraph data structure
if(nrow(dag)==1 & all(is.na(dag$e))) {
# Special case with one node
coords <- data.frame(x=0, y=0, name = dag$x)
dag$e <- dag$y <- "NA"
} else {
# Usual case
coords <- (dag |> ggraph::ggraph(layout = "sugiyama"))$data |>
dplyr::select(x, y, name)
}

# reorder nodes to match model ordering
.r <- match(coords$name, model$nodes)
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_get_posterior_distribution.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
context("Testing get_posterior_distribution")

test_that("get_posterior_distribution triggers deprecation warning", {
model <- make_model("X") |> update_model()
model <- make_model("X -> Y") |> update_model()
expect_warning(CausalQueries:::get_posterior_distribution(model = model),
regexp = "is deprecated")
})
14 changes: 12 additions & 2 deletions tests/testthat/test_plot_dag.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ testthat::skip_on_cran()


testthat::test_that(
desc = "Testing labels",
desc = "Testing basic functioning",
code = {
model <- make_model("X -> M -> Y; X -> Y")
pdf(file = NULL)
Expand All @@ -16,7 +16,17 @@ testthat::test_that(


testthat::test_that(
desc = "Testing setting labels",
desc = "Testing plot isolate",
code = {
model <- make_model("X")
pdf(file = NULL)
expect_silent(plot(model))
dev.off()
})


testthat::test_that(
desc = "Testing setting coordinates",
code = {
model <- make_model("X -> K -> Y")
x <- c(1, 2, 3)
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_restrictions.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ for(i in length(dags)){
model <- make_model(dags[i])
rest_model <- set_restrictions(model, monotonicity[i])
expect_true(
length(get_nodal_types(model)$Y) > length(get_nodal_types(rest_model)$Y)
(length(get_nodal_types(model)$Y) - length(get_nodal_types(rest_model)$Y) == 4)
)
}
)
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_set_confound.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ testthat::test_that(
make_model('X -> Y -> W') |>
set_confound(list('Y <-> X', 'X <-> W'))

expect_identical(model$statement, "X -> Y -> W; Y <-> X; W <-> X")
expect_identical(model$statement, "X -> Y; Y -> W; Y <-> X; W <-> X")

expect_message(
make_model('X -> Y -> W') |> set_confound(),
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_set_prior_distribution.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ testthat::test_that(
desc = "Deprecated get_prior_distribution.",

code = {
expect_warning(make_model("X") |>
expect_warning(make_model("X -> Y") |>
CausalQueries:::get_prior_distribution(n_draws = 10) |>
ncol())
}
Expand Down

0 comments on commit 6c25661

Please sign in to comment.