From bba75948314995fd99ffa0c573c9f5269ba4826b Mon Sep 17 00:00:00 2001 From: dwjak123lkdmaKOP Date: Mon, 13 Jan 2025 14:30:36 +0100 Subject: [PATCH] fix embeddings bug --- R/dnn.R | 8 ++++++-- R/utils.R | 8 +++++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/R/dnn.R b/R/dnn.R index ba156e4..e4d5c58 100644 --- a/R/dnn.R +++ b/R/dnn.R @@ -242,6 +242,7 @@ dnn <- function(formula = NULL, Z_formula = tmp_data$Z_terms data = tmp_data$data + if(!is.null(Z)) { Z_args = list() @@ -250,7 +251,7 @@ dnn <- function(formula = NULL, } embeddings = list(inputs = ncol(Z), - dims = apply(Z, 2, function(z) length(levels(as.factor(z)))), # input embeddings + dims = tmp_data$Zlvls, # input embeddings args = Z_args ) } else { embeddings = NULL @@ -775,7 +776,10 @@ predict.citodnn <- function(object, newdata = NULL, else X <- torch::torch_tensor(stats::model.matrix(stats::as.formula(stats::delete.response(object$call$formula)), data.frame(newdata), xlev = object$data$xlvls)[,-1,drop=FALSE]) if(!is.null(object$model_properties$embeddings)) { - tmp = do.call(cbind, lapply(object$Z_formula, function(term) newdata[, term])) + tmp = do.call(cbind, lapply(object$Z_formula, function(term) { + if(!is.factor(newdata[[term]])) stop(paste0(term, " must be factor.")) + newdata[[term]] |> as.integer() + }) ) Z = torch::torch_tensor(tmp, dtype = torch::torch_long()) } } diff --git a/R/utils.R b/R/utils.R index 1f455e2..a7b1d31 100644 --- a/R/utils.R +++ b/R/utils.R @@ -668,13 +668,19 @@ get_X_Y = function(formula, X, Y, data) { out = list(X = X, Y = Y, formula = formula, data = data, Z = NULL, Z_terms = NULL) } else { terms = sapply(Specials$terms, as.character) + + Zlvls = + sapply(terms, function(i) { + if(!is.factor(data[,i])) stop("Embeddings must be passed as factor/categorical feature.") + return(nlevels(data[,i])) + }) Z = lapply(terms, function(i) { return(as.integer(data[,i])) }) Z = do.call(cbind, Z) colnames(Z) = terms - out = list(X = X, Y = Y, formula = formula, data = data, Z = Z, Z_terms = terms, Z_args = Specials$args) + out = list(X = X, Y = Y, formula = formula, data = data, Z = Z, Z_terms = terms, Z_args = Specials$args, Zlvls = Zlvls) } out$old_formula = old_formula return(out)