forked from qiufengqijun/mini_qwen
mini_qwen_dpo.py
import os
import torch
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer
# Set an environment variable to optimize CUDA memory allocation
# os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
# Load the tokenizer and model
output_path = "results/dpo"
model_path = "results/sft/checkpoint-15461"  # continue training from the 2-epoch SFT checkpoint
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
tokenizer = AutoTokenizer.from_pretrained(model_path)
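# Note (assumption, not in the original script): attn_implementation="flash_attention_2"
# requires the flash-attn package and an Ampere-or-newer GPU; if either is unavailable,
# "sdpa" is a drop-in fallback, e.g.:
# model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, attn_implementation="sdpa")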
# Load the dataset and preprocess it (both parquet shards are read into a single training split)
data_files = ["data/dpo/train-00000-of-00001.parquet", "data/dpo/test-00000-of-00001.parquet"]
dataset = load_dataset("parquet", data_files=data_files, split="train")
dataset = dataset.shuffle(seed=42)
# dataset = dataset.shuffle(seed=42).select(range(20))
# print(dataset[:2]["chosen"]);input()
def preprocess_dataset(examples):
    """Convert raw DPO records into prompt/chosen/rejected strings in the Qwen chat format."""
    prompt, chosen, rejected = [], [], []
    for i in range(len(examples["prompt"])):
        # Wrap the user prompt in Qwen chat markers and open the assistant turn
        text = f"<|im_start|>user\n{examples['prompt'][i]}<|im_end|>\n<|im_start|>assistant\n"
        prompt.append(text)
        # The second turn of each conversation is expected to be the assistant reply
        assert examples["chosen"][i][1]["role"] == "assistant"
        text = f"{examples['chosen'][i][1]['content']}<|im_end|>"
        chosen.append(text)
        assert examples["rejected"][i][1]["role"] == "assistant"
        text = f"{examples['rejected'][i][1]['content']}<|im_end|>"
        rejected.append(text)
    result = {"prompt": prompt, "chosen": chosen, "rejected": rejected}
    return result
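# A hypothetical example of the assumed record layout (inferred from the asserts above:
# "chosen"/"rejected" are two-turn conversations whose second entry is the assistant reply):
#   {"prompt": "What is 1+1?",
#    "chosen":   [{"role": "user", "content": "What is 1+1?"}, {"role": "assistant", "content": "1+1 equals 2."}],
#    "rejected": [{"role": "user", "content": "What is 1+1?"}, {"role": "assistant", "content": "1+1 equals 3."}]}
# After preprocessing, this record becomes:
#   prompt   -> "<|im_start|>user\nWhat is 1+1?<|im_end|>\n<|im_start|>assistant\n"
#   chosen   -> "1+1 equals 2.<|im_end|>"
#   rejected -> "1+1 equals 3.<|im_end|>"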
# Apply the preprocessing function
train_dataset = dataset.map(
    preprocess_dataset,
    batched=True,
    batch_size=5000,
    remove_columns=dataset.column_names,
    num_proc=16,
)
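# Optional sanity check (illustrative, in the spirit of the commented debug prints above):
# print(train_dataset[0]["prompt"]); print(train_dataset[0]["chosen"]); print(train_dataset[0]["rejected"])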
# Training configuration
training_args = DPOConfig(
    output_dir=output_path,
    overwrite_output_dir=True,
    learning_rate=5e-7,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=16,
    save_strategy="epoch",  # save an intermediate checkpoint every epoch
    save_total_limit=3,
    bf16=True,
    # save_only_model=True,
    logging_steps=20,
)
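# Note (assumption, not from the original): the effective batch size is
# per_device_train_batch_size * gradient_accumulation_steps = 4 * 16 = 64 per device,
# multiplied by the number of processes when launched with accelerate/torchrun.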
# Initialize the DPO trainer
trainer = DPOTrainer(
    model=model,
    train_dataset=train_dataset,
    args=training_args,
    tokenizer=tokenizer,
    dataset_num_proc=16,
    max_length=1024,
    max_prompt_length=512,
)
# Start training
trainer.train()
trainer.save_model()  # save the model
tokenizer.save_pretrained(output_path)  # save the tokenizer
def plot_loss(save_directory, log_history):
    """Plot the training loss curve and save it as an image."""
    plt.switch_backend("agg")  # use a non-interactive backend
    key = "loss"  # plot the 'loss' entry by default
    steps, metrics = [], []
    # Extract loss values from the trainer's log history
    for log_entry in log_history:
        if key in log_entry:
            steps.append(log_entry["step"])
            metrics.append(log_entry[key])
    # Draw the curve
    plt.figure()
    plt.plot(steps, metrics, color="#1f77b4", label="original")
    plt.title(f"Training {key} of {save_directory}")
    plt.xlabel("Step")
    plt.ylabel(key.capitalize())
    plt.legend()
    # Save the figure
    figure_path = os.path.join(save_directory, f"training_{key.replace('/', '_')}.png")
    plt.savefig(figure_path, format="png", dpi=100)
    print(f"Figure saved at: {figure_path}")
# Plot and save the loss curve
plot_loss(output_path, trainer.state.log_history)
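# Quick qualitative check after training (a minimal sketch; the prompt below is illustrative
# and not part of the original script): generate one reply with the DPO-tuned model,
# using the same chat markers as the training data.
model.eval()
text = "<|im_start|>user\nWhat is 1+1?<|im_end|>\n<|im_start|>assistant\n"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=100)
# Decode only the newly generated tokens (the prompt is included in generate's output)
print(tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))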