-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathfinetune_orpo.py
192 lines (175 loc) · 9.39 KB
/
finetune_orpo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import os
import torch
from datasets import load_dataset, load_from_disk, Dataset, DatasetDict
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TrainingArguments,
)
from trl import ORPOConfig, ORPOTrainer
import argparse
from threads import finetune_prompts_dpo
import pandas as pd
def _parse_args():
    """Define the command-line interface of this script and return the parsed arguments."""
    parser = argparse.ArgumentParser(description="Finetune a base instruct/chat model using (Q)LoRA and PEFT using ORPO (RLHF)")
    parser.add_argument('tuned_model', type=str,
                        help='The name of the resulting tuned model.')
    parser.add_argument('dataset_name', type=str,
                        help='The name of the dataset to use for fine-tuning. This should be the output of the combine_checkpoints script.')
    parser.add_argument('instruction_prompt', type=str,
                        help='An instruction message added to every prompt given to the chatbot to force it to answer in the target language. Example: "You are a generic chatbot that always answers in English."')
    parser.add_argument('--base_model', type=str, default="NousResearch/Meta-Llama-3-8B-Instruct",
                        help='The base foundation model. Default is "NousResearch/Meta-Llama-3-8B-Instruct".')
    parser.add_argument('--base_dataset_text_field', type=str, default="text",
                        help="The dataset's column name containing the actual text to translate. Defaults to text")
    parser.add_argument('--base_dataset_rank_field', type=str, default="rank",
                        help="The dataset's column name containing the rank of an answer given to a prompt. Defaults to rank")
    parser.add_argument('--base_dataset_id_field', type=str, default="message_id",
                        help="The dataset's column name containing the id of a text. Defaults to message_id")
    parser.add_argument('--base_dataset_parent_field', type=str, default="parent_id",
                        help="The dataset's column name containing the parent id of a text. Defaults to parent_id")
    parser.add_argument('--quant8', action='store_true',
                        help='Finetunes the model in 8 bits. Requires more memory than the default 4 bit.')
    parser.add_argument('--noquant', action='store_true',
                        help='Do not quantize the finetuning. Requires more memory than the default 4 bit and optional 8 bit.')
    parser.add_argument('--max_seq_length', type=int, default=512,
                        help='The maximum sequence length to use in finetuning. Should most likely line up with your base model\'s default max_seq_length. Default is 512.')
    # Default 128 equals the value trl falls back to when max_prompt_length is unset,
    # which is what this script effectively used before the flag was wired up.
    parser.add_argument('--max_prompt_length', type=int, default=128,
                        help='The maximum length of the prompts to use. Should be smaller than --max_seq_length, otherwise answers can be truncated away. Default is 128.')
    parser.add_argument('--num_train_epochs', type=int, default=2,
                        help='Number of epochs to use. 2 is default and has been shown to work well.')
    parser.add_argument('--batch_size', type=int, default=4,
                        help='The batch size to use in finetuning. Adjust to fit in your GPU vRAM. Default is 4')
    parser.add_argument('--threads_output_name', type=str, default=None,
                        help='If specified, the threads created in this script for finetuning will also be saved to disk or HuggingFace Hub.')
    parser.add_argument('--thread_template', type=str, default="threads/template_default.txt",
                        help='A file containing the thread template to use. Default is threads/template_default.txt')
    parser.add_argument('--max_steps', type=int, default=-1,
                        help='The maximum number of steps to run ORPO for. Default is -1 which will run the data through fully for the number of epochs but this will be very time-consuming.')
    parser.add_argument('--padding', type=str, default="left", choices=["left", "right"],
                        help='What padding to use, can be either left or right.')
    return parser.parse_args()


def _load_base_model(base_model, quant8, noquant):
    """Load ``base_model`` as a causal LM on GPU 0.

    ``noquant`` is checked first and loads without quantization; otherwise
    ``quant8`` selects 8-bit loading, and the default is 4-bit NF4 with
    double quantization and float16 compute.
    """
    if noquant:
        quant_config = None
    elif quant8:
        # BitsAndBytesConfig has no bnb_8bit_* options; load_in_8bit is all that applies.
        quant_config = BitsAndBytesConfig(load_in_8bit=True)
    else:
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )
    # quantization_config=None is the from_pretrained default, i.e. a plain load.
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        quantization_config=quant_config,
        device_map={"": 0},
        trust_remote_code=True,
    )
    # The KV cache is not needed during training and conflicts with gradient checkpointing.
    model.config.use_cache = False
    model.config.pretraining_tp = 1
    return model


