forked from cofe-ai/nanoLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pretrain_bert.py
81 lines (70 loc) · 1.98 KB
/
pretrain_bert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import argparse
from accelerate import Accelerator
from omegaconf import open_dict
import hydra
import torch
import time
import os
from cofellm.train_utils import (
train,
predict,
eval
)
from cofellm.model.utils import (
get_config,
get_lr_scheduler,
get_optimizer,
get_tokenizer,
get_model,
get_dataloaders,
)
from cofellm.utils import setup_basics
from cofellm.arguments import get_argparse
import pdb
def main() -> None:
    """Build the model, data and optimizer from CLI args, then train, eval, or predict.

    Parses arguments with ``get_argparse()``, creates an ``Accelerator`` and all
    training components, and wraps them with ``accelerator.prepare``. It
    optionally applies ``torch.compile``, then dispatches on ``args.eval_only``
    or ``args.predict_only``, falling back to training. ``logger.finish()`` is
    always called, even when the selected run mode raises.
    """
    args = get_argparse().parse_args()

    accelerator = Accelerator(
        cpu=args.device == "cpu",
        mixed_precision=args.precision,
    )
    # Report whether muP is enabled. accelerator.print only emits on the main
    # process, so multi-process launches don't repeat the line once per rank.
    accelerator.print("#############" + str(args.use_mup))

    logger = setup_basics(accelerator, args)

    config = get_config(args)
    model = get_model(args, config)
    tokenizer = get_tokenizer(args)
    optimizer = get_optimizer(model, args)
    lr_scheduler = get_lr_scheduler(optimizer, args, logger)
    train_dataloader, test_dataloader = get_dataloaders(tokenizer, config, args)

    logger.log_args(args)

    (
        model,
        optimizer,
        lr_scheduler,
        train_dataloader,
        test_dataloader,
    ) = accelerator.prepare(
        model, optimizer, lr_scheduler, train_dataloader, test_dataloader
    )

    # Debug dump of per-parameter statistics (e.g. to sanity-check the
    # initialization). Restricted to the main process so it appears once
    # instead of once per rank. no_grad avoids building autograd nodes for
    # these throwaway reductions.
    if accelerator.is_main_process:
        with torch.no_grad():
            for name, param in model.named_parameters():
                print(name, param.size(), param.mean().item(), param.var().item())

    if args.compile:
        model = torch.compile(model)

    # Run-state counters stored on args. They are presumably read and updated
    # by the cofellm train/eval/predict helpers (TODO confirm).
    args.current_train_step = 1
    args.current_epoch = 1
    args.last_log = time.time()

    try:
        if args.eval_only:
            model.eval()
            with torch.no_grad():
                # `eval` is cofellm.train_utils.eval (shadows the builtin).
                eval(model, test_dataloader, logger, args, tokenizer)
        elif args.predict_only:
            model.eval()
            with torch.no_grad():
                predict(model, test_dataloader, logger, args, tokenizer)
        else:
            train(model, train_dataloader, test_dataloader, accelerator,
                  lr_scheduler, optimizer, logger, args, tokenizer)
    finally:
        # Always finalize the logger, even if the run fails, so the logging
        # backend set up by setup_basics is closed (presumably flushing any
        # pending records; verify against setup_basics).
        logger.finish()
if __name__ == "__main__":
    # Launch pretraining only when this file is executed directly, not on import.
    main()