Skip to content

Commit

Permalink
deploy weibull_pdf function with lagged effect
Browse files Browse the repository at this point in the history
- added weibull_pdf as option into adstock_weibull() and adapted all related functions
- improved documentation and readme
  • Loading branch information
gufengzhou committed Oct 22, 2021
1 parent b1450da commit 1140a3f
Show file tree
Hide file tree
Showing 17 changed files with 720 additions and 166 deletions.
1 change: 1 addition & 0 deletions R/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ importFrom(reticulate,virtualenv_create)
importFrom(stats,AIC)
importFrom(stats,BIC)
importFrom(stats,coef)
importFrom(stats,dweibull)
importFrom(stats,end)
importFrom(stats,lm)
importFrom(stats,model.matrix)
Expand Down
2 changes: 1 addition & 1 deletion R/R/allocator.R
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ robyn_allocator <- function(robyn_object = NULL,
## get adstock parameters for each channel
if (InputCollect$adstock == "geometric") {
getAdstockHypPar <- unlist(dt_hyppar[, .SD, .SDcols = na.omit(str_extract(names(dt_hyppar), ".*_thetas"))])
} else if (InputCollect$adstock == "weibull") {
} else if (InputCollect$adstock %in% c("weibull_cdf", "weibull_pdf")) {
getAdstockHypPar <- unlist(dt_hyppar[, .SD, .SDcols = na.omit(str_extract(names(dt_hyppar), ".*_shapes|.*_scales"))])
}

Expand Down
6 changes: 4 additions & 2 deletions R/R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,11 @@ check_windows <- function(dt_input, date_var, all_media, window_start, window_en
}

# Validate the adstock type and return its normalised name.
#
# Args:
#   adstock: a single character string naming the adstock transformation.
#     Accepted values are "geometric", "weibull_cdf" and "weibull_pdf".
#     The legacy value "weibull" is still accepted and is mapped to
#     "weibull_cdf".
#
# Returns:
#   The normalised adstock name. Callers must use the returned value
#   (e.g. `adstock <- check_adstock(adstock)`) so that the legacy alias is
#   resolved before any further comparison.
#
# Raises:
#   An error via stop() when `adstock` is not a single non-missing string or
#   is not one of the accepted values.
check_adstock <- function(adstock) {
  # Guard against NULL / NA / vector input, which would otherwise make the
  # scalar `if` conditions below fail with an unrelated error message.
  if (!is.character(adstock) || length(adstock) != 1L || is.na(adstock)) {
    stop("'adstock' must be 'geometric', 'weibull_cdf' or 'weibull_pdf'")
  }
  # Backward compatibility: plain "weibull" means the CDF variant.
  if (adstock == "weibull") adstock <- "weibull_cdf"
  if (!adstock %in% c("geometric", "weibull_cdf", "weibull_pdf")) {
    stop("'adstock' must be 'geometric', 'weibull_cdf' or 'weibull_pdf'")
  }
  return(adstock)
}

check_hyperparameters <- function(hyperparameters = NULL, adstock = NULL, all_media = NULL) {
Expand Down
2 changes: 1 addition & 1 deletion R/R/imports.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
#' virtualenv_create py_install use_virtualenv
#' @importFrom rPref low psel
#' @importFrom stats AIC BIC coef end lm model.matrix na.omit nls.control
#' predict pweibull quantile qunif start
#' predict pweibull dweibull quantile qunif start
#' @importFrom stringr str_detect str_remove str_which str_extract str_replace
#' @importFrom utils askYesNo head setTxtProgressBar txtProgressBar
"_PACKAGE"
Expand Down
27 changes: 16 additions & 11 deletions R/R/inputs.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,17 @@
#' order and same length as \code{organic_vars}.
#' @param factor_vars Character vector. Specify which of the provided
#' variables in organic_vars or context_vars should be forced as a factor
#' @param adstock Character. Choose any of \code{c("geometric", "weibull")}.
#' Weibull adtock is a two-parametric function and thus more flexible, but
#' takes longer time than the traditional geometric one-parametric function.
#' Time estimation: with geometric adstock, 2000 iterations * 5 trials on 8
#' cores, it takes less than 30 minutes. Weibull takes at least twice as
#' much time.
#' @param adstock Character. Choose any of \code{c("geometric", "weibull_cdf",
#' "weibull_pdf")}. Weibull adstock is a two-parametric function and thus more
#' flexible, but takes longer time than the traditional geometric one-parametric
#' function. CDF, or cumulative distribution function of the Weibull distribution,
#' allows changing decay rate over time in both C and S shape, while the peak value
#' will always stay at the first period, meaning no lagged effect. PDF, or the
#' probability density function, enables peak value occurring after the first
#' period when shape >=1, allowing lagged effect. Run \code{plot_adstock()} to
#' see the difference visually. Time estimation: with geometric adstock, 2000
#' iterations * 5 trials on 8 cores, it takes less than 30 minutes. Both Weibull
#' options take up to twice as much time.
#' @param hyperparameters List containing hyperparameter lower and upper bounds.
#' Names of elements in list must be identical to output of \code{hyper_names()}
#' @param window_start Character. Set start date of modelling period.
Expand Down Expand Up @@ -237,7 +242,7 @@ robyn_inputs <- function(dt_input = NULL,
rollingWindowLength <- windows$rollingWindowLength

## check adstock
check_adstock(adstock)
adstock <- check_adstock(adstock)

## check hyperparameters (if passed)
check_hyperparameters(hyperparameters, adstock, all_media)
Expand Down Expand Up @@ -335,7 +340,7 @@ robyn_inputs <- function(dt_input = NULL,
#' to get correct hyperparameter names. All names in hyperparameters must
#' equal names from \code{hyper_names()}, case sensitive.
#' \item{Get guidance for setting hyperparameter bounds:
#' For geometric adstock, use theta, alpha & gamma. For weibull adstock,
#' For geometric adstock, use theta, alpha & gamma. For both weibull adstock options,
#' use shape, scale, alpha, gamma.}
#' \itemize{
#' \item{Theta: }{In geometric adstock, theta is decay rate. guideline for usual media genre:
Expand Down Expand Up @@ -363,7 +368,7 @@ robyn_inputs <- function(dt_input = NULL,
#' }
#'
#' @param adstock A character. Default to \code{InputCollect$adstock}.
#' Accepts "geometric" or "weibull"
#' Accepts "geometric", "weibull_cdf" or "weibull_pdf"
#' @param all_media A character vector. Default to \code{InputCollect$all_media}.
#' Includes \code{InputCollect$paid_media_vars} and \code{InputCollect$organic_vars}.
#' @examples
Expand Down Expand Up @@ -427,11 +432,11 @@ robyn_inputs <- function(dt_input = NULL,
#' }
#' @export
hyper_names <- function(adstock, all_media) {
  # check_adstock() validates the value and returns it; its result is kept so
  # that the legacy "weibull" alias (mapped to "weibull_cdf") is handled by the
  # branches below. It stops on any unsupported value, so `keep_pattern` is
  # always assigned when execution continues.
  adstock <- check_adstock(adstock)
  global_name <- c("thetas", "shapes", "scales", "alphas", "gammas", "lambdas")
  # Geometric adstock is parameterised by theta; both Weibull variants use
  # shape and scale instead. Alpha and gamma are needed in every case.
  if (adstock == "geometric") {
    keep_pattern <- "thetas|alphas|gammas"
  } else if (adstock %in% c("weibull_cdf", "weibull_pdf")) {
    keep_pattern <- "shapes|scales|alphas|gammas"
  }
  # Build every "<media>_<hyperparameter>" combination and sort the names.
  local_name <- sort(apply(
    expand.grid(all_media, global_name[global_name %like% keep_pattern]),
    1, paste,
    collapse = "_"
  ))
  return(local_name)
}
Expand Down
33 changes: 22 additions & 11 deletions R/R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ robyn_run <- function(InputCollect,
))

for (ngt in 1:InputCollect$trials) {
message(paste(" Running trial nr.", ngt, "\n"))
message(paste(" Running trial nr.", ngt))
model_output <- robyn_mmm(
hyper_collect = InputCollect$hyperparameters,
InputCollect = InputCollect,
Expand All @@ -158,9 +158,8 @@ robyn_run <- function(InputCollect,
num_coef0_mod <- model_output$resultCollect$decompSpendDist[decomp.rssd == Inf, uniqueN(paste0(iterNG, "_", iterPar))]
num_coef0_mod <- ifelse(num_coef0_mod > InputCollect$iterations, InputCollect$iterations, num_coef0_mod)
message("This trial contains ", num_coef0_mod, " iterations with all 0 media coefficient. Please reconsider your media variable choice if the pareto choices are unreasonable.
\nRecommendations are: \n1. increase hyperparameter ranges for 0-coef channels on theta (max.reco. c(0, 0.9) ) and gamma (max.reco. c(0.1, 1) ) to give Robyn more freedom\n2. split media into sub-channels, and/or aggregate similar channels, and/or introduce other media\n3. increase trials to get more samples\n")
}

\nRecommendations are: \n1. increase hyperparameter ranges for 0-coef channels to give Robyn more freedom\n2. split media into sub-channels, and/or aggregate similar channels, and/or introduce other media\n3. increase trials to get more samples\n")
}
model_output["trial"] <- ngt
model_output_collect[[ngt]] <- model_output
}
Expand Down Expand Up @@ -498,10 +497,14 @@ robyn_run <- function(InputCollect,
if (InputCollect$adstock == "geometric") {
theta <- hypParam[paste0(InputCollect$all_media[med], "_thetas")]
x_list <- adstock_geometric(x = m, theta = theta)
} else if (InputCollect$adstock == "weibull") {
} else if (InputCollect$adstock == "weibull_cdf") {
shape <- hypParam[paste0(InputCollect$all_media[med], "_shapes")]
scale <- hypParam[paste0(InputCollect$all_media[med], "_scales")]
x_list <- adstock_weibull(x = m, shape = shape, scale = scale, windlen = InputCollect$rollingWindowLength, type = "cdf")
} else if (InputCollect$adstock == "weibull_pdf") {
shape <- hypParam[paste0(InputCollect$all_media[med], "_shapes")]
scale <- hypParam[paste0(InputCollect$all_media[med], "_scales")]
x_list <- adstock_weibull(x = m, shape = shape, scale = scale)
x_list <- adstock_weibull(x = m, shape = shape, scale = scale, windlen = InputCollect$rollingWindowLength, type = "pdf")
}
m_adstocked <- x_list$x_decayed
dt_transformAdstock[, (med_select) := m_adstocked]
Expand Down Expand Up @@ -1035,13 +1038,17 @@ robyn_mmm <- function(hyper_collect,
if (adstock == "geometric") {
theta <- hypParamSam[paste0(all_media[v], "_thetas")]
x_list <- adstock_geometric(x = m, theta = theta)
} else if (adstock == "weibull") {
} else if (adstock == "weibull_cdf") {
shape <- hypParamSam[paste0(all_media[v], "_shapes")]
scale <- hypParamSam[paste0(all_media[v], "_scales")]
x_list <- adstock_weibull(x = m, shape = shape, scale = scale)
x_list <- adstock_weibull(x = m, shape = shape, scale = scale, windlen = rollingWindowLength, type = "cdf")
} else if (adstock == "weibull_pdf") {
shape <- hypParamSam[paste0(all_media[v], "_shapes")]
scale <- hypParamSam[paste0(all_media[v], "_scales")]
x_list <- adstock_weibull(x = m, shape = shape, scale = scale, windlen = rollingWindowLength, type = "pdf")
} else {
break
print("adstock parameter must be geometric or weibull")
print("adstock parameter must be geometric, weibull_cdf or weibull_pdf")
}

m_adstocked <- x_list$x_decayed
Expand Down Expand Up @@ -1540,10 +1547,14 @@ robyn_response <- function(robyn_object = NULL,
if (adstock == "geometric") {
theta <- dt_hyppar[solID == select_model, get(paste0(paid_media_var, "_thetas"))]
x_list <- adstock_geometric(x = mediaVar, theta = theta)
} else if (adstock == "weibull") {
} else if (adstock == "weibull_cdf") {
shape <- dt_hyppar[solID == select_model, get(paste0(paid_media_var, "_shapes"))]
scale <- dt_hyppar[solID == select_model, get(paste0(paid_media_var, "_scales"))]
x_list <- adstock_weibull(x = mediaVar, shape = shape, scale = scale, windlen = InputCollect$rollingWindowLength, type = "cdf")
} else if (adstock == "weibull_pdf") {
shape <- dt_hyppar[solID == select_model, get(paste0(paid_media_var, "_shapes"))]
scale <- dt_hyppar[solID == select_model, get(paste0(paid_media_var, "_scales"))]
x_list <- adstock_weibull(x = mediaVar, shape = shape, scale = scale)
x_list <- adstock_weibull(x = mediaVar, shape = shape, scale = scale, windlen = InputCollect$rollingWindowLength, type = "pdf")
}
m_adstocked <- x_list$x_decayed

Expand Down
Loading

0 comments on commit 1140a3f

Please sign in to comment.