# train.py

import os
import json
import pathlib
from tqdm import tqdm
import torch
import numpy as np
# TODO add llamatokenizer where needed
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, LlamaTokenizer
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from quality_metrics.common import get_hf_dataset, formatting_prompts_func
from utils_test import (
    prompt_train,
    Prompt_Intstruction,
)


def train(config, run_id):
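    """Fine-tune config.model_id on the JSON dataset at config.train_path.

    Checkpoints and result files are written under outputs/{run_id}; the fitted
    SFTTrainer is returned.
    """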
print("\n=============\n ")
print("Path of training data ", config.train_path)
print("\n=============\n ")
print("Save path ", f'outputs/{run_id}')
print("\n=============\n ")
    if config.slurm.cluster == 'jz':
        os.environ["WANDB_DISABLED"] = "True"
        os.environ['TRANSFORMERS_OFFLINE'] = "1"
        os.environ['WANDB_MODE'] = "offline"
    os.environ["WANDB_PROJECT"] = 'measure_pp'
    os.environ['WANDB_CACHE_DIR'] = 'wandb_cache'
    os.environ['TOKENIZERS_PARALLELISM'] = "True"
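
    # V100 GPUs have no bfloat16 support, so fall back to fp16 there.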
    if config.slurm.gpu == "v100":
        bf16 = False
        fp16 = True
    else:
        bf16 = True
        fp16 = False

    run_name_wandb = f'{config.model_id}_{run_id}'
    model_save_dir = f'outputs/{run_id}/save_models'
    name_json_save_all = f'outputs/{run_id}/save_results/passk.json'
    model_name = config.model_id.split('/')[-1]
    print(f'run_name_wandb {run_name_wandb}')
    print(os.getcwd())

    # Load the JSON training data and convert it to a Hugging Face Dataset.
    with open(config.train_path, encoding="utf-8") as f:
        dataset = json.load(f)
    # with open('data/dataset.json', 'r') as f:
    #     dataset = json.load(f)
    dataset = get_hf_dataset(dataset)

    # TODO maybe move this in eval
    save_all_dir = str(pathlib.Path(name_json_save_all).parent)
    if not os.path.exists(save_all_dir):
        os.makedirs(save_all_dir)
    if not os.path.exists(name_json_save_all):
        # Initialize an empty JSON file for the pass@k results
        sample_data = {}
        with open(name_json_save_all, 'w') as file:
            json.dump(sample_data, file, indent=4)
    if not os.path.exists('save_results'):
        os.makedirs('save_results')
    if not os.path.exists('save_sol'):
        os.makedirs('save_sol')
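
    # Right-side padding, reusing the EOS token as pad token (common for causal-LM
    # checkpoints that ship without a dedicated padding token).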
    tokenizer = AutoTokenizer.from_pretrained(config.model_id)
    tokenizer.padding_side = 'right'
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        config.model_id,
        device_map="auto",
    )

    lr_scheduler_type = "cosine"
    warmup_ratio = 0.1
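
    # The response template marks where the target completion begins; the collator
    # masks everything up to it so the loss is computed on the completion only.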
    if config.sol:
        response_template = "Solution 1:"
    else:
        response_template = "Problem 1:"
    collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer, mlm=False)

    learning_rate = 1e-5
    training_arguments = TrainingArguments(
        per_device_train_batch_size=config.batch_size,
        gradient_accumulation_steps=config.accum_step,
        run_name=run_name_wandb,
        save_strategy="no",
        warmup_ratio=warmup_ratio,
        lr_scheduler_type=lr_scheduler_type,
        num_train_epochs=config.num_epochs,
        learning_rate=learning_rate,
        bf16=bf16,
        fp16=fp16,
        gradient_checkpointing=False,
        logging_steps=1,
        output_dir="outputs",
        optim="adamw_torch",
        max_grad_norm=0.3,
    )
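
    # formatting_prompts_func (from quality_metrics.common) is expected to map each
    # raw example to the full prompt + completion string that SFTTrainer trains on.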
    # TODO add validation dataset maybe
    trainer = SFTTrainer(
        model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        formatting_func=formatting_prompts_func,
        data_collator=collator,
        max_seq_length=1024,
        args=training_arguments,
    )
    trainer.train()

    output_dir = os.path.join(model_save_dir, f'{model_name}_{run_id}')
    trainer.save_model(output_dir)
    del model
    del tokenizer
    return trainer
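

# Minimal usage sketch (not part of the original script): how train() might be invoked.
# The config fields below are inferred from their use in train(); the real project builds
# its config elsewhere, so every value here is a hypothetical placeholder.
if __name__ == "__main__":
    from types import SimpleNamespace

    example_config = SimpleNamespace(
        model_id="codellama/CodeLlama-7b-hf",   # hypothetical model choice
        train_path="data/train.json",           # hypothetical path to the JSON training set
        slurm=SimpleNamespace(cluster="local", gpu="a100"),
        sol=True,            # True -> completions start at "Solution 1:", else "Problem 1:"
        batch_size=2,
        accum_step=4,
        num_epochs=1,
    )
    train(example_config, run_id="debug_run")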