def main():
    """Fine-tune a base chat model with ORPO using LoRA adapters, then save or push it.

    Steps:
      1. Parse the CLI arguments and optionally confirm running without ``HF_TOKEN``.
      2. Load the dataset (a local directory via ``load_from_disk``, otherwise via
         ``load_dataset``) and turn every split into prompt threads with
         ``finetune_prompts_dpo.create_prompts`` and the given chat template;
         optionally write those threads to disk or the Hub.
      3. Load the (optionally quantized) base model and train it with
         ``ORPOTrainer`` on the ``train`` split.
      4. Save model and tokenizer to ``tuned_model`` if it is an existing
         directory, otherwise push both to the Hugging Face Hub.
    """
    # Imported here so the LoRA config can be built through trl, which the
    # script already depends on, instead of a direct peft import.
    from trl import ModelConfig, get_peft_config

    args = _parse_args()
    if args.max_prompt_length >= args.max_seq_length:
        print(f"[WARNING] --max_prompt_length ({args.max_prompt_length}) is not smaller than "
              f"--max_seq_length ({args.max_seq_length}); answers of over-long examples may be truncated away completely.")

    # Check for HF_TOKEN
    if 'HF_TOKEN' not in os.environ:
        print("[WARNING] Environment variable 'HF_TOKEN' is not set!")
        user_input = input("Do you want to continue? (yes/no): ").strip().lower()
        if user_input != "yes":
            print("Terminating the program.")
            return

    # Load the base translated dataset
    if os.path.isdir(args.dataset_name):
        dataset = load_from_disk(args.dataset_name)
    else:
        dataset = load_dataset(args.dataset_name)
    # load_from_disk may return a single Dataset; treat it as the train split.
    if isinstance(dataset, Dataset):
        dataset = DatasetDict({"train": dataset})

    # Tokenizer used only for building the prompt threads. A separate fresh one is
    # loaded for training below in case create_prompts modifies it (not visible here).
    prompt_tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    with open(args.thread_template, 'r', encoding="utf8") as f:
        chat_template = f.read()

    # Compute the threads for every split of the dataset
    prompts = {}
    for fold in dataset.keys():
        print(f"[---- LLaMa2Lang ----] Generating prompts using chat template {args.thread_template} for fold {fold}")
        templated_prompts = finetune_prompts_dpo.create_prompts(
            dataset[fold], prompt_tokenizer,
            args.base_dataset_rank_field, args.base_dataset_parent_field,
            args.base_dataset_id_field, args.base_dataset_text_field,
            args.instruction_prompt, chat_template,
        )
        prompts[fold] = Dataset.from_pandas(pd.DataFrame(data=templated_prompts))
    prompts = DatasetDict(prompts)

    # Optionally write out the threads (all folds)
    if args.threads_output_name is not None:
        print(f"[---- LLaMa2Lang ----] Writing out ORPO thread prompts dataset to {args.threads_output_name}")
        if os.path.isdir(args.threads_output_name):
            prompts.save_to_disk(args.threads_output_name)
        else:
            prompts.push_to_hub(args.threads_output_name)

    model = _load_base_model(args.base_model, args.quant8, args.noquant)

    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    # Just like Alpaca, because we allow to add history in the prompts, it makes more sense to do left-padding to have the most informative text at the end.
    # In this case, we need a different pad token than EOS because we actually do _not_ pad end of sentence.
    if args.padding == 'left':
        tokenizer.pad_token_id = 0
    else:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = args.padding

    # target_modules belongs to the LoRA config, not ORPOConfig (which rejects it).
    # Other LoRA hyperparameters use trl's ModelConfig defaults.
    peft_config = get_peft_config(ModelConfig(use_peft=True, lora_target_modules="all-linear"))

    orpo_config = ORPOConfig(
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=1,
        gradient_checkpointing=True,
        learning_rate=5e-5,
        lr_scheduler_type="cosine",
        max_steps=args.max_steps,
        max_length=args.max_seq_length,
        max_prompt_length=args.max_prompt_length,
        save_strategy="no",
        logging_steps=1,
        output_dir="./results",
        optim="paged_adamw_32bit",
        warmup_steps=100,
        bf16=True,
        # NOTE(review): in transformers 4.x, report_to=None means "all installed
        # integrations"; use "none" if reporting should be disabled — confirm intent.
        report_to=None,
        remove_unused_columns=False,
        beta=0.1,  # the lambda/alpha hyperparameter in the paper/code
    )
    trainer = ORPOTrainer(
        model,
        args=orpo_config,
        train_dataset=prompts['train'],
        tokenizer=tokenizer,
        peft_config=peft_config,
    )

    # Before starting training, free up memory
    torch.cuda.empty_cache()
    # Train the ORPO model
    trainer.train()

    # Check if output location is a valid directory
    print(f"[---- LLaMa2Lang ----] Writing model and tokenizer out to {args.tuned_model}")
    if os.path.isdir(args.tuned_model):
        # Models and tokenizers persist via save_pretrained (there is no save_to_disk).
        trainer.model.save_pretrained(args.tuned_model)
        tokenizer.save_pretrained(args.tuned_model)
    else:
        # Try to push to hub, requires HF_TOKEN environment variable to be set, see https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hftoken
        trainer.model.push_to_hub(args.tuned_model)
        tokenizer.push_to_hub(args.tuned_model)
if __name__ == "__main__":
    # Run the fine-tuning pipeline only when executed as a script, not when imported.
    main()