Commit

[2024.3]
Bruce committed Mar 18, 2024
1 parent ce64c22 commit ae12ace
Showing 4 changed files with 281 additions and 56 deletions.
DESCRIPTION (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
Package: FMAT
Title: The Fill-Mask Association Test
Version: 2024.3
Date: 2024-03-16
Date: 2024-03-18
Authors@R:
c(person(given = "Han-Wu-Shuang",
family = "Bao",
R/FMAT.R (67 changes: 49 additions & 18 deletions)
@@ -105,13 +105,14 @@ transformers_init = function() {
reticulate::py_capture_output({
os = reticulate::import("os")
os$environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
os$environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

torch = reticulate::import("torch")
torch.ver = torch$`__version__`
torch.cuda = torch$cuda$is_available()
if(torch.cuda) {
cuda.ver = torch$cuda_version
gpu.info = paste("GPU Devices:", paste(torch$cuda$get_device_name(), collapse=", "))
gpu.info = paste("GPU (Device):", paste(torch$cuda$get_device_name(), collapse=", "))
} else {
cuda.ver = "NULL"
gpu.info = paste("(To use GPU, install PyTorch with CUDA support,",
@@ -137,12 +138,28 @@ transformers_init = function() {
}


find_cached_models = function(cache.folder) {
models.name = list.files(cache.folder, "^models--")
if(length(models.name) > 0) {
models.size = sapply(paste0(cache.folder, "/", models.name), function(folder) {
models.file = list.files(folder, pattern="(model.safetensors$|pytorch_model.bin$|tf_model.h5$)", recursive=TRUE, full.names=TRUE)
paste(paste0(sprintf("%.0f", file.size(models.file) / 1024^2), " MB"), collapse=" / ")
})
models.name = str_replace_all(str_remove(models.name, "^models--"), "--", "/")
models.info = data.frame(Size=models.size, row.names=models.name)
} else {
models.info = NULL
}
return(models.info)
}
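
As a rough illustration (not part of this commit), the new find_cached_models() helper could be called directly on the HuggingFace cache folder. The path below is an assumption based on the documented default location; the package itself derives it from transformers$TRANSFORMERS_CACHE.

library(stringr)  # supplies str_remove() / str_replace_all() used above
cache.folder = file.path(Sys.getenv("USERPROFILE"), ".cache/huggingface/hub")  # assumed default path
find_cached_models(cache.folder)
# Returns a data.frame with one row per cached model (row names such as
# "bert-base-uncased") and a "Size" column in MB, or NULL if no
# "models--*" folders are found.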


#### BERT ####


#' DownLoad BERT models.
#' Download BERT models.
#'
#' DownLoad BERT models to local cache folder "\%USERPROFILE\%/.cache/huggingface".
#' Download and save BERT models to local cache folder "\%USERPROFILE\%/.cache/huggingface".
#'
#' @param models Model names at \href{https://huggingface.co/models}{HuggingFace}.
#'
@@ -151,7 +168,6 @@ transformers_init = function() {
#'
#' @return
#' No return value.
#' Model files will be saved in "\%USERPROFILE\%/.cache/huggingface".
#'
#' @seealso
#' \code{\link{FMAT_load}}
@@ -160,22 +176,33 @@ transformers_init = function() {
#' \dontrun{
#' model.names = c("bert-base-uncased", "bert-base-cased")
#' BERT_download(model.names)
#'
#' BERT_download() # check downloaded models
#' }
#'
#' @export
BERT_download = function(models) {
BERT_download = function(models=NULL) {
transformers = transformers_init()
lapply(models, function(model) {
cli::cli_h1("Downloading model \"{m}\"")
cli::cli_text("Downloading configuration...")
config = transformers$AutoConfig$from_pretrained(model)
cli::cli_text("Downloading tokenizer...")
tokenizer = transformers$AutoTokenizer$from_pretrained(model)
cli::cli_text("Downloading model...")
model = transformers$AutoModel$from_pretrained(model)
cli::cli_alert_success("Successfully downloaded model \"{model}\"")
gc()
})
if(!is.null(models)) {
lapply(models, function(model) {
cli::cli_h1("Downloading model {.val {model}}")
cli::cli_text("Downloading configuration...")
transformers$AutoConfig$from_pretrained(model)
cli::cli_text("Downloading tokenizer...")
transformers$AutoTokenizer$from_pretrained(model)
cli::cli_text("Downloading model...")
transformers$AutoModel$from_pretrained(model)
cli::cli_alert_success("Successfully downloaded model {.val {model}}")
gc()
})
}
cache.folder = str_replace_all(transformers$TRANSFORMERS_CACHE, "\\\\", "/")
cache.sizegb = sum(file.size(list.files(cache.folder, recursive=TRUE, full.names=TRUE))) / 1024^3
local.models = find_cached_models(cache.folder)
cli::cli_h2("Downloaded models:")
print(local.models)
cat("\n")
cli::cli_alert_success("Downloaded models saved at {.path {cache.folder}} ({sprintf('%.2f', cache.sizegb)} GB)")
}
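
A hedged usage sketch mirroring the documented example: with the new models=NULL default, calling the function without arguments only lists what is already cached.

BERT_download(c("bert-base-uncased", "bert-base-cased"))  # download, then summarize the cache
BERT_download()  # no download; just list cached models and total cache size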


@@ -228,15 +255,19 @@ BERT_download = function(models) {
#' @export
FMAT_load = function(models, gpu=FALSE) {
transformers = transformers_init()
cache.folder = str_replace_all(transformers$TRANSFORMERS_CACHE, "\\\\", "/")
device = gpu_to_device(gpu)
check_gpu_enabled(device)
cli::cli_text("Loading models...")
cli::cli_text("Loading models from {.path {cache.folder}}...")
fms = lapply(models, function(model) {
t0 = Sys.time()
reticulate::py_capture_output({
fill_mask = transformers$pipeline("fill-mask", model=model, device=device)
})
cli::cli_alert_success("{model} ({dtime(t0)})")
if(device %in% c(-1L, "cpu"))
cli::cli_alert_success("{model} ({dtime(t0)}) - CPU")
else
cli::cli_alert_success("{model} ({dtime(t0)}) - GPU (device id = {device})")
return(list(model.name=model, fill.mask=fill_mask))
})
names(fms) = models
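
A minimal usage sketch for the updated loader (illustrative; model names are assumptions): gpu=FALSE keeps the fill-mask pipelines on CPU, while gpu=TRUE requests a CUDA device if PyTorch reports one as available.

fms = FMAT_load(c("bert-base-uncased", "bert-base-cased"))  # CPU by default
# fms = FMAT_load("bert-base-uncased", gpu=TRUE)  # GPU, if CUDA is available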