finetune.py
# Import general libraries
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
import hydra
from peft import LoraConfig, get_peft_model
from pathlib import Path
from omegaconf import OmegaConf
# Import custom libraries
from data_module import TextDatasetQA, custom_data_collator
from dataloader import CustomTrainer
from utils import get_model_identifiers_from_yaml, find_all_linear_names, print_trainable_parameters
# Import Habana libraries
import habana_frameworks.torch.core as htcore
import habana_frameworks.torch.hpu as hthpu
from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments


# Hydra entry point: loads the configuration from config/finetune.yaml
@hydra.main(version_base=None, config_path="config", config_name="finetune")
def main(cfg):
    '''
    Main function to fine-tune the model on the given dataset. It loads the model and
    tokenizer, builds the training arguments, and trains the model with a custom trainer.
    The fine-tuned model is saved at the end of training.
    '''
    if os.environ.get('LOCAL_RANK') is not None:
        local_rank = int(os.environ.get('LOCAL_RANK', '0'))
        device_map = {'': local_rank}

    set_seed(cfg.seed)

    # Get the model identifiers from the model family
    model_cfg = get_model_identifiers_from_yaml(cfg.model_family)
    model_id = model_cfg["hf_key"]

    Path(cfg.save_dir).mkdir(parents=True, exist_ok=True)
    # If master process, save the resolved config next to the checkpoints
    if os.environ.get('LOCAL_RANK') is None or local_rank == 0:
        with open(f'{cfg.save_dir}/cfg.yaml', 'w') as f:
            OmegaConf.save(cfg, f)

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    # Load the dataset
    ft_dataset = TextDatasetQA(cfg.data_path, tokenizer=tokenizer, model_family=cfg.model_family, max_length=500, split=cfg.split, is_local_csv=True)

    # Derive the total number of optimizer steps from epochs, dataset size, batch size, and device count
    num_devices = hthpu.device_count()
    max_steps = int(cfg.num_epochs * len(ft_dataset)) // (cfg.batch_size * cfg.gradient_accumulation_steps * num_devices)
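    # Illustrative arithmetic (hypothetical numbers, not taken from the config): with 4,000
    # examples, num_epochs=5, batch_size=4, gradient_accumulation_steps=4, and 8 HPUs,
    # max_steps = (5 * 4000) // (4 * 4 * 8) = 20000 // 128 = 156.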

    # Training arguments
    training_args = GaudiTrainingArguments(
        use_habana=True,
        use_lazy_mode=False,
        gaudi_config_name=cfg.gaudi_config_name,
        per_device_train_batch_size=cfg.batch_size,
        per_device_eval_batch_size=cfg.batch_size,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        warmup_steps=max(1, max_steps // cfg.num_epochs),
        max_steps=max_steps,
        learning_rate=cfg.lr,
        bf16=True,
        bf16_full_eval=True,
        logging_steps=max(1, max_steps // 20),
        logging_dir=f'{cfg.save_dir}/logs',
        output_dir=cfg.save_dir,
        optim="paged_adamw_32bit",
        save_steps=max_steps // 5,
        save_only_model=True,
        ddp_find_unused_parameters=False,
        evaluation_strategy="no",
        deepspeed='config/ds_config.json',  # DeepSpeed configuration via Transformers' built-in arg (see sketch below)
        weight_decay=cfg.weight_decay,
        seed=cfg.seed,
    )
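    # The DeepSpeed file config/ds_config.json is referenced above but not shown in this
    # repository view. As an illustration only (a hypothetical sketch, not the actual file),
    # a minimal bf16 ZeRO config driven by the Transformers integration could look like:
    #
    #   {
    #     "train_micro_batch_size_per_gpu": "auto",
    #     "gradient_accumulation_steps": "auto",
    #     "bf16": {"enabled": true},
    #     "zero_optimization": {"stage": 2}
    #   }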

    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, trust_remote_code=True)
    # Hot fix for https://discuss.huggingface.co/t/help-with-llama-2-finetuning-setup/50035
    model.generation_config.do_sample = True
    if model_cfg["gradient_checkpointing"] == "true":
        model.gradient_checkpointing_enable()

    # LoRA configuration (r == 0 disables LoRA and fine-tunes the full model)
    if cfg.LoRA.r != 0:
        config = LoraConfig(
            r=cfg.LoRA.r,
            lora_alpha=cfg.LoRA.alpha,
            target_modules=find_all_linear_names(model),
            lora_dropout=cfg.LoRA.dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, config)
        model.enable_input_require_grads()

    # Training using CustomTrainer from dataloader.py
    trainer = CustomTrainer(
        model=model,
        train_dataset=ft_dataset,
        eval_dataset=ft_dataset,
        args=training_args,
        data_collator=custom_data_collator,
    )
    model.config.use_cache = False  # silence the warnings
    trainer.train()

    # Save the model (merge LoRA adapters into the base weights first, if LoRA was used)
    if cfg.LoRA.r != 0:
        model = model.merge_and_unload()
    model.save_pretrained(cfg.save_dir)
    tokenizer.save_pretrained(cfg.save_dir)

if __name__ == "__main__":
    main()
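
# Usage note (illustrative; assumes config/finetune.yaml defines the keys referenced above,
# e.g. model_family, data_path, split, save_dir, batch_size, and LoRA.*). With Hydra, any
# config key can be overridden on the command line, for example:
#
#   python finetune.py save_dir=./ft_model LoRA.r=8
#
# For multi-HPU runs with the DeepSpeed config above, optimum-habana's gaudi_spawn.py
# launcher is one option (flag names taken from the optimum-habana examples; verify against
# the installed version):
#
#   python gaudi_spawn.py --world_size 8 --use_deepspeed finetune.py save_dir=./ft_model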