-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathpublish_checkpoint.py
79 lines (60 loc) · 2.25 KB
/
publish_checkpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os
import sys
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers.tokenization_utils import PaddingStrategy, TruncationStrategy
from transformers.trainer_utils import get_last_checkpoint
from peft import PeftModel, PeftConfig
pretrain_output_dir="Japanese-TinyLlama-1.1B-1.0T"
checkpoint_dir="output_dir/"
last_checkpoint = get_last_checkpoint(checkpoint_dir)
if last_checkpoint is None and len(os.listdir(checkpoint_dir)) > 0:
raise ValueError(
f"Output directory ({checkpoint_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
print(last_checkpoint)
tokenizer = LlamaTokenizer.from_pretrained(last_checkpoint)
print(tokenizer)
peft_model_path = last_checkpoint
config = PeftConfig.from_pretrained(peft_model_path)
print(config.base_model_name_or_path)
# Reconstruct the Japanese TinyLlama model:
#   1. load the base TinyLlama model
#   2. resize its embeddings to match the extended tokenizer
#   3. load the LoRA weights and merge them in
model = LlamaForCausalLM.from_pretrained(config.base_model_name_or_path, device_map="cpu")

model_vocab_size = model.get_output_embeddings().weight.size(0)
if model_vocab_size != len(tokenizer):
    # Training padded the embedding matrix to a multiple of 64; replicate that
    # layout here so the LoRA weight shapes line up when the adapter loads.
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
print(model)

# Attach the LoRA adapter on top of the (resized) base model.
model = PeftModel.from_pretrained(model, peft_model_path)
print("peft", model)

# Fold the LoRA weights into the base weights and drop the PEFT wrapper.
merged_model = model.merge_and_unload()

# Revert the embedding size: trim the 64-padding back to the exact tokenizer
# length. BUG FIX: this must be applied to merged_model — the original called
# it on the already-unloaded PEFT wrapper, so the saved model kept the padding.
merged_model.resize_token_embeddings(len(tokenizer))

# Save the merged standalone model plus its tokenizer.
os.makedirs(pretrain_output_dir, exist_ok=True)
merged_model.save_pretrained(pretrain_output_dir)
tokenizer.save_pretrained(pretrain_output_dir)
print("Wrote merged pretrained model to: ", pretrain_output_dir)
# Eval smoke test: sample a continuation from the merged model to confirm it
# still generates coherent Japanese text.
text = "ずんだもんは、 東北に住むかわいい妖精です。"
inputs = tokenizer(text, add_special_tokens=False, return_tensors="pt")
print(inputs)

# Sampling settings for the smoke test (greedy beam count, temperature 0.8,
# and a 2-gram repetition block); special-token ids come from the tokenizer.
generation_kwargs = dict(
    max_new_tokens=2048,
    min_new_tokens=250,
    do_sample=True,
    num_beams=1,
    temperature=0.8,
    no_repeat_ngram_size=2,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

with torch.no_grad():
    output_ids = merged_model.generate(**inputs, **generation_kwargs)

output = tokenizer.decode(output_ids.tolist()[0])
print(output)