-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathrun_squad_diffmask.py
82 lines (70 loc) · 2.7 KB
/
run_squad_diffmask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#!/usr/bin/env python
# coding: utf-8
import os
import argparse
import torch
import numpy as np
import pytorch_lightning as pl
from diffmask.models.question_answering_squad_diffmask import (
BertQuestionAnsweringSquadDiffMask,
PerSampleBertQuestionAnsweringSquadDiffMask,
)
from diffmask.utils.callbacks import CallbackSquadDiffMask
if __name__ == "__main__":
    # Script entry point: parse the command-line hyper-parameters, seed the
    # random number generators, build the DiffMask question-answering model
    # and fit it with a PyTorch Lightning trainer that checkpoints under
    # "outputs/".

    # Command-line options as (flag, add_argument keyword arguments) pairs.
    # They are registered in this order, which is also the order shown by
    # --help.
    cli_options = [
        ("--gpu", dict(type=str, default="0")),
        (
            "--model",
            dict(
                type=str,
                default="bert-large-uncased-whole-word-masking-finetuned-squad",
            ),
        ),
        ("--epochs", dict(type=int, default=1)),
        (
            "--train_filename",
            dict(
                type=str,
                default="./datasets/squad/train-v1.1_bert-large-uncased-whole-word-masking-finetuned-squad.json",
            ),
        ),
        (
            "--val_filename",
            dict(
                type=str,
                default="./datasets/squad/dev-v1.1_bert-large-uncased-whole-word-masking-finetuned-squad.json",
            ),
        ),
        ("--batch_size", dict(type=int, default=8)),
        ("--learning_rate", dict(type=float, default=3e-4)),
        ("--seed", dict(type=int, default=0)),
        ("--gate_bias", dict(action="store_true")),
        ("--learning_rate_alpha", dict(type=float, default=3e-1)),
        ("--learning_rate_placeholder", dict(type=float, default=1e-3)),
        ("--eps", dict(type=float, default=1)),
        ("--eps_valid", dict(type=float, default=3)),
        ("--acc_valid", dict(type=float, default=0.0)),
        ("--placeholder", dict(action="store_true")),
        ("--stop_train", dict(action="store_true")),
        (
            "--gate",
            dict(
                type=str,
                default="input",
                choices=[
                    "input",
                    "hidden",
                    "per_sample-reinforce",
                    "per_sample-diffmask",
                ],
            ),
        ),
        ("--layer_pred", dict(type=int, default=-1)),
    ]

    arg_parser = argparse.ArgumentParser()
    for flag, spec in cli_options:
        arg_parser.add_argument(flag, **spec)
    hparams = arg_parser.parse_args()

    # Seed both torch and numpy with the same user-supplied seed.
    torch.manual_seed(hparams.seed)
    np.random.seed(hparams.seed)

    # Restrict the visible CUDA devices to the ones named by --gpu.
    os.environ["CUDA_VISIBLE_DEVICES"] = hparams.gpu

    # Checkpoints are written below outputs/<run_dir>/; the file name pattern
    # is passed to ModelCheckpoint unformatted.
    run_dir = "squad-bert-{}-layer_pred={}".format(hparams.gate, hparams.layer_pred)
    checkpoint_pattern = os.path.join(
        "outputs",
        run_dir,
        "{epoch}-{val_acc:.2f}-{val_f1:.2f}-{val_l0:.2f}",
    )

    # One GPU is requested when --gpu is a non-empty string, none otherwise.
    gpu_count = int(hparams.gpu != "")

    # NOTE(review): --gate accepts the "per_sample-*" choices, but this script
    # always instantiates BertQuestionAnsweringSquadDiffMask; confirm that
    # class handles those gate values.
    model = BertQuestionAnsweringSquadDiffMask(hparams)

    diffmask_callback = CallbackSquadDiffMask()
    checkpointer = pl.callbacks.ModelCheckpoint(
        filepath=checkpoint_pattern,
        verbose=True,
        save_top_k=50,
    )

    trainer = pl.Trainer(
        gpus=gpu_count,
        progress_bar_refresh_rate=1,
        max_epochs=hparams.epochs,
        callbacks=[diffmask_callback],
        checkpoint_callback=checkpointer,
    )
    trainer.fit(model)