From 03b38a12e6e38676db7e586733fa74597696bf07 Mon Sep 17 00:00:00 2001 From: HannaMeyer Date: Mon, 11 Mar 2024 17:14:32 +0100 Subject: [PATCH] first version of spatial error profile --- R/errorProfiles.R | 63 +++++++++++++++++++++++--------- inst/examples/ex_errorProfiles.R | 23 ++++++++---- 2 files changed, 61 insertions(+), 25 deletions(-) diff --git a/R/errorProfiles.R b/R/errorProfiles.R index 34364a93..36c49990 100644 --- a/R/errorProfiles.R +++ b/R/errorProfiles.R @@ -2,6 +2,7 @@ #' @description Performance metrics are calculated for moving windows of dissimilarity values based on cross-validated training data #' @param model the model used to get the AOA #' @param trainDI the result of \code{\link{trainDI}} or aoa object \code{\link{aoa}} +#' @param locations Optional. sf object for the training data used in model. Only used if variable=="geodist". Note that they must be in the same order as model$trainingData. #' @param variable Character. Which dissimilarity or distance measure to use for the error metric. Current options are "DI" or "LPD" #' @param multiCV Logical. Re-run model fitting and validation with different CV strategies. See details. #' @param window.size Numeric. Size of the moving window. See \code{\link{rollapply}}. @@ -31,7 +32,8 @@ errorProfiles <- function(model, - trainDI, + trainDI=NULL, + locations=NULL, variable = "DI", multiCV=FALSE, length.out = 10, @@ -47,13 +49,17 @@ errorProfiles <- function(model, trainDI = trainDI$parameters } + if(!is.null(locations)&variable=="geodist"){ + message("warning: Please ensure that the order of the locations matches to model$trainingData") + } + # get DIs and Errormetrics OR calculate new ones from multiCV if(!multiCV){ - preds_all <- get_preds_all(model, trainDI, variable) + preds_all <- get_preds_all(model, trainDI, locations, variable) } if(multiCV){ - preds_all <- multiCV(model, length.out, method, useWeight, variable) + preds_all <- multiCV(model, locations, length.out, method, useWeight, variable) } # train model between DI and Errormetric @@ -126,6 +132,8 @@ errorModel <- function(preds_all, model, window.size, calib, k, m, variable){ errormodel <- lm(metric ~ DI, data = performance) } else if (variable == "LPD") { errormodel <- lm(metric ~ LPD, data = performance) + } else if (variable=="geodist"){ + errormodel <- lm(metric ~ geodist, data = performance) } } if(calib=="scam"){ @@ -133,15 +141,22 @@ errorModel <- function(preds_all, model, window.size, calib, k, m, variable){ stop("Package \"scam\" needed for this function to work. Please install it.", call. = FALSE) } - if (variable == "DI") { + if (variable %in% c("DI","geodist")) { if (model$maximize){ # e.g. accuracy, kappa, r2 bs="mpd" }else{ bs="mpi" #e.g. RMSE } - errormodel <- scam::scam(metric~s(DI, k=k, bs=bs, m=m), - data=performance, - family=stats::gaussian(link="identity")) + if(variable=="DI"){ + errormodel <- scam::scam(metric~s(DI, k=k, bs=bs, m=m), + data=performance, + family=stats::gaussian(link="identity")) + }else if (variable=="geodist"){ + errormodel <- scam::scam(metric~s(geodist, k=k, bs=bs, m=m), + data=performance, + family=stats::gaussian(link="identity")) + } + } else if (variable == "LPD") { if (model$maximize){ # e.g. accuracy, kappa, r2 bs="mpi" @@ -149,13 +164,13 @@ errorModel <- function(preds_all, model, window.size, calib, k, m, variable){ bs="mpd" #e.g. RMSE } errormodel <- scam::scam(metric~s(LPD, k=k, bs=bs, m=m), - data=performance, - family=stats::gaussian(link="identity")) + data=performance, + family=stats::gaussian(link="identity")) } } if(calib=="exp"){ - if (variable == "DI") { - stop("Eexponential model currently only implemented for LPD") + if (variable %in% c("DI","geodist")) { + stop("Exponential model currently only implemented for LPD") } else if (variable == "LPD") { errormodel <- lm(metric ~ log(LPD), data = performance) } @@ -168,7 +183,7 @@ errorModel <- function(preds_all, model, window.size, calib, k, m, variable){ # MultiCV -multiCV <- function(model, length.out, method, useWeight, variable,...){ +multiCV <- function(model, locations, length.out, method, useWeight, variable,...){ preds_all <- data.frame() train_predictors <- model$trainingData[,-which(names(model$trainingData)==".outcome")] @@ -199,6 +214,10 @@ multiCV <- function(model, length.out, method, useWeight, variable,...){ trainDI_new <- trainDI(model_new, method=method, useWeight=useWeight) } else if (variable == "LPD") { trainDI_new <- trainDI(model_new, method=method, useWeight=useWeight, LPD = TRUE) + } else if (variable=="geodist"){ + tmp_gd_new <- CAST::geodist(locations,modeldomain=locations,cvfolds = model$control$indexOut) + geodist_new <- tmp_gd_new[tmp_gd_new$what=="CV-distances","dist"] + } @@ -213,21 +232,26 @@ multiCV <- function(model, length.out, method, useWeight, variable,...){ preds_dat_tmp <- data.frame(preds,"LPD"=trainDI_new$trainLPD) preds_dat_tmp <- preds_dat_tmp[preds_dat_tmp$LPD > 0,] preds_all <- rbind(preds_all,preds_dat_tmp) + } else if (variable == "geodist"){ + preds_dat_tmp <- data.frame(preds,"geodist"=geodist_new) + preds_all <- rbind(preds_all,preds_dat_tmp) + # NO AOA used here } } - - attr(preds_all, "AOA_threshold") <- trainDI_new$threshold + if(variable%in%c("DI","LPD")){ + attr(preds_all, "AOA_threshold") <- trainDI_new$threshold + message(paste0("Note: multiCV=TRUE calculated new AOA threshold of ", round(trainDI_new$threshold, 5), + "\nThreshold is stored in the attributes, access with attr(error_model, 'AOA_threshold').", + "\nPlease refere to examples and details for further information.")) + } attr(preds_all, "variable") <- variable attr(preds_all, "metric") <- model$metric - message(paste0("Note: multiCV=TRUE calculated new AOA threshold of ", round(trainDI_new$threshold, 5), - "\nThreshold is stored in the attributes, access with attr(error_model, 'AOA_threshold').", - "\nPlease refere to examples and details for further information.")) return(preds_all) } # Get Preds all -get_preds_all <- function(model, trainDI, variable){ +get_preds_all <- function(model, trainDI, locations, variable){ if(is.null(model$pred)){ stop("no cross-predictions can be retrieved from the model. Train with savePredictions=TRUE or provide calibration data") @@ -252,6 +276,9 @@ get_preds_all <- function(model, trainDI, variable){ preds_all$LPD <- trainDI$trainLPD[!is.na(trainDI$trainLPD)] ## only take predictions from inside the AOA: preds_all <- preds_all[preds_all$LPD>0,] + } else if(variable=="geodist"){ + tmp_gd <- CAST::geodist(locations,modeldomain=locations,cvfolds = model$control$indexOut) + preds_all$geodist <- tmp_gd[tmp_gd$what=="CV-distances","dist"] } attr(preds_all, "AOA_threshold") <- trainDI$threshold diff --git a/inst/examples/ex_errorProfiles.R b/inst/examples/ex_errorProfiles.R index fef05917..50a4e2dd 100644 --- a/inst/examples/ex_errorProfiles.R +++ b/inst/examples/ex_errorProfiles.R @@ -7,22 +7,21 @@ data(splotdata) - splotdata <- st_drop_geometry(splotdata) predictors <- terra::rast(system.file("extdata","predictors_chile.tif", package="CAST")) - model <- caret::train(splotdata[,6:16], splotdata$Species_richness, ntree = 10, + model <- caret::train(st_drop_geometry(splotdata)[,6:16], splotdata$Species_richness, ntree = 10, trControl = trainControl(method = "cv", savePredictions = TRUE)) AOA <- aoa(predictors, model, LPD = TRUE, maxLPD = 1) - # DI ~ error + ### DI ~ error errormodel_DI <- errorProfiles(model, AOA, variable = "DI") plot(errormodel_DI) expected_error_DI = terra::predict(AOA$DI, errormodel_DI) plot(expected_error_DI) - # LPD ~ error + ### LPD ~ error errormodel_LPD <- errorProfiles(model, AOA, variable = "LPD") plot(errormodel_LPD) @@ -30,7 +29,19 @@ plot(expected_error_LPD) - # with multiCV = TRUE (for DI ~ error) + + ### geodist ~ error + errormodel_geodist = errorProfiles(model, locations=splotdata, + variable = "geodist") + plot(errormodel_geodist) + + dist <- terra::distance(predictors[[1]],vect(splotdata)) + names(dist) <- "geodist" + expected_error_DI <- terra::predict(dist, errormodel_geodist) + plot(expected_error_DI) + + + ### with multiCV = TRUE (for DI ~ error) errormodel_DI = errorProfiles(model, AOA, multiCV = TRUE, length.out = 3, variable = "DI") plot(errormodel_DI) @@ -43,7 +54,5 @@ plot(mask_aoa) - - }