Skip to content

Commit

Permalink
add elastic net algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
dchiu911 committed Feb 13, 2024
1 parent d1f62a0 commit 01ad2fd
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 2 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# splendid 0.4.0

* add elastic net penalty for multinomial logistic regression
* fix calculation of multi-class log loss metric
* increased minimum R version to 4.1.0
* fix warnings pertaining to deprecated functions
Expand Down
23 changes: 23 additions & 0 deletions R/classification.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ classification <- function(data, class, algorithms, rfe = FALSE, ova = FALSE,
mlr_glm = mlr_model(data, class, "mlr_glm"),
mlr_lasso = cv_mlr_model(data, class, "mlr_lasso", seed_alg),
mlr_ridge = cv_mlr_model(data, class, "mlr_ridge", seed_alg),
mlr_enet = enet_model(data, class, seed_alg),
mlr_nnet = mlr_model(data, class, "mlr_nnet"),
nnet = nnet_model(data, class),
nbayes = nbayes_model(data, class),
Expand Down Expand Up @@ -307,6 +308,28 @@ cv_mlr_model <- function(data, class, algorithms, seed_alg = NULL) {
}
}

#' Elastic net model
#' @noRd
enet_model <- function(data, class, seed_alg = NULL) {
if (!is.null(seed_alg)) set.seed(seed_alg)
if (!requireNamespace("caret", quietly = TRUE)) {
stop("Package \"caret\" is needed. Please install it.",
call. = FALSE)
} else {
caret::train(
x = as.matrix(data),
y = class,
method = "glmnet",
tuneLength = 10,
trControl = caret::trainControl(
method = "cv",
number = 5,
classProbs = TRUE
)
)
}
}

#' neural network model
#' @noRd
nnet_model <- function(data, class) {
Expand Down
1 change: 1 addition & 0 deletions R/splendid.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#' * Generalized Linear Model with no penalization ("mlr_glm")
#' * GLM with LASSO penalty ("mlr_lasso")
#' * GLM with ridge penalty ("mlr_ridge")
#' * GLM with elastic net penalty ("mlr_enet")
#' * Neural Networks ("mlr_nnet")
#' * Neural Networks ("nnet")
#' * Naive Bayes ("nbayes")
Expand Down
4 changes: 2 additions & 2 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ globalVariables(".")

# Algorithm functions and classes
ALG.NAME <- c("pam", "svm", "rf", "lda", "slda", "sdda", "mlr_glm", "mlr_lasso",
"mlr_ridge", "mlr_nnet", "nnet", "nbayes", "adaboost",
"mlr_ridge", "mlr_enet", "mlr_nnet", "nnet", "nbayes", "adaboost",
"adaboost_m1", "xgboost", "knn")
ALG.CLASS <- c("pamrtrained", "train", "svm", "randomForest", "lda", "sda",
"cv.glmnet", "glmnet", "multinom", "nnet.formula", "naiveBayes",
"maboost", "boosting", "xgb.Booster", "knn")

# Algorithms that need all continuous predictors
ALG.CONT <- c("svm", "lda", "mlr_glm", "mlr_lasso", "mlr_ridge")
ALG.CONT <- c("svm", "lda", "mlr_glm", "mlr_lasso", "mlr_ridge", "mlr_enet")

#' Redirect any console printouts from print() or cat() to null device
#' @references
Expand Down
1 change: 1 addition & 0 deletions man/classification.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/splendid.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/splendid_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 01ad2fd

Please sign in to comment.