Skip to content

Commit

Permalink
fix embeddings bug
Browse files Browse the repository at this point in the history
  • Loading branch information
MaximilianPi committed Jan 13, 2025
1 parent 185f86b commit bba7594
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
8 changes: 6 additions & 2 deletions R/dnn.R
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ dnn <- function(formula = NULL,
Z_formula = tmp_data$Z_terms
data = tmp_data$data


if(!is.null(Z)) {

Z_args = list()
Expand All @@ -250,7 +251,7 @@ dnn <- function(formula = NULL,
}

embeddings = list(inputs = ncol(Z),
dims = apply(Z, 2, function(z) length(levels(as.factor(z)))), # input embeddings
dims = tmp_data$Zlvls, # input embeddings
args = Z_args )
} else {
embeddings = NULL
Expand Down Expand Up @@ -775,7 +776,10 @@ predict.citodnn <- function(object, newdata = NULL,
else X <- torch::torch_tensor(stats::model.matrix(stats::as.formula(stats::delete.response(object$call$formula)), data.frame(newdata), xlev = object$data$xlvls)[,-1,drop=FALSE])

if(!is.null(object$model_properties$embeddings)) {
tmp = do.call(cbind, lapply(object$Z_formula, function(term) newdata[, term]))
tmp = do.call(cbind, lapply(object$Z_formula, function(term) {
if(!is.factor(newdata[[term]])) stop(paste0(term, " must be factor."))
newdata[[term]] |> as.integer()
}) )
Z = torch::torch_tensor(tmp, dtype = torch::torch_long())
}
}
Expand Down
8 changes: 7 additions & 1 deletion R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -668,13 +668,19 @@ get_X_Y = function(formula, X, Y, data) {
out = list(X = X, Y = Y, formula = formula, data = data, Z = NULL, Z_terms = NULL)
} else {
terms = sapply(Specials$terms, as.character)

Zlvls =
sapply(terms, function(i) {
if(!is.factor(data[,i])) stop("Embeddings must be passed as factor/categorical feature.")
return(nlevels(data[,i]))
})
Z =
lapply(terms, function(i) {
return(as.integer(data[,i]))
})
Z = do.call(cbind, Z)
colnames(Z) = terms
out = list(X = X, Y = Y, formula = formula, data = data, Z = Z, Z_terms = terms, Z_args = Specials$args)
out = list(X = X, Y = Y, formula = formula, data = data, Z = Z, Z_terms = terms, Z_args = Specials$args, Zlvls = Zlvls)
}
out$old_formula = old_formula
return(out)
Expand Down

0 comments on commit bba7594

Please sign in to comment